From 91ff23009934f659899e2062be685c0241ecfa30 Mon Sep 17 00:00:00 2001 From: Karen Chen Date: Mon, 26 Feb 2024 17:44:18 -0800 Subject: [PATCH] feat: implement configuration profiles --- .../aurora_connection_tracker_plugin.py | 62 +++-- ...rora_initial_connection_strategy_plugin.py | 230 ++++++++++++++++++ .../driver_configuration_profiles.py | 44 ---- .../failover_plugin.py | 4 + .../host_list_provider.py | 2 +- .../mysql_driver_dialect.py | 4 +- .../pg_driver_dialect.py | 4 +- aws_advanced_python_wrapper/plugin_service.py | 30 ++- .../profiles/__init__.py | 13 + .../profiles/configuration_profile.py | 69 ++++++ .../configuration_profile_preset_codes.py | 31 +++ .../profiles/driver_configuration_profiles.py | 62 +++++ ...dvanced_python_wrapper_messages.properties | 7 +- .../utils/properties.py | 10 + aws_advanced_python_wrapper/wrapper.py | 22 +- .../writer_failover_handler.py | 2 - benchmarks/plugin_manager_benchmarks.py | 4 +- tests/unit/test_developer_plugin.py | 9 +- 18 files changed, 522 insertions(+), 87 deletions(-) create mode 100644 aws_advanced_python_wrapper/aurora_initial_connection_strategy_plugin.py delete mode 100644 aws_advanced_python_wrapper/driver_configuration_profiles.py create mode 100644 aws_advanced_python_wrapper/profiles/__init__.py create mode 100644 aws_advanced_python_wrapper/profiles/configuration_profile.py create mode 100644 aws_advanced_python_wrapper/profiles/configuration_profile_preset_codes.py create mode 100644 aws_advanced_python_wrapper/profiles/driver_configuration_profiles.py diff --git a/aws_advanced_python_wrapper/aurora_connection_tracker_plugin.py b/aws_advanced_python_wrapper/aurora_connection_tracker_plugin.py index a2ce609b5..a51104968 100644 --- a/aws_advanced_python_wrapper/aurora_connection_tracker_plugin.py +++ b/aws_advanced_python_wrapper/aurora_connection_tracker_plugin.py @@ -64,6 +64,19 @@ def populate_opened_connection_set(self, host_info: HostInfo, conn: Connection): self._track_connection(instance_endpoint, conn) + def invalidate_current_connection(self, host_info: HostInfo, conn: Optional[Connection]): + host: Optional[str] = host_info.as_alias() \ + if self._rds_utils.is_rds_instance(host_info.host) \ + else next(alias for alias in host_info.aliases if self._rds_utils.is_rds_instance(alias)) + + if not host: + return + + connection_set: Optional[WeakSet] = self._opened_connections.get(host) + if connection_set is not None: + self._log_connection_set(host, connection_set) + connection_set.discard(conn) + def invalidate_all_connections(self, host_info: Optional[HostInfo] = None, host: Optional[FrozenSet[str]] = None): """ Invalidates all opened connections pointing to the same host in a daemon thread. @@ -77,14 +90,10 @@ def invalidate_all_connections(self, host_info: Optional[HostInfo] = None, host: self.invalidate_all_connections(host=host_info.as_aliases()) return - instance_endpoint: Optional[str] = None if host is None: return - for instance in host: - if instance is not None and self._rds_utils.is_rds_instance(instance): - instance_endpoint = instance - break + instance_endpoint = next(instance for instance in host if self._rds_utils.is_rds_instance(instance)) if not instance_endpoint: return @@ -135,8 +144,8 @@ def log_opened_connections(self): return logger.debug("OpenedConnectionTracker.OpenedConnectionsTracked", msg) - def _log_connection_set(self, host: str, conn_set: Optional[WeakSet]): - if conn_set is None or len(conn_set) == 0: + def _log_connection_set(self, host: Optional[str], conn_set: Optional[WeakSet]): + if host is None or conn_set is None or len(conn_set) == 0: return conn = "" @@ -148,13 +157,14 @@ def _log_connection_set(self, host: str, conn_set: Optional[WeakSet]): class AuroraConnectionTrackerPlugin(Plugin): - _SUBSCRIBED_METHODS: Set[str] = {"*"} + _SUBSCRIBED_METHODS: Set[str] = {"connect", "force_connect"} _current_writer: Optional[HostInfo] = None _need_update_current_writer: bool = False + _METHOD_CLOSE = "Connection.close" @property def subscribed_methods(self) -> Set[str]: - return self._SUBSCRIBED_METHODS + return AuroraConnectionTrackerPlugin._SUBSCRIBED_METHODS.union(self._plugin_service.network_bound_methods) def __init__(self, plugin_service: PluginService, @@ -201,19 +211,20 @@ def _connect(self, host_info: HostInfo, connect_func: Callable): return conn def execute(self, target: object, method_name: str, execute_func: Callable, *args: Any, **kwargs: Any) -> Any: - if self._current_writer is None or self._need_update_current_writer: - self._current_writer = self._get_writer(self._plugin_service.hosts) - self._need_update_current_writer = False + self._remember_writer() try: - return execute_func() + results = execute_func() + if method_name == AuroraConnectionTrackerPlugin._METHOD_CLOSE and self._plugin_service.current_host_info is not None: + self._tracker.invalidate_current_connection(self._plugin_service.current_host_info, self._plugin_service.current_connection) + elif self._need_update_current_writer: + self._check_writer_changed() + return results except Exception as e: # Check that e is a FailoverError and that the writer has changed - if isinstance(e, FailoverError) and self._get_writer(self._plugin_service.hosts) != self._current_writer: - self._tracker.invalidate_all_connections(host_info=self._current_writer) - self._tracker.log_opened_connections() - self._need_update_current_writer = True + if isinstance(e, FailoverError): + self._check_writer_changed() raise e def _get_writer(self, hosts: Tuple[HostInfo, ...]) -> Optional[HostInfo]: @@ -222,6 +233,23 @@ def _get_writer(self, hosts: Tuple[HostInfo, ...]) -> Optional[HostInfo]: return host return None + def _remember_writer(self): + if self._current_writer is None or self._need_update_current_writer: + self._current_writer = self._get_writer(self._plugin_service.hosts) + self._need_update_current_writer = False + + def _check_writer_changed(self): + host_info_after_failover = self._get_writer(self._plugin_service.hosts) + + if self._current_writer is None: + self._current_writer = host_info_after_failover + self._need_update_current_writer = False + elif self._current_writer != host_info_after_failover: + self._tracker.invalidate_all_connections(self._current_writer) + self._tracker.log_opened_connections() + self._current_writer = host_info_after_failover + self._need_update_current_writer = False + class AuroraConnectionTrackerPluginFactory(PluginFactory): def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin: diff --git a/aws_advanced_python_wrapper/aurora_initial_connection_strategy_plugin.py b/aws_advanced_python_wrapper/aurora_initial_connection_strategy_plugin.py new file mode 100644 index 000000000..6c3f53194 --- /dev/null +++ b/aws_advanced_python_wrapper/aurora_initial_connection_strategy_plugin.py @@ -0,0 +1,230 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import time +from typing import TYPE_CHECKING, Callable, Optional, Set + +if TYPE_CHECKING: + from aws_advanced_python_wrapper.driver_dialect import DriverDialect + from aws_advanced_python_wrapper.pep249 import Connection + from aws_advanced_python_wrapper.plugin_service import PluginService + from aws_advanced_python_wrapper.host_list_provider import HostListProviderService + from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils + +from aws_advanced_python_wrapper.errors import AwsWrapperError +from aws_advanced_python_wrapper.host_availability import HostAvailability +from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole +from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory +from aws_advanced_python_wrapper.utils.messages import Messages +from aws_advanced_python_wrapper.utils.properties import (Properties, + WrapperProperties) +from aws_advanced_python_wrapper.utils.rds_url_type import RdsUrlType + + +class AuroraInitialConnectionStrategyPlugin(Plugin): + _plugin_service: PluginService + _host_list_provider_service: HostListProviderService + _rds_utils: RdsUtils + + @property + def subscribed_methods(self) -> Set[str]: + return {"init_host_provider", "connect", "force_connect"} + + def __init__(self, plugin_service: PluginService, properties: Properties): + self._plugin_service = plugin_service + + def init_host_provider(self, props: Properties, host_list_provider_service: HostListProviderService, init_host_provider_func: Callable): + self._host_list_provider_service = host_list_provider_service + if host_list_provider_service.is_static_host_list_provider(): + raise AwsWrapperError(Messages.get("AuroraInitialConnectionStrategyPlugin.RequireDynamicProvider")) + init_host_provider_func() + + def connect(self, target_driver_func: Callable, driver_dialect: DriverDialect, host_info: HostInfo, props: Properties, + is_initial_connection: bool, connect_func: Callable) -> Connection: + return self._connect_internal(host_info, props, is_initial_connection, connect_func) + + def force_connect(self, target_driver_func: Callable, driver_dialect: DriverDialect, host_info: HostInfo, props: Properties, + is_initial_connection: bool, force_connect_func: Callable) -> Connection: + return self._connect_internal(host_info, props, is_initial_connection, force_connect_func) + + def _connect_internal(self, host_info: HostInfo, props: Properties, is_initial_connection: bool, connect_func: Callable) -> Connection: + type: RdsUrlType = self._rds_utils.identify_rds_type(host_info.host) + if not type.is_rds_cluster: + return connect_func() + + if type == RdsUrlType.RDS_WRITER_CLUSTER: + writer_candidate_conn = self._get_verified_writer_connection(props, is_initial_connection, connect_func) + if writer_candidate_conn is None: + return connect_func() + return writer_candidate_conn + + if type == RdsUrlType.RDS_READER_CLUSTER: + reader_candidate_conn = self._get_verified_reader_connection(props, is_initial_connection, connect_func) + if reader_candidate_conn is None: + return connect_func() + return reader_candidate_conn + + # Continue with a normal workflow. + return connect_func() + + def _get_verified_writer_connection(self, props: Properties, is_initial_connection: bool, connect_func: Callable) -> Optional[Connection]: + retry_delay_ms: int = WrapperProperties.OPEN_CONNECTION_RETRY_TIMEOUT_MS.get_int(props) + end_time_nano = self._get_time() + retry_delay_ms * 1_000_000 + + writer_candidate_conn: Optional[Connection] + writer_candidate: Optional[HostInfo] + + while self._get_time() < end_time_nano: + writer_candidate_conn = None + writer_candidate = None + + try: + writer_candidate = self._get_writer() + if writer_candidate_conn is None or self._rds_utils.is_rds_cluster_dns(writer_candidate.host): + writer_candidate_conn = connect_func() + self._plugin_service.force_refresh_host_list(writer_candidate_conn) + writer_candidate = self._plugin_service.identify_connection(writer_candidate_conn) + + if writer_candidate is not None and writer_candidate.role != HostRole.WRITER: + # Shouldn't be here. But let's try again. + self._close_connection(writer_candidate_conn) + self._delay(retry_delay_ms) + continue + + if is_initial_connection: + self._host_list_provider_service.initial_connection_host_info = writer_candidate + + return writer_candidate_conn + + writer_candidate_conn = self._plugin_service.connect(writer_candidate, props) + + if self._plugin_service.get_host_role(writer_candidate_conn) != HostRole.WRITER: + self._plugin_service.force_refresh_host_list(writer_candidate_conn) + self._close_connection(writer_candidate_conn) + self._delay(retry_delay_ms) + continue + + if is_initial_connection: + self._host_list_provider_service.initial_connection_host_info = writer_candidate + return writer_candidate_conn + + except Exception as e: + if writer_candidate is not None: + self._plugin_service.set_availability(writer_candidate.as_aliases(), HostAvailability.UNAVAILABLE) + self._close_connection(writer_candidate_conn) + raise e + + return None + + def _get_verified_reader_connection(self, props: Properties, is_initial_connection: bool, connect_func: Callable) -> Optional[Connection]: + retry_delay_ms: int = WrapperProperties.OPEN_CONNECTION_RETRY_INTERVAL_MS.get_int(props) + end_time_nano = self._get_time() + WrapperProperties.OPEN_CONNECTION_RETRY_TIMEOUT_MS.get_int(props) * 1_000_000 + + reader_candidate_conn: Optional[Connection] + reader_candidate: Optional[HostInfo] + + while self._get_time() < end_time_nano: + reader_candidate_conn = None + reader_candidate = None + + try: + reader_candidate = self._get_reader(props) + if reader_candidate is None or self._rds_utils.is_rds_cluster_dns(reader_candidate.host): + # Reader not found, topology may be outdated + reader_candidate_conn = connect_func() + self._plugin_service.force_refresh_host_list(reader_candidate_conn) + reader_candidate = self._plugin_service.identify_connection(reader_candidate_conn) + + if reader_candidate is not None and reader_candidate.role != HostRole.READER: + if self._has_no_readers(): + # Cluster has no readers. Simulate Aurora reader cluster endpoint logic + if is_initial_connection and reader_candidate.host is not None: + self._host_list_provider_service.initial_connection_host_info = reader_candidate + return reader_candidate_conn + self._close_connection(reader_candidate_conn) + self._delay(retry_delay_ms) + continue + + if reader_candidate is not None and is_initial_connection: + self._host_list_provider_service.initial_connection_host_info = reader_candidate + return reader_candidate_conn + + reader_candidate_conn = self._plugin_service.connect(reader_candidate, props) + if self._plugin_service.get_host_role(reader_candidate_conn) != HostRole.READER: + # If the new connection resolves to a writer instance, this means the topology is outdated. + # Force refresh to update the topology. + self._plugin_service.force_refresh_host_list(reader_candidate_conn) + + if self._has_no_readers(): + # Cluster has no readers. Simulate Aurora reader cluster endpoint logic + if is_initial_connection: + self._host_list_provider_service.initial_connection_host_info = reader_candidate + return reader_candidate_conn + + self._close_connection(reader_candidate_conn) + self._delay(retry_delay_ms) + continue + + # Reader connection is valid and verified. + if is_initial_connection: + self._host_list_provider_service.initial_connection_host_info = reader_candidate + return reader_candidate_conn + + except Exception: + self._close_connection(reader_candidate_conn) + if reader_candidate is not None: + self._plugin_service.set_availability(reader_candidate.as_aliases(), HostAvailability.AVAILABLE) + + return None + + def _close_connection(self, connection: Optional[Connection]): + if connection is not None: + try: + connection.close() + except Exception: + # ignore + pass + + def _delay(self, delay_ms: int): + time.sleep(delay_ms / 1000) + + def _get_writer(self) -> Optional[HostInfo]: + return next(host for host in self._plugin_service.hosts if host.role == HostRole.WRITER) + + def _get_reader(self, props: Properties) -> Optional[HostInfo]: + strategy: Optional[str] = WrapperProperties.READER_HOST_SELECTOR_STRATEGY.get(props) + if strategy is not None and self._plugin_service.accepts_strategy(HostRole.READER, strategy): + try: + return self._plugin_service.get_host_info_by_strategy(HostRole.READER, strategy) + except Exception: + # Host isn't found + return None + + raise AwsWrapperError(Messages.get_formatted("AuroraInitialConnectionStrategyPlugin.UnsupportedStrategy", strategy)) + + def _has_no_readers(self) -> bool: + if len(self._plugin_service.hosts) == 0: + # Topology inconclusive. + return False + return next(host_info for host_info in self._plugin_service.hosts if host_info.role == HostRole.READER) is None + + def _get_time(self): + return time.perf_counter_ns() + + +class AuroraInitialConnectionStrategyPluginFactory(PluginFactory): + def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin: + return AuroraInitialConnectionStrategyPlugin(plugin_service, props) diff --git a/aws_advanced_python_wrapper/driver_configuration_profiles.py b/aws_advanced_python_wrapper/driver_configuration_profiles.py deleted file mode 100644 index a5ea171a4..000000000 --- a/aws_advanced_python_wrapper/driver_configuration_profiles.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from typing import TYPE_CHECKING, Dict, List - -if TYPE_CHECKING: - from aws_advanced_python_wrapper.plugin import PluginFactory - - -class DriverConfigurationProfiles: - _profiles: Dict[str, List[PluginFactory]] = {} - - @classmethod - def clear_profiles(cls): - cls._profiles.clear() - - @classmethod - def add_or_replace_profile(cls, profile_name: str, factories: List[PluginFactory]): - cls._profiles[profile_name] = factories - - @classmethod - def remove_profile(cls, profile_name: str): - cls._profiles.pop(profile_name) - - @classmethod - def contains_profile(cls, profile_name: str): - return profile_name in cls._profiles - - @classmethod - def get_plugin_factories(cls, profile_name: str): - return cls._profiles[profile_name] diff --git a/aws_advanced_python_wrapper/failover_plugin.py b/aws_advanced_python_wrapper/failover_plugin.py index f0776611c..d4ff35d85 100644 --- a/aws_advanced_python_wrapper/failover_plugin.py +++ b/aws_advanced_python_wrapper/failover_plugin.py @@ -273,6 +273,10 @@ def _failover(self, failed_host: Optional[HostInfo]): :param failed_host: The host with network errors. """ + + if failed_host is not None: + self._plugin_service.set_availability(failed_host.as_aliases(), HostAvailability.AVAILABLE) + if self._failover_mode == FailoverMode.STRICT_WRITER: self._failover_writer() else: diff --git a/aws_advanced_python_wrapper/host_list_provider.py b/aws_advanced_python_wrapper/host_list_provider.py index f6b94d89a..90d04fb0f 100644 --- a/aws_advanced_python_wrapper/host_list_provider.py +++ b/aws_advanced_python_wrapper/host_list_provider.py @@ -127,7 +127,7 @@ def initial_connection_host_info(self) -> Optional[HostInfo]: ... @initial_connection_host_info.setter - def initial_connection_host_info(self, value: HostInfo): + def initial_connection_host_info(self, value: Optional[HostInfo]): ... def is_static_host_list_provider(self) -> bool: diff --git a/aws_advanced_python_wrapper/mysql_driver_dialect.py b/aws_advanced_python_wrapper/mysql_driver_dialect.py index c999eba60..5dcb528f5 100644 --- a/aws_advanced_python_wrapper/mysql_driver_dialect.py +++ b/aws_advanced_python_wrapper/mysql_driver_dialect.py @@ -63,7 +63,9 @@ class MySQLDriverDialect(DriverDialect): } def is_dialect(self, connect_func: Callable) -> bool: - return MySQLDriverDialect.TARGET_DRIVER_CODE in str(signature(connect_func)) + if MySQLDriverDialect.TARGET_DRIVER_CODE not in str(signature(connect_func)): + return MySQLDriverDialect.TARGET_DRIVER_CODE.lower() in (connect_func.__module__ + connect_func.__qualname__).lower() + return True def is_closed(self, conn: Connection) -> bool: if isinstance(conn, CMySQLConnection) or isinstance(conn, MySQLConnection): diff --git a/aws_advanced_python_wrapper/pg_driver_dialect.py b/aws_advanced_python_wrapper/pg_driver_dialect.py index 921e59c9e..51333a759 100644 --- a/aws_advanced_python_wrapper/pg_driver_dialect.py +++ b/aws_advanced_python_wrapper/pg_driver_dialect.py @@ -58,7 +58,9 @@ class PgDriverDialect(DriverDialect): } def is_dialect(self, connect_func: Callable) -> bool: - return PgDriverDialect.TARGET_DRIVER_CODE in str(signature(connect_func)) + if PgDriverDialect.TARGET_DRIVER_CODE not in str(signature(connect_func)): + return PgDriverDialect.TARGET_DRIVER_CODE.lower() in (connect_func.__module__ + connect_func.__qualname__).lower() + return True def is_closed(self, conn: Connection) -> bool: if isinstance(conn, psycopg.Connection): diff --git a/aws_advanced_python_wrapper/plugin_service.py b/aws_advanced_python_wrapper/plugin_service.py index f05f11fab..f2bb03c6d 100644 --- a/aws_advanced_python_wrapper/plugin_service.py +++ b/aws_advanced_python_wrapper/plugin_service.py @@ -24,6 +24,7 @@ from aws_advanced_python_wrapper.driver_dialect_manager import DriverDialectManager from aws_advanced_python_wrapper.pep249 import Connection from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory + from aws_advanced_python_wrapper.profiles.configuration_profile import ConfigurationProfile from threading import Event from abc import abstractmethod @@ -45,8 +46,6 @@ UnknownDatabaseDialect) from aws_advanced_python_wrapper.default_plugin import DefaultPlugin from aws_advanced_python_wrapper.developer_plugin import DeveloperPluginFactory -from aws_advanced_python_wrapper.driver_configuration_profiles import \ - DriverConfigurationProfiles from aws_advanced_python_wrapper.errors import (AwsWrapperError, QueryTimeoutError, UnsupportedOperationError) @@ -249,7 +248,6 @@ def get_telemetry_factory(self) -> TelemetryFactory: class PluginServiceImpl(PluginService, HostListProviderService, CanReleaseResources): - _host_availability_expiring_cache: CacheMap[str, HostAvailability] = CacheMap() _executor: ClassVar[Executor] = ThreadPoolExecutor(thread_name_prefix="PluginServiceImplExecutor") @@ -260,10 +258,12 @@ def __init__( props: Properties, target_func: Callable, driver_dialect_manager: DriverDialectManager, - driver_dialect: DriverDialect): + driver_dialect: DriverDialect, + profile: Optional[ConfigurationProfile]): self._container = container self._container.plugin_service = self self._props = props + self._configuration_profile = profile self._original_url = PropertiesUtils.get_url(props) self._host_list_provider: HostListProvider = ConnectionStringHostListProvider(self, props) @@ -277,7 +277,9 @@ def __init__( self._target_func = target_func self._driver_dialect_manager = driver_dialect_manager self._driver_dialect = driver_dialect - self._database_dialect = self._dialect_provider.get_dialect(driver_dialect.dialect_code, props) + self._database_dialect = self._configuration_profile.database_dialect \ + if self._configuration_profile is not None and self._configuration_profile.database_dialect is not None \ + else self._dialect_provider.get_dialect(driver_dialect.dialect_code, props) @property def hosts(self) -> Tuple[HostInfo, ...]: @@ -597,14 +599,18 @@ class PluginManager(CanReleaseResources): FederatedAuthPluginFactory: WEIGHT_RELATIVE_TO_PRIOR_PLUGIN } - def __init__( - self, container: PluginServiceManagerContainer, props: Properties, telemetry_factory: TelemetryFactory): + def __init__(self, + container: PluginServiceManagerContainer, + props: Properties, + telemetry_factory: TelemetryFactory, + profile: Optional[ConfigurationProfile] = None): self._props: Properties = props self._function_cache: Dict[str, Callable] = {} self._container = container self._container.plugin_manager = self self._connection_provider_manager = ConnectionProviderManager() self._telemetry_factory = telemetry_factory + self._configuration_profile: Optional[ConfigurationProfile] = profile self._plugins = self.get_plugins() @property @@ -632,12 +638,10 @@ def get_plugins(self) -> List[Plugin]: plugin_factories: List[PluginFactory] = [] plugins: List[Plugin] = [] - profile_name = WrapperProperties.PROFILE_NAME.get(self._props) - if profile_name is not None: - if not DriverConfigurationProfiles.contains_profile(profile_name): - raise AwsWrapperError( - Messages.get_formatted("PluginManager.ConfigurationProfileNotFound", profile_name)) - plugin_factories = DriverConfigurationProfiles.get_plugin_factories(profile_name) + if self._configuration_profile is not None: + factories = self._configuration_profile.plugin_factories + if factories is not None: + plugin_factories = self._configuration_profile.plugin_factories else: plugin_codes = WrapperProperties.PLUGINS.get(self._props) if plugin_codes is None: diff --git a/aws_advanced_python_wrapper/profiles/__init__.py b/aws_advanced_python_wrapper/profiles/__init__.py new file mode 100644 index 000000000..bd4acb2bf --- /dev/null +++ b/aws_advanced_python_wrapper/profiles/__init__.py @@ -0,0 +1,13 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/aws_advanced_python_wrapper/profiles/configuration_profile.py b/aws_advanced_python_wrapper/profiles/configuration_profile.py new file mode 100644 index 000000000..32a49fafa --- /dev/null +++ b/aws_advanced_python_wrapper/profiles/configuration_profile.py @@ -0,0 +1,69 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, List, Optional + +if TYPE_CHECKING: + from aws_advanced_python_wrapper.connection_provider import \ + ConnectionProvider + from aws_advanced_python_wrapper.database_dialect import DatabaseDialect + from aws_advanced_python_wrapper.driver_dialect import DriverDialect + from aws_advanced_python_wrapper.exception_handling import ExceptionHandler + from aws_advanced_python_wrapper.plugin import PluginFactory + from aws_advanced_python_wrapper.utils.properties import Properties + + +class ConfigurationProfile: + def __init__( + self, + name: str, + properties: Properties, + plugin_factories: List[PluginFactory] = [], + dialect: Optional[DatabaseDialect] = None, + target_driver_dialect: Optional[DriverDialect] = None, + exception_handler: Optional[ExceptionHandler] = None, + connection_provider: Optional[ConnectionProvider] = None): + self._name = name + self._plugin_factories = plugin_factories + self._properties = properties + self._database_dialect = dialect + self._target_driver_dialect = target_driver_dialect + self._exception_handler = exception_handler + self._connection_provider = connection_provider + + @property + def name(self) -> str: + return self._name + + @property + def properties(self) -> Properties: + return self._properties + + @property + def plugin_factories(self) -> List[PluginFactory]: + return self._plugin_factories + + @property + def database_dialect(self) -> Optional[DatabaseDialect]: + return self._database_dialect + + @property + def target_driver_dialect(self) -> Optional[DriverDialect]: + return self._target_driver_dialect + + @property + def connection_provider(self) -> Optional[ConnectionProvider]: + return self._connection_provider diff --git a/aws_advanced_python_wrapper/profiles/configuration_profile_preset_codes.py b/aws_advanced_python_wrapper/profiles/configuration_profile_preset_codes.py new file mode 100644 index 000000000..a05240c4c --- /dev/null +++ b/aws_advanced_python_wrapper/profiles/configuration_profile_preset_codes.py @@ -0,0 +1,31 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http:#www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +class ConfigurationProfilePresetCodes: + A0 = "A0" # Normal + A1 = "A1" # Easy + A2 = "A2" # Aggressive + B = "B" # Normal + C0 = "C0" # Normal + C1 = "C1" # Aggressive + D0 = "D0" # Normal + D1 = "D1" # Easy + E = "E" # Normal + F0 = "F0" # Normal + F1 = "F1" # Aggressive + G0 = "G0" # Normal + G1 = "G1" # Easy + H = "H" # Normal + I0 = "I0" # Normal + I1 = "I1" # Aggressive diff --git a/aws_advanced_python_wrapper/profiles/driver_configuration_profiles.py b/aws_advanced_python_wrapper/profiles/driver_configuration_profiles.py new file mode 100644 index 000000000..d953ad257 --- /dev/null +++ b/aws_advanced_python_wrapper/profiles/driver_configuration_profiles.py @@ -0,0 +1,62 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Dict, Optional + +from aws_advanced_python_wrapper.profiles.configuration_profile import \ + ConfigurationProfile +from aws_advanced_python_wrapper.profiles.configuration_profile_preset_codes import \ + ConfigurationProfilePresetCodes +from aws_advanced_python_wrapper.utils.properties import (Properties, + WrapperProperties) + + +class DriverConfigurationProfiles: + _profiles: Dict[str, Optional[ConfigurationProfile]] = {} + _presets: Dict[str, ConfigurationProfile] = { + ConfigurationProfilePresetCodes.A0: ConfigurationProfile( + name=ConfigurationProfilePresetCodes.A0, + properties=Properties({WrapperProperties.CONNECT_TIMEOUT_SEC.name: "10", + WrapperProperties.TCP_KEEPALIVE.name: False}) + ), + } + + @classmethod + def clear_profiles(cls): + cls._profiles.clear() + + @classmethod + def add_or_replace_profile(cls, profile_name: str, profile: Optional[ConfigurationProfile]): + cls._profiles[profile_name] = profile + + @classmethod + def remove_profile(cls, profile_name: str): + cls._profiles.pop(profile_name) + + @classmethod + def contains_profile(cls, profile_name: str): + return profile_name in cls._profiles + + @classmethod + def get_plugin_factories(cls, profile_name: str): + return cls._profiles[profile_name] + + @classmethod + def get_profile_configuration(cls, profile_name: str): + profile: Optional[ConfigurationProfile] = cls._profiles.get(profile_name) + if profile is not None: + return profile + return cls._presets.get(profile_name) diff --git a/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties b/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties index f8b7a1cf7..e7f79eb5b 100644 --- a/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties +++ b/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties @@ -25,6 +25,9 @@ AdfsCredentialsProviderFactory.SignOnPagePostActionRequestFailed=[AdfsCredential AdfsCredentialsProviderFactory.SignOnPageRequestFailed=[AdfsCredentialsProviderFactory] ADFS SignOn Page Request Failed with HTTP status '{}', reason phrase '{}', and response '{}' AdfsCredentialsProviderFactory.SignOnPageUrl=[AdfsCredentialsProviderFactory] ADFS SignOn URL: '{}' +AuroraInitialConnectionStrategyPlugin.UnsupportedStrategy=Unsupported host selection strategy '{}'. +AuroraInitialConnectionStrategyPlugin.RequireDynamicProvider=Dynamic host list provider is required. + AwsSdk.UnsupportedRegion=[AwsSdk] Unsupported AWS region {}. For supported regions please read https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/Concepts.RegionsAndAvailabilityZones.html AwsSecretsManagerPlugin.ConnectException=[AwsSecretsManagerPlugin] Error occurred while opening a connection: {} @@ -160,7 +163,7 @@ OpenTelemetryFactory.WrongParameterType="[OpenTelemetryFactory] Wrong parameter Plugin.UnsupportedMethod=[Plugin] '{}' is not supported by this plugin. -PluginManager.ConfigurationProfileNotFound=PluginManager] Configuration profile '{}' not found. +PluginManager.ConfigurationProfileNotFound=[PluginManager] Configuration profile '{}' not found. PluginManager.InvalidPlugin=[PluginManager] Invalid plugin requested: '{}'. PluginManager.MethodInvokedAgainstOldConnection = [PluginManager] The internal connection has changed since '{}' was created. This is likely due to failover or read-write splitting functionality. To ensure you are using the updated connection, please re-create Cursor objects after failover and/or setting readonly. PluginManager.PipelineNone=[PluginManager] A pipeline was requested but the created pipeline evaluated to None. @@ -273,6 +276,8 @@ Wrapper.ConnectMethod=[Wrapper] Target driver should be a target driver's connec Wrapper.RequiredTargetDriver=[Wrapper] Target driver is required. Wrapper.UnsupportedAttribute=[Wrapper] Target driver does not have the attribute: '{}' Wrapper.Properties=[Wrapper] "Connection Properties: " +Wrapper.ConfigurationProfileNotFound=[Wrapper] Configuration profile '{}' not found. + WriterFailoverHandler.AlreadyWriter=[WriterFailoverHandler] Current reader connection is actually a new writer connection. WriterFailoverHandler.CurrentTopologyNone=[WriterFailoverHandler] Current topology cannot be None. diff --git a/aws_advanced_python_wrapper/utils/properties.py b/aws_advanced_python_wrapper/utils/properties.py index 8859bac80..31ba7636b 100644 --- a/aws_advanced_python_wrapper/utils/properties.py +++ b/aws_advanced_python_wrapper/utils/properties.py @@ -314,6 +314,16 @@ class WrapperProperties: False ) + # Aurora Initial Connection Strategy Plugin + READER_INITIAL_HOST_SELECTOR_STRATEGY = WrapperProperty("reader_initial_connection_host_selector_strategy", + "The strategy that should be used to select a " + "new reader host while opening a new connection.", + "random") + + OPEN_CONNECTION_RETRY_TIMEOUT_MS = WrapperProperty("open_connection_retry_timeout_ms", + "Maximum allowed time for the retries opening a connection.", 30_000) + OPEN_CONNECTION_RETRY_INTERVAL_MS = WrapperProperty("open_connection_retry_interval_ms", "Time between each retry of opening a connection.", 1000) + class PropertiesUtils: @staticmethod diff --git a/aws_advanced_python_wrapper/wrapper.py b/aws_advanced_python_wrapper/wrapper.py index c6318a021..43cc9cc16 100644 --- a/aws_advanced_python_wrapper/wrapper.py +++ b/aws_advanced_python_wrapper/wrapper.py @@ -17,6 +17,12 @@ from typing import (TYPE_CHECKING, Any, Callable, Iterator, List, Optional, Union) +if TYPE_CHECKING: + from aws_advanced_python_wrapper.profiles.configuration_profile import ConfigurationProfile + +from aws_advanced_python_wrapper.profiles.driver_configuration_profiles import \ + DriverConfigurationProfiles + if TYPE_CHECKING: from aws_advanced_python_wrapper.host_list_provider import HostListProviderService @@ -32,7 +38,8 @@ from aws_advanced_python_wrapper.utils.log import Logger from aws_advanced_python_wrapper.utils.messages import Messages from aws_advanced_python_wrapper.utils.properties import (Properties, - PropertiesUtils) + PropertiesUtils, + WrapperProperties) from aws_advanced_python_wrapper.utils.telemetry.default_telemetry_factory import \ DefaultTelemetryFactory from aws_advanced_python_wrapper.utils.telemetry.telemetry import \ @@ -137,10 +144,19 @@ def connect( try: driver_dialect_manager: DriverDialectManager = DriverDialectManager() driver_dialect = driver_dialect_manager.get_dialect(target_func, props) + + profile_name: Optional[str] = WrapperProperties.PROFILE_NAME.get(props) + configuration_profile: Optional[ConfigurationProfile] = None + if profile_name: + configuration_profile = DriverConfigurationProfiles.get_profile_configuration(profile_name) + if configuration_profile is None: + raise AwsWrapperError(Messages.get_formatted("Wrapper.ConfigurationProfileNotFound")) + props = Properties({**props, **configuration_profile.properties}) + container: PluginServiceManagerContainer = PluginServiceManagerContainer() plugin_service = PluginServiceImpl( - container, props, target_func, driver_dialect_manager, driver_dialect) - plugin_manager: PluginManager = PluginManager(container, props, telemetry_factory) + container, props, target_func, driver_dialect_manager, driver_dialect, configuration_profile) + plugin_manager: PluginManager = PluginManager(container, props, telemetry_factory, configuration_profile) return AwsWrapperConnection(target_func, plugin_service, plugin_service, plugin_manager) except Exception as ex: diff --git a/aws_advanced_python_wrapper/writer_failover_handler.py b/aws_advanced_python_wrapper/writer_failover_handler.py index 600c24d22..fc0a22d2c 100644 --- a/aws_advanced_python_wrapper/writer_failover_handler.py +++ b/aws_advanced_python_wrapper/writer_failover_handler.py @@ -116,8 +116,6 @@ def get_writer(self, topology: Tuple[HostInfo, ...]) -> Optional[HostInfo]: def get_result_from_future(self, current_topology: Tuple[HostInfo, ...]) -> WriterFailoverResult: writer_host: Optional[HostInfo] = self.get_writer(current_topology) if writer_host is not None: - self._plugin_service.set_availability(writer_host.as_aliases(), HostAvailability.UNAVAILABLE) - with ThreadPoolExecutor(thread_name_prefix="WriterFailoverHandlerExecutor") as executor: try: futures = [executor.submit(self.reconnect_to_writer, writer_host), diff --git a/benchmarks/plugin_manager_benchmarks.py b/benchmarks/plugin_manager_benchmarks.py index 0ed9ae3ad..32a31bbb6 100644 --- a/benchmarks/plugin_manager_benchmarks.py +++ b/benchmarks/plugin_manager_benchmarks.py @@ -22,11 +22,11 @@ from aws_advanced_python_wrapper.driver_dialect import DriverDialect from aws_advanced_python_wrapper.plugin import PluginFactory -from aws_advanced_python_wrapper.driver_configuration_profiles import \ - DriverConfigurationProfiles from aws_advanced_python_wrapper.hostinfo import HostInfo from aws_advanced_python_wrapper.plugin_service import ( PluginManager, PluginServiceManagerContainer) +from aws_advanced_python_wrapper.profiles.driver_configuration_profiles import \ + DriverConfigurationProfiles from aws_advanced_python_wrapper.utils.properties import Properties from benchmarks.benchmark_plugin import BenchmarkPluginFactory diff --git a/tests/unit/test_developer_plugin.py b/tests/unit/test_developer_plugin.py index 51625fe14..06832fd8a 100644 --- a/tests/unit/test_developer_plugin.py +++ b/tests/unit/test_developer_plugin.py @@ -50,14 +50,19 @@ def mock_dialect_manager(mocker, mock_driver_dialect): return dialect_manager +@pytest.fixture +def mock_configuration_profile(mocker): + return mocker.MagicMock() + + @pytest.fixture def container(): return PluginServiceManagerContainer() @pytest.fixture -def plugin_service(mocker, container, props, mock_dialect_manager, mock_driver_dialect): - return PluginServiceImpl(container, props, mocker.MagicMock(), mock_dialect_manager, mock_driver_dialect) +def plugin_service(mocker, container, props, mock_dialect_manager, mock_driver_dialect, mock_configuration_profile): + return PluginServiceImpl(container, props, mocker.MagicMock(), mock_dialect_manager, mock_driver_dialect, mock_configuration_profile) @pytest.fixture