diff --git a/aws_advanced_python_wrapper/aurora_connection_tracker_plugin.py b/aws_advanced_python_wrapper/aurora_connection_tracker_plugin.py index 2e0e45ab..3a9254be 100644 --- a/aws_advanced_python_wrapper/aurora_connection_tracker_plugin.py +++ b/aws_advanced_python_wrapper/aurora_connection_tracker_plugin.py @@ -30,6 +30,7 @@ from aws_advanced_python_wrapper.errors import FailoverError from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole +from aws_advanced_python_wrapper.pep249_methods import DbApiMethod from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory from aws_advanced_python_wrapper.utils.log import Logger from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils @@ -147,13 +148,12 @@ def _log_connection_set(self, host: str, conn_set: Optional[WeakSet]): class AuroraConnectionTrackerPlugin(Plugin): - _SUBSCRIBED_METHODS: Set[str] = {"*"} _current_writer: Optional[HostInfo] = None _need_update_current_writer: bool = False @property def subscribed_methods(self) -> Set[str]: - return self._SUBSCRIBED_METHODS + return self._subscribed_methods def __init__(self, plugin_service: PluginService, @@ -164,6 +164,11 @@ def __init__(self, self._props = props self._rds_utils = rds_utils self._tracker = tracker + self._subscribed_methods: Set[str] = {DbApiMethod.CONNECT.method_name, + DbApiMethod.CONNECTION_CLOSE.method_name, + DbApiMethod.CONNECT.method_name, + DbApiMethod.NOTIFY_HOST_LIST_CHANGED.method_name} + self._subscribed_methods.update(self._plugin_service.network_bound_methods) def connect( self, @@ -210,5 +215,6 @@ def _get_writer(self, hosts: Tuple[HostInfo, ...]) -> Optional[HostInfo]: class AuroraConnectionTrackerPluginFactory(PluginFactory): - def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin: + @staticmethod + def get_instance(plugin_service: PluginService, props: Properties) -> Plugin: return AuroraConnectionTrackerPlugin(plugin_service, props) diff --git a/aws_advanced_python_wrapper/aurora_initial_connection_strategy_plugin.py b/aws_advanced_python_wrapper/aurora_initial_connection_strategy_plugin.py index 7e56daaf..5a7d3862 100644 --- a/aws_advanced_python_wrapper/aurora_initial_connection_strategy_plugin.py +++ b/aws_advanced_python_wrapper/aurora_initial_connection_strategy_plugin.py @@ -233,5 +233,6 @@ def _has_no_readers(self) -> bool: class AuroraInitialConnectionStrategyPluginFactory(PluginFactory): - def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin: + @staticmethod + def get_instance(plugin_service: PluginService, props: Properties) -> Plugin: return AuroraInitialConnectionStrategyPlugin(plugin_service) diff --git a/aws_advanced_python_wrapper/aws_secrets_manager_plugin.py b/aws_advanced_python_wrapper/aws_secrets_manager_plugin.py index 895a0d1e..b225c92a 100644 --- a/aws_advanced_python_wrapper/aws_secrets_manager_plugin.py +++ b/aws_advanced_python_wrapper/aws_secrets_manager_plugin.py @@ -32,6 +32,7 @@ from aws_advanced_python_wrapper.plugin_service import PluginService from aws_advanced_python_wrapper.errors import AwsWrapperError +from aws_advanced_python_wrapper.pep249_methods import DbApiMethod from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory from aws_advanced_python_wrapper.utils.log import Logger from aws_advanced_python_wrapper.utils.messages import Messages @@ -45,7 +46,7 @@ class AwsSecretsManagerPlugin(Plugin): - _SUBSCRIBED_METHODS: Set[str] = {"connect", "force_connect"} + _SUBSCRIBED_METHODS: Set[str] = {DbApiMethod.CONNECT.method_name, DbApiMethod.FORCE_CONNECT.method_name} _SECRETS_ARN_PATTERN = r"^arn:aws:secretsmanager:(?P[^:\n]*):[^:\n]*:([^:/\n]*[:/])?(.*)$" _ONE_YEAR_IN_SECONDS = 60 * 60 * 24 * 365 @@ -136,7 +137,8 @@ def _update_secret(self, token_expiration_ns: int, force_refetch: bool = False) """ telemetry_factory = self._plugin_service.get_telemetry_factory() context = telemetry_factory.open_telemetry_context("fetch credentials", TelemetryTraceLevel.NESTED) - self._fetch_credentials_counter.inc() + if self._fetch_credentials_counter is not None: + self._fetch_credentials_counter.inc() try: fetched: bool = False @@ -167,11 +169,13 @@ def _update_secret(self, token_expiration_ns: int, force_refetch: bool = False) return fetched except Exception as ex: - context.set_success(False) - context.set_exception(ex) + if context is not None: + context.set_success(False) + context.set_exception(ex) raise ex finally: - context.close_context() + if context is not None: + context.close_context() def _fetch_latest_credentials(self): """ @@ -228,5 +232,6 @@ def _get_rds_region(self, secret_id: str, props: Properties) -> str: class AwsSecretsManagerPluginFactory(PluginFactory): - def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin: + @staticmethod + def get_instance(plugin_service: PluginService, props: Properties) -> Plugin: return AwsSecretsManagerPlugin(plugin_service, props) diff --git a/aws_advanced_python_wrapper/blue_green_plugin.py b/aws_advanced_python_wrapper/blue_green_plugin.py index aeed76e5..f0f95ac6 100644 --- a/aws_advanced_python_wrapper/blue_green_plugin.py +++ b/aws_advanced_python_wrapper/blue_green_plugin.py @@ -42,6 +42,7 @@ from aws_advanced_python_wrapper.host_availability import HostAvailability from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole from aws_advanced_python_wrapper.iam_plugin import IamAuthPlugin +from aws_advanced_python_wrapper.pep249_methods import DbApiMethod from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory from aws_advanced_python_wrapper.utils.atomic import AtomicInt from aws_advanced_python_wrapper.utils.concurrent import (ConcurrentDict, @@ -471,7 +472,8 @@ def apply( "SuspendConnectRouting.SwitchoverCompleteContinueWithConnect", (time.time() - start_time_sec) * 1000)) finally: - telemetry_context.close_context() + if telemetry_context is not None: + telemetry_context.close_context() # return None so that the next routing can attempt a connection return None @@ -540,7 +542,8 @@ def apply( host_info.host, (time.time() - start_time_sec) * 1000)) finally: - telemetry_context.close_context() + if telemetry_context is not None: + telemetry_context.close_context() # return None so that the next routing can attempt a connection return None @@ -615,15 +618,16 @@ def apply( method_name, (time.time() - start_time_sec) * 1000)) finally: - telemetry_context.close_context() + if telemetry_context is not None: + telemetry_context.close_context() # return empty so that the next routing can attempt a connection return ValueContainer.empty() class BlueGreenPlugin(Plugin): - _SUBSCRIBED_METHODS: Set[str] = {"connect"} - _CLOSE_METHODS: ClassVar[Set[str]] = {"Connection.close", "Cursor.close"} + _SUBSCRIBED_METHODS: Set[str] = {DbApiMethod.CONNECT.method_name} + _CLOSE_METHODS: ClassVar[Set[str]] = {DbApiMethod.CONNECTION_CLOSE.method_name, DbApiMethod.CURSOR_CLOSE.method_name} _status_providers: ClassVar[ConcurrentDict[str, BlueGreenStatusProvider]] = ConcurrentDict() def __init__(self, plugin_service: PluginService, props: Properties): @@ -779,7 +783,8 @@ def get_hold_time_ns(self) -> int: class BlueGreenPluginFactory(PluginFactory): - def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin: + @staticmethod + def get_instance(plugin_service: PluginService, props: Properties) -> Plugin: return BlueGreenPlugin(plugin_service, props) diff --git a/aws_advanced_python_wrapper/connect_time_plugin.py b/aws_advanced_python_wrapper/connect_time_plugin.py index d754b9f7..2df4d96c 100644 --- a/aws_advanced_python_wrapper/connect_time_plugin.py +++ b/aws_advanced_python_wrapper/connect_time_plugin.py @@ -27,6 +27,7 @@ from time import perf_counter_ns from typing import Callable, Set +from aws_advanced_python_wrapper.pep249_methods import DbApiMethod from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory from aws_advanced_python_wrapper.utils.messages import Messages @@ -42,7 +43,7 @@ def reset_connect_time(): @property def subscribed_methods(self) -> Set[str]: - return {"connect", "force_connect"} + return {DbApiMethod.CONNECT.method_name, DbApiMethod.FORCE_CONNECT.method_name} def connect( self, @@ -65,5 +66,6 @@ def connect( class ConnectTimePluginFactory(PluginFactory): - def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin: + @staticmethod + def get_instance(plugin_service: PluginService, props: Properties) -> Plugin: return ConnectTimePlugin() diff --git a/aws_advanced_python_wrapper/custom_endpoint_plugin.py b/aws_advanced_python_wrapper/custom_endpoint_plugin.py index 8e490e02..2db46ae5 100644 --- a/aws_advanced_python_wrapper/custom_endpoint_plugin.py +++ b/aws_advanced_python_wrapper/custom_endpoint_plugin.py @@ -37,6 +37,7 @@ from boto3 import Session +from aws_advanced_python_wrapper.pep249_methods import DbApiMethod from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory from aws_advanced_python_wrapper.utils.log import Logger from aws_advanced_python_wrapper.utils.properties import WrapperProperties @@ -194,7 +195,9 @@ def _run(self): self._custom_endpoint_host_info.host, endpoint_info, CustomEndpointMonitor._CUSTOM_ENDPOINT_INFO_EXPIRATION_NS) - self._info_changed_counter.inc() + + if self._info_changed_counter is not None: + self._info_changed_counter.inc() elapsed_time = perf_counter_ns() - start_ns sleep_duration = max(0, self._refresh_rate_ns - elapsed_time) @@ -228,7 +231,7 @@ class CustomEndpointPlugin(Plugin): A plugin that analyzes custom endpoints for custom endpoint information and custom endpoint changes, such as adding or removing an instance in the custom endpoint. """ - _SUBSCRIBED_METHODS: ClassVar[Set[str]] = {"connect"} + _SUBSCRIBED_METHODS: ClassVar[Set[str]] = {DbApiMethod.CONNECT.method_name} _CACHE_CLEANUP_RATE_NS: ClassVar[int] = 6 * 10 ^ 10 # 1 minute _monitors: ClassVar[SlidingExpirationCacheWithCleanupThread[str, CustomEndpointMonitor]] = \ SlidingExpirationCacheWithCleanupThread(_CACHE_CLEANUP_RATE_NS, @@ -250,7 +253,7 @@ def __init__(self, plugin_service: PluginService, props: Properties): self._custom_endpoint_host_info: Optional[HostInfo] = None self._custom_endpoint_id: Optional[str] = None telemetry_factory: TelemetryFactory = self._plugin_service.get_telemetry_factory() - self._wait_for_info_counter: TelemetryCounter = telemetry_factory.create_counter("customEndpoint.waitForInfo.counter") + self._wait_for_info_counter: TelemetryCounter | None = telemetry_factory.create_counter("customEndpoint.waitForInfo.counter") CustomEndpointPlugin._SUBSCRIBED_METHODS.update(self._plugin_service.network_bound_methods) @@ -312,7 +315,8 @@ def _wait_for_info(self, monitor: CustomEndpointMonitor): if has_info: return - self._wait_for_info_counter.inc() + if self._wait_for_info_counter is not None: + self._wait_for_info_counter.inc() host_info = cast('HostInfo', self._custom_endpoint_host_info) hostname = host_info.host logger.debug("CustomEndpointPlugin.WaitingForCustomEndpointInfo", hostname, self._wait_for_info_timeout_ms) @@ -343,5 +347,6 @@ def execute(self, target: type, method_name: str, execute_func: Callable, *args: class CustomEndpointPluginFactory(PluginFactory): - def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin: + @staticmethod + def get_instance(plugin_service: PluginService, props: Properties) -> Plugin: return CustomEndpointPlugin(plugin_service, props) diff --git a/aws_advanced_python_wrapper/default_plugin.py b/aws_advanced_python_wrapper/default_plugin.py index 86934f4d..31d532c5 100644 --- a/aws_advanced_python_wrapper/default_plugin.py +++ b/aws_advanced_python_wrapper/default_plugin.py @@ -31,6 +31,7 @@ from aws_advanced_python_wrapper.errors import AwsWrapperError from aws_advanced_python_wrapper.host_availability import HostAvailability from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole +from aws_advanced_python_wrapper.pep249_methods import DbApiMethod from aws_advanced_python_wrapper.plugin import Plugin from aws_advanced_python_wrapper.utils.messages import Messages from aws_advanced_python_wrapper.utils.telemetry.telemetry import \ @@ -38,8 +39,7 @@ class DefaultPlugin(Plugin): - _SUBSCRIBED_METHODS: Set[str] = {"*"} - _CLOSE_METHOD = "Connection.close" + _SUBSCRIBED_METHODS: Set[str] = {DbApiMethod.ALL.method_name} def __init__(self, plugin_service: PluginService, connection_provider_manager: ConnectionProviderManager): self._plugin_service: PluginService = plugin_service @@ -74,7 +74,8 @@ def _connect( database_dialect = self._plugin_service.database_dialect conn = conn_provider.connect(target_func, driver_dialect, database_dialect, host_info, props) finally: - context.close_context() + if context is not None: + context.close_context() self._plugin_service.set_availability(host_info.all_aliases, HostAvailability.AVAILABLE) self._plugin_service.update_driver_dialect(conn_provider) @@ -106,9 +107,10 @@ def execute(self, target: object, method_name: str, execute_func: Callable, *arg try: result = self._plugin_service.driver_dialect.execute(method_name, execute_func, *args, **kwargs) finally: - context.close_context() + if context is not None: + context.close_context() - if method_name != DefaultPlugin._CLOSE_METHOD and self._plugin_service.current_connection is not None: + if method_name != DbApiMethod.CONNECTION_CLOSE.method_name and self._plugin_service.current_connection is not None: self._plugin_service.update_in_transaction() return result diff --git a/aws_advanced_python_wrapper/developer_plugin.py b/aws_advanced_python_wrapper/developer_plugin.py index e8bbd152..ffca0370 100644 --- a/aws_advanced_python_wrapper/developer_plugin.py +++ b/aws_advanced_python_wrapper/developer_plugin.py @@ -23,6 +23,7 @@ from aws_advanced_python_wrapper.utils.properties import Properties from aws_advanced_python_wrapper.hostinfo import HostInfo +from aws_advanced_python_wrapper.pep249_methods import DbApiMethod from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory from aws_advanced_python_wrapper.utils.log import Logger from aws_advanced_python_wrapper.utils.messages import Messages @@ -76,8 +77,7 @@ def set_method_callback(method_callback: Optional[ExceptionSimulatorMethodCallba class DeveloperPlugin(Plugin): - _ALL_METHODS: str = "*" - _SUBSCRIBED_METHODS: Set[str] = {_ALL_METHODS} + _SUBSCRIBED_METHODS: Set[str] = {DbApiMethod.ALL.method_name} @property def subscribed_methods(self) -> Set[str]: @@ -90,7 +90,7 @@ def execute(self, target: type, method_name: str, execute_func: Callable, *args: def raise_method_exception_if_set( self, target: type, method_name: str, execute_func: Callable, *args: Any, **kwargs: Any) -> None: if ExceptionSimulatorManager.next_method_exception is not None: - if DeveloperPlugin._ALL_METHODS == ExceptionSimulatorManager.next_method_name or \ + if DbApiMethod.ALL.method_name == ExceptionSimulatorManager.next_method_name or \ method_name == ExceptionSimulatorManager.next_method_name: self.raise_exception_on_method(ExceptionSimulatorManager.next_method_exception, method_name) elif ExceptionSimulatorManager.method_callback is not None: @@ -158,6 +158,6 @@ def raise_exception_on_connect(self, error: Optional[Exception]) -> None: class DeveloperPluginFactory(PluginFactory): - - def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin: + @staticmethod + def get_instance(plugin_service: PluginService, props: Properties) -> Plugin: return DeveloperPlugin() diff --git a/aws_advanced_python_wrapper/driver_dialect.py b/aws_advanced_python_wrapper/driver_dialect.py index 9722cf94..1da64e7b 100644 --- a/aws_advanced_python_wrapper/driver_dialect.py +++ b/aws_advanced_python_wrapper/driver_dialect.py @@ -26,6 +26,7 @@ from aws_advanced_python_wrapper.driver_dialect_codes import DriverDialectCodes from aws_advanced_python_wrapper.errors import (QueryTimeoutError, UnsupportedOperationError) +from aws_advanced_python_wrapper.pep249_methods import DbApiMethod from aws_advanced_python_wrapper.thread_pool_container import \ ThreadPoolContainer from aws_advanced_python_wrapper.utils.decorators import timeout @@ -40,11 +41,10 @@ class DriverDialect(ABC): Driver dialects help the driver-agnostic AWS Python Driver interface with the driver-specific functionality of the underlying Python Driver. """ _QUERY = "SELECT 1" - _ALL_METHODS = "*" _executor_name: ClassVar[str] = "DriverDialectExecutor" _dialect_code: str = DriverDialectCodes.GENERIC - _network_bound_methods: Set[str] = {_ALL_METHODS} + _network_bound_methods: Set[str] = {DbApiMethod.ALL.method_name} _read_only: bool = False _autocommit: bool = False _driver_name: str = "Generic" @@ -130,7 +130,7 @@ def execute( *args: Any, exec_timeout: Optional[float] = None, **kwargs: Any) -> Cursor: - if DriverDialect._ALL_METHODS not in self.network_bound_methods and method_name not in self.network_bound_methods: + if DbApiMethod.ALL.method_name not in self.network_bound_methods and method_name not in self.network_bound_methods: return exec_func() if exec_timeout is None: @@ -161,7 +161,7 @@ def ping(self, conn: Connection) -> bool: try: with conn.cursor() as cursor: query = DriverDialect._QUERY - self.execute("Cursor.execute", lambda: cursor.execute(query), query, exec_timeout=10) + self.execute(DbApiMethod.CURSOR_EXECUTE.method_name, lambda: cursor.execute(query), query, exec_timeout=10) cursor.fetchone() return True except Exception: diff --git a/aws_advanced_python_wrapper/execute_time_plugin.py b/aws_advanced_python_wrapper/execute_time_plugin.py index dbda10b3..6aa48cb3 100644 --- a/aws_advanced_python_wrapper/execute_time_plugin.py +++ b/aws_advanced_python_wrapper/execute_time_plugin.py @@ -54,5 +54,6 @@ def execute(self, target: type, method_name: str, execute_func: Callable, *args: class ExecuteTimePluginFactory(PluginFactory): - def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin: + @staticmethod + def get_instance(plugin_service: PluginService, props: Properties) -> Plugin: return ExecuteTimePlugin() diff --git a/aws_advanced_python_wrapper/failover_plugin.py b/aws_advanced_python_wrapper/failover_plugin.py index c4cb5374..be47fc5f 100644 --- a/aws_advanced_python_wrapper/failover_plugin.py +++ b/aws_advanced_python_wrapper/failover_plugin.py @@ -31,6 +31,7 @@ TransactionResolutionUnknownError) from aws_advanced_python_wrapper.host_availability import HostAvailability from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole +from aws_advanced_python_wrapper.pep249_methods import DbApiMethod from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory from aws_advanced_python_wrapper.reader_failover_handler import ( ReaderFailoverHandler, ReaderFailoverHandlerImpl) @@ -57,18 +58,18 @@ class FailoverPlugin(Plugin): This plugin provides cluster-aware failover features. The plugin switches connections upon detecting communication related exceptions and/or cluster topology changes. """ - _SUBSCRIBED_METHODS: Set[str] = {"init_host_provider", - "connect", - "notify_host_list_changed"} + _SUBSCRIBED_METHODS: Set[str] = {DbApiMethod.INIT_HOST_PROVIDER.method_name, + DbApiMethod.CONNECT.method_name, + DbApiMethod.NOTIFY_HOST_LIST_CHANGED.method_name} _METHODS_REQUIRE_UPDATED_TOPOLOGY: Set[str] = { - "Connection.commit", - "Connection.autocommit", - "Connection.autocommit_setter", - "Connection.rollback", - "Connection.cursor", - "Cursor.callproc", - "Cursor.execute" + DbApiMethod.CONNECTION_COMMIT.method_name, + DbApiMethod.CONNECTION_AUTOCOMMIT.method_name, + DbApiMethod.CONNECTION_AUTOCOMMIT_SETTER.method_name, + DbApiMethod.CONNECTION_ROLLBACK.method_name, + DbApiMethod.CONNECTION_CURSOR.method_name, + DbApiMethod.CURSOR_CALLPROC.method_name, + DbApiMethod.CURSOR_EXECUTE.method_name } def __init__(self, plugin_service: PluginService, props: Properties): @@ -254,7 +255,8 @@ def _failover(self, failed_host: Optional[HostInfo]): def _failover_reader(self, failed_host: Optional[HostInfo]): telemetry_factory = self._plugin_service.get_telemetry_factory() context = telemetry_factory.open_telemetry_context("failover to replica", TelemetryTraceLevel.NESTED) - self._failover_reader_triggered_counter.inc() + if self._failover_reader_triggered_counter is not None: + self._failover_reader_triggered_counter.inc() try: logger.info("FailoverPlugin.StartReaderFailover") @@ -284,26 +286,33 @@ def _failover_reader(self, failed_host: Optional[HostInfo]): logger.info("FailoverPlugin.EstablishedConnection", self._plugin_service.current_host_info) - self._failover_reader_success_counter.inc() + if self._failover_reader_success_counter is not None: + self._failover_reader_success_counter.inc() except FailoverSuccessError as fse: - context.set_success(True) - context.set_exception(fse) - self._failover_reader_success_counter.inc() + if context is not None: + context.set_success(True) + context.set_exception(fse) + if self._failover_reader_success_counter is not None: + self._failover_reader_success_counter.inc() raise fse except Exception as ex: - context.set_success(False) - context.set_exception(ex) - self._failover_reader_failed_counter.inc() + if context is not None: + context.set_success(False) + context.set_exception(ex) + if self._failover_reader_failed_counter is not None: + self._failover_reader_failed_counter.inc() raise ex finally: - context.close_context() - if self._telemetry_failover_additional_top_trace_setting: - telemetry_factory.post_copy(context, TelemetryTraceLevel.FORCE_TOP_LEVEL) + if context is not None: + context.close_context() + if self._telemetry_failover_additional_top_trace_setting: + telemetry_factory.post_copy(context, TelemetryTraceLevel.FORCE_TOP_LEVEL) def _failover_writer(self): telemetry_factory = self._plugin_service.get_telemetry_factory() context = telemetry_factory.open_telemetry_context("failover to writer host", TelemetryTraceLevel.NESTED) - self._failover_writer_triggered_counter.inc() + if self._failover_writer_triggered_counter is not None: + self._failover_writer_triggered_counter.inc() try: logger.info("FailoverPlugin.StartWriterFailover") @@ -328,22 +337,27 @@ def _failover_writer(self): logger.info("FailoverPlugin.EstablishedConnection", self._plugin_service.current_host_info) self._plugin_service.refresh_host_list() - - self._failover_writer_success_counter.inc() + if self._failover_writer_success_counter is not None: + self._failover_writer_success_counter.inc() except FailoverSuccessError as fse: - context.set_success(True) - context.set_exception(fse) - self._failover_writer_success_counter.inc() + if context is not None: + context.set_success(True) + context.set_exception(fse) + if self._failover_writer_success_counter is not None: + self._failover_writer_success_counter.inc() raise fse except Exception as ex: - context.set_success(False) - context.set_exception(ex) - self._failover_writer_failed_counter.inc() + if context is not None: + context.set_success(False) + context.set_exception(ex) + if self._failover_writer_success_counter is not None: + self._failover_writer_failed_counter.inc() raise ex finally: - context.close_context() - if self._telemetry_failover_additional_top_trace_setting: - telemetry_factory.post_copy(context, TelemetryTraceLevel.FORCE_TOP_LEVEL) + if context is not None: + context.close_context() + if self._telemetry_failover_additional_top_trace_setting: + telemetry_factory.post_copy(context, TelemetryTraceLevel.FORCE_TOP_LEVEL) def _invalidate_current_connection(self): """ @@ -358,14 +372,14 @@ def _invalidate_current_connection(self): if self._plugin_service.is_in_transaction: self._plugin_service.update_in_transaction(True) try: - driver_dialect.execute("Connection.rollback", lambda: conn.rollback()) + driver_dialect.execute(DbApiMethod.CONNECTION_ROLLBACK.method_name, lambda: conn.rollback()) conn.rollback() except Exception: pass if not driver_dialect.is_closed(conn): try: - return driver_dialect.execute("Connection.close", lambda: conn.close()) + return driver_dialect.execute(DbApiMethod.CONNECTION_CLOSE.method_name, lambda: conn.close()) except Exception: pass @@ -476,9 +490,9 @@ def _can_direct_execute(method_name): :param method_name: The name of the method that is being called. :return: `True` if the method can be executed directly; `False` otherwise. """ - return method_name == "Connection.close" or \ - method_name == "Connection.is_closed" or \ - method_name == "Cursor.close" + return method_name == DbApiMethod.CONNECTION_CLOSE.method_name or \ + method_name == DbApiMethod.CONNECTION_IS_CLOSED.method_name or \ + method_name == DbApiMethod.CURSOR_CLOSE.method_name @staticmethod def _allowed_on_closed_connection(method_name: str): @@ -488,7 +502,7 @@ def _allowed_on_closed_connection(method_name: str): :param method_name: The method being executed at the moment. :return: `True` if the given method is allowed on closed connections. """ - return method_name == "Connection.autocommit" + return method_name == DbApiMethod.CONNECTION_AUTOCOMMIT.method_name def _requires_update_topology(self, method_name: str): """ @@ -503,5 +517,6 @@ def _requires_update_topology(self, method_name: str): class FailoverPluginFactory(PluginFactory): - def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin: + @staticmethod + def get_instance(plugin_service: PluginService, props: Properties) -> Plugin: return FailoverPlugin(plugin_service, props) diff --git a/aws_advanced_python_wrapper/fastest_response_strategy_plugin.py b/aws_advanced_python_wrapper/fastest_response_strategy_plugin.py index e85d40cc..da963fcf 100644 --- a/aws_advanced_python_wrapper/fastest_response_strategy_plugin.py +++ b/aws_advanced_python_wrapper/fastest_response_strategy_plugin.py @@ -138,8 +138,8 @@ class ResponseTimeTuple: class FastestResponseStrategyPluginFactory: - - def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin: + @staticmethod + def get_instance(plugin_service: PluginService, props: Properties) -> Plugin: return FastestResponseStrategyPlugin(plugin_service, props) @@ -169,7 +169,7 @@ def __init__(self, plugin_service: PluginService, host_info: HostInfo, props: Pr # Report current response time (in milliseconds) to telemetry engine. # Report -1 if response time couldn't be measured. - self._response_time_gauge: TelemetryGauge = \ + self._response_time_gauge: TelemetryGauge | None = \ self._telemetry_factory.create_gauge("frt.response.time." + self._host_id, lambda: self._response_time if self._response_time != MAX_VALUE else -1) self._daemon_thread.start() @@ -201,7 +201,9 @@ def _get_current_time(self): def run(self): context: TelemetryContext = self._telemetry_factory.open_telemetry_context( "host response time thread", TelemetryTraceLevel.TOP_LEVEL) - context.set_attribute("url", self._host_info.url) + + if context is not None: + context.set_attribute("url", self._host_info.url) try: while not self.is_stopped: self._open_connection() @@ -290,7 +292,7 @@ def __init__(self, plugin_service: PluginService, props: Properties, interval_ms self._interval_ms = interval_ms self._hosts: Tuple[HostInfo, ...] = () self._telemetry_factory: TelemetryFactory = self._plugin_service.get_telemetry_factory() - self._host_count_gauge: TelemetryGauge = self._telemetry_factory.create_gauge("frt.hosts.count", lambda: len(self._monitoring_hosts)) + self._host_count_gauge: TelemetryGauge | None = self._telemetry_factory.create_gauge("frt.hosts.count", lambda: len(self._monitoring_hosts)) @property def hosts(self) -> Tuple[HostInfo, ...]: diff --git a/aws_advanced_python_wrapper/federated_plugin.py b/aws_advanced_python_wrapper/federated_plugin.py index 2f48019f..a83b3652 100644 --- a/aws_advanced_python_wrapper/federated_plugin.py +++ b/aws_advanced_python_wrapper/federated_plugin.py @@ -144,7 +144,8 @@ def _update_authentication_token(self, port: int = IamAuthUtils.get_port(props, host_info, self._plugin_service.database_dialect.default_port) credentials: Optional[Dict[str, str]] = self._credentials_provider_factory.get_aws_credentials(region, props) - self._fetch_token_counter.inc() + if self._fetch_token_counter is not None: + self._fetch_token_counter.inc() token: str = IamAuthUtils.generate_authentication_token( self._plugin_service, user, @@ -158,10 +159,12 @@ def _update_authentication_token(self, class FederatedAuthPluginFactory(PluginFactory): - def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin: - return FederatedAuthPlugin(plugin_service, self.get_credentials_provider_factory(plugin_service, props)) + @staticmethod + def get_instance(plugin_service: PluginService, props: Properties) -> Plugin: + return FederatedAuthPlugin(plugin_service, FederatedAuthPluginFactory.get_credentials_provider_factory(plugin_service, props)) - def get_credentials_provider_factory(self, plugin_service: PluginService, props: Properties) -> AdfsCredentialsProviderFactory: + @staticmethod + def get_credentials_provider_factory(plugin_service: PluginService, props: Properties) -> AdfsCredentialsProviderFactory: idp_name = WrapperProperties.IDP_NAME.get(props) if idp_name is None or idp_name == "" or idp_name == "adfs": return AdfsCredentialsProviderFactory(plugin_service, props) diff --git a/aws_advanced_python_wrapper/host_monitoring_plugin.py b/aws_advanced_python_wrapper/host_monitoring_plugin.py index a5e5752f..6b9271d7 100644 --- a/aws_advanced_python_wrapper/host_monitoring_plugin.py +++ b/aws_advanced_python_wrapper/host_monitoring_plugin.py @@ -33,6 +33,7 @@ from aws_advanced_python_wrapper.errors import AwsWrapperError from aws_advanced_python_wrapper.host_availability import HostAvailability +from aws_advanced_python_wrapper.pep249_methods import DbApiMethod from aws_advanced_python_wrapper.plugin import (CanReleaseResources, Plugin, PluginFactory) from aws_advanced_python_wrapper.thread_pool_container import \ @@ -54,12 +55,17 @@ class HostMonitoringPluginFactory(PluginFactory): - def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin: + @staticmethod + def get_instance(plugin_service: PluginService, props: Properties) -> Plugin: return HostMonitoringPlugin(plugin_service, props) class HostMonitoringPlugin(Plugin, CanReleaseResources): - _SUBSCRIBED_METHODS: Set[str] = {"connect", "notify_host_list_changed", "notify_connection_changed"} + _SUBSCRIBED_METHODS: Set[str] = { + DbApiMethod.CONNECT.method_name, + DbApiMethod.NOTIFY_HOST_LIST_CHANGED.method_name, + DbApiMethod.NOTIFY_CONNECTION_CHANGED.method_name + } def __init__(self, plugin_service, props): dialect: DriverDialect = plugin_service.driver_dialect @@ -74,6 +80,10 @@ def __init__(self, plugin_service, props): self._rds_utils: RdsUtils = RdsUtils() self._monitor_service: MonitorService = MonitorService(plugin_service) self._lock: Lock = Lock() + self._is_enabled = WrapperProperties.FAILURE_DETECTION_ENABLED.get_bool(self._props) + self._failure_detection_time_ms = WrapperProperties.FAILURE_DETECTION_TIME_MS.get_int(self._props) + self._failure_detection_interval = WrapperProperties.FAILURE_DETECTION_INTERVAL_MS.get_int(self._props) + self._failure_detection_count = WrapperProperties.FAILURE_DETECTION_COUNT.get_int(self._props) HostMonitoringPlugin._SUBSCRIBED_METHODS.update(self._plugin_service.network_bound_methods) @property @@ -105,14 +115,9 @@ def execute(self, target: object, method_name: str, execute_func: Callable, *arg if host_info is None: raise AwsWrapperError(Messages.get_formatted("HostMonitoringPlugin.HostInfoNoneForMethod", method_name)) - is_enabled = WrapperProperties.FAILURE_DETECTION_ENABLED.get_bool(self._props) - if not is_enabled or not self._plugin_service.is_network_bound_method(method_name): + if not self._is_enabled or not self._plugin_service.is_network_bound_method(method_name): return execute_func() - failure_detection_time_ms = WrapperProperties.FAILURE_DETECTION_TIME_MS.get_int(self._props) - failure_detection_interval = WrapperProperties.FAILURE_DETECTION_INTERVAL_MS.get_int(self._props) - failure_detection_count = WrapperProperties.FAILURE_DETECTION_COUNT.get_int(self._props) - monitor_context = None result = None @@ -123,9 +128,9 @@ def execute(self, target: object, method_name: str, execute_func: Callable, *arg self._get_monitoring_host_info().all_aliases, self._get_monitoring_host_info(), self._props, - failure_detection_time_ms, - failure_detection_interval, - failure_detection_count + self._failure_detection_time_ms, + self._failure_detection_interval, + self._failure_detection_count ) result = execute_func() finally: @@ -216,7 +221,7 @@ def __init__( failure_detection_time_ms: int, failure_detection_interval_ms: int, failure_detection_count: int, - aborted_connections_counter: TelemetryCounter): + aborted_connections_counter: TelemetryCounter | None): self._monitor: Monitor = monitor self._connection: Connection = connection self._target_dialect: DriverDialect = target_dialect @@ -322,7 +327,8 @@ def _set_host_availability( logger.debug("MonitorContext.HostUnavailable", host) self._is_host_unavailable = True self._abort_connection() - self._aborted_connections_counter.inc() + if self._aborted_connections_counter is not None: + self._aborted_connections_counter.inc() return logger.debug("MonitorContext.HostNotResponding", host, self._current_failure_count) @@ -506,7 +512,9 @@ def run(self): def _check_host_status(self, host_check_timeout_ms: int) -> HostStatus: context = self._telemetry_factory.open_telemetry_context( "connection status check", TelemetryTraceLevel.FORCE_TOP_LEVEL) - context.set_attribute("url", self._host_info.url) + + if context is not None: + context.set_attribute("url", self._host_info.url) start_ns = perf_counter_ns() try: @@ -527,13 +535,16 @@ def _check_host_status(self, host_check_timeout_ms: int) -> HostStatus: start_ns = perf_counter_ns() is_available = self._is_host_available(self._monitoring_conn, host_check_timeout_ms / 1000) if not is_available: - self._host_invalid_counter.inc() + if self._host_invalid_counter is not None: + self._host_invalid_counter.inc() return Monitor.HostStatus(is_available, perf_counter_ns() - start_ns) except Exception: - self._host_invalid_counter.inc() + if self._host_invalid_counter is not None: + self._host_invalid_counter.inc() return Monitor.HostStatus(False, perf_counter_ns() - start_ns) finally: - context.close_context() + if context is not None: + context.close_context() def _is_host_available(self, conn: Connection, timeout_sec: float) -> bool: try: @@ -546,7 +557,7 @@ def _execute_conn_check(self, conn: Connection, timeout_sec: float): driver_dialect = self._plugin_service.driver_dialect with conn.cursor() as cursor: query = Monitor._QUERY - driver_dialect.execute("Cursor.execute", lambda: cursor.execute(query), query, exec_timeout=timeout_sec) + driver_dialect.execute(DbApiMethod.CURSOR_EXECUTE.method_name, lambda: cursor.execute(query), query, exec_timeout=timeout_sec) cursor.fetchone() def sleep(self, duration: int): diff --git a/aws_advanced_python_wrapper/host_monitoring_v2_plugin.py b/aws_advanced_python_wrapper/host_monitoring_v2_plugin.py index c2291027..b3dc881c 100644 --- a/aws_advanced_python_wrapper/host_monitoring_v2_plugin.py +++ b/aws_advanced_python_wrapper/host_monitoring_v2_plugin.py @@ -22,6 +22,7 @@ from aws_advanced_python_wrapper.errors import AwsWrapperError from aws_advanced_python_wrapper.host_availability import HostAvailability +from aws_advanced_python_wrapper.pep249_methods import DbApiMethod from aws_advanced_python_wrapper.plugin import (CanReleaseResources, Plugin, PluginFactory) from aws_advanced_python_wrapper.utils.atomic import (AtomicBoolean, @@ -161,8 +162,8 @@ def release_resources(self): class HostMonitoringV2PluginFactory(PluginFactory): - - def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin: + @staticmethod + def get_instance(plugin_service: PluginService, props: Properties) -> Plugin: return HostMonitoringV2Plugin(plugin_service, props) @@ -216,7 +217,7 @@ def __init__( failure_detection_time_ms: int, failure_detection_interval_ms: int, failure_detection_count: int, - aborted_connection_counter: TelemetryCounter): + aborted_connection_counter: TelemetryCounter | None): self._plugin_service: PluginService = plugin_service self._host_info: HostInfo = host_info self._props: Properties = props @@ -224,7 +225,7 @@ def __init__( self._failure_detection_time_ns: int = failure_detection_time_ms * 10**6 self._failure_detection_interval_ns: int = failure_detection_interval_ms * 10**6 self._failure_detection_count: int = failure_detection_count - self._aborted_connection_counter: TelemetryCounter = aborted_connection_counter + self._aborted_connection_counter: TelemetryCounter | None = aborted_connection_counter self._active_contexts: Queue = Queue() self._new_contexts: ConcurrentDict[float, Queue] = ConcurrentDict() @@ -399,7 +400,7 @@ def _execute_conn_check(self, conn: Connection, timeout_sec: float): driver_dialect = self._plugin_service.driver_dialect with conn.cursor() as cursor: query = HostMonitorV2._QUERY - driver_dialect.execute("Cursor.execute", lambda: cursor.execute(query), query, exec_timeout=timeout_sec) + driver_dialect.execute(DbApiMethod.CURSOR_EXECUTE.method_name, lambda: cursor.execute(query), query, exec_timeout=timeout_sec) cursor.fetchone() def _update_host_health_status( diff --git a/aws_advanced_python_wrapper/iam_plugin.py b/aws_advanced_python_wrapper/iam_plugin.py index a503be4c..f42ef92c 100644 --- a/aws_advanced_python_wrapper/iam_plugin.py +++ b/aws_advanced_python_wrapper/iam_plugin.py @@ -30,6 +30,7 @@ from typing import Callable, Dict, Optional, Set from aws_advanced_python_wrapper.errors import AwsWrapperError +from aws_advanced_python_wrapper.pep249_methods import DbApiMethod from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory from aws_advanced_python_wrapper.utils.log import Logger from aws_advanced_python_wrapper.utils.messages import Messages @@ -41,7 +42,7 @@ class IamAuthPlugin(Plugin): - _SUBSCRIBED_METHODS: Set[str] = {"connect", "force_connect"} + _SUBSCRIBED_METHODS: Set[str] = {DbApiMethod.CONNECT.method_name, DbApiMethod.FORCE_CONNECT.method_name} # Leave 30 second buffer to prevent time-of-check to time-of-use errors _DEFAULT_TOKEN_EXPIRATION_SEC = 15 * 60 - 30 @@ -101,7 +102,8 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl self._plugin_service.driver_dialect.set_password(props, token_info.token) else: token_expiry = datetime.now() + timedelta(seconds=token_expiration_sec) - self._fetch_token_counter.inc() + if self._fetch_token_counter is not None: + self._fetch_token_counter.inc() token: str = IamAuthUtils.generate_authentication_token(self._plugin_service, user, host, port, region, client_session=self._session) self._plugin_service.driver_dialect.set_password(props, token) IamAuthPlugin._token_cache[cache_key] = TokenInfo(token, token_expiry) @@ -119,7 +121,8 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl # Login unsuccessful with cached token # Try to generate a new token and try to connect again token_expiry = datetime.now() + timedelta(seconds=token_expiration_sec) - self._fetch_token_counter.inc() + if self._fetch_token_counter is not None: + self._fetch_token_counter.inc() token = IamAuthUtils.generate_authentication_token(self._plugin_service, user, host, port, region, client_session=self._session) self._plugin_service.driver_dialect.set_password(props, token) IamAuthPlugin._token_cache[cache_key] = TokenInfo(token, token_expiry) @@ -141,5 +144,6 @@ def force_connect( class IamAuthPluginFactory(PluginFactory): - def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin: + @staticmethod + def get_instance(plugin_service: PluginService, props: Properties) -> Plugin: return IamAuthPlugin(plugin_service) diff --git a/aws_advanced_python_wrapper/limitless_plugin.py b/aws_advanced_python_wrapper/limitless_plugin.py index 6f597583..85cb15da 100644 --- a/aws_advanced_python_wrapper/limitless_plugin.py +++ b/aws_advanced_python_wrapper/limitless_plugin.py @@ -26,6 +26,7 @@ UnsupportedOperationError) from aws_advanced_python_wrapper.host_availability import HostAvailability from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole +from aws_advanced_python_wrapper.pep249_methods import DbApiMethod from aws_advanced_python_wrapper.plugin import Plugin from aws_advanced_python_wrapper.utils.concurrent import ConcurrentDict from aws_advanced_python_wrapper.utils.log import Logger @@ -100,8 +101,8 @@ def connect( class LimitlessPluginFactory: - - def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin: + @staticmethod + def get_instance(plugin_service: PluginService, props: Properties) -> Plugin: return LimitlessPlugin(plugin_service, props) @@ -233,7 +234,7 @@ def query_for_limitless_routers(self, connection: Connection, host_port_to_map: query = aurora_limitless_dialect.limitless_router_endpoint_query with closing(connection.cursor()) as cursor: - self._plugin_service.driver_dialect.execute("Cursor.execute", + self._plugin_service.driver_dialect.execute(DbApiMethod.CURSOR_EXECUTE.method_name, lambda: cursor.execute(query), query, exec_timeout=LimitlessQueryHelper._DEFAULT_QUERY_TIMEOUT_SEC) diff --git a/aws_advanced_python_wrapper/mysql_driver_dialect.py b/aws_advanced_python_wrapper/mysql_driver_dialect.py index d777cb4a..c2949ceb 100644 --- a/aws_advanced_python_wrapper/mysql_driver_dialect.py +++ b/aws_advanced_python_wrapper/mysql_driver_dialect.py @@ -26,6 +26,7 @@ from aws_advanced_python_wrapper.driver_dialect import DriverDialect from aws_advanced_python_wrapper.driver_dialect_codes import DriverDialectCodes from aws_advanced_python_wrapper.errors import UnsupportedOperationError +from aws_advanced_python_wrapper.pep249_methods import DbApiMethod from aws_advanced_python_wrapper.thread_pool_container import \ ThreadPoolContainer from aws_advanced_python_wrapper.utils.decorators import timeout @@ -61,18 +62,18 @@ class MySQLDriverDialect(DriverDialect): _dialect_code: str = DriverDialectCodes.MYSQL_CONNECTOR_PYTHON _network_bound_methods: Set[str] = { - "Connection.commit", - "Connection.autocommit", - "Connection.autocommit_setter", - "Connection.is_read_only", - "Connection.set_read_only", - "Connection.rollback", - "Connection.cursor", - "Cursor.close", - "Cursor.execute", - "Cursor.fetchone", - "Cursor.fetchmany", - "Cursor.fetchall" + DbApiMethod.CONNECTION_COMMIT.method_name, + DbApiMethod.CONNECTION_AUTOCOMMIT.method_name, + DbApiMethod.CONNECTION_AUTOCOMMIT_SETTER.method_name, + DbApiMethod.CONNECTION_IS_READ_ONLY.method_name, + DbApiMethod.CONNECTION_SET_READ_ONLY.method_name, + DbApiMethod.CONNECTION_ROLLBACK.method_name, + DbApiMethod.CONNECTION_CURSOR.method_name, + DbApiMethod.CURSOR_CLOSE.method_name, + DbApiMethod.CURSOR_EXECUTE.method_name, + DbApiMethod.CURSOR_FETCHONE.method_name, + DbApiMethod.CURSOR_FETCHMANY.method_name, + DbApiMethod.CURSOR_FETCHALL.method_name } @staticmethod diff --git a/aws_advanced_python_wrapper/okta_plugin.py b/aws_advanced_python_wrapper/okta_plugin.py index f1334e91..d9caf6e2 100644 --- a/aws_advanced_python_wrapper/okta_plugin.py +++ b/aws_advanced_python_wrapper/okta_plugin.py @@ -139,7 +139,8 @@ def _update_authentication_token(self, token_expiry: datetime = datetime.now() + timedelta(seconds=token_expiration_sec) port: int = IamAuthUtils.get_port(props, host_info, self._plugin_service.database_dialect.default_port) credentials: Optional[Dict[str, str]] = self._credentials_provider_factory.get_aws_credentials(region, props) - + if self._fetch_token_counter: + self._fetch_token_counter.inc() token: str = IamAuthUtils.generate_authentication_token( self._plugin_service, user, @@ -227,8 +228,10 @@ def get_saml_assertion(self, props: Properties): class OktaAuthPluginFactory(PluginFactory): - def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin: - return OktaAuthPlugin(plugin_service, self.get_credentials_provider_factory(plugin_service, props)) + @staticmethod + def get_instance(plugin_service: PluginService, props: Properties) -> Plugin: + return OktaAuthPlugin(plugin_service, OktaAuthPluginFactory.get_credentials_provider_factory(plugin_service, props)) - def get_credentials_provider_factory(self, plugin_service: PluginService, props: Properties) -> OktaCredentialsProviderFactory: + @staticmethod + def get_credentials_provider_factory(plugin_service: PluginService, props: Properties) -> OktaCredentialsProviderFactory: return OktaCredentialsProviderFactory(plugin_service, props) diff --git a/aws_advanced_python_wrapper/pep249_methods.py b/aws_advanced_python_wrapper/pep249_methods.py new file mode 100644 index 00000000..8d3f6487 --- /dev/null +++ b/aws_advanced_python_wrapper/pep249_methods.py @@ -0,0 +1,90 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum +from typing import Dict + + +class DbApiMethod(Enum): + """Enumeration of database API methods with tuple-based attributes.""" + + # Connection methods - Core PEP 249 + CONNECTION_CLOSE = (0, "Connection.close", False) + CONNECTION_COMMIT = (1, "Connection.commit", True) + CONNECTION_ROLLBACK = (2, "Connection.rollback", True) + CONNECTION_CURSOR = (3, "Connection.cursor", False) + + # Connection methods - Optional PEP 249 extensions + CONNECTION_TPC_BEGIN = (4, "Connection.tpc_begin", True) + CONNECTION_TPC_PREPARE = (5, "Connection.tpc_prepare", True) + CONNECTION_TPC_COMMIT = (6, "Connection.tpc_commit", True) + CONNECTION_TPC_ROLLBACK = (7, "Connection.tpc_rollback", True) + CONNECTION_TPC_RECOVER = (8, "Connection.tpc_recover", False) + + # Connection properties + CONNECTION_AUTOCOMMIT = (9, "Connection.autocommit", False) + CONNECTION_AUTOCOMMIT_SETTER = (10, "Connection.autocommit_setter", False) + CONNECTION_IS_READ_ONLY = (11, "Connection.is_read_only", False) + CONNECTION_SET_READ_ONLY = (12, "Connection.set_read_only", False) + CONNECTION_IS_CLOSED = (13, "Connection.is_closed", False) + + # Cursor methods - Core PEP 249 + CURSOR_CLOSE = (14, "Cursor.close", False) + CURSOR_EXECUTE = (15, "Cursor.execute", False) + CURSOR_EXECUTEMANY = (16, "Cursor.executemany", False) + CURSOR_FETCHONE = (17, "Cursor.fetchone", False) + CURSOR_FETCHMANY = (18, "Cursor.fetchmany", False) + CURSOR_FETCHALL = (19, "Cursor.fetchall", False) + CURSOR_NEXTSET = (20, "Cursor.nextset", False) + CURSOR_SETINPUTSIZES = (21, "Cursor.setinputsizes", False) + CURSOR_SETOUTPUTSIZE = (22, "Cursor.setoutputsize", False) + + # Cursor methods - Optional extensions + CURSOR_CALLPROC = (23, "Cursor.callproc", False) + CURSOR_SCROLL = (24, "Cursor.scroll", False) + CURSOR_COPY_FROM = (25, "Cursor.copy_from", False) + CURSOR_COPY_TO = (26, "Cursor.copy_to", False) + CURSOR_COPY_EXPERT = (27, "Cursor.copy_expert", False) + + # Cursor properties and attributes + CURSOR_CONNECTION = (28, "Cursor.connection", False) + CURSOR_ROWNUMBER = (29, "Cursor.rownumber", False) + CURSOR_NEXT = (30, "Cursor.__next__", False) + CURSOR_LASTROWID = (31, "Cursor.lastrowid", False) + + # AWS Advaced Python Wrapper Methods for + CONNECT = (32, "connect", True) + FORCE_CONNECT = (33, "force_connect", True) + INIT_HOST_PROVIDER = (34, "init_host_provider", True) + NOTIFY_CONNECTION_CHANGED = (35, "notify_connection_changed", True) + NOTIFY_HOST_LIST_CHANGED = (36, "notify_host_list_changed", True) + GET_HOST_INFO_BY_STRATEGY = (37, "get_host_info_by_strategy", True) + ACCEPTS_STRATEGY = (38, "accepts_strategy", True) + + # Special marker for all methods + ALL = (39, "*", False) + + def __init__(self, id: int, method_name: str, always_use_pipeline: bool): + self.id = id + self.method_name = method_name + self.always_use_pipeline = always_use_pipeline + + +# Reverse lookup for method name to enum +_NAME_TO_METHOD: Dict[str, DbApiMethod] = {method.method_name: method for method in DbApiMethod} + + +def get_method_by_name(method_name: str) -> DbApiMethod: + """Get DbApiMethod enum by method name string.""" + return _NAME_TO_METHOD.get(method_name, DbApiMethod.ALL) diff --git a/aws_advanced_python_wrapper/pg_driver_dialect.py b/aws_advanced_python_wrapper/pg_driver_dialect.py index fbe441f3..42366831 100644 --- a/aws_advanced_python_wrapper/pg_driver_dialect.py +++ b/aws_advanced_python_wrapper/pg_driver_dialect.py @@ -27,6 +27,7 @@ from aws_advanced_python_wrapper.driver_dialect import DriverDialect from aws_advanced_python_wrapper.driver_dialect_codes import DriverDialectCodes from aws_advanced_python_wrapper.errors import UnsupportedOperationError +from aws_advanced_python_wrapper.pep249_methods import DbApiMethod from aws_advanced_python_wrapper.utils.messages import Messages from aws_advanced_python_wrapper.utils.properties import (Properties, PropertiesUtils, @@ -42,20 +43,20 @@ class PgDriverDialect(DriverDialect): _dialect_code: str = DriverDialectCodes.PSYCOPG _network_bound_methods: Set[str] = { - "Connection.commit", - "Connection.autocommit", - "Connection.autocommit_setter", - "Connection.is_read_only", - "Connection.set_read_only", - "Connection.rollback", - "Connection.close", - "Connection.cursor", - "Cursor.close", - "Cursor.callproc", - "Cursor.execute", - "Cursor.fetchone", - "Cursor.fetchmany", - "Cursor.fetchall" + DbApiMethod.CONNECTION_COMMIT.method_name, + DbApiMethod.CONNECTION_AUTOCOMMIT.method_name, + DbApiMethod.CONNECTION_AUTOCOMMIT_SETTER.method_name, + DbApiMethod.CONNECTION_IS_READ_ONLY.method_name, + DbApiMethod.CONNECTION_SET_READ_ONLY.method_name, + DbApiMethod.CONNECTION_ROLLBACK.method_name, + DbApiMethod.CONNECTION_CLOSE.method_name, + DbApiMethod.CONNECTION_CURSOR.method_name, + DbApiMethod.CURSOR_CLOSE.method_name, + DbApiMethod.CURSOR_CALLPROC.method_name, + DbApiMethod.CURSOR_EXECUTE.method_name, + DbApiMethod.CURSOR_FETCHONE.method_name, + DbApiMethod.CURSOR_FETCHMANY.method_name, + DbApiMethod.CURSOR_FETCHALL.method_name } def is_dialect(self, connect_func: Callable) -> bool: diff --git a/aws_advanced_python_wrapper/plugin.py b/aws_advanced_python_wrapper/plugin.py index 6a1d1be4..7706364c 100644 --- a/aws_advanced_python_wrapper/plugin.py +++ b/aws_advanced_python_wrapper/plugin.py @@ -135,7 +135,8 @@ def init_host_provider( class PluginFactory(Protocol): - def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin: + @staticmethod + def get_instance(plugin_service: PluginService, props: Properties) -> Plugin: pass diff --git a/aws_advanced_python_wrapper/plugin_service.py b/aws_advanced_python_wrapper/plugin_service.py index fa1879ab..789c364a 100644 --- a/aws_advanced_python_wrapper/plugin_service.py +++ b/aws_advanced_python_wrapper/plugin_service.py @@ -79,6 +79,7 @@ HostMonitoringV2PluginFactory from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole from aws_advanced_python_wrapper.iam_plugin import IamAuthPluginFactory +from aws_advanced_python_wrapper.pep249_methods import DbApiMethod from aws_advanced_python_wrapper.plugin import CanReleaseResources from aws_advanced_python_wrapper.read_write_splitting_plugin import \ ReadWriteSplittingPluginFactory @@ -746,15 +747,14 @@ def get_status(self, clazz: Type[StatusType], key: str) -> Optional[StatusType]: return status -class PluginManager(CanReleaseResources): - _ALL_METHODS: str = "*" - _CONNECT_METHOD: str = "connect" - _FORCE_CONNECT_METHOD: str = "force_connect" - _NOTIFY_CONNECTION_CHANGED_METHOD: str = "notify_connection_changed" - _NOTIFY_HOST_LIST_CHANGED_METHOD: str = "notify_host_list_changed" - _GET_HOST_INFO_BY_STRATEGY_METHOD: str = "get_host_info_by_strategy" - _INIT_HOST_LIST_PROVIDER_METHOD: str = "init_host_provider" +class PluginChainCallableInfo: + """Container for plugin chain callable and subscription information.""" + def __init__(self, func: Callable, is_subscribed: bool): + self.func = func + self.is_subscribed = is_subscribed + +class PluginManager(CanReleaseResources): PLUGIN_FACTORIES: Dict[str, Type[PluginFactory]] = { "iam": IamAuthPluginFactory, "aws_secrets_manager": AwsSecretsManagerPluginFactory, @@ -807,11 +807,12 @@ class PluginManager(CanReleaseResources): def __init__( self, container: PluginServiceManagerContainer, props: Properties, telemetry_factory: TelemetryFactory): self._props: Properties = props - self._function_cache: Dict[str, Callable] = {} + self._function_cache: list[Optional[PluginChainCallableInfo]] = [None] * (DbApiMethod.ALL.id + 1) # last element in DbApiMethod self._container = container self._container.plugin_manager = self self._connection_provider_manager = ConnectionProviderManager() self._telemetry_factory = telemetry_factory + self._telemetry_in_use = telemetry_factory.in_use() self._plugins = self.get_plugins() @property @@ -906,101 +907,114 @@ def get_factory_weights(factory_types: List[Type[PluginFactory]]) -> Dict[Type[P return weights - def execute(self, target: object, method_name: str, target_driver_func: Callable, *args, **kwargs) -> Any: + def must_use_pipeline(self, method: DbApiMethod): + plugin_chain_info: Optional[PluginChainCallableInfo] = self._function_cache[method.id] + return method.always_use_pipeline or plugin_chain_info is None or plugin_chain_info.is_subscribed or self._telemetry_in_use + + def execute(self, target: object, method: DbApiMethod, target_driver_func: Callable, *args, **kwargs) -> Any: plugin_service = self._container.plugin_service driver_dialect = plugin_service.driver_dialect conn: Optional[Connection] = driver_dialect.get_connection_from_obj(target) current_conn: Optional[Connection] = driver_dialect.unwrap_connection(plugin_service.current_connection) - if method_name not in ["Connection.close", "Cursor.close"] and conn is not None and conn != current_conn: + if method not in [DbApiMethod.CONNECTION_CLOSE, DbApiMethod.CURSOR_CLOSE] and conn is not None and conn != current_conn: raise AwsWrapperError(Messages.get_formatted("PluginManager.MethodInvokedAgainstOldConnection", target)) - if conn is None and method_name in ["Connection.close", "Cursor.close"]: + if conn is None and method in [DbApiMethod.CONNECTION_CLOSE, DbApiMethod.CURSOR_CLOSE]: return - context: TelemetryContext - context = self._telemetry_factory.open_telemetry_context(method_name, TelemetryTraceLevel.TOP_LEVEL) - context.set_attribute("python_call", method_name) + context: TelemetryContext | None + context = self._telemetry_factory.open_telemetry_context(method.method_name, TelemetryTraceLevel.TOP_LEVEL) + if context is not None: + context.set_attribute("python_call", method.method_name) try: result = self._execute_with_subscribed_plugins( - method_name, + method, # next_plugin_func is defined later in make_pipeline - lambda plugin, next_plugin_func: plugin.execute(target, method_name, next_plugin_func, *args, **kwargs), + lambda plugin, next_plugin_func: plugin.execute(target, method.method_name, next_plugin_func, *args, **kwargs), target_driver_func, None) - - context.set_success(True) + if context is not None: + context.set_success(True) return result except Exception as e: - context.set_success(False) + if context is not None: + context.set_success(False) raise e finally: - context.close_context() + if context is not None: + context.close_context() def _execute_with_telemetry(self, plugin_name: str, func: Callable): context = self._telemetry_factory.open_telemetry_context(plugin_name, TelemetryTraceLevel.NESTED) try: return func() finally: - context.close_context() + if context is not None: + context.close_context() def _execute_with_subscribed_plugins( self, - method_name: str, + method: DbApiMethod, plugin_func: Callable, target_driver_func: Callable, plugin_to_skip: Optional[Plugin] = None): - cache_key = method_name if plugin_to_skip is None else method_name + plugin_to_skip.__class__.__name__ - pipeline_func: Optional[Callable] = self._function_cache.get(cache_key) - if pipeline_func is None: - pipeline_func = self._make_pipeline(method_name, plugin_to_skip) - self._function_cache[cache_key] = pipeline_func - - return pipeline_func(plugin_func, target_driver_func) + pipeline_func_info: Optional[PluginChainCallableInfo] = self._function_cache[method.id] + if pipeline_func_info is None: + pipeline_func_info = self._make_pipeline(method.method_name) + self._function_cache[method.id] = pipeline_func_info + + # Execute only if method needs to use pipeline, or a plugin is subscribed to this method + if method.always_use_pipeline or pipeline_func_info.is_subscribed: + return pipeline_func_info.func(plugin_func, target_driver_func, method.method_name, plugin_to_skip) + else: + return target_driver_func() # Builds the plugin pipeline function chain. The pipeline is built in a way that allows plugins to perform logic # both before and after the target driver function call. - def _make_pipeline(self, method_name: str, plugin_to_skip: Optional[Plugin] = None) -> Callable: + def _make_pipeline(self, method_name: str) -> PluginChainCallableInfo: pipeline_func: Optional[Callable] = None num_plugins: int = len(self._plugins) + is_subscribed: bool = False # Build the pipeline starting at the end and working backwards for i in range(num_plugins - 1, -1, -1): plugin: Plugin = self._plugins[i] - if plugin_to_skip is not None and plugin_to_skip == plugin: - continue subscribed_methods: Set[str] = plugin.subscribed_methods - is_subscribed: bool = PluginManager._ALL_METHODS in subscribed_methods or method_name in subscribed_methods - if not is_subscribed: - continue - - if pipeline_func is None: - # Defines the call to DefaultPlugin, which is the last plugin in the pipeline - pipeline_func = self._create_base_pipeline_func(plugin) - else: + is_plugin_subscribed = DbApiMethod.ALL.method_name in subscribed_methods or method_name in subscribed_methods + is_subscribed |= is_plugin_subscribed + + if is_plugin_subscribed: + if pipeline_func is None: + # Defines the call to DefaultPlugin, which is the last plugin in the pipeline + pipeline_func = self._create_base_pipeline_func(plugin) + continue pipeline_func = self._extend_pipeline_func(plugin, pipeline_func) if pipeline_func is None: raise AwsWrapperError(Messages.get("PluginManager.PipelineNone")) - else: - return pipeline_func + return PluginChainCallableInfo(pipeline_func, is_subscribed) def _create_base_pipeline_func(self, plugin: Plugin): # The plugin passed here will be the DefaultPlugin, which is the last plugin in the pipeline # The second arg to plugin_func is the next call in the pipeline. Here, it is the target driver function plugin_name = plugin.__class__.__name__ - return lambda plugin_func, target_driver_func: self._execute_with_telemetry( + return lambda plugin_func, target_driver_func, *_: self._execute_with_telemetry( plugin_name, lambda: plugin_func(plugin, target_driver_func)) def _extend_pipeline_func(self, plugin: Plugin, pipeline_so_far: Callable): # Defines the call to a plugin that precedes the DefaultPlugin in the pipeline # The second arg to plugin_func effectively appends the tail end of the pipeline to the current plugin's call plugin_name = plugin.__class__.__name__ - return lambda plugin_func, target_driver_func: self._execute_with_telemetry( - plugin_name, lambda: plugin_func(plugin, lambda: pipeline_so_far(plugin_func, target_driver_func))) + return lambda plugin_func, target_driver_func, method_name, plugin_to_skip: ( + pipeline_so_far(plugin_func, target_driver_func, method_name, plugin_to_skip) + if plugin_to_skip is not None and plugin_to_skip == plugin + else self._execute_with_telemetry( + plugin_name, lambda: plugin_func(plugin, lambda: pipeline_so_far(plugin_func, target_driver_func, method_name, plugin_to_skip))) + ) def connect( self, @@ -1010,17 +1024,18 @@ def connect( props: Properties, is_initial_connection: bool, plugin_to_skip: Optional[Plugin] = None) -> Connection: - context = self._telemetry_factory.open_telemetry_context("connect", TelemetryTraceLevel.NESTED) + context = self._telemetry_factory.open_telemetry_context(DbApiMethod.CONNECT.method_name, TelemetryTraceLevel.NESTED) try: return self._execute_with_subscribed_plugins( - PluginManager._CONNECT_METHOD, + DbApiMethod.CONNECT, lambda plugin, func: plugin.connect( target_func, driver_dialect, host_info, props, is_initial_connection, func), # The final connect action will be handled by the ConnectionProvider, so this lambda will not be called. lambda: None, plugin_to_skip) finally: - context.close_context() + if context is not None: + context.close_context() def force_connect( self, @@ -1031,7 +1046,7 @@ def force_connect( is_initial_connection: bool, plugin_to_skip: Optional[Plugin] = None) -> Connection: return self._execute_with_subscribed_plugins( - PluginManager._FORCE_CONNECT_METHOD, + DbApiMethod.FORCE_CONNECT, lambda plugin, func: plugin.force_connect( target_func, driver_dialect, host_info, props, is_initial_connection, func), # The final connect action will be handled by the ConnectionProvider, so this lambda will not be called. @@ -1041,7 +1056,7 @@ def force_connect( def notify_connection_changed(self, changes: Set[ConnectionEvent]) -> OldConnectionSuggestedAction: old_conn_suggestions: Set[OldConnectionSuggestedAction] = set() self._notify_subscribed_plugins( - PluginManager._NOTIFY_CONNECTION_CHANGED_METHOD, + DbApiMethod.NOTIFY_CONNECTION_CHANGED.method_name, lambda plugin: self._notify_plugin_conn_changed(plugin, changes, old_conn_suggestions)) if OldConnectionSuggestedAction.PRESERVE in old_conn_suggestions: @@ -1054,7 +1069,7 @@ def notify_connection_changed(self, changes: Set[ConnectionEvent]) -> OldConnect def _notify_subscribed_plugins(self, method_name: str, notify_plugin_func: Callable): for plugin in self._plugins: subscribed_methods = plugin.subscribed_methods - is_subscribed = PluginManager._ALL_METHODS in subscribed_methods or method_name in subscribed_methods + is_subscribed = DbApiMethod.ALL.method_name in subscribed_methods or method_name in subscribed_methods if is_subscribed: notify_plugin_func(plugin) @@ -1067,15 +1082,15 @@ def _notify_plugin_conn_changed( old_conn_suggestions.add(suggestion) def notify_host_list_changed(self, changes: Dict[str, Set[HostEvent]]): - self._notify_subscribed_plugins(PluginManager._NOTIFY_HOST_LIST_CHANGED_METHOD, + self._notify_subscribed_plugins(DbApiMethod.NOTIFY_HOST_LIST_CHANGED.method_name, lambda plugin: plugin.notify_host_list_changed(changes)) def accepts_strategy(self, role: HostRole, strategy: str) -> bool: for plugin in self._plugins: plugin_subscribed_methods = plugin.subscribed_methods is_subscribed = \ - self._ALL_METHODS in plugin_subscribed_methods \ - or self._GET_HOST_INFO_BY_STRATEGY_METHOD in plugin_subscribed_methods + DbApiMethod.ALL.method_name in plugin_subscribed_methods \ + or DbApiMethod.GET_HOST_INFO_BY_STRATEGY.method_name in plugin_subscribed_methods if is_subscribed: if plugin.accepts_strategy(role, strategy): return True @@ -1086,8 +1101,8 @@ def get_host_info_by_strategy(self, role: HostRole, strategy: str, host_list: Op for plugin in self._plugins: plugin_subscribed_methods = plugin.subscribed_methods is_subscribed = \ - self._ALL_METHODS in plugin_subscribed_methods \ - or self._GET_HOST_INFO_BY_STRATEGY_METHOD in plugin_subscribed_methods + DbApiMethod.ALL.method_name in plugin_subscribed_methods \ + or DbApiMethod.GET_HOST_INFO_BY_STRATEGY.method_name in plugin_subscribed_methods if is_subscribed: try: @@ -1100,15 +1115,16 @@ def get_host_info_by_strategy(self, role: HostRole, strategy: str, host_list: Op return None def init_host_provider(self, props: Properties, host_list_provider_service: HostListProviderService): - context = self._telemetry_factory.open_telemetry_context("init_host_provider", TelemetryTraceLevel.NESTED) + context = self._telemetry_factory.open_telemetry_context(DbApiMethod.INIT_HOST_PROVIDER.method_name, TelemetryTraceLevel.NESTED) try: return self._execute_with_subscribed_plugins( - PluginManager._INIT_HOST_LIST_PROVIDER_METHOD, + DbApiMethod.INIT_HOST_PROVIDER, lambda plugin, func: plugin.init_host_provider(props, host_list_provider_service, func), lambda: None, None) finally: - context.close_context() + if context is not None: + context.close_context() def is_plugin_in_use(self, plugin_class: Type[Plugin]) -> bool: if not self._plugins: diff --git a/aws_advanced_python_wrapper/read_write_splitting_plugin.py b/aws_advanced_python_wrapper/read_write_splitting_plugin.py index aa3aaaa1..b7f7e4c4 100644 --- a/aws_advanced_python_wrapper/read_write_splitting_plugin.py +++ b/aws_advanced_python_wrapper/read_write_splitting_plugin.py @@ -30,6 +30,7 @@ from aws_advanced_python_wrapper.errors import (AwsWrapperError, FailoverError, ReadWriteSplittingError) from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole +from aws_advanced_python_wrapper.pep249_methods import DbApiMethod from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory from aws_advanced_python_wrapper.utils.log import Logger from aws_advanced_python_wrapper.utils.messages import Messages @@ -44,13 +45,12 @@ class ReadWriteSplittingConnectionManager(Plugin): """Base class that manages connection switching logic.""" _SUBSCRIBED_METHODS: Set[str] = { - "init_host_provider", - "connect", - "notify_connection_changed", - "Connection.set_read_only", + DbApiMethod.INIT_HOST_PROVIDER.method_name, + DbApiMethod.CONNECT.method_name, + DbApiMethod.NOTIFY_CONNECTION_CHANGED.method_name, + DbApiMethod.CONNECTION_SET_READ_ONLY.method_name, } _POOL_PROVIDER_CLASS_NAME = "aws_advanced_python_wrapper.sql_alchemy_connection_provider.SqlAlchemyPooledConnectionProvider" - _CLOSE_METHOD = "Connection.close" def __init__( self, @@ -128,7 +128,7 @@ def execute( raise AwsWrapperError(msg) if ( - method_name == "Connection.set_read_only" + method_name == DbApiMethod.CONNECTION_SET_READ_ONLY.method_name and args is not None and len(args) > 0 ): @@ -390,7 +390,7 @@ def _close_connection_if_idle(self, internal_conn: Optional[Connection]): if internal_conn != current_conn and self._is_connection_usable( internal_conn, driver_dialect ): - driver_dialect.execute(ReadWriteSplittingConnectionManager._CLOSE_METHOD, lambda: internal_conn.close()) + driver_dialect.execute(DbApiMethod.CONNECTION_CLOSE.method_name, lambda: internal_conn.close()) if internal_conn == self._writer_connection: self._writer_connection = None self._writer_host_info = None @@ -431,7 +431,7 @@ def _is_connection_usable(conn: Optional[Connection], driver_dialect: DriverDial def close_connection(conn: Optional[Connection], driver_dialect: DriverDialect): if conn is not None: try: - driver_dialect.execute(ReadWriteSplittingConnectionManager._CLOSE_METHOD, lambda: conn.close()) + driver_dialect.execute(DbApiMethod.CONNECTION_CLOSE.method_name, lambda: conn.close()) except Exception: # Swallow exception return @@ -677,5 +677,6 @@ def __init__(self, plugin_service: PluginService, props: Properties): class ReadWriteSplittingPluginFactory(PluginFactory): - def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin: + @staticmethod + def get_instance(plugin_service: PluginService, props: Properties) -> Plugin: return ReadWriteSplittingPlugin(plugin_service, props) diff --git a/aws_advanced_python_wrapper/simple_read_write_splitting_plugin.py b/aws_advanced_python_wrapper/simple_read_write_splitting_plugin.py index fc33d48d..da811295 100644 --- a/aws_advanced_python_wrapper/simple_read_write_splitting_plugin.py +++ b/aws_advanced_python_wrapper/simple_read_write_splitting_plugin.py @@ -302,5 +302,6 @@ def __init__(self, plugin_service: PluginService, props: Properties): class SimpleReadWriteSplittingPluginFactory(PluginFactory): - def get_instance(self, plugin_service, props: Properties): + @staticmethod + def get_instance(plugin_service, props: Properties): return SimpleReadWriteSplittingPlugin(plugin_service, props) diff --git a/aws_advanced_python_wrapper/stale_dns_plugin.py b/aws_advanced_python_wrapper/stale_dns_plugin.py index 310ee69c..dc5efd83 100644 --- a/aws_advanced_python_wrapper/stale_dns_plugin.py +++ b/aws_advanced_python_wrapper/stale_dns_plugin.py @@ -27,6 +27,7 @@ from aws_advanced_python_wrapper.errors import AwsWrapperError from aws_advanced_python_wrapper.hostinfo import HostRole +from aws_advanced_python_wrapper.pep249_methods import DbApiMethod from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory from aws_advanced_python_wrapper.utils.log import Logger from aws_advanced_python_wrapper.utils.messages import Messages @@ -152,9 +153,9 @@ def _get_writer(self) -> Optional[HostInfo]: class StaleDnsPlugin(Plugin): - _SUBSCRIBED_METHODS: Set[str] = {"init_host_provider", - "connect", - "notify_host_list_changed"} + _SUBSCRIBED_METHODS: Set[str] = {DbApiMethod.INIT_HOST_PROVIDER.method_name, + DbApiMethod.CONNECT.method_name, + DbApiMethod.NOTIFY_HOST_LIST_CHANGED.method_name} def __init__(self, plugin_service: PluginService) -> None: self._plugin_service = plugin_service @@ -203,5 +204,6 @@ def notify_host_list_changed(self, changes: Dict[str, Set[HostEvent]]): class StaleDnsPluginFactory(PluginFactory): - def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin: + @staticmethod + def get_instance(plugin_service: PluginService, props: Properties) -> Plugin: return StaleDnsPlugin(plugin_service) diff --git a/aws_advanced_python_wrapper/utils/iam_utils.py b/aws_advanced_python_wrapper/utils/iam_utils.py index ecb5868f..3eb312ca 100644 --- a/aws_advanced_python_wrapper/utils/iam_utils.py +++ b/aws_advanced_python_wrapper/utils/iam_utils.py @@ -110,11 +110,13 @@ def generate_authentication_token( logger.debug("IamAuthUtils.GeneratedNewAuthToken", token) return token except Exception as ex: - context.set_success(False) - context.set_exception(ex) + if context is not None: + context.set_success(False) + context.set_exception(ex) raise ex finally: - context.close_context() + if context is not None: + context.close_context() class TokenInfo: diff --git a/aws_advanced_python_wrapper/utils/telemetry/default_telemetry_factory.py b/aws_advanced_python_wrapper/utils/telemetry/default_telemetry_factory.py index 0d9b18d1..a6d9846e 100644 --- a/aws_advanced_python_wrapper/utils/telemetry/default_telemetry_factory.py +++ b/aws_advanced_python_wrapper/utils/telemetry/default_telemetry_factory.py @@ -12,38 +12,43 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import ClassVar + from aws_advanced_python_wrapper.utils.messages import Messages from aws_advanced_python_wrapper.utils.properties import (Properties, WrapperProperties) -from aws_advanced_python_wrapper.utils.telemetry.null_telemetry import \ - NullTelemetryFactory from aws_advanced_python_wrapper.utils.telemetry.open_telemetry import \ OpenTelemetryFactory from aws_advanced_python_wrapper.utils.telemetry.telemetry import ( - TelemetryContext, TelemetryCounter, TelemetryFactory, TelemetryTraceLevel) + TelemetryContext, TelemetryCounter, TelemetryFactory, TelemetryGauge, + TelemetryTraceLevel) from aws_advanced_python_wrapper.utils.telemetry.xray_telemetry import \ XRayTelemetryFactory class DefaultTelemetryFactory(TelemetryFactory): + _OPEN_TELEMETRY_FACTORY: ClassVar[OpenTelemetryFactory] = OpenTelemetryFactory() + _XRAY_TELEMETRY_FACTORY: ClassVar[XRayTelemetryFactory] = XRayTelemetryFactory() + def __init__(self, properties: Properties): self._enable_telemetry = WrapperProperties.ENABLE_TELEMETRY.get(properties) self._telemetry_submit_toplevel = WrapperProperties.TELEMETRY_SUBMIT_TOPLEVEL.get(properties) self._telemetry_traces_backend = WrapperProperties.TELEMETRY_TRACES_BACKEND.get(properties) self._telemetry_metrics_backend = WrapperProperties.TELEMETRY_METRICS_BACKEND.get(properties) - self._traces_telemetry_factory: TelemetryFactory - self._metrics_telemetry_factory: TelemetryFactory + self._traces_telemetry_factory: TelemetryFactory | None + self._metrics_telemetry_factory: TelemetryFactory | None + self._telemetry_in_use: bool if self._enable_telemetry: if self._telemetry_traces_backend is not None: traces_backend = self._telemetry_traces_backend.upper() if traces_backend == "OTLP": - self._traces_telemetry_factory = OpenTelemetryFactory() + self._traces_telemetry_factory = DefaultTelemetryFactory._OPEN_TELEMETRY_FACTORY elif traces_backend == "XRAY": - self._traces_telemetry_factory = XRayTelemetryFactory() + self._traces_telemetry_factory = DefaultTelemetryFactory._XRAY_TELEMETRY_FACTORY elif traces_backend == "NONE": - self._traces_telemetry_factory = NullTelemetryFactory() + self._traces_telemetry_factory = None else: raise RuntimeError(Messages.get_formatted( "DefaultTelemetryFactory.InvalidTracingBackend", self._telemetry_traces_backend)) @@ -53,9 +58,9 @@ def __init__(self, properties: Properties): if self._telemetry_metrics_backend is not None: metrics_backend = self._telemetry_metrics_backend.upper() if metrics_backend == "OTLP": - self._metrics_telemetry_factory = OpenTelemetryFactory() + self._metrics_telemetry_factory = DefaultTelemetryFactory._OPEN_TELEMETRY_FACTORY elif metrics_backend == "NONE": - self._metrics_telemetry_factory = NullTelemetryFactory() + self._metrics_telemetry_factory = None else: raise RuntimeError(Messages.get_formatted( "DefaultTelemetryFactory.InvalidMetricsBackend", self._telemetry_metrics_backend)) @@ -63,20 +68,34 @@ def __init__(self, properties: Properties): raise RuntimeError(Messages.get_formatted("DefaultTelemetryFactory.NoMetricsBackendProvided")) else: - self._traces_telemetry_factory = NullTelemetryFactory() - self._metrics_telemetry_factory = NullTelemetryFactory() + self._traces_telemetry_factory = None + self._metrics_telemetry_factory = None + + self._telemetry_in_use = self._traces_telemetry_factory is not None or self._metrics_telemetry_factory is not None + + def open_telemetry_context(self, name: str, trace_level: TelemetryTraceLevel) -> TelemetryContext | None: + if self._traces_telemetry_factory is None: + return None - def open_telemetry_context(self, name: str, trace_level: TelemetryTraceLevel) -> TelemetryContext: if not self._telemetry_submit_toplevel and trace_level == TelemetryTraceLevel.TOP_LEVEL: return self._traces_telemetry_factory.open_telemetry_context(name, TelemetryTraceLevel.NESTED) return self._traces_telemetry_factory.open_telemetry_context(name, trace_level) def post_copy(self, context: TelemetryContext, trace_level: TelemetryTraceLevel): + if self._traces_telemetry_factory is None: + return self._traces_telemetry_factory.post_copy(context, trace_level) - def create_counter(self, name: str) -> TelemetryCounter: + def create_counter(self, name: str) -> TelemetryCounter | None: + if self._metrics_telemetry_factory is None: + return None return self._metrics_telemetry_factory.create_counter(name) - def create_gauge(self, name: str, callback): + def create_gauge(self, name: str, callback) -> TelemetryGauge | None: + if self._metrics_telemetry_factory is None: + return None return self._metrics_telemetry_factory.create_gauge(name, callback) + + def in_use(self) -> bool: + return self._telemetry_in_use diff --git a/aws_advanced_python_wrapper/utils/telemetry/null_telemetry.py b/aws_advanced_python_wrapper/utils/telemetry/null_telemetry.py index d7ddfb6a..55d5e194 100644 --- a/aws_advanced_python_wrapper/utils/telemetry/null_telemetry.py +++ b/aws_advanced_python_wrapper/utils/telemetry/null_telemetry.py @@ -12,44 +12,44 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import ClassVar + from aws_advanced_python_wrapper.utils.telemetry.telemetry import ( TelemetryContext, TelemetryCounter, TelemetryFactory, TelemetryGauge, TelemetryTraceLevel) class NullTelemetryContext(TelemetryContext): - def __init__(self, name: str): - self.name: str = name - def get_name(self): - return self.name + return "null" class NullTelemetryCounter(TelemetryCounter): - def __init__(self, name: str): - self.name: str = name - def get_name(self): - return self.name + return "null" class NullTelemetryGauge(TelemetryGauge): - def __init__(self, name: str): - self.name: str = name - def get_name(self): - return self.name + return "null" class NullTelemetryFactory(TelemetryFactory): + _NULL_CONTEXT: ClassVar[NullTelemetryContext] = NullTelemetryContext() + _NULL_COUNTER: ClassVar[NullTelemetryCounter] = NullTelemetryCounter() + _NULL_GAUGE: ClassVar[NullTelemetryGauge] = NullTelemetryGauge() + def open_telemetry_context(self, name: str, trace_level: TelemetryTraceLevel) -> TelemetryContext: - return NullTelemetryContext(name) + return NullTelemetryFactory._NULL_CONTEXT def post_copy(self, context: TelemetryContext, trace_level: TelemetryTraceLevel): pass # Do nothing def create_counter(self, name: str) -> TelemetryCounter: - return NullTelemetryCounter(name) + return NullTelemetryFactory._NULL_COUNTER + + def create_gauge(self, name: str, callback) -> TelemetryGauge: + return NullTelemetryFactory._NULL_GAUGE - def create_gauge(self, name: str, callback): - return NullTelemetryGauge(name) + def in_use(self) -> bool: + return False diff --git a/aws_advanced_python_wrapper/utils/telemetry/open_telemetry.py b/aws_advanced_python_wrapper/utils/telemetry/open_telemetry.py index 59087022..4ea991a4 100644 --- a/aws_advanced_python_wrapper/utils/telemetry/open_telemetry.py +++ b/aws_advanced_python_wrapper/utils/telemetry/open_telemetry.py @@ -173,7 +173,7 @@ def _callback_observation(self, options: CallbackOptions): class OpenTelemetryFactory(TelemetryFactory): - def open_telemetry_context(self, name: str, trace_level: TelemetryTraceLevel) -> TelemetryContext: + def open_telemetry_context(self, name: str, trace_level: TelemetryTraceLevel) -> TelemetryContext | None: return OpenTelemetryContext(trace.get_tracer(INSTRUMENTATION_NAME), name, trace_level) # type: ignore def post_copy(self, context: TelemetryContext, trace_level: TelemetryTraceLevel): @@ -182,8 +182,11 @@ def post_copy(self, context: TelemetryContext, trace_level: TelemetryTraceLevel) else: raise RuntimeError(Messages.get_formatted("OpenTelemetryFactory.WrongParameterType", type(context))) - def create_counter(self, name: str) -> TelemetryCounter: + def create_counter(self, name: str) -> TelemetryCounter | None: return OpenTelemetryCounter(get_meter(INSTRUMENTATION_NAME), name) - def create_gauge(self, name: str, callback: Callable[[], Union[float, int]]): + def create_gauge(self, name: str, callback: Callable[[], Union[float, int]]) -> TelemetryGauge | None: return OpenTelemetryGauge(get_meter(INSTRUMENTATION_NAME), name, callback) + + def in_use(self) -> bool: + return True diff --git a/aws_advanced_python_wrapper/utils/telemetry/telemetry.py b/aws_advanced_python_wrapper/utils/telemetry/telemetry.py index 8586d547..ba3fffca 100644 --- a/aws_advanced_python_wrapper/utils/telemetry/telemetry.py +++ b/aws_advanced_python_wrapper/utils/telemetry/telemetry.py @@ -69,7 +69,7 @@ class TelemetryGauge(ABC): class TelemetryFactory(ABC): @abstractmethod - def open_telemetry_context(self, name: str, trace_level: TelemetryTraceLevel) -> TelemetryContext: + def open_telemetry_context(self, name: str, trace_level: TelemetryTraceLevel) -> TelemetryContext | None: pass @abstractmethod @@ -77,9 +77,13 @@ def post_copy(self, context: TelemetryContext, trace_level: TelemetryTraceLevel) pass @abstractmethod - def create_counter(self, name: str) -> TelemetryCounter: + def create_counter(self, name: str) -> TelemetryCounter | None: pass @abstractmethod - def create_gauge(self, name: str, callback): + def create_gauge(self, name: str, callback) -> TelemetryGauge | None: + pass + + @abstractmethod + def in_use(self) -> bool: pass diff --git a/aws_advanced_python_wrapper/utils/telemetry/xray_telemetry.py b/aws_advanced_python_wrapper/utils/telemetry/xray_telemetry.py index b3a7bf53..5ec1ac4c 100644 --- a/aws_advanced_python_wrapper/utils/telemetry/xray_telemetry.py +++ b/aws_advanced_python_wrapper/utils/telemetry/xray_telemetry.py @@ -28,7 +28,7 @@ from aws_advanced_python_wrapper.utils.messages import Messages from aws_advanced_python_wrapper.utils.telemetry.telemetry import ( TelemetryConst, TelemetryContext, TelemetryCounter, TelemetryFactory, - TelemetryTraceLevel) + TelemetryGauge, TelemetryTraceLevel) logger = Logger(__name__) @@ -110,7 +110,7 @@ def _clone_and_close_context(context: XRayTelemetryContext, trace_level: Telemet class XRayTelemetryFactory(TelemetryFactory): - def open_telemetry_context(self, name: str, trace_level: TelemetryTraceLevel) -> TelemetryContext: + def open_telemetry_context(self, name: str, trace_level: TelemetryTraceLevel) -> TelemetryContext | None: return XRayTelemetryContext(name, trace_level) def post_copy(self, context: TelemetryContext, trace_level: TelemetryTraceLevel): @@ -119,8 +119,11 @@ def post_copy(self, context: TelemetryContext, trace_level: TelemetryTraceLevel) else: raise RuntimeError(Messages.get_formatted("XRayTelemetryFactory.WrongParameterType", type(context))) - def create_counter(self, name: str) -> TelemetryCounter: + def create_counter(self, name: str) -> TelemetryCounter | None: raise RuntimeError(Messages.get_formatted("XRayTelemetryFactory.MetricsNotSupported")) - def create_gauge(self, name: str, callback): + def create_gauge(self, name: str, callback) -> TelemetryGauge | None: raise RuntimeError(Messages.get_formatted("XRayTelemetryFactory.MetricsNotSupported")) + + def in_use(self) -> bool: + return True diff --git a/aws_advanced_python_wrapper/wrapper.py b/aws_advanced_python_wrapper/wrapper.py index e04d293f..7cfac6b8 100644 --- a/aws_advanced_python_wrapper/wrapper.py +++ b/aws_advanced_python_wrapper/wrapper.py @@ -25,6 +25,7 @@ from aws_advanced_python_wrapper.errors import (AwsWrapperError, FailoverSuccessError) from aws_advanced_python_wrapper.pep249 import Connection, Cursor, Error +from aws_advanced_python_wrapper.pep249_methods import DbApiMethod from aws_advanced_python_wrapper.plugin import CanReleaseResources from aws_advanced_python_wrapper.plugin_service import ( PluginManager, PluginService, PluginServiceImpl, @@ -91,14 +92,14 @@ def is_closed(self): def read_only(self) -> bool: return self._plugin_manager.execute( self.target_connection, - "Connection.is_read_only", + DbApiMethod.CONNECTION_IS_READ_ONLY, lambda: self._is_read_only()) @read_only.setter def read_only(self, val: bool): self._plugin_manager.execute( self.target_connection, - "Connection.set_read_only", + DbApiMethod.CONNECTION_SET_READ_ONLY, lambda: self._set_read_only(val), val) @@ -116,14 +117,14 @@ def _set_read_only(self, val: bool): def autocommit(self): return self._plugin_manager.execute( self.target_connection, - "Connection.autocommit", + DbApiMethod.CONNECTION_AUTOCOMMIT, lambda: self._plugin_service.driver_dialect.get_autocommit(self.target_connection)) @autocommit.setter def autocommit(self, val: bool): self._plugin_manager.execute( self.target_connection, - "Connection.autocommit_setter", + DbApiMethod.CONNECTION_AUTOCOMMIT_SETTER, lambda: self._set_autocommit(val), val) @@ -166,48 +167,50 @@ def connect( return AwsWrapperConnection(target_func, plugin_service, plugin_service, plugin_manager) except Exception as ex: - context.set_exception(ex) - context.set_success(False) + if context is not None: + context.set_exception(ex) + context.set_success(False) raise ex finally: - context.close_context() + if context is not None: + context.close_context() def close(self) -> None: - self._plugin_manager.execute(self.target_connection, "Connection.close", + self._plugin_manager.execute(self.target_connection, DbApiMethod.CONNECTION_CLOSE, lambda: self.target_connection.close()) def cursor(self, *args: Any, **kwargs: Any) -> AwsWrapperCursor: - _cursor = self._plugin_manager.execute(self.target_connection, "Connection.cursor", + _cursor = self._plugin_manager.execute(self.target_connection, DbApiMethod.CONNECTION_CURSOR, lambda: self.target_connection.cursor(*args, **kwargs), *args, **kwargs) return AwsWrapperCursor(self, self._plugin_service, self._plugin_manager, _cursor) def commit(self) -> None: - self._plugin_manager.execute(self.target_connection, "Connection.commit", + self._plugin_manager.execute(self.target_connection, DbApiMethod.CONNECTION_COMMIT, lambda: self.target_connection.commit()) def rollback(self) -> None: - self._plugin_manager.execute(self.target_connection, "Connection.rollback", + self._plugin_manager.execute(self.target_connection, DbApiMethod.CONNECTION_ROLLBACK, lambda: self.target_connection.rollback()) def tpc_begin(self, xid: Any) -> None: - self._plugin_manager.execute(self.target_connection, "Connection.tpc_begin", + self._plugin_manager.execute(self.target_connection, DbApiMethod.CONNECTION_TPC_BEGIN, lambda: self.target_connection.tpc_begin(xid), xid) def tpc_prepare(self) -> None: - self._plugin_manager.execute(self.target_connection, "Connection.tpc_prepare", + self._plugin_manager.execute(self.target_connection, DbApiMethod.CONNECTION_TPC_PREPARE, lambda: self.target_connection.tpc_prepare()) def tpc_commit(self, xid: Any = None) -> None: - self._plugin_manager.execute(self.target_connection, "Connection.tpc_commit", + self._plugin_manager.execute(self.target_connection, DbApiMethod.CONNECTION_TPC_COMMIT, lambda: self.target_connection.tpc_commit(xid), xid) def tpc_rollback(self, xid: Any = None) -> None: - self._plugin_manager.execute(self.target_connection, "Connection.tpc_rollback", + self._plugin_manager.execute(self.target_connection, DbApiMethod.CONNECTION_TPC_ROLLBACK, lambda: self.target_connection.tpc_rollback(xid), xid) def tpc_recover(self) -> Any: - return self._plugin_manager.execute(self.target_connection, "Connection.tpc_recover", + return self._plugin_manager.execute(self.target_connection, DbApiMethod.CONNECTION_TPC_RECOVER, lambda: self.target_connection.tpc_recover()) def release_resources(self): @@ -226,7 +229,7 @@ def __enter__(self: AwsWrapperConnection) -> AwsWrapperConnection: return self def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: - self._plugin_manager.execute(self.target_connection, "Connection.close", + self._plugin_manager.execute(self.target_connection, DbApiMethod.CONNECTION_CLOSE, lambda: self.target_connection.close(), exc_type, exc_val, exc_tb) @@ -266,11 +269,11 @@ def arraysize(self) -> int: return self.target_cursor.arraysize def close(self) -> None: - self._plugin_manager.execute(self.target_cursor, "Cursor.close", + self._plugin_manager.execute(self.target_cursor, DbApiMethod.CURSOR_CLOSE, lambda: self.target_cursor.close()) def callproc(self, *args: Any, **kwargs: Any): - return self._plugin_manager.execute(self.target_cursor, "Cursor.callproc", + return self._plugin_manager.execute(self.target_cursor, DbApiMethod.CURSOR_CALLPROC, lambda: self.target_cursor.callproc(**kwargs), *args, **kwargs) def execute( @@ -279,7 +282,7 @@ def execute( **kwargs: Any ) -> AwsWrapperCursor: try: - return self._plugin_manager.execute(self.target_cursor, "Cursor.execute", + return self._plugin_manager.execute(self.target_cursor, DbApiMethod.CURSOR_EXECUTE, lambda: self.target_cursor.execute(*args, **kwargs), *args, **kwargs) except FailoverSuccessError as e: self._target_cursor = self.connection.target_connection.cursor() @@ -290,35 +293,35 @@ def executemany( *args: Any, **kwargs: Any ) -> None: - self._plugin_manager.execute(self.target_cursor, "Cursor.executemany", + self._plugin_manager.execute(self.target_cursor, DbApiMethod.CURSOR_EXECUTEMANY, lambda: self.target_cursor.executemany(*args, **kwargs), *args, **kwargs) def nextset(self) -> bool: - return self._plugin_manager.execute(self.target_cursor, "Cursor.nextset", + return self._plugin_manager.execute(self.target_cursor, DbApiMethod.CURSOR_NEXTSET, lambda: self.target_cursor.nextset()) def fetchone(self) -> Any: - return self._plugin_manager.execute(self.target_cursor, "Cursor.fetchone", + return self._plugin_manager.execute(self.target_cursor, DbApiMethod.CURSOR_FETCHONE, lambda: self.target_cursor.fetchone()) def fetchmany(self, size: int = 0) -> List[Any]: - return self._plugin_manager.execute(self.target_cursor, "Cursor.fetchmany", + return self._plugin_manager.execute(self.target_cursor, DbApiMethod.CURSOR_FETCHMANY, lambda: self.target_cursor.fetchmany(size), size) def fetchall(self) -> List[Any]: - return self._plugin_manager.execute(self.target_cursor, "Cursor.fetchall", + return self._plugin_manager.execute(self.target_cursor, DbApiMethod.CURSOR_FETCHALL, lambda: self.target_cursor.fetchall()) def __iter__(self) -> Iterator[Any]: return self.target_cursor.__iter__() def setinputsizes(self, sizes: Any) -> None: - return self._plugin_manager.execute(self.target_cursor, "Cursor.setinputsizes", + return self._plugin_manager.execute(self.target_cursor, DbApiMethod.CURSOR_SETINPUTSIZES, lambda: self.target_cursor.setinputsizes(sizes), sizes) def setoutputsize(self, size: Any, column: Optional[int] = None) -> None: - return self._plugin_manager.execute(self.target_cursor, "Cursor.setoutputsize", + return self._plugin_manager.execute(self.target_cursor, DbApiMethod.CURSOR_SETOUTPUTSIZE, lambda: self.target_cursor.setoutputsize(size, column), size, column) def __enter__(self: AwsWrapperCursor) -> AwsWrapperCursor: diff --git a/benchmarks/benchmark_plugin.py b/benchmarks/benchmark_plugin.py index 602f70a3..4dde356b 100644 --- a/benchmarks/benchmark_plugin.py +++ b/benchmarks/benchmark_plugin.py @@ -87,5 +87,6 @@ def init_host_provider( class BenchmarkPluginFactory(PluginFactory): - def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin: + @staticmethod + def get_instance(plugin_service: PluginService, props: Properties) -> Plugin: return BenchmarkPlugin() diff --git a/benchmarks/plugin_benchmarks.py b/benchmarks/plugin_benchmarks.py index fb8a0c71..7106c35f 100644 --- a/benchmarks/plugin_benchmarks.py +++ b/benchmarks/plugin_benchmarks.py @@ -30,6 +30,8 @@ from aws_advanced_python_wrapper.sql_alchemy_connection_provider import \ SqlAlchemyPooledConnectionProvider from aws_advanced_python_wrapper.utils.properties import Properties +from aws_advanced_python_wrapper.utils.telemetry.null_telemetry import \ + NullTelemetryFactory host_info = HostInfo(host="host", port=1234) @@ -83,7 +85,7 @@ def plugin_service_manager_container_mock(mocker, plugin_service_mock): @pytest.fixture def plugin_manager_with_execute_time_plugin(plugin_service_manager_container_mock, props_with_execute_time_plugin): - manager: PluginManager = PluginManager(plugin_service_manager_container_mock, props_with_execute_time_plugin) + manager: PluginManager = PluginManager(plugin_service_manager_container_mock, props_with_execute_time_plugin, NullTelemetryFactory()) return manager @@ -92,7 +94,7 @@ def plugin_manager_with_aurora_connection_tracker_plugin( plugin_service_manager_container_mock, props_with_aurora_connection_tracker_plugin): manager: PluginManager = PluginManager( - plugin_service_manager_container_mock, props_with_aurora_connection_tracker_plugin) + plugin_service_manager_container_mock, props_with_aurora_connection_tracker_plugin, NullTelemetryFactory()) return manager @@ -101,7 +103,7 @@ def plugin_manager_with_execute_time_and_aurora_connection_tracker_plugin( plugin_service_manager_container_mock, props_with_execute_time_and_aurora_connection_tracker_plugin): manager: PluginManager = PluginManager( - plugin_service_manager_container_mock, props_with_execute_time_and_aurora_connection_tracker_plugin) + plugin_service_manager_container_mock, props_with_execute_time_and_aurora_connection_tracker_plugin, NullTelemetryFactory()) return manager @@ -110,7 +112,7 @@ def plugin_manager_with_read_write_splitting_plugin( plugin_service_manager_container_mock, props_with_read_write_splitting_plugin): manager: PluginManager = PluginManager( - plugin_service_manager_container_mock, props_with_read_write_splitting_plugin) + plugin_service_manager_container_mock, props_with_read_write_splitting_plugin, NullTelemetryFactory()) return manager @@ -119,7 +121,7 @@ def plugin_manager_with_aurora_connection_tracker_and_read_write_splitting_plugi plugin_service_manager_container_mock, props_with_aurora_connection_tracker_and_read_write_splitting_plugin): manager: PluginManager = PluginManager( - plugin_service_manager_container_mock, props_with_aurora_connection_tracker_and_read_write_splitting_plugin) + plugin_service_manager_container_mock, props_with_aurora_connection_tracker_and_read_write_splitting_plugin, NullTelemetryFactory()) return manager diff --git a/benchmarks/plugin_manager_benchmarks.py b/benchmarks/plugin_manager_benchmarks.py index 0ed9ae3a..cb659d3a 100644 --- a/benchmarks/plugin_manager_benchmarks.py +++ b/benchmarks/plugin_manager_benchmarks.py @@ -25,9 +25,12 @@ from aws_advanced_python_wrapper.driver_configuration_profiles import \ DriverConfigurationProfiles from aws_advanced_python_wrapper.hostinfo import HostInfo +from aws_advanced_python_wrapper.pep249_methods import DbApiMethod from aws_advanced_python_wrapper.plugin_service import ( PluginManager, PluginServiceManagerContainer) from aws_advanced_python_wrapper.utils.properties import Properties +from aws_advanced_python_wrapper.utils.telemetry.null_telemetry import \ + NullTelemetryFactory from benchmarks.benchmark_plugin import BenchmarkPluginFactory host_info = HostInfo(host="host", port=1234) @@ -76,31 +79,31 @@ def plugin_service_manager_container_mock(mocker, plugin_service_mock): @pytest.fixture def plugin_manager_with_no_plugins(plugin_service_manager_container_mock, props_without_plugins): - manager = PluginManager(plugin_service_manager_container_mock, props_without_plugins) + manager = PluginManager(plugin_service_manager_container_mock, props_without_plugins, NullTelemetryFactory()) return manager @pytest.fixture def plugin_manager_with_plugins(plugin_service_manager_container_mock, props_with_plugins): - manager = PluginManager(plugin_service_manager_container_mock, props_with_plugins) + manager = PluginManager(plugin_service_manager_container_mock, props_with_plugins, NullTelemetryFactory()) return manager def init_plugin_manager(plugin_service_manager_container, props): - manager = PluginManager(plugin_service_manager_container, props) + manager = PluginManager(plugin_service_manager_container, props, NullTelemetryFactory()) return manager +@pytest.mark.benchmark(group="plugin_manager_init") def test_init_plugin_manager_with_no_plugins( benchmark, plugin_service_manager_container_mock, props_without_plugins): - result = benchmark(init_plugin_manager, plugin_service_manager_container_mock, props_without_plugins) assert result is not None +@pytest.mark.benchmark(group="plugin_manager_init") def test_init_plugin_manager_with_plugins( benchmark, plugin_service_manager_container_mock, props_with_plugins): - result = benchmark(init_plugin_manager, plugin_service_manager_container_mock, props_with_plugins) assert result is not None @@ -110,26 +113,30 @@ def connect(mocker, plugin_manager, props): return conn +@pytest.mark.benchmark(group="plugin_manager_connect") def test_connect_with_no_plugins(benchmark, mocker, plugin_manager_with_no_plugins, props_without_plugins): result = benchmark(connect, mocker, plugin_manager_with_no_plugins, props_without_plugins) assert result is not None +@pytest.mark.benchmark(group="plugin_manager_connect") def test_connect_with_plugins(benchmark, mocker, plugin_manager_with_plugins, props_with_plugins): result = benchmark(connect, mocker, plugin_manager_with_plugins, props_with_plugins) assert result is not None def execute(mocker, plugin_manager, statement): - result = plugin_manager.execute(mocker.MagicMock(), "Statement.execute", statement) + result = plugin_manager.execute(mocker.MagicMock(), DbApiMethod.CURSOR_EXECUTE, statement) return result +@pytest.mark.benchmark(group="plugin_manager_execute") def test_execute_with_no_plugins(benchmark, mocker, plugin_manager_with_no_plugins, statement_mock): result = benchmark(execute, mocker, plugin_manager_with_no_plugins, statement_mock) assert result is not None +@pytest.mark.benchmark(group="plugin_manager_execute") def test_execute_with_plugins(benchmark, mocker, plugin_manager_with_plugins, statement_mock): result = benchmark(execute, mocker, plugin_manager_with_plugins, statement_mock) assert result is not None @@ -139,10 +146,12 @@ def init_host_provider(mocker, plugin_manager, props): plugin_manager.init_host_provider(props, mocker.MagicMock()) +@pytest.mark.benchmark(group="plugin_manager_host_provider") def test_init_host_provider_with_no_plugins(benchmark, mocker, plugin_manager_with_no_plugins, props_without_plugins): benchmark(init_host_provider, mocker, plugin_manager_with_no_plugins, props_without_plugins) +@pytest.mark.benchmark(group="plugin_manager_host_provider") def test_init_host_provider_with_plugins(benchmark, mocker, plugin_manager_with_plugins, props_with_plugins): benchmark(init_host_provider, mocker, plugin_manager_with_plugins, props_with_plugins) @@ -152,11 +161,13 @@ def notify_connection_changed(mocker, plugin_manager): return result +@pytest.mark.benchmark(group="plugin_manager_notify") def test_notify_connection_changed_with_no_plugins(benchmark, mocker, plugin_manager_with_no_plugins): result = benchmark(notify_connection_changed, mocker, plugin_manager_with_no_plugins) assert result is not None +@pytest.mark.benchmark(group="plugin_manager_notify") def test_notify_connection_changed_with_plugins(benchmark, mocker, plugin_manager_with_plugins): result = benchmark(notify_connection_changed, mocker, plugin_manager_with_plugins) assert result is not None @@ -166,9 +177,11 @@ def release_resources(plugin_manager): plugin_manager.release_resources() +@pytest.mark.benchmark(group="plugin_manager_cleanup") def test_release_resources_with_no_plugins(benchmark, plugin_manager_with_no_plugins): benchmark(release_resources, plugin_manager_with_no_plugins) +@pytest.mark.benchmark(group="plugin_manager_cleanup") def test_release_resources_with_plugins(benchmark, plugin_manager_with_plugins): benchmark(release_resources, plugin_manager_with_plugins) diff --git a/tests/unit/test_plugin_manager.py b/tests/unit/test_plugin_manager.py index 3fbb1e09..fd3b758d 100644 --- a/tests/unit/test_plugin_manager.py +++ b/tests/unit/test_plugin_manager.py @@ -20,6 +20,7 @@ from aws_advanced_python_wrapper.driver_dialect import DriverDialect from aws_advanced_python_wrapper.pep249 import Connection +from enum import Enum from typing import Any, Callable, Dict, List, Optional, Set import psycopg @@ -34,6 +35,7 @@ HostMonitoringPlugin from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole from aws_advanced_python_wrapper.iam_plugin import IamAuthPlugin +from aws_advanced_python_wrapper.pep249_methods import DbApiMethod from aws_advanced_python_wrapper.plugin import Plugin from aws_advanced_python_wrapper.plugin_service import PluginManager from aws_advanced_python_wrapper.utils.notifications import ( @@ -41,6 +43,20 @@ from aws_advanced_python_wrapper.utils.properties import Properties +class DbApiMethodTest(Enum): + """Test-specific enumeration of API methods for testing purposes.""" + + TEST_CALL_A = (DbApiMethod.ALL.id + 1, "test_call_a", False) + TEST_CALL_B = (DbApiMethod.ALL.id + 2, "test_call_b", False) + TEST_CALL_C = (DbApiMethod.ALL.id + 3, "test_call_c", False) + TEST_EXECUTE = (DbApiMethod.ALL.id + 4, "test_execute", False) + + def __init__(self, id: int, method_name: str, always_use_pipeline: bool): + self.id = id + self.method_name = method_name + self.always_use_pipeline = always_use_pipeline + + @pytest.fixture def mock_conn(mocker): return mocker.MagicMock(spec=psycopg.Connection) @@ -141,7 +157,7 @@ def test_unknown_profile(mocker, mock_telemetry_factory): PluginManager(mocker.MagicMock(), props, mock_telemetry_factory()) -def test_execute_call_a(mocker, mock_conn, container, mock_plugin_service, mock_driver_dialect, mock_telemetry_factory): +def test_execute_call_method(mocker, mock_conn, container, mock_plugin_service, mock_driver_dialect, mock_telemetry_factory): calls = [] args = [10, "arg2", 3.33] plugins = [TestPluginOne(calls), TestPluginTwo(calls), TestPluginThree(calls)] @@ -152,12 +168,12 @@ def test_execute_call_a(mocker, mock_conn, container, mock_plugin_service, mock_ manager._container = container manager._plugins = plugins manager._telemetry_factory = mock_telemetry_factory - manager._function_cache = {} + manager._function_cache = [None] * (DbApiMethodTest.TEST_EXECUTE.id + 1) make_pipeline_func = mocker.patch.object(manager, '_make_pipeline', wraps=manager._make_pipeline) - result = manager.execute(mock_conn, "test_call_a", lambda: _target_call(calls), *args) + result = manager.execute(mock_conn, DbApiMethodTest.TEST_CALL_A, lambda: _target_call(calls), *args) - make_pipeline_func.assert_called_once_with("test_call_a", None) + make_pipeline_func.assert_called_once_with(DbApiMethodTest.TEST_CALL_A.method_name) assert result == "result_value" assert len(calls) == 7 assert calls[0] == "TestPluginOne:before execute" @@ -169,10 +185,10 @@ def test_execute_call_a(mocker, mock_conn, container, mock_plugin_service, mock_ assert calls[6] == "TestPluginOne:after execute" calls.clear() - result = manager.execute(mock_conn, "test_call_a", lambda: _target_call(calls), *args) + result = manager.execute(mock_conn, DbApiMethodTest.TEST_CALL_A, lambda: _target_call(calls), *args) # The first execute call should cache the pipeline - make_pipeline_func.assert_called_once_with("test_call_a", None) + make_pipeline_func.assert_called_once_with(DbApiMethodTest.TEST_CALL_A.method_name) assert result == "result_value" assert len(calls) == 7 assert calls[0] == "TestPluginOne:before execute" @@ -199,9 +215,9 @@ def test_execute_call_b(mocker, container, mock_driver_dialect, mock_telemetry_f manager._container = container manager._plugins = plugins manager._telemetry_factory = mock_telemetry_factory - manager._function_cache = {} + manager._function_cache = [None] * (DbApiMethodTest.TEST_EXECUTE.id + 1) - result = manager.execute(mock_conn, "test_call_b", lambda: _target_call(calls), *args) + result = manager.execute(mock_conn, DbApiMethodTest.TEST_CALL_B, lambda: _target_call(calls), *args) assert result == "result_value" assert len(calls) == 5 @@ -222,9 +238,9 @@ def test_execute_call_c(mocker, container, mock_driver_dialect, mock_telemetry_f manager._container = container manager._plugins = plugins manager._telemetry_factory = mock_telemetry_factory - manager._function_cache = {} + manager._function_cache = [None] * (DbApiMethodTest.TEST_EXECUTE.id + 1) - result = manager.execute(mock_conn, "test_call_c", lambda: _target_call(calls), *args) + result = manager.execute(mock_conn, DbApiMethodTest.TEST_CALL_C, lambda: _target_call(calls), *args) assert result == "result_value" assert len(calls) == 3 @@ -239,12 +255,12 @@ def test_execute_against_old_target(mocker, container, mock_driver_dialect, mock manager._container = container manager._plugins = "" manager._telemetry_factory = mock_telemetry_factory - manager._function_cache = {} + manager._function_cache = [None] * (DbApiMethodTest.TEST_EXECUTE.id + 1) # Set current connection to a new connection object container.plugin_service.current_connection = mocker.MagicMock(spec=psycopg.Connection) with pytest.raises(AwsWrapperError): - manager.execute(mock_conn, "test_execute", lambda: _target_call([])) + manager.execute(mock_conn, DbApiMethodTest.TEST_EXECUTE, lambda: _target_call([])) def test_connect(mocker, container, mock_conn, mock_driver_dialect, mock_telemetry_factory): @@ -255,7 +271,7 @@ def test_connect(mocker, container, mock_conn, mock_driver_dialect, mock_telemet mocker.patch.object(PluginManager, "__init__", lambda w, x, y, z: None) manager = PluginManager(mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock()) manager._plugins = plugins - manager._function_cache = {} + manager._function_cache = [None] * (DbApiMethodTest.TEST_EXECUTE.id + 1) manager._telemetry_factory = mock_telemetry_factory manager._container = container @@ -278,7 +294,7 @@ def test_connect__skip_plugin(mocker, container, mock_conn, mock_driver_dialect, mocker.patch.object(PluginManager, "__init__", lambda w, x, y, z: None) manager = PluginManager(mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock()) manager._plugins = plugins - manager._function_cache = {} + manager._function_cache = [None] * (DbApiMethodTest.TEST_EXECUTE.id + 1) manager._telemetry_factory = mock_telemetry_factory manager._container = container @@ -298,7 +314,7 @@ def test_force_connect(mocker, container, mock_conn, mock_driver_dialect, mock_t mocker.patch.object(PluginManager, "__init__", lambda w, x, y, z: None) manager = PluginManager(mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock()) manager._plugins = plugins - manager._function_cache = {} + manager._function_cache = [None] * (DbApiMethodTest.TEST_EXECUTE.id + 1) manager._telemetry_factory = mock_telemetry_factory manager._container = container @@ -306,7 +322,7 @@ def test_force_connect(mocker, container, mock_conn, mock_driver_dialect, mock_t # The first call to force_connect should generate the plugin pipeline and cache it result = manager.force_connect(mocker.MagicMock(), mocker.MagicMock(), HostInfo("localhost"), Properties(), True) - make_pipeline_func.assert_called_once_with("force_connect", None) + make_pipeline_func.assert_called_once_with(DbApiMethod.FORCE_CONNECT.method_name) assert result == mock_conn assert len(calls) == 4 assert calls[0] == "TestPluginOne:before forceConnect" @@ -319,7 +335,7 @@ def test_force_connect(mocker, container, mock_conn, mock_driver_dialect, mock_t result = manager.force_connect(mocker.MagicMock(), mocker.MagicMock(), HostInfo("localhost"), Properties(), True) # The second call should have used the cached plugin pipeline, so make_pipeline should not have been called again - make_pipeline_func.assert_called_once_with("force_connect", None) + make_pipeline_func.assert_called_once_with(DbApiMethod.FORCE_CONNECT.method_name) assert result == mock_conn assert len(calls) == 4 assert calls[0] == "TestPluginOne:before forceConnect" @@ -336,7 +352,7 @@ def test_force_connect__cached(mocker, container, mock_conn, mock_driver_dialect mocker.patch.object(PluginManager, "__init__", lambda w, x, y, z: None) manager = PluginManager(mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock()) manager._plugins = plugins - manager._function_cache = {} + manager._function_cache = [None] * (DbApiMethodTest.TEST_EXECUTE.id + 1) manager._telemetry_factory = mock_telemetry_factory manager._container = container @@ -359,7 +375,7 @@ def test_exception_before_connect(mocker, container, mock_telemetry_factory): manager = PluginManager(mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock()) manager._plugins = plugins manager._telemetry_factory = mock_telemetry_factory - manager._function_cache = {} + manager._function_cache = [None] * (DbApiMethodTest.TEST_EXECUTE.id + 1) manager._container = container with pytest.raises(AwsWrapperError): @@ -381,7 +397,7 @@ def test_exception_after_connect(mocker, container, mock_telemetry_factory): manager = PluginManager(mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock()) manager._plugins = plugins manager._telemetry_factory = mock_telemetry_factory - manager._function_cache = {} + manager._function_cache = [None] * (DbApiMethodTest.TEST_EXECUTE.id + 1) manager._container = container with pytest.raises(AwsWrapperError): @@ -424,7 +440,7 @@ def test_notify_connection_changed(mocker, mock_telemetry_factory): manager = PluginManager(mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock()) manager._plugins = plugins manager._telemetry_factory = mock_telemetry_factory - manager._function_cache = {} + manager._function_cache = [None] * (DbApiMethodTest.TEST_EXECUTE.id + 1) old_connection_suggestion = manager.notify_connection_changed({ConnectionEvent.CONNECTION_OBJECT_CHANGED}) @@ -451,7 +467,7 @@ def test_notify_host_list_changed(mocker, mock_telemetry_factory): manager = PluginManager(mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock()) manager._plugins = plugins manager._telemetry_factory = mock_telemetry_factory - manager._function_cache = {} + manager._function_cache = [None] * (DbApiMethodTest.TEST_EXECUTE.id + 1) manager.notify_host_list_changed( {"host-1": {HostEvent.CONVERTED_TO_READER}, "host-2": {HostEvent.CONVERTED_TO_WRITER}}) @@ -532,7 +548,7 @@ class TestPluginTwo(TestPlugin): @property def subscribed_methods(self) -> Set[str]: - return {"test_call_a", "test_call_b", "notify_connection_changed"} + return {DbApiMethodTest.TEST_CALL_A.method_name, DbApiMethodTest.TEST_CALL_B.method_name, DbApiMethod.NOTIFY_CONNECTION_CHANGED.method_name} def notify_connection_changed(self, changes: Set[ConnectionEvent]) -> OldConnectionSuggestedAction: self._calls.append(type(self).__name__ + ":notify_connection_changed") @@ -540,10 +556,15 @@ def notify_connection_changed(self, changes: Set[ConnectionEvent]) -> OldConnect class TestPluginThree(TestPlugin): - @property def subscribed_methods(self) -> Set[str]: - return {"test_call_a", "connect", "force_connect", "notify_connection_changed", "notify_host_list_changed"} + return { + DbApiMethodTest.TEST_CALL_A.method_name, + DbApiMethod.CONNECT.method_name, + DbApiMethod.FORCE_CONNECT.method_name, + DbApiMethod.NOTIFY_CONNECTION_CHANGED.method_name, + DbApiMethod.NOTIFY_HOST_LIST_CHANGED.method_name, + } def notify_connection_changed(self, changes: Set[ConnectionEvent]) -> OldConnectionSuggestedAction: self._calls.append(type(self).__name__ + ":notify_connection_changed")