Initial commit: igny8 project

This commit is contained in:
igny8
2025-11-09 10:27:02 +00:00
commit 60b8188111
27265 changed files with 4360521 additions and 0 deletions

View File

@@ -0,0 +1,530 @@
import asyncio
import logging
from typing import Any, Awaitable, Callable, Coroutine, List, Optional, Union
from redis.asyncio.client import PubSubHandler
from redis.asyncio.multidb.command_executor import DefaultCommandExecutor
from redis.asyncio.multidb.config import DEFAULT_GRACE_PERIOD, MultiDbConfig
from redis.asyncio.multidb.database import AsyncDatabase, Databases
from redis.asyncio.multidb.failure_detector import AsyncFailureDetector
from redis.asyncio.multidb.healthcheck import HealthCheck, HealthCheckPolicy
from redis.background import BackgroundScheduler
from redis.commands import AsyncCoreCommands, AsyncRedisModuleCommands
from redis.multidb.circuit import CircuitBreaker
from redis.multidb.circuit import State as CBState
from redis.multidb.exception import NoValidDatabaseException, UnhealthyDatabaseException
from redis.typing import ChannelT, EncodableT, KeyT
from redis.utils import experimental
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
@experimental
class MultiDBClient(AsyncRedisModuleCommands, AsyncCoreCommands):
    """
    Client that operates on multiple logical Redis databases.

    Should be used in Active-Active database setups.
    """

    def __init__(self, config: MultiDbConfig):
        """
        Wire the client together from the given configuration.

        No health check is run and no active database is chosen here; that
        happens in ``initialize()``.

        Args:
            config: Provides the databases, health checks, failure detectors,
                failover strategy, retry policy and event dispatcher.
        """
        self._databases = config.databases()

        # Default health checks first, user-supplied ones appended after.
        self._health_checks = config.default_health_checks()
        if config.health_checks is not None:
            self._health_checks.extend(config.health_checks)

        self._health_check_interval = config.health_check_interval
        self._health_check_policy: HealthCheckPolicy = config.health_check_policy.value(
            config.health_check_probes, config.health_check_delay
        )

        # Default failure detectors first, user-supplied ones appended after.
        self._failure_detectors = config.default_failure_detectors()
        if config.failure_detectors is not None:
            self._failure_detectors.extend(config.failure_detectors)

        self._failover_strategy = (
            config.default_failover_strategy()
            if config.failover_strategy is None
            else config.failover_strategy
        )
        self._failover_strategy.set_databases(self._databases)

        self._auto_fallback_interval = config.auto_fallback_interval
        self._event_dispatcher = config.event_dispatcher
        self._command_retry = config.command_retry
        self._command_retry.update_supported_errors([ConnectionRefusedError])

        self.command_executor = DefaultCommandExecutor(
            failure_detectors=self._failure_detectors,
            databases=self._databases,
            command_retry=self._command_retry,
            failover_strategy=self._failover_strategy,
            failover_attempts=config.failover_attempts,
            failover_delay=config.failover_delay,
            event_dispatcher=self._event_dispatcher,
            auto_fallback_interval=self._auto_fallback_interval,
        )

        self.initialized = False
        self._hc_lock = asyncio.Lock()
        self._bg_scheduler = BackgroundScheduler()
        self._config = config

        # Handles of background tasks, kept so __aexit__ can cancel them.
        self._recurring_hc_task = None
        self._hc_tasks = []
        self._half_open_state_task = None

    async def __aenter__(self: "MultiDBClient") -> "MultiDBClient":
        """Initialize the client on first entry and return it."""
        if not self.initialized:
            await self.initialize()
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        """Cancel the background health-check tasks started by this client."""
        if self._recurring_hc_task:
            self._recurring_hc_task.cancel()
        if self._half_open_state_task:
            self._half_open_state_task.cancel()
        for hc_task in self._hc_tasks:
            hc_task.cancel()

    async def initialize(self):
        """
        Perform initialization of databases to define their initial state.

        Runs one health-check round, starts the recurring health checks,
        registers the circuit state callback on every database and activates
        the first database (in iteration order) whose circuit is CLOSED.

        Raises:
            Exception: The original exception of a failed initial health check.
            asyncio.TimeoutError: If the initial health-check round exceeds
                the health check interval.
            NoValidDatabaseException: If no database has a CLOSED circuit.
        """

        async def raise_exception_on_failed_hc(error):
            raise error

        # Initial databases check to define initial state.
        await self._check_databases_health(on_error=raise_exception_on_failed_hc)

        # Start recurring health checks in the background.
        self._recurring_hc_task = asyncio.create_task(
            self._bg_scheduler.run_recurring_async(
                self._health_check_interval,
                self._check_databases_health,
            )
        )

        is_active_db_found = False

        for database, weight in self._databases:
            # Set the on-state-changed callback for each circuit.
            database.circuit.on_state_changed(self._on_circuit_state_change_callback)

            # Only the first database with a CLOSED circuit becomes active.
            if database.circuit.state == CBState.CLOSED and not is_active_db_found:
                await self.command_executor.set_active_database(database)
                is_active_db_found = True

        if not is_active_db_found:
            raise NoValidDatabaseException(
                "Initial connection failed - no active database found"
            )

        self.initialized = True

    def get_databases(self) -> Databases:
        """
        Returns a sorted (by weight) list of all databases.
        """
        return self._databases

    def _has_database(self, database: AsyncDatabase) -> bool:
        """Return True if ``database`` equals a member of the database list."""
        return any(existing_db == database for existing_db, _ in self._databases)

    async def set_active_database(self, database: AsyncDatabase) -> None:
        """
        Promote one of the existing databases to become the active one.

        The database is health-checked first and only promoted when its
        circuit ends up CLOSED.

        Raises:
            ValueError: If ``database`` is not a member of the database list.
            NoValidDatabaseException: If the database is unhealthy.
        """
        if not self._has_database(database):
            raise ValueError("Given database is not a member of database list")

        await self._check_db_health(database)

        if database.circuit.state == CBState.CLOSED:
            await self.command_executor.set_active_database(database)
            return

        raise NoValidDatabaseException(
            "Cannot set active database, database is unhealthy"
        )

    async def add_database(self, database: AsyncDatabase):
        """
        Adds a new database to the database list.

        The database is health-checked before being added. It becomes the
        active database when it outweighs the previous top-weighted database
        and its circuit is CLOSED.

        Raises:
            ValueError: If an equal database is already in the list.
        """
        if self._has_database(database):
            raise ValueError("Given database already exists")

        await self._check_db_health(database)

        # Capture the current top database before the new one is inserted.
        highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0]
        self._databases.add(database, database.weight)
        await self._change_active_database(database, highest_weighted_db)

    async def _change_active_database(
        self, new_database: AsyncDatabase, highest_weight_database: AsyncDatabase
    ):
        """
        Activate ``new_database`` if it outweighs ``highest_weight_database``
        and its circuit is CLOSED; otherwise do nothing.
        """
        if (
            new_database.weight > highest_weight_database.weight
            and new_database.circuit.state == CBState.CLOSED
        ):
            await self.command_executor.set_active_database(new_database)

    async def remove_database(self, database: AsyncDatabase):
        """
        Removes a database from the database list.

        Afterwards the remaining top-weighted database is activated when its
        weight does not exceed the removed one's and its circuit is CLOSED.
        """
        weight = self._databases.remove(database)
        highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0]

        if (
            highest_weight <= weight
            and highest_weighted_db.circuit.state == CBState.CLOSED
        ):
            await self.command_executor.set_active_database(highest_weighted_db)

    async def update_database_weight(self, database: AsyncDatabase, weight: float):
        """
        Updates the weight of a database from the database list.

        The database becomes active when its new weight exceeds that of the
        previous top-weighted database and its circuit is CLOSED.

        Raises:
            ValueError: If ``database`` is not a member of the database list.
        """
        if not self._has_database(database):
            raise ValueError("Given database is not a member of database list")

        # Capture the current top database before the weights change.
        highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0]
        self._databases.update_weight(database, weight)
        database.weight = weight
        await self._change_active_database(database, highest_weighted_db)

    def add_failure_detector(self, failure_detector: AsyncFailureDetector):
        """
        Adds a new failure detector to the client's list of failure detectors.
        """
        self._failure_detectors.append(failure_detector)

    async def add_health_check(self, healthcheck: HealthCheck):
        """
        Adds a new health check to the client's list of health checks.
        """
        async with self._hc_lock:
            self._health_checks.append(healthcheck)

    async def execute_command(self, *args, **options):
        """
        Executes a single command and returns its result.

        Initializes the client first if that has not happened yet.
        """
        if not self.initialized:
            await self.initialize()

        return await self.command_executor.execute_command(*args, **options)

    def pipeline(self):
        """
        Enters into pipeline mode of the client.
        """
        return Pipeline(self)

    async def transaction(
        self,
        func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]],
        *watches: KeyT,
        shard_hint: Optional[str] = None,
        value_from_callable: bool = False,
        watch_delay: Optional[float] = None,
    ):
        """
        Executes callable as transaction.

        Initializes the client first if that has not happened yet; all
        arguments are forwarded unchanged to the command executor.
        """
        if not self.initialized:
            await self.initialize()

        return await self.command_executor.execute_transaction(
            func,
            *watches,
            shard_hint=shard_hint,
            value_from_callable=value_from_callable,
            watch_delay=watch_delay,
        )

    async def pubsub(self, **kwargs):
        """
        Return a Publish/Subscribe object. With this object, you can
        subscribe to channels and listen for messages that get published to
        them.
        """
        if not self.initialized:
            await self.initialize()

        return PubSub(self, **kwargs)

    async def _check_databases_health(
        self,
        on_error: Optional[Callable[[Exception], Coroutine[Any, Any, None]]] = None,
    ):
        """
        Runs health checks against all databases concurrently.

        Used both for the initial check and as the recurring background job.

        Args:
            on_error: Optional callback invoked with the original exception of
                every failed health check. A coroutine result is awaited.

        Raises:
            asyncio.TimeoutError: If the whole round takes longer than the
                health check interval.
        """
        try:
            self._hc_tasks = [
                asyncio.create_task(self._check_db_health(database))
                for database, _ in self._databases
            ]
            # return_exceptions=True: one failing database must not abort the
            # checks of the others; failures are inspected below.
            results = await asyncio.wait_for(
                asyncio.gather(
                    *self._hc_tasks,
                    return_exceptions=True,
                ),
                timeout=self._health_check_interval,
            )
        except asyncio.TimeoutError:
            raise asyncio.TimeoutError(
                "Health check execution exceeds health_check_interval"
            )

        for result in results:
            if isinstance(result, UnhealthyDatabaseException):
                unhealthy_db = result.database
                unhealthy_db.circuit.state = CBState.OPEN

                logger.exception(
                    "Health check failed, due to exception",
                    exc_info=result.original_exception,
                )

                if on_error:
                    outcome = on_error(result.original_exception)
                    # The callback is declared as a coroutine function; without
                    # awaiting it the coroutine would never run.
                    if asyncio.iscoroutine(outcome):
                        await outcome

    async def _check_db_health(self, database: AsyncDatabase) -> bool:
        """
        Runs the health check policy on the given database and syncs its
        circuit state with the outcome.

        Returns:
            The result of the health check policy.
        """
        is_healthy = await self._health_check_policy.execute(
            self._health_checks, database
        )

        # Assign only on an actual change so unchanged states are not re-set.
        target_state = CBState.CLOSED if is_healthy else CBState.OPEN
        if database.circuit.state != target_state:
            database.circuit.state = target_state

        return is_healthy

    def _on_circuit_state_change_callback(
        self, circuit: CircuitBreaker, old_state: CBState, new_state: CBState
    ):
        """
        React to a circuit state transition.

        HALF_OPEN triggers an immediate health check of the circuit's
        database; CLOSED -> OPEN schedules a move to HALF_OPEN after
        ``DEFAULT_GRACE_PERIOD``. Must be called with a running event loop.
        """
        loop = asyncio.get_running_loop()

        if new_state == CBState.HALF_OPEN:
            self._half_open_state_task = asyncio.create_task(
                self._check_db_health(circuit.database)
            )
            return

        if old_state == CBState.CLOSED and new_state == CBState.OPEN:
            loop.call_later(DEFAULT_GRACE_PERIOD, _half_open_circuit, circuit)

    async def aclose(self):
        """Close the client of the currently active database, if there is one."""
        if self.command_executor.active_database:
            await self.command_executor.active_database.client.aclose()
def _half_open_circuit(circuit: CircuitBreaker):
    """Move the given circuit breaker into the HALF_OPEN state."""
    half_open = CBState.HALF_OPEN
    circuit.state = half_open
class Pipeline(AsyncRedisModuleCommands, AsyncCoreCommands):
    """
    Pipeline implementation for multiple logical Redis databases.

    Commands are only queued locally; ``execute()`` hands the whole queue to
    the owning client's command executor.
    """

    def __init__(self, client: MultiDBClient):
        self._client = client
        self._command_stack = []

    async def __aenter__(self: "Pipeline") -> "Pipeline":
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        # Drop queued commands, then run the owning client's own exit logic.
        await self.reset()
        await self._client.__aexit__(exc_type, exc_value, traceback)

    def __await__(self):
        # Awaiting a pipeline yields the pipeline itself.
        return self._async_self().__await__()

    async def _async_self(self):
        return self

    def __len__(self) -> int:
        """Number of commands currently queued."""
        return len(self._command_stack)

    def __bool__(self) -> bool:
        """Pipeline instances should always evaluate to True"""
        return True

    async def reset(self) -> None:
        """Discard every queued command."""
        self._command_stack = []

    async def aclose(self) -> None:
        """Close the pipeline"""
        await self.reset()

    def pipeline_execute_command(self, *args, **options) -> "Pipeline":
        """
        Queue a command for the next ``execute()`` call.

        The pipeline itself is returned, so calls can be chained:
        pipe = pipe.set('foo', 'bar').incr('baz').decr('bang')
        Running ``pipe.execute()`` later executes everything queued so far.
        """
        queued_command = (args, options)
        self._command_stack.append(queued_command)
        return self

    def execute_command(self, *args, **kwargs):
        """Adds a command to the stack"""
        return self.pipeline_execute_command(*args, **kwargs)

    async def execute(self) -> List[Any]:
        """Execute all the commands in the current pipeline"""
        client = self._client
        if not client.initialized:
            await client.initialize()

        try:
            # Snapshot the queue as a tuple before handing it over.
            snapshot = tuple(self._command_stack)
            return await client.command_executor.execute_pipeline(snapshot)
        finally:
            # The queue is emptied whether execution succeeded or not.
            await self.reset()
class PubSub:
    """
    PubSub object for multi database client.

    Every operation is forwarded to the command executor of the owning
    client, which applies it to its active pub/sub object.
    """

    def __init__(self, client: MultiDBClient, **kwargs):
        """Initialize the PubSub object for a multi-database client.

        Args:
            client: MultiDBClient instance to use for pub/sub operations
            **kwargs: Additional keyword arguments to pass to the underlying pubsub implementation
        """
        self._client = client
        client.command_executor.pubsub(**kwargs)

    async def __aenter__(self) -> "PubSub":
        return self

    async def __aexit__(self, exc_type, exc_value, traceback) -> None:
        await self.aclose()

    async def _forward(self, method_name: str, *args, **kwargs):
        """Run ``method_name`` on the active pub/sub via the command executor."""
        executor = self._client.command_executor
        return await executor.execute_pubsub_method(method_name, *args, **kwargs)

    async def aclose(self):
        """Close the active pub/sub object."""
        return await self._forward("aclose")

    @property
    def subscribed(self) -> bool:
        """Whether the active pub/sub object reports itself as subscribed."""
        active = self._client.command_executor.active_pubsub
        return active.subscribed

    async def execute_command(self, *args: EncodableT):
        """Execute a raw command on the active pub/sub object."""
        return await self._forward("execute_command", *args)

    async def psubscribe(self, *args: ChannelT, **kwargs: PubSubHandler):
        """
        Subscribe to channel patterns. Keyword arguments map a pattern name
        to a callable; that callable is invoked automatically for messages
        received on the pattern instead of the message being produced via
        ``listen()``.
        """
        return await self._forward("psubscribe", *args, **kwargs)

    async def punsubscribe(self, *args: ChannelT):
        """
        Unsubscribe from the supplied patterns. If empty, unsubscribe from
        all patterns.
        """
        return await self._forward("punsubscribe", *args)

    async def subscribe(self, *args: ChannelT, **kwargs: Callable):
        """
        Subscribe to channels. Keyword arguments map a channel name to a
        callable; that callable is invoked automatically for messages
        received on the channel instead of the message being produced via
        ``listen()`` or ``get_message()``.
        """
        return await self._forward("subscribe", *args, **kwargs)

    async def unsubscribe(self, *args):
        """
        Unsubscribe from the supplied channels. If empty, unsubscribe from
        all channels
        """
        return await self._forward("unsubscribe", *args)

    async def get_message(
        self, ignore_subscribe_messages: bool = False, timeout: Optional[float] = 0.0
    ):
        """
        Get the next message if one is available, otherwise None.

        If timeout is specified, the system will wait for `timeout` seconds
        before returning. Timeout should be specified as a floating point
        number or None to wait indefinitely.
        """
        return await self._forward(
            "get_message",
            ignore_subscribe_messages=ignore_subscribe_messages,
            timeout=timeout,
        )

    async def run(
        self,
        *,
        exception_handler=None,
        poll_timeout: float = 1.0,
    ) -> None:
        """Process pub/sub messages using registered callbacks.

        This is the equivalent of :py:meth:`redis.PubSub.run_in_thread` in
        redis-py, but it is a coroutine. To launch it as a separate task, use
        ``asyncio.create_task``:

            >>> task = asyncio.create_task(pubsub.run())

        To shut it down, use asyncio cancellation:

            >>> task.cancel()
            >>> await task
        """
        executor = self._client.command_executor
        return await executor.execute_pubsub_run(
            sleep_time=poll_timeout, exception_handler=exception_handler, pubsub=self
        )

View File

@@ -0,0 +1,339 @@
from abc import abstractmethod
from asyncio import iscoroutinefunction
from datetime import datetime
from typing import Any, Awaitable, Callable, List, Optional, Union
from redis.asyncio import RedisCluster
from redis.asyncio.client import Pipeline, PubSub
from redis.asyncio.multidb.database import AsyncDatabase, Database, Databases
from redis.asyncio.multidb.event import (
AsyncActiveDatabaseChanged,
CloseConnectionOnActiveDatabaseChanged,
RegisterCommandFailure,
ResubscribeOnActiveDatabaseChanged,
)
from redis.asyncio.multidb.failover import (
DEFAULT_FAILOVER_ATTEMPTS,
DEFAULT_FAILOVER_DELAY,
AsyncFailoverStrategy,
DefaultFailoverStrategyExecutor,
FailoverStrategyExecutor,
)
from redis.asyncio.multidb.failure_detector import AsyncFailureDetector
from redis.asyncio.retry import Retry
from redis.event import AsyncOnCommandsFailEvent, EventDispatcherInterface
from redis.multidb.circuit import State as CBState
from redis.multidb.command_executor import BaseCommandExecutor, CommandExecutor
from redis.multidb.config import DEFAULT_AUTO_FALLBACK_INTERVAL
from redis.typing import KeyT
class AsyncCommandExecutor(CommandExecutor):
    """Interface of a command executor working with asynchronous databases."""

    @property
    @abstractmethod
    def databases(self) -> Databases:
        """Returns a list of databases."""

    @property
    @abstractmethod
    def failure_detectors(self) -> List[AsyncFailureDetector]:
        """Returns a list of failure detectors."""

    @abstractmethod
    def add_failure_detector(self, failure_detector: AsyncFailureDetector) -> None:
        """Adds a new failure detector to the list of failure detectors."""

    @property
    @abstractmethod
    def active_database(self) -> Optional[AsyncDatabase]:
        """Returns currently active database."""

    @abstractmethod
    async def set_active_database(self, database: AsyncDatabase) -> None:
        """Sets the currently active database."""

    @property
    @abstractmethod
    def active_pubsub(self) -> Optional[PubSub]:
        """Returns currently active pubsub."""

    @active_pubsub.setter
    @abstractmethod
    def active_pubsub(self, pubsub: PubSub) -> None:
        """Sets currently active pubsub."""

    @property
    @abstractmethod
    def failover_strategy_executor(self) -> FailoverStrategyExecutor:
        """Returns failover strategy executor."""

    @property
    @abstractmethod
    def command_retry(self) -> Retry:
        """Returns command retry object."""

    @abstractmethod
    def pubsub(self, **kwargs):
        """
        Initializes a PubSub object on a currently active database.

        Declared synchronous: callers invoke it without ``await``.
        """

    @abstractmethod
    async def execute_command(self, *args, **options):
        """Executes a command and returns the result."""

    @abstractmethod
    async def execute_pipeline(self, command_stack: tuple):
        """Executes a stack of commands in pipeline."""

    @abstractmethod
    async def execute_transaction(
        self, transaction: Callable[[Pipeline], None], *watches, **options
    ):
        """Executes a transaction block wrapped in callback."""

    @abstractmethod
    async def execute_pubsub_method(self, method_name: str, *args, **kwargs):
        """Executes a given method on active pub/sub."""

    @abstractmethod
    async def execute_pubsub_run(self, sleep_time: float, **kwargs) -> Any:
        """Executes the pub/sub run loop as a coroutine."""
class DefaultCommandExecutor(BaseCommandExecutor, AsyncCommandExecutor):
    """
    Command executor that runs every operation on the active database through
    the configured retry policy, switching the active database when needed.
    """

    def __init__(
        self,
        failure_detectors: List[AsyncFailureDetector],
        databases: Databases,
        command_retry: Retry,
        failover_strategy: AsyncFailoverStrategy,
        event_dispatcher: EventDispatcherInterface,
        failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS,
        failover_delay: float = DEFAULT_FAILOVER_DELAY,
        auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL,
    ):
        """
        Initialize the DefaultCommandExecutor instance.

        Args:
            failure_detectors: List of failure detector instances to monitor database health
            databases: Collection of available databases to execute commands on
            command_retry: Retry policy for failed command execution
            failover_strategy: Strategy for handling database failover
            event_dispatcher: Interface for dispatching events
            failover_attempts: Number of failover attempts
            failover_delay: Delay between failover attempts
            auto_fallback_interval: Time interval in seconds between attempts to fall back to a primary database
        """
        super().__init__(auto_fallback_interval)

        # Give every detector a back-reference to this executor.
        for fd in failure_detectors:
            fd.set_command_executor(command_executor=self)

        self._databases = databases
        self._failure_detectors = failure_detectors
        self._command_retry = command_retry
        self._failover_strategy_executor = DefaultFailoverStrategyExecutor(
            failover_strategy, failover_attempts, failover_delay
        )
        self._event_dispatcher = event_dispatcher
        self._active_database: Optional[Database] = None
        self._active_pubsub: Optional[PubSub] = None
        # Remembered so they can be attached to active-database-changed events.
        self._active_pubsub_kwargs = {}
        self._setup_event_dispatcher()
        self._schedule_next_fallback()

    @property
    def databases(self) -> Databases:
        """Databases this executor works with."""
        return self._databases

    @property
    def failure_detectors(self) -> List[AsyncFailureDetector]:
        """Failure detectors registered on this executor."""
        return self._failure_detectors

    def add_failure_detector(self, failure_detector: AsyncFailureDetector) -> None:
        """Append a failure detector to the list of failure detectors."""
        self._failure_detectors.append(failure_detector)

    @property
    def active_database(self) -> Optional[AsyncDatabase]:
        """Currently active database, or None if none was set yet."""
        return self._active_database

    async def set_active_database(self, database: AsyncDatabase) -> None:
        """
        Make ``database`` the active one.

        An ``AsyncActiveDatabaseChanged`` event is dispatched only when a
        different database object was active before.
        """
        old_active = self._active_database
        self._active_database = database

        if old_active is not None and old_active is not database:
            await self._event_dispatcher.dispatch_async(
                AsyncActiveDatabaseChanged(
                    old_active,
                    self._active_database,
                    self,
                    **self._active_pubsub_kwargs,
                )
            )

    @property
    def active_pubsub(self) -> Optional[PubSub]:
        """Currently active pub/sub object, or None if none was created."""
        return self._active_pubsub

    @active_pubsub.setter
    def active_pubsub(self, pubsub: PubSub) -> None:
        self._active_pubsub = pubsub

    @property
    def failover_strategy_executor(self) -> FailoverStrategyExecutor:
        """Executor used to pick a database on failover."""
        return self._failover_strategy_executor

    @property
    def command_retry(self) -> Retry:
        """Retry policy applied to every command execution."""
        return self._command_retry

    def pubsub(self, **kwargs):
        """
        Create the active pub/sub object on the active database's client,
        unless one already exists.

        Raises:
            ValueError: If the active database's client is a RedisCluster.
        """
        if self._active_pubsub is None:
            if isinstance(self._active_database.client, RedisCluster):
                raise ValueError("PubSub is not supported for RedisCluster")

            self._active_pubsub = self._active_database.client.pubsub(**kwargs)
            self._active_pubsub_kwargs = kwargs

    async def execute_command(self, *args, **options):
        """Execute a single command on the active database and return its response."""

        async def callback():
            response = await self._active_database.client.execute_command(
                *args, **options
            )
            await self._register_command_execution(args)
            return response

        return await self._execute_with_failure_detection(callback, args)

    async def execute_pipeline(self, command_stack: tuple):
        """
        Execute a stack of ``(args, options)`` commands in one pipeline on the
        active database and return the pipeline's response.
        """

        async def callback():
            async with self._active_database.client.pipeline() as pipe:
                for command, options in command_stack:
                    pipe.execute_command(*command, **options)

                response = await pipe.execute()
                await self._register_command_execution(command_stack)
                return response

        return await self._execute_with_failure_detection(callback, command_stack)

    async def execute_transaction(
        self,
        func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]],
        *watches: KeyT,
        shard_hint: Optional[str] = None,
        value_from_callable: bool = False,
        watch_delay: Optional[float] = None,
    ):
        """
        Run ``func`` as a transaction on the active database's client.

        All arguments are forwarded unchanged to the client's ``transaction``.
        """

        async def callback():
            response = await self._active_database.client.transaction(
                func,
                *watches,
                shard_hint=shard_hint,
                value_from_callable=value_from_callable,
                watch_delay=watch_delay,
            )
            await self._register_command_execution(())
            return response

        return await self._execute_with_failure_detection(callback)

    async def execute_pubsub_method(self, method_name: str, *args, **kwargs):
        """
        Call ``method_name`` on the active pub/sub object and return its
        result; coroutine functions are awaited, plain methods called directly.
        """

        async def callback():
            method = getattr(self.active_pubsub, method_name)
            if iscoroutinefunction(method):
                response = await method(*args, **kwargs)
            else:
                response = method(*args, **kwargs)

            await self._register_command_execution(args)
            return response

        # Pass the args tuple as the single ``cmds`` parameter. Unpacking it
        # would raise TypeError for two or more arguments.
        return await self._execute_with_failure_detection(callback, args)

    async def execute_pubsub_run(
        self, sleep_time: float, exception_handler=None, pubsub=None
    ) -> Any:
        """
        Run the active pub/sub object's ``run`` coroutine.

        Args:
            sleep_time: Forwarded as ``poll_timeout``.
            exception_handler: Forwarded unchanged.
            pubsub: Forwarded unchanged.
        """

        async def callback():
            return await self._active_pubsub.run(
                poll_timeout=sleep_time,
                exception_handler=exception_handler,
                pubsub=pubsub,
            )

        return await self._execute_with_failure_detection(callback)

    async def _execute_with_failure_detection(
        self, callback: Callable, cmds: tuple = ()
    ):
        """
        Execute a command callback with failure detection.

        Args:
            callback: Coroutine function performing the actual execution.
            cmds: Commands being executed; reported to the failure event on
                error.
        """

        async def wrapper():
            # On each retry we need to check active database as it might change.
            await self._check_active_database()
            return await callback()

        return await self._command_retry.call_with_retry(
            wrapper,
            lambda error: self._on_command_fail(error, *cmds),
        )

    async def _check_active_database(self):
        """
        Checks if the active database needs to be updated.

        A new database is picked through the failover strategy executor when
        none is active, the active circuit is not CLOSED, or a non-default
        auto-fallback interval is set and the next fallback attempt is due.
        """
        if (
            self._active_database is None
            or self._active_database.circuit.state != CBState.CLOSED
            or (
                self._auto_fallback_interval != DEFAULT_AUTO_FALLBACK_INTERVAL
                and self._next_fallback_attempt <= datetime.now()
            )
        ):
            await self.set_active_database(
                await self._failover_strategy_executor.execute()
            )
            self._schedule_next_fallback()

    async def _on_command_fail(self, error, *args):
        """Dispatch an ``AsyncOnCommandsFailEvent`` for the failed commands."""
        await self._event_dispatcher.dispatch_async(
            AsyncOnCommandsFailEvent(args, error)
        )

    async def _register_command_execution(self, cmd: tuple):
        """Report an executed command to every failure detector."""
        for detector in self._failure_detectors:
            await detector.register_command_execution(cmd)

    def _setup_event_dispatcher(self):
        """
        Registers necessary listeners.
        """
        failure_listener = RegisterCommandFailure(self._failure_detectors)
        resubscribe_listener = ResubscribeOnActiveDatabaseChanged()
        close_connection_listener = CloseConnectionOnActiveDatabaseChanged()
        self._event_dispatcher.register_listeners(
            {
                AsyncOnCommandsFailEvent: [failure_listener],
                AsyncActiveDatabaseChanged: [
                    close_connection_listener,
                    resubscribe_listener,
                ],
            }
        )

View File

@@ -0,0 +1,210 @@
from dataclasses import dataclass, field
from typing import List, Optional, Type, Union
import pybreaker
from redis.asyncio import ConnectionPool, Redis, RedisCluster
from redis.asyncio.multidb.database import Database, Databases
from redis.asyncio.multidb.failover import (
DEFAULT_FAILOVER_ATTEMPTS,
DEFAULT_FAILOVER_DELAY,
AsyncFailoverStrategy,
WeightBasedFailoverStrategy,
)
from redis.asyncio.multidb.failure_detector import (
AsyncFailureDetector,
FailureDetectorAsyncWrapper,
)
from redis.asyncio.multidb.healthcheck import (
DEFAULT_HEALTH_CHECK_DELAY,
DEFAULT_HEALTH_CHECK_INTERVAL,
DEFAULT_HEALTH_CHECK_POLICY,
DEFAULT_HEALTH_CHECK_PROBES,
HealthCheck,
HealthCheckPolicies,
PingHealthCheck,
)
from redis.asyncio.retry import Retry
from redis.backoff import ExponentialWithJitterBackoff, NoBackoff
from redis.data_structure import WeightedList
from redis.event import EventDispatcher, EventDispatcherInterface
from redis.multidb.circuit import (
DEFAULT_GRACE_PERIOD,
CircuitBreaker,
PBCircuitBreakerAdapter,
)
from redis.multidb.failure_detector import (
DEFAULT_FAILURE_RATE_THRESHOLD,
DEFAULT_FAILURES_DETECTION_WINDOW,
DEFAULT_MIN_NUM_FAILURES,
CommandFailureDetector,
)
# Default for MultiDbConfig.auto_fallback_interval (presumably seconds - TODO confirm).
DEFAULT_AUTO_FALLBACK_INTERVAL = 120
def default_event_dispatcher() -> EventDispatcherInterface:
    """Create a fresh ``EventDispatcher`` (used as a dataclass default factory)."""
    dispatcher = EventDispatcher()
    return dispatcher
@dataclass
class DatabaseConfig:
    """
    Configuration of a single database inside a multi-database setup.

    Holds the client options, the way the connection is obtained (URL,
    ready-made pool, or plain client keyword arguments), the circuit breaker
    settings and the optional health check URL.

    Attributes:
        weight (float): Weight of the database to define the active one.
        client_kwargs (dict): Additional parameters for the database client connection.
        from_url (Optional[str]): Redis URL way of connecting to the database.
        from_pool (Optional[ConnectionPool]): A pre-configured connection pool to use.
        circuit (Optional[CircuitBreaker]): Custom circuit breaker implementation.
        grace_period (float): Grace period after which we need to check if the circuit could be closed again.
        health_check_url (Optional[str]): URL for health checks. Cluster FQDN is typically used
            on public Redis Enterprise endpoints.

    Methods:
        default_circuit_breaker:
            Builds the circuit breaker used when ``circuit`` is not supplied.
    """

    weight: float = 1.0
    client_kwargs: dict = field(default_factory=dict)
    from_url: Optional[str] = None
    from_pool: Optional[ConnectionPool] = None
    circuit: Optional[CircuitBreaker] = None
    grace_period: float = DEFAULT_GRACE_PERIOD
    health_check_url: Optional[str] = None

    def default_circuit_breaker(self) -> CircuitBreaker:
        """Return a pybreaker circuit breaker, reset after ``grace_period``, wrapped in the adapter."""
        return PBCircuitBreakerAdapter(
            pybreaker.CircuitBreaker(reset_timeout=self.grace_period)
        )
@dataclass
class MultiDbConfig:
    """
    Configuration class for managing multiple database connections in a resilient and fail-safe manner.

    Attributes:
        databases_config: A list of database configurations.
        client_class: The client class used to manage database connections.
        command_retry: Retry strategy for executing database commands. Each
            config instance gets its own default ``Retry`` object.
        failure_detectors: Optional list of additional failure detectors for monitoring database failures.
        min_num_failures: Minimal count of failures required for failover
        failure_rate_threshold: Percentage of failures required for failover
        failures_detection_window: Time interval for tracking database failures.
        health_checks: Optional list of additional health checks performed on databases.
        health_check_interval: Time interval for executing health checks.
        health_check_probes: Number of attempts to evaluate the health of a database.
        health_check_delay: Delay between health check attempts.
        health_check_policy: Policy used to evaluate the health checks.
        failover_strategy: Optional strategy for handling database failover scenarios.
        failover_attempts: Number of retries allowed for failover operations.
        failover_delay: Delay between failover attempts.
        auto_fallback_interval: Time interval to trigger automatic fallback.
        event_dispatcher: Interface for dispatching events related to database operations.

    Methods:
        databases:
            Builds the weighted collection of databases from ``databases_config``.
            Lower-level clients are created with retries disabled so that only
            the global ``command_retry`` logic applies.
        default_failure_detectors:
            Returns the default list of failure detectors used to monitor database failures.
        default_health_checks:
            Returns the default list of health checks used to monitor database health.
        default_failover_strategy:
            Provides the default failover strategy used for handling failover scenarios.
    """

    databases_config: List[DatabaseConfig]
    client_class: Type[Union[Redis, RedisCluster]] = Redis
    # A factory is required here: the retry object is mutated by its consumers,
    # so one shared class-level instance would leak state between configs.
    command_retry: Retry = field(
        default_factory=lambda: Retry(
            backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3
        )
    )
    failure_detectors: Optional[List[AsyncFailureDetector]] = None
    min_num_failures: int = DEFAULT_MIN_NUM_FAILURES
    failure_rate_threshold: float = DEFAULT_FAILURE_RATE_THRESHOLD
    failures_detection_window: float = DEFAULT_FAILURES_DETECTION_WINDOW
    health_checks: Optional[List[HealthCheck]] = None
    health_check_interval: float = DEFAULT_HEALTH_CHECK_INTERVAL
    health_check_probes: int = DEFAULT_HEALTH_CHECK_PROBES
    health_check_delay: float = DEFAULT_HEALTH_CHECK_DELAY
    health_check_policy: HealthCheckPolicies = DEFAULT_HEALTH_CHECK_POLICY
    failover_strategy: Optional[AsyncFailoverStrategy] = None
    failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS
    failover_delay: float = DEFAULT_FAILOVER_DELAY
    auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL
    event_dispatcher: EventDispatcherInterface = field(
        default_factory=default_event_dispatcher
    )

    def databases(self) -> Databases:
        """
        Build the weighted collection of databases described by ``databases_config``.

        For every entry a client is created from ``from_url``, else from
        ``from_pool``, else from ``client_kwargs`` alone, and paired with the
        configured or default circuit breaker.

        Returns:
            A ``WeightedList`` of ``Database`` objects keyed by their weight.
        """
        databases = WeightedList()

        for database_config in self.databases_config:
            # The retry object is not used in the lower level clients, so it is
            # overridden with a no-retry policy; global retries are handled by
            # command_retry. A copy is used so the caller's dict is not mutated.
            client_kwargs = {
                **database_config.client_kwargs,
                "retry": Retry(retries=0, backoff=NoBackoff()),
            }

            if database_config.from_url:
                client = self.client_class.from_url(
                    database_config.from_url, **client_kwargs
                )
            elif database_config.from_pool:
                database_config.from_pool.set_retry(
                    Retry(retries=0, backoff=NoBackoff())
                )
                client = self.client_class.from_pool(
                    connection_pool=database_config.from_pool
                )
            else:
                client = self.client_class(**client_kwargs)

            circuit = (
                database_config.default_circuit_breaker()
                if database_config.circuit is None
                else database_config.circuit
            )
            databases.add(
                Database(
                    client=client,
                    circuit=circuit,
                    weight=database_config.weight,
                    health_check_url=database_config.health_check_url,
                ),
                database_config.weight,
            )

        return databases

    def default_failure_detectors(self) -> List[AsyncFailureDetector]:
        """Return the default failure detectors, built from this config's thresholds."""
        return [
            FailureDetectorAsyncWrapper(
                CommandFailureDetector(
                    min_num_failures=self.min_num_failures,
                    failure_rate_threshold=self.failure_rate_threshold,
                    failure_detection_window=self.failures_detection_window,
                )
            ),
        ]

    def default_health_checks(self) -> List[HealthCheck]:
        """Return the default health checks: a single ``PingHealthCheck``."""
        return [
            PingHealthCheck(),
        ]

    def default_failover_strategy(self) -> AsyncFailoverStrategy:
        """Return the default failover strategy: ``WeightBasedFailoverStrategy``."""
        return WeightBasedFailoverStrategy()

View File

@@ -0,0 +1,69 @@
from abc import abstractmethod
from typing import Optional, Union
from redis.asyncio import Redis, RedisCluster
from redis.data_structure import WeightedList
from redis.multidb.circuit import CircuitBreaker
from redis.multidb.database import AbstractDatabase, BaseDatabase
from redis.typing import Number
class AsyncDatabase(AbstractDatabase):
    """
    Interface of a database backed by an asynchronous redis client.

    Concrete implementations must provide readable and writable ``client``
    and ``circuit`` properties.
    """

    @property
    @abstractmethod
    def client(self) -> Union[Redis, RedisCluster]:
        """Return the underlying redis client."""

    @client.setter
    @abstractmethod
    def client(self, client: Union[Redis, RedisCluster]):
        """Replace the underlying redis client."""

    @property
    @abstractmethod
    def circuit(self) -> CircuitBreaker:
        """Return the circuit breaker guarding this database."""

    @circuit.setter
    @abstractmethod
    def circuit(self, circuit: CircuitBreaker):
        """Replace the circuit breaker guarding this database."""
# Type alias: a weighted list whose entries pair an async database with its numeric weight.
Databases = WeightedList[tuple[AsyncDatabase, Number]]
class Database(BaseDatabase, AsyncDatabase):
    """
    Default ``AsyncDatabase`` implementation.

    Holds the async redis client and the circuit breaker of one database and
    keeps the circuit breaker's ``database`` back-reference pointing at this
    instance.
    """

    def __init__(
        self,
        client: Union[Redis, RedisCluster],
        circuit: CircuitBreaker,
        weight: float,
        health_check_url: Optional[str] = None,
    ):
        """
        Args:
            client: Underlying async redis client (standalone or cluster).
            circuit: Circuit breaker guarding this database; it is bound to
                this instance through its ``database`` attribute.
            weight: Weight of the database, forwarded to ``BaseDatabase``.
            health_check_url: Optional health check URL, forwarded to
                ``BaseDatabase``.
        """
        self._client = client
        self._cb = circuit
        self._cb.database = self
        super().__init__(weight, health_check_url)

    @property
    def client(self) -> Union[Redis, RedisCluster]:
        """The underlying redis client."""
        return self._client

    @client.setter
    def client(self, client: Union[Redis, RedisCluster]):
        self._client = client

    @property
    def circuit(self) -> CircuitBreaker:
        """Circuit breaker for the current database."""
        return self._cb

    @circuit.setter
    def circuit(self, circuit: CircuitBreaker):
        self._cb = circuit
        # Mirror __init__: a replacement circuit breaker must point back at this
        # database, otherwise its back-reference stays unset or stale.
        self._cb.database = self

View File

@@ -0,0 +1,84 @@
from typing import List
from redis.asyncio import Redis
from redis.asyncio.multidb.database import AsyncDatabase
from redis.asyncio.multidb.failure_detector import AsyncFailureDetector
from redis.event import AsyncEventListenerInterface, AsyncOnCommandsFailEvent
class AsyncActiveDatabaseChanged:
    """
    Event describing a switch of the active database of an async client.

    Carries the previous and the new database, the command executor that
    performed the switch and any extra keyword arguments given by the emitter.
    """

    def __init__(
        self,
        old_database: AsyncDatabase,
        new_database: AsyncDatabase,
        command_executor,
        **kwargs,
    ):
        self._old_database, self._new_database = old_database, new_database
        self._command_executor = command_executor
        self._kwargs = kwargs

    @property
    def old_database(self) -> AsyncDatabase:
        """Database that was active before the change."""
        return self._old_database

    @property
    def new_database(self) -> AsyncDatabase:
        """Database that is active after the change."""
        return self._new_database

    @property
    def command_executor(self):
        """Command executor associated with the change."""
        return self._command_executor

    @property
    def kwargs(self):
        """Extra keyword arguments supplied when the event was created."""
        return self._kwargs
class ResubscribeOnActiveDatabaseChanged(AsyncEventListenerInterface):
    """
    Moves the currently active pub / sub over to the new active database.
    """

    async def listen(self, event: AsyncActiveDatabaseChanged):
        """
        Recreate the active pub / sub on the new database and close the old one.

        Does nothing when the command executor has no active pub / sub.
        """
        executor = event.command_executor
        previous_pubsub = executor.active_pubsub
        if previous_pubsub is None:
            return

        # Re-assign old channels and patterns so they will be automatically subscribed on connection.
        replacement = event.new_database.client.pubsub(**event.kwargs)
        replacement.channels = previous_pubsub.channels
        replacement.patterns = previous_pubsub.patterns
        await replacement.on_connect(None)

        executor.active_pubsub = replacement
        await previous_pubsub.aclose()
class CloseConnectionOnActiveDatabaseChanged(AsyncEventListenerInterface):
    """
    Closes the connection to the database that is no longer active.
    """

    async def listen(self, event: AsyncActiveDatabaseChanged):
        """
        Close the old database's client; for a standalone ``Redis`` client the
        connection pool is additionally flagged for reconnect and disconnected.
        """
        old_client = event.old_database.client
        await old_client.aclose()

        if not isinstance(old_client, Redis):
            return

        pool = old_client.connection_pool
        await pool.update_active_connections_for_reconnect()
        await pool.disconnect()
class RegisterCommandFailure(AsyncEventListenerInterface):
    """
    Listener that forwards command failures to every configured failure detector.
    """

    def __init__(self, failure_detectors: List[AsyncFailureDetector]):
        self._failure_detectors = failure_detectors

    async def listen(self, event: AsyncOnCommandsFailEvent) -> None:
        """Register the event's exception and commands with each detector, in order."""
        error, commands = event.exception, event.commands
        for detector in self._failure_detectors:
            await detector.register_failure(error, commands)

View File

@@ -0,0 +1,125 @@
import time
from abc import ABC, abstractmethod
from redis.asyncio.multidb.database import AsyncDatabase, Databases
from redis.data_structure import WeightedList
from redis.multidb.circuit import State as CBState
from redis.multidb.exception import (
NoValidDatabaseException,
TemporaryUnavailableException,
)
# Default number of failover attempts counted by DefaultFailoverStrategyExecutor
# before the "no valid database" error is re-raised.
DEFAULT_FAILOVER_ATTEMPTS = 10
# Default delay between counted failover attempts; it is added to time.time(), i.e. seconds.
DEFAULT_FAILOVER_DELAY = 12
class AsyncFailoverStrategy(ABC):
    """Interface of a strategy that picks the database to fail over to."""

    @abstractmethod
    async def database(self) -> AsyncDatabase:
        """Return the database selected by the strategy."""

    @abstractmethod
    def set_databases(self, databases: Databases) -> None:
        """Provide the databases the strategy selects from."""
class FailoverStrategyExecutor(ABC):
    """Interface of an executor that runs a failover strategy."""

    @property
    @abstractmethod
    def failover_attempts(self) -> int:
        """Return the number of failover attempts."""

    @property
    @abstractmethod
    def failover_delay(self) -> float:
        """Return the delay between failover attempts."""

    @property
    @abstractmethod
    def strategy(self) -> AsyncFailoverStrategy:
        """Return the strategy being executed."""

    @abstractmethod
    async def execute(self) -> AsyncDatabase:
        """Run the failover strategy and return the selected database."""
class WeightBasedFailoverStrategy(AsyncFailoverStrategy):
    """
    Failover strategy that walks the weighted database list in its iteration
    order and picks the first database whose circuit is closed.
    """

    def __init__(self):
        self._databases = WeightedList()

    async def database(self) -> AsyncDatabase:
        """
        Return the first database with a CLOSED circuit.

        Raises:
            NoValidDatabaseException: If no database has a closed circuit.
        """
        closed_databases = (
            candidate
            for candidate, _ in self._databases
            if candidate.circuit.state == CBState.CLOSED
        )
        selected = next(closed_databases, None)
        if selected is None:
            raise NoValidDatabaseException(
                "No valid database available for communication"
            )
        return selected

    def set_databases(self, databases: Databases) -> None:
        """Replace the weighted list of databases the strategy selects from."""
        self._databases = databases
class DefaultFailoverStrategyExecutor(FailoverStrategyExecutor):
    """
    Executes given failover strategy.

    While the strategy cannot find a valid database the executor reports a
    temporary condition and counts one attempt per elapsed ``failover_delay``
    window. Once more than ``failover_attempts`` attempts were counted the
    original ``NoValidDatabaseException`` is propagated and the state is reset.
    """

    def __init__(
        self,
        strategy: AsyncFailoverStrategy,
        failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS,
        failover_delay: float = DEFAULT_FAILOVER_DELAY,
    ):
        """
        Args:
            strategy: Strategy used to select a database.
            failover_attempts: Number of attempts counted before giving up.
            failover_delay: Delay between counted attempts, in seconds.
        """
        self._strategy = strategy
        self._failover_attempts = failover_attempts
        self._failover_delay = failover_delay
        # time.time() based timestamp of the next counted attempt (a float);
        # 0 means no failover sequence is currently in progress.
        self._next_attempt_ts: float = 0
        self._failover_counter: int = 0

    @property
    def failover_attempts(self) -> int:
        """The number of failover attempts."""
        return self._failover_attempts

    @property
    def failover_delay(self) -> float:
        """The delay between failover attempts."""
        return self._failover_delay

    @property
    def strategy(self) -> AsyncFailoverStrategy:
        """The strategy to execute."""
        return self._strategy

    async def execute(self) -> AsyncDatabase:
        """
        Execute the failover strategy.

        Returns:
            The database selected by the strategy.

        Raises:
            TemporaryUnavailableException: No database is available yet but the
                attempt budget is not exhausted; the caller may retry.
            NoValidDatabaseException: The attempt budget is exhausted.
        """
        try:
            database = await self._strategy.database()
        except NoValidDatabaseException as e:
            if self._next_attempt_ts == 0:
                # First failure of a sequence: open the first delay window.
                self._next_attempt_ts = time.time() + self._failover_delay
                self._failover_counter += 1
            elif time.time() >= self._next_attempt_ts:
                # The current window elapsed: count another attempt.
                self._next_attempt_ts += self._failover_delay
                self._failover_counter += 1

            if self._failover_counter > self._failover_attempts:
                self._reset()
                raise

            raise TemporaryUnavailableException(
                "No database connections currently available. "
                "This is a temporary condition - please retry the operation."
            ) from e

        # Success ends any failover sequence in progress.
        self._reset()
        return database

    def _reset(self) -> None:
        """Clear the attempt counter and the next-attempt timestamp."""
        self._next_attempt_ts = 0
        self._failover_counter = 0

View File

@@ -0,0 +1,38 @@
from abc import ABC, abstractmethod
from redis.multidb.failure_detector import FailureDetector
class AsyncFailureDetector(ABC):
    """Interface of a failure detector usable from asynchronous code."""

    @abstractmethod
    async def register_failure(self, exception: Exception, cmd: tuple) -> None:
        """Record that executing ``cmd`` failed with ``exception``."""

    @abstractmethod
    async def register_command_execution(self, cmd: tuple) -> None:
        """Record that ``cmd`` was executed."""

    @abstractmethod
    def set_command_executor(self, command_executor) -> None:
        """Attach the command executor this detector works with."""
class FailureDetectorAsyncWrapper(AsyncFailureDetector):
    """
    Adapter exposing a synchronous failure detector through the async interface.

    Every call is delegated directly to the wrapped detector.
    """

    def __init__(self, failure_detector: FailureDetector) -> None:
        self._failure_detector = failure_detector

    async def register_failure(self, exception: Exception, cmd: tuple) -> None:
        """Forward the failure to the wrapped detector."""
        wrapped = self._failure_detector
        wrapped.register_failure(exception, cmd)

    async def register_command_execution(self, cmd: tuple) -> None:
        """Forward the command execution to the wrapped detector."""
        wrapped = self._failure_detector
        wrapped.register_command_execution(cmd)

    def set_command_executor(self, command_executor) -> None:
        """Forward the command executor to the wrapped detector."""
        wrapped = self._failure_detector
        wrapped.set_command_executor(command_executor)

View File

@@ -0,0 +1,285 @@
import asyncio
import logging
from abc import ABC, abstractmethod
from enum import Enum
from typing import List, Optional, Tuple, Union
from redis.asyncio import Redis
from redis.asyncio.http.http_client import DEFAULT_TIMEOUT, AsyncHTTPClientWrapper
from redis.backoff import NoBackoff
from redis.http.http_client import HttpClient
from redis.multidb.exception import UnhealthyDatabaseException
from redis.retry import Retry
# Default number of probes executed per health check.
DEFAULT_HEALTH_CHECK_PROBES = 3
# NOTE(review): presumably the interval, in seconds, between background health check
# rounds; it is not used in this module - confirm against the client.
DEFAULT_HEALTH_CHECK_INTERVAL = 5
# Default pause between probes; the policies pass the delay to asyncio.sleep(), i.e. seconds.
DEFAULT_HEALTH_CHECK_DELAY = 0.5
# Default lag tolerance in milliseconds (sent as availability_lag_tolerance_ms).
DEFAULT_LAG_AWARE_TOLERANCE = 5000
# Module-level logger.
logger = logging.getLogger(__name__)
class HealthCheck(ABC):
    """Interface of a single asynchronous health check."""

    @abstractmethod
    async def check_health(self, database) -> bool:
        """Return whether ``database`` is considered healthy."""
class HealthCheckPolicy(ABC):
    """
    Interface of a policy that decides how health check probes are executed
    and combined into one health status.
    """

    @property
    @abstractmethod
    def health_check_probes(self) -> int:
        """Return the number of probes executed per health check."""

    @property
    @abstractmethod
    def health_check_delay(self) -> float:
        """Return the delay between health check probes."""

    @abstractmethod
    async def execute(self, health_checks: List[HealthCheck], database) -> bool:
        """Run the health checks against ``database`` and return its health status."""
class AbstractHealthCheckPolicy(HealthCheckPolicy):
    """
    Base class for health check policies: validates and stores the number of
    probes and the delay between them, leaving ``execute`` to subclasses.
    """

    def __init__(self, health_check_probes: int, health_check_delay: float):
        """
        Args:
            health_check_probes: Number of probes per health check, at least 1.
            health_check_delay: Delay between probes.

        Raises:
            ValueError: If ``health_check_probes`` is lower than 1.
        """
        if not health_check_probes >= 1:
            raise ValueError("health_check_probes must be greater than 0")

        self._health_check_delay = health_check_delay
        self._health_check_probes = health_check_probes

    @property
    def health_check_probes(self) -> int:
        """Number of probes to execute health checks."""
        return self._health_check_probes

    @property
    def health_check_delay(self) -> float:
        """Delay between health check probes."""
        return self._health_check_delay

    @abstractmethod
    async def execute(self, health_checks: List[HealthCheck], database) -> bool:
        """Execute health checks and return database health status."""
class HealthyAllPolicy(AbstractHealthCheckPolicy):
    """
    Policy that reports a database healthy only when every probe of every
    health check succeeds.
    """

    def __init__(self, health_check_probes: int, health_check_delay: float):
        super().__init__(health_check_probes, health_check_delay)

    async def execute(self, health_checks: List[HealthCheck], database) -> bool:
        """
        Run all probes of all health checks, stopping at the first failure.

        Returns:
            False as soon as a probe reports unhealthy, True otherwise.

        Raises:
            UnhealthyDatabaseException: As soon as a probe raises.
        """
        last_probe = self.health_check_probes - 1

        for check in health_checks:
            for probe in range(self.health_check_probes):
                try:
                    probe_ok = await check.check_health(database)
                except Exception as error:
                    raise UnhealthyDatabaseException(
                        "Unhealthy database", database, error
                    )

                if not probe_ok:
                    return False

                # No pause is needed after the final probe of a check.
                if probe < last_probe:
                    await asyncio.sleep(self._health_check_delay)

        return True
class HealthyMajorityPolicy(AbstractHealthCheckPolicy):
    """
    Policy that reports a database healthy when, for every health check, a
    majority of the probes succeed.
    """

    def __init__(self, health_check_probes: int, health_check_delay: float):
        super().__init__(health_check_probes, health_check_delay)

    async def execute(self, health_checks: List[HealthCheck], database) -> bool:
        """
        Run the probes of each health check until too many of them fail.

        Returns:
            False when the failure that exhausts the budget is an unhealthy
            result, True when every health check keeps its majority.

        Raises:
            UnhealthyDatabaseException: When the failure that exhausts the
                budget is a raised exception.
        """
        probes = self.health_check_probes

        for check in health_checks:
            # Failures that break the majority: probes / 2 for an even count,
            # (probes + 1) / 2 for an odd one - both equal (probes + 1) // 2.
            failures_left = (probes + 1) // 2

            for probe in range(probes):
                try:
                    probe_ok = await check.check_health(database)
                except Exception as error:
                    failures_left -= 1
                    if failures_left <= 0:
                        raise UnhealthyDatabaseException(
                            "Unhealthy database", database, error
                        )
                else:
                    if not probe_ok:
                        failures_left -= 1
                        if failures_left <= 0:
                            return False

                if probe < probes - 1:
                    await asyncio.sleep(self._health_check_delay)

        return True
class HealthyAnyPolicy(AbstractHealthCheckPolicy):
    """
    Policy that returns True if at least one health check probe is successful.

    The rule is applied to every health check independently: each one needs
    at least one successful probe of its own.
    """

    def __init__(self, health_check_probes: int, health_check_delay: float):
        super().__init__(health_check_probes, health_check_delay)

    async def execute(self, health_checks: List[HealthCheck], database) -> bool:
        """
        Run each health check until one of its probes succeeds.

        Returns:
            True when every health check had a successful probe; False when a
            health check had no successful probe and none of its probes raised,
            or when ``health_checks`` is empty.

        Raises:
            UnhealthyDatabaseException: When a health check had no successful
                probe and at least one of its probes raised; the last such
                error is reported.
        """
        is_healthy = False

        for health_check in health_checks:
            # Reset per health check: a success of a previous check must not
            # mask a check whose probes all raise.
            is_healthy = False
            exception = None

            for attempt in range(self.health_check_probes):
                try:
                    if await health_check.check_health(database):
                        is_healthy = True
                        break
                except Exception as e:
                    exception = UnhealthyDatabaseException(
                        "Unhealthy database", database, e
                    )

                if attempt < self.health_check_probes - 1:
                    await asyncio.sleep(self._health_check_delay)

            if not is_healthy:
                if exception:
                    raise exception
                return False

        return is_healthy
# Enumeration of the available health check policies; each member's value is the
# policy class itself, instantiated by callers via `.value(probes, delay)`.
class HealthCheckPolicies(Enum):
# Every probe of every health check must succeed.
HEALTHY_ALL = HealthyAllPolicy
# A majority of the probes of every health check must succeed.
HEALTHY_MAJORITY = HealthyMajorityPolicy
# At least one probe must succeed.
HEALTHY_ANY = HealthyAnyPolicy
# Policy applied when the configuration does not choose one.
DEFAULT_HEALTH_CHECK_POLICY: HealthCheckPolicies = HealthCheckPolicies.HEALTHY_ALL
class PingHealthCheck(HealthCheck):
    """
    Health check based on PING command.
    """

    async def check_health(self, database) -> bool:
        """
        PING the database and report whether it answered.

        For a standalone client the result of PING is returned. For a cluster
        client every node returned by ``get_nodes()`` is pinged and the
        database is healthy only if all of them answer.
        """
        client = database.client
        if isinstance(client, Redis):
            return await client.execute_command("PING")

        # For a cluster checks if all nodes are healthy.
        for node in client.get_nodes():
            # asyncio cluster nodes are expected to expose execute_command()
            # themselves; nodes carrying a ``redis_connection`` are still pinged
            # through it, as before.
            target = getattr(node, "redis_connection", None) or node
            if not await target.execute_command("PING"):
                return False

        return True
class LagAwareHealthCheck(HealthCheck):
    """
    Health check available for Redis Enterprise deployments.
    Verify via REST API that the database is healthy based on different lags.
    """

    def __init__(
        self,
        rest_api_port: int = 9443,
        lag_aware_tolerance: int = DEFAULT_LAG_AWARE_TOLERANCE,
        timeout: float = DEFAULT_TIMEOUT,
        auth_basic: Optional[Tuple[str, str]] = None,
        verify_tls: bool = True,
        # TLS verification (server) options
        ca_file: Optional[str] = None,
        ca_path: Optional[str] = None,
        ca_data: Optional[Union[str, bytes]] = None,
        # Mutual TLS (client cert) options
        client_cert_file: Optional[str] = None,
        client_key_file: Optional[str] = None,
        client_key_password: Optional[str] = None,
    ):
        """
        Initialize LagAwareHealthCheck with the specified parameters.

        Args:
            rest_api_port: Port number for Redis Enterprise REST API (default: 9443)
            lag_aware_tolerance: Tolerance in lag between databases in MS
                (default: DEFAULT_LAG_AWARE_TOLERANCE)
            timeout: Request timeout in seconds (default: DEFAULT_TIMEOUT)
            auth_basic: Tuple of (username, password) for basic authentication
            verify_tls: Whether to verify TLS certificates (default: True)
            ca_file: Path to CA certificate file for TLS verification
            ca_path: Path to CA certificates directory for TLS verification
            ca_data: CA certificate data as string or bytes
            client_cert_file: Path to client certificate file for mutual TLS
            client_key_file: Path to client private key file for mutual TLS
            client_key_password: Password for encrypted client private key
        """
        # The HTTP client itself never retries.
        self._http_client = AsyncHTTPClientWrapper(
            HttpClient(
                timeout=timeout,
                auth_basic=auth_basic,
                retry=Retry(NoBackoff(), retries=0),
                verify_tls=verify_tls,
                ca_file=ca_file,
                ca_path=ca_path,
                ca_data=ca_data,
                client_cert_file=client_cert_file,
                client_key_file=client_key_file,
                client_key_password=client_key_password,
            )
        )
        self._rest_api_port = rest_api_port
        self._lag_aware_tolerance = lag_aware_tolerance

    async def check_health(self, database) -> bool:
        """
        Ask the REST API whether the bdb serving ``database`` is available
        within the configured lag tolerance.

        Returns:
            True when the availability request completes; per the original
            note, the HTTP client raises on an unsuccessful status.

        Raises:
            ValueError: If the database has no health check url, or no bdb
                matches the database host.
        """
        if database.health_check_url is None:
            raise ValueError(
                "Database health check url is not set. Please check DatabaseConfig for the current database."
            )

        if isinstance(database.client, Redis):
            db_host = database.client.get_connection_kwargs()["host"]
        else:
            db_host = database.client.startup_nodes[0].host

        # NOTE(review): the base url is stored on the shared HTTP client, so one
        # instance checking several databases concurrently may interleave - confirm.
        base_url = f"{database.health_check_url}:{self._rest_api_port}"
        self._http_client.client.base_url = base_url

        # Find bdb matching to the current database host
        bdbs = await self._http_client.get("/v1/bdbs")
        matching_bdb = self._find_matching_bdb(bdbs, db_host)

        if matching_bdb is None:
            logger.warning("LagAwareHealthCheck failed: Couldn't find a matching bdb")
            raise ValueError("Could not find a matching bdb")

        url = (
            f"/v1/bdbs/{matching_bdb['uid']}/availability"
            f"?extend_check=lag&availability_lag_tolerance_ms={self._lag_aware_tolerance}"
        )
        await self._http_client.get(url, expect_json=False)

        # Status checked in an http client, otherwise HttpError will be raised
        return True

    @staticmethod
    def _find_matching_bdb(bdbs, db_host):
        """Return the first bdb with an endpoint whose DNS name or address equals ``db_host``, else None."""
        for bdb in bdbs:
            for endpoint in bdb["endpoints"]:
                if endpoint["dns_name"] == db_host:
                    return bdb

                # In case if the host was set as public IP
                if any(addr == db_host for addr in endpoint["addr"]):
                    return bdb

        return None