Initial commit: igny8 project

This commit is contained in:
igny8
2025-11-09 10:27:02 +00:00
commit 60b8188111
27265 changed files with 4360521 additions and 0 deletions

View File

@@ -0,0 +1,144 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Callable
import pybreaker
DEFAULT_GRACE_PERIOD = 60
class State(Enum):
    """The three possible states of a circuit breaker."""

    CLOSED = "closed"
    OPEN = "open"
    HALF_OPEN = "half-open"
class CircuitBreaker(ABC):
    """Abstract circuit breaker guarding access to a single database."""

    @property
    @abstractmethod
    def grace_period(self) -> float:
        """Seconds the circuit should be kept open before probing again."""

    @grace_period.setter
    @abstractmethod
    def grace_period(self, grace_period: float):
        """Set the grace period in seconds."""

    @property
    @abstractmethod
    def state(self) -> State:
        """Current state of the circuit."""

    @state.setter
    @abstractmethod
    def state(self, state: State):
        """Force the circuit into the given state."""

    @property
    @abstractmethod
    def database(self):
        """Database associated with this circuit."""

    @database.setter
    @abstractmethod
    def database(self, database):
        """Set the database associated with this circuit."""

    @abstractmethod
    def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]):
        """Register a callback invoked whenever the circuit changes state."""
class BaseCircuitBreaker(CircuitBreaker):
    """
    Base implementation of the CircuitBreaker interface on top of pybreaker.
    """

    def __init__(self, cb: pybreaker.CircuitBreaker):
        self._cb = cb
        # Maps each of our states to the pybreaker method forcing that state.
        self._state_pb_mapper = {
            State.CLOSED: self._cb.close,
            State.OPEN: self._cb.open,
            State.HALF_OPEN: self._cb.half_open,
        }
        self._database = None

    @property
    def grace_period(self) -> float:
        # pybreaker calls this value the "reset timeout".
        return self._cb.reset_timeout

    @grace_period.setter
    def grace_period(self, grace_period: float):
        self._cb.reset_timeout = grace_period

    @property
    def state(self) -> State:
        # pybreaker state names match our State values.
        return State(self._cb.state.name)

    @state.setter
    def state(self, state: State):
        # Delegate to the matching pybreaker transition method.
        self._state_pb_mapper[state]()

    @property
    def database(self):
        return self._database

    @database.setter
    def database(self, database):
        self._database = database

    @abstractmethod
    def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]):
        """Callback called when the state of the circuit changes."""
class PBListener(pybreaker.CircuitBreakerListener):
    """Adapts our state-change callback to pybreaker's listener interface."""

    def __init__(
        self,
        cb: Callable[[CircuitBreaker, State, State], None],
        database,
    ):
        """
        Initialize a PBListener instance.

        Args:
            cb: Invoked as cb(circuit, old_state, new_state) whenever the
                circuit breaker changes state.
            database: Database instance associated with this circuit breaker.
        """
        self._cb = cb
        self._database = database

    def state_change(self, cb, old_state, new_state):
        # Wrap the raw pybreaker objects in our own abstractions before
        # handing them to the registered callback.
        adapter = PBCircuitBreakerAdapter(cb)
        adapter.database = self._database
        self._cb(adapter, State(old_state.name), State(new_state.name))
class PBCircuitBreakerAdapter(BaseCircuitBreaker):
    """Concrete CircuitBreaker backed by a pybreaker.CircuitBreaker."""

    def __init__(self, cb: pybreaker.CircuitBreaker):
        """
        Initialize a PBCircuitBreakerAdapter instance.

        Args:
            cb: The pybreaker CircuitBreaker instance to adapt to our
                CircuitBreaker interface.
        """
        super().__init__(cb)

    def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]):
        """Register *cb* to be invoked on every state transition."""
        self._cb.add_listener(PBListener(cb, self.database))

View File

@@ -0,0 +1,526 @@
import logging
import threading
from concurrent.futures import as_completed
from concurrent.futures.thread import ThreadPoolExecutor
from typing import Any, Callable, List, Optional
from redis.background import BackgroundScheduler
from redis.client import PubSubWorkerThread
from redis.commands import CoreCommands, RedisModuleCommands
from redis.multidb.circuit import CircuitBreaker
from redis.multidb.circuit import State as CBState
from redis.multidb.command_executor import DefaultCommandExecutor
from redis.multidb.config import DEFAULT_GRACE_PERIOD, MultiDbConfig
from redis.multidb.database import Database, Databases, SyncDatabase
from redis.multidb.exception import NoValidDatabaseException, UnhealthyDatabaseException
from redis.multidb.failure_detector import FailureDetector
from redis.multidb.healthcheck import HealthCheck, HealthCheckPolicy
from redis.utils import experimental
logger = logging.getLogger(__name__)
@experimental
class MultiDBClient(RedisModuleCommands, CoreCommands):
    """
    Client that operates on multiple logical Redis databases.
    Should be used in Active-Active database setups.
    """

    def __init__(self, config: MultiDbConfig):
        self._databases = config.databases()
        # Default health checks always run; user-supplied checks are appended.
        self._health_checks = config.default_health_checks()
        if config.health_checks is not None:
            self._health_checks.extend(config.health_checks)
        self._health_check_interval = config.health_check_interval
        self._health_check_policy: HealthCheckPolicy = config.health_check_policy.value(
            config.health_check_probes, config.health_check_probes_delay
        )
        self._failure_detectors = config.default_failure_detectors()
        if config.failure_detectors is not None:
            self._failure_detectors.extend(config.failure_detectors)
        self._failover_strategy = (
            config.default_failover_strategy()
            if config.failover_strategy is None
            else config.failover_strategy
        )
        self._failover_strategy.set_databases(self._databases)
        self._auto_fallback_interval = config.auto_fallback_interval
        self._event_dispatcher = config.event_dispatcher
        self._command_retry = config.command_retry
        # Connection refusals are retryable here: another database may serve.
        self._command_retry.update_supported_errors((ConnectionRefusedError,))
        self.command_executor = DefaultCommandExecutor(
            failure_detectors=self._failure_detectors,
            databases=self._databases,
            command_retry=self._command_retry,
            failover_strategy=self._failover_strategy,
            failover_attempts=config.failover_attempts,
            failover_delay=config.failover_delay,
            event_dispatcher=self._event_dispatcher,
            auto_fallback_interval=self._auto_fallback_interval,
        )
        # Initialization is lazy: performed on the first command if needed.
        self.initialized = False
        self._hc_lock = threading.RLock()
        self._bg_scheduler = BackgroundScheduler()
        self._config = config

    def initialize(self):
        """
        Perform initialization of databases to define their initial state.

        Raises:
            NoValidDatabaseException: If no database passes its health checks.
        """

        def raise_exception_on_failed_hc(error):
            raise error

        # Initial databases check to define initial state; any health check
        # failure is fatal at this stage.
        self._check_databases_health(on_error=raise_exception_on_failed_hc)
        # Starts recurring health checks on the background.
        self._bg_scheduler.run_recurring(
            self._health_check_interval,
            self._check_databases_health,
        )
        is_active_db_found = False
        for database, weight in self._databases:
            # Set on state changed callback for each circuit.
            database.circuit.on_state_changed(self._on_circuit_state_change_callback)
            # Databases are iterated by weight, so the first healthy one
            # becomes the active database.
            if database.circuit.state == CBState.CLOSED and not is_active_db_found:
                self.command_executor.active_database = database
                is_active_db_found = True
        if not is_active_db_found:
            raise NoValidDatabaseException(
                "Initial connection failed - no active database found"
            )
        self.initialized = True

    def get_databases(self) -> Databases:
        """
        Returns a sorted (by weight) list of all databases.
        """
        return self._databases

    def set_active_database(self, database: SyncDatabase) -> None:
        """
        Promote one of the existing databases to become an active.

        Raises:
            ValueError: If the database is not a member of the database list.
            NoValidDatabaseException: If the database fails its health checks.
        """
        exists = None
        for existing_db, _ in self._databases:
            if existing_db == database:
                exists = True
                break
        if not exists:
            raise ValueError("Given database is not a member of database list")
        self._check_db_health(database)
        if database.circuit.state == CBState.CLOSED:
            self.command_executor.active_database = database
            return
        raise NoValidDatabaseException(
            "Cannot set active database, database is unhealthy"
        )

    def add_database(self, database: SyncDatabase):
        """
        Adds a new database to the database list.

        Raises:
            ValueError: If the database is already a member of the list.
        """
        for existing_db, _ in self._databases:
            if existing_db == database:
                raise ValueError("Given database already exists")
        self._check_db_health(database)
        # Capture the current top database before insertion so we can decide
        # whether the new database should take over as active.
        highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0]
        self._databases.add(database, database.weight)
        self._change_active_database(database, highest_weighted_db)

    def _change_active_database(
        self, new_database: SyncDatabase, highest_weight_database: SyncDatabase
    ):
        # Promote the new database only if it both outweighs the current top
        # database and is healthy.
        if (
            new_database.weight > highest_weight_database.weight
            and new_database.circuit.state == CBState.CLOSED
        ):
            self.command_executor.active_database = new_database

    def remove_database(self, database: Database):
        """
        Removes a database from the database list.
        """
        weight = self._databases.remove(database)
        highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0]
        # If the removed database outweighed the remaining top one, promote
        # the healthiest remaining candidate.
        if (
            highest_weight <= weight
            and highest_weighted_db.circuit.state == CBState.CLOSED
        ):
            self.command_executor.active_database = highest_weighted_db

    def update_database_weight(self, database: SyncDatabase, weight: float):
        """
        Updates a weight of the given database in the database list.

        Raises:
            ValueError: If the database is not a member of the database list.
        """
        exists = None
        for existing_db, _ in self._databases:
            if existing_db == database:
                exists = True
                break
        if not exists:
            raise ValueError("Given database is not a member of database list")
        # Capture the current top database before the update so the promotion
        # decision compares against the old ordering.
        highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0]
        self._databases.update_weight(database, weight)
        database.weight = weight
        self._change_active_database(database, highest_weighted_db)

    def add_failure_detector(self, failure_detector: FailureDetector):
        """
        Adds a new failure detector to the database.
        """
        self._failure_detectors.append(failure_detector)

    def add_health_check(self, healthcheck: HealthCheck):
        """
        Adds a new health check to the database.
        """
        # Guard against concurrent mutation by the background health checks.
        with self._hc_lock:
            self._health_checks.append(healthcheck)

    def execute_command(self, *args, **options):
        """
        Executes a single command and return its result.
        """
        if not self.initialized:
            self.initialize()
        return self.command_executor.execute_command(*args, **options)

    def pipeline(self):
        """
        Enters into pipeline mode of the client.
        """
        return Pipeline(self)

    def transaction(self, func: Callable[["Pipeline"], None], *watches, **options):
        """
        Executes callable as transaction.
        """
        if not self.initialized:
            self.initialize()
        # Bug fix: options must be forwarded as keyword arguments; the
        # previous "*options" unpacked only the dict's keys positionally.
        return self.command_executor.execute_transaction(func, *watches, **options)

    def pubsub(self, **kwargs):
        """
        Return a Publish/Subscribe object. With this object, you can
        subscribe to channels and listen for messages that get published to
        them.
        """
        if not self.initialized:
            self.initialize()
        return PubSub(self, **kwargs)

    def _check_db_health(self, database: SyncDatabase) -> bool:
        """
        Runs health checks on the given database until first failure.

        Side effect: the database's circuit state is updated to reflect the
        outcome (OPEN on failure, CLOSED on success).
        """
        # Health check will setup circuit state
        is_healthy = self._health_check_policy.execute(self._health_checks, database)
        if not is_healthy:
            if database.circuit.state != CBState.OPEN:
                database.circuit.state = CBState.OPEN
        elif database.circuit.state != CBState.CLOSED:
            database.circuit.state = CBState.CLOSED
        return is_healthy

    def _check_databases_health(self, on_error: Callable[[Exception], None] = None):
        """
        Runs health checks against all databases in parallel.

        Also used as the recurring background task.

        Args:
            on_error: Optional callback invoked with the original exception
                of each database whose health check raised.
        """
        with ThreadPoolExecutor(max_workers=len(self._databases)) as executor:
            # Submit all health checks
            futures = {
                executor.submit(self._check_db_health, database)
                for database, _ in self._databases
            }
            try:
                for future in as_completed(
                    futures, timeout=self._health_check_interval
                ):
                    try:
                        future.result()
                    except UnhealthyDatabaseException as e:
                        unhealthy_db = e.database
                        unhealthy_db.circuit.state = CBState.OPEN
                        logger.exception(
                            "Health check failed, due to exception",
                            exc_info=e.original_exception,
                        )
                        if on_error:
                            on_error(e.original_exception)
            except TimeoutError:
                # NOTE(review): on Python < 3.11 concurrent.futures.TimeoutError
                # is not the builtin TimeoutError, so this clause may not catch
                # the timeout raised by as_completed - confirm target versions.
                raise TimeoutError(
                    "Health check execution exceeds health_check_interval"
                )

    def _on_circuit_state_change_callback(
        self, circuit: CircuitBreaker, old_state: CBState, new_state: CBState
    ):
        if new_state == CBState.HALF_OPEN:
            # The circuit is probing: re-check the database immediately.
            self._check_db_health(circuit.database)
            return
        if old_state == CBState.CLOSED and new_state == CBState.OPEN:
            # Schedule a transition to HALF_OPEN once the grace period elapses.
            # NOTE(review): uses DEFAULT_GRACE_PERIOD rather than
            # circuit.grace_period - confirm this is intentional.
            self._bg_scheduler.run_once(
                DEFAULT_GRACE_PERIOD, _half_open_circuit, circuit
            )

    def close(self):
        """Close the client of the currently active database."""
        self.command_executor.active_database.client.close()
def _half_open_circuit(circuit: CircuitBreaker):
    """Transition *circuit* to HALF_OPEN (scheduled after the grace period)."""
    circuit.state = CBState.HALF_OPEN
class Pipeline(RedisModuleCommands, CoreCommands):
    """
    Pipeline implementation for multiple logical Redis databases.

    Commands are buffered locally and sent as a single batch on execute().
    """

    def __init__(self, client: MultiDBClient):
        self._client = client
        self._command_stack = []

    def __enter__(self) -> "Pipeline":
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.reset()

    def __del__(self):
        try:
            self.reset()
        except Exception:
            pass

    def __len__(self) -> int:
        return len(self._command_stack)

    def __bool__(self) -> bool:
        """Pipeline instances should always evaluate to True"""
        return True

    def reset(self) -> None:
        """Drop all buffered commands."""
        self._command_stack = []

    def close(self) -> None:
        """Close the pipeline"""
        self.reset()

    def pipeline_execute_command(self, *args, **options) -> "Pipeline":
        """
        Stage a command to be executed when execute() is next called.

        Returns this Pipeline object so commands can be chained, such as:
        pipe = pipe.set('foo', 'bar').incr('baz').decr('bang')
        At some other point, you can then run: pipe.execute(),
        which will execute all commands queued in the pipe.
        """
        self._command_stack.append((args, options))
        return self

    def execute_command(self, *args, **kwargs):
        """Adds a command to the stack"""
        return self.pipeline_execute_command(*args, **kwargs)

    def execute(self) -> List[Any]:
        """Execute all the commands in the current pipeline"""
        if not self._client.initialized:
            self._client.initialize()
        try:
            return self._client.command_executor.execute_pipeline(
                tuple(self._command_stack)
            )
        finally:
            self.reset()
class PubSub:
    """
    PubSub object for multi database client.
    """

    def __init__(self, client: MultiDBClient, **kwargs):
        """Initialize the PubSub object for a multi-database client.
        Args:
            client: MultiDBClient instance to use for pub/sub operations
            **kwargs: Additional keyword arguments to pass to the underlying pubsub implementation
        """
        self._client = client
        self._client.command_executor.pubsub(**kwargs)

    def __enter__(self) -> "PubSub":
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Bug fix: __enter__ existed without __exit__, so using this object
        # in a "with" statement failed; release subscriptions on exit.
        self.reset()

    def __del__(self) -> None:
        try:
            # if this object went out of scope prior to shutting down
            # subscriptions, close the connection manually before
            # returning it to the connection pool
            self.reset()
        except Exception:
            pass

    def reset(self) -> None:
        """Reset the underlying pub/sub object on the active database."""
        return self._client.command_executor.execute_pubsub_method("reset")

    def close(self) -> None:
        """Alias of reset()."""
        self.reset()

    @property
    def subscribed(self) -> bool:
        """Whether the active pub/sub currently has any subscriptions."""
        return self._client.command_executor.active_pubsub.subscribed

    def execute_command(self, *args):
        """Execute a raw pub/sub command on the active database."""
        return self._client.command_executor.execute_pubsub_method(
            "execute_command", *args
        )

    def psubscribe(self, *args, **kwargs):
        """
        Subscribe to channel patterns. Patterns supplied as keyword arguments
        expect a pattern name as the key and a callable as the value. A
        pattern's callable will be invoked automatically when a message is
        received on that pattern rather than producing a message via
        ``listen()``.
        """
        return self._client.command_executor.execute_pubsub_method(
            "psubscribe", *args, **kwargs
        )

    def punsubscribe(self, *args):
        """
        Unsubscribe from the supplied patterns. If empty, unsubscribe from
        all patterns.
        """
        return self._client.command_executor.execute_pubsub_method(
            "punsubscribe", *args
        )

    def subscribe(self, *args, **kwargs):
        """
        Subscribe to channels. Channels supplied as keyword arguments expect
        a channel name as the key and a callable as the value. A channel's
        callable will be invoked automatically when a message is received on
        that channel rather than producing a message via ``listen()`` or
        ``get_message()``.
        """
        return self._client.command_executor.execute_pubsub_method(
            "subscribe", *args, **kwargs
        )

    def unsubscribe(self, *args):
        """
        Unsubscribe from the supplied channels. If empty, unsubscribe from
        all channels
        """
        return self._client.command_executor.execute_pubsub_method("unsubscribe", *args)

    def ssubscribe(self, *args, **kwargs):
        """
        Subscribes the client to the specified shard channels.
        Channels supplied as keyword arguments expect a channel name as the key
        and a callable as the value. A channel's callable will be invoked automatically
        when a message is received on that channel rather than producing a message via
        ``listen()`` or ``get_sharded_message()``.
        """
        return self._client.command_executor.execute_pubsub_method(
            "ssubscribe", *args, **kwargs
        )

    def sunsubscribe(self, *args):
        """
        Unsubscribe from the supplied shard_channels. If empty, unsubscribe from
        all shard_channels
        """
        return self._client.command_executor.execute_pubsub_method(
            "sunsubscribe", *args
        )

    def get_message(
        self, ignore_subscribe_messages: bool = False, timeout: float = 0.0
    ):
        """
        Get the next message if one is available, otherwise None.
        If timeout is specified, the system will wait for `timeout` seconds
        before returning. Timeout should be specified as a floating point
        number, or None, to wait indefinitely.
        """
        return self._client.command_executor.execute_pubsub_method(
            "get_message",
            ignore_subscribe_messages=ignore_subscribe_messages,
            timeout=timeout,
        )

    def get_sharded_message(
        self, ignore_subscribe_messages: bool = False, timeout: float = 0.0
    ):
        """
        Get the next message if one is available in a sharded channel, otherwise None.
        If timeout is specified, the system will wait for `timeout` seconds
        before returning. Timeout should be specified as a floating point
        number, or None, to wait indefinitely.
        """
        return self._client.command_executor.execute_pubsub_method(
            "get_sharded_message",
            ignore_subscribe_messages=ignore_subscribe_messages,
            timeout=timeout,
        )

    def run_in_thread(
        self,
        sleep_time: float = 0.0,
        daemon: bool = False,
        exception_handler: Optional[Callable] = None,
        sharded_pubsub: bool = False,
    ) -> "PubSubWorkerThread":
        """Start a worker thread that dispatches incoming messages."""
        return self._client.command_executor.execute_pubsub_run(
            sleep_time,
            daemon=daemon,
            exception_handler=exception_handler,
            pubsub=self,
            sharded_pubsub=sharded_pubsub,
        )

View File

@@ -0,0 +1,350 @@
from abc import ABC, abstractmethod
from datetime import datetime, timedelta
from typing import Any, Callable, List, Optional
from redis.client import Pipeline, PubSub, PubSubWorkerThread
from redis.event import EventDispatcherInterface, OnCommandsFailEvent
from redis.multidb.circuit import State as CBState
from redis.multidb.config import DEFAULT_AUTO_FALLBACK_INTERVAL
from redis.multidb.database import Database, Databases, SyncDatabase
from redis.multidb.event import (
ActiveDatabaseChanged,
CloseConnectionOnActiveDatabaseChanged,
RegisterCommandFailure,
ResubscribeOnActiveDatabaseChanged,
)
from redis.multidb.failover import (
DEFAULT_FAILOVER_ATTEMPTS,
DEFAULT_FAILOVER_DELAY,
DefaultFailoverStrategyExecutor,
FailoverStrategy,
FailoverStrategyExecutor,
)
from redis.multidb.failure_detector import FailureDetector
from redis.retry import Retry
class CommandExecutor(ABC):
    """Minimal executor contract: a configurable auto-fallback interval."""

    @property
    @abstractmethod
    def auto_fallback_interval(self) -> float:
        """Interval, in seconds, between automatic fallback attempts."""

    @auto_fallback_interval.setter
    @abstractmethod
    def auto_fallback_interval(self, auto_fallback_interval: float) -> None:
        """Set the interval between automatic fallback attempts."""
class BaseCommandExecutor(CommandExecutor):
    """Common state shared by concrete command executors.

    Stores the auto-fallback interval and tracks when the next automatic
    fallback to a higher-priority database should be attempted.
    """

    def __init__(
        self,
        auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL,
    ):
        self._auto_fallback_interval = auto_fallback_interval
        # Declared here for readability; only assigned by
        # _schedule_next_fallback() when auto-fallback is enabled
        # (i.e. the interval differs from the default).
        self._next_fallback_attempt: datetime

    @property
    def auto_fallback_interval(self) -> float:
        return self._auto_fallback_interval

    @auto_fallback_interval.setter
    def auto_fallback_interval(self, auto_fallback_interval: float) -> None:
        self._auto_fallback_interval = auto_fallback_interval

    def _schedule_next_fallback(self) -> None:
        # The default interval means auto-fallback is disabled: nothing to schedule.
        if self._auto_fallback_interval == DEFAULT_AUTO_FALLBACK_INTERVAL:
            return
        self._next_fallback_attempt = datetime.now() + timedelta(
            seconds=self._auto_fallback_interval
        )
class SyncCommandExecutor(CommandExecutor):
    """Synchronous executor contract for running commands across databases."""

    @property
    @abstractmethod
    def databases(self) -> Databases:
        """All databases this executor may operate on."""

    @property
    @abstractmethod
    def failure_detectors(self) -> List[FailureDetector]:
        """Failure detectors currently monitoring command execution."""

    @abstractmethod
    def add_failure_detector(self, failure_detector: FailureDetector) -> None:
        """Register an additional failure detector."""

    @property
    @abstractmethod
    def active_database(self) -> Optional[Database]:
        """The database commands are currently routed to, if any."""

    @active_database.setter
    @abstractmethod
    def active_database(self, database: SyncDatabase) -> None:
        """Route subsequent commands to the given database."""

    @property
    @abstractmethod
    def active_pubsub(self) -> Optional[PubSub]:
        """The pub/sub object bound to the active database, if any."""

    @active_pubsub.setter
    @abstractmethod
    def active_pubsub(self, pubsub: PubSub) -> None:
        """Bind a pub/sub object to this executor."""

    @property
    @abstractmethod
    def failover_strategy_executor(self) -> FailoverStrategyExecutor:
        """Executor that selects a replacement database on failover."""

    @property
    @abstractmethod
    def command_retry(self) -> Retry:
        """Retry policy applied to command execution."""

    @abstractmethod
    def pubsub(self, **kwargs):
        """Initialize a PubSub object on the currently active database."""

    @abstractmethod
    def execute_command(self, *args, **options):
        """Execute a single command and return its result."""

    @abstractmethod
    def execute_pipeline(self, command_stack: tuple):
        """Execute a stack of commands as a pipeline."""

    @abstractmethod
    def execute_transaction(
        self, transaction: Callable[[Pipeline], None], *watches, **options
    ):
        """Execute a transaction block wrapped in a callback."""

    @abstractmethod
    def execute_pubsub_method(self, method_name: str, *args, **kwargs):
        """Invoke the named method on the active pub/sub object."""

    @abstractmethod
    def execute_pubsub_run(self, sleep_time: float, **kwargs) -> Any:
        """Run pub/sub message handling in a worker thread."""
class DefaultCommandExecutor(SyncCommandExecutor, BaseCommandExecutor):
    def __init__(
        self,
        failure_detectors: List[FailureDetector],
        databases: Databases,
        command_retry: Retry,
        failover_strategy: FailoverStrategy,
        event_dispatcher: EventDispatcherInterface,
        failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS,
        failover_delay: float = DEFAULT_FAILOVER_DELAY,
        auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL,
    ):
        """
        Initialize the DefaultCommandExecutor instance.
        Args:
            failure_detectors: List of failure detector instances to monitor database health
            databases: Collection of available databases to execute commands on
            command_retry: Retry policy for failed command execution
            failover_strategy: Strategy for handling database failover
            event_dispatcher: Interface for dispatching events
            failover_attempts: Number of failover attempts
            failover_delay: Delay between failover attempts
            auto_fallback_interval: Time interval in seconds between attempts to fall back to a primary database
        """
        super().__init__(auto_fallback_interval)
        # Detectors report failures back through this executor.
        for fd in failure_detectors:
            fd.set_command_executor(command_executor=self)
        self._databases = databases
        self._failure_detectors = failure_detectors
        self._command_retry = command_retry
        self._failover_strategy_executor = DefaultFailoverStrategyExecutor(
            failover_strategy, failover_attempts, failover_delay
        )
        self._event_dispatcher = event_dispatcher
        self._active_database: Optional[Database] = None
        self._active_pubsub: Optional[PubSub] = None
        self._active_pubsub_kwargs = {}
        self._setup_event_dispatcher()
        self._schedule_next_fallback()

    @property
    def databases(self) -> Databases:
        return self._databases

    @property
    def failure_detectors(self) -> List[FailureDetector]:
        return self._failure_detectors

    def add_failure_detector(self, failure_detector: FailureDetector) -> None:
        self._failure_detectors.append(failure_detector)

    @property
    def command_retry(self) -> Retry:
        return self._command_retry

    @property
    def active_database(self) -> Optional[SyncDatabase]:
        return self._active_database

    @active_database.setter
    def active_database(self, database: SyncDatabase) -> None:
        old_active = self._active_database
        self._active_database = database
        # Only notify listeners on a genuine switch between two databases.
        if old_active is not None and old_active is not database:
            self._event_dispatcher.dispatch(
                ActiveDatabaseChanged(
                    old_active,
                    self._active_database,
                    self,
                    **self._active_pubsub_kwargs,
                )
            )

    @property
    def active_pubsub(self) -> Optional[PubSub]:
        return self._active_pubsub

    @active_pubsub.setter
    def active_pubsub(self, pubsub: PubSub) -> None:
        self._active_pubsub = pubsub

    @property
    def failover_strategy_executor(self) -> FailoverStrategyExecutor:
        return self._failover_strategy_executor

    def execute_command(self, *args, **options):
        """Execute a single command on the active database with failover."""

        def callback():
            response = self._active_database.client.execute_command(*args, **options)
            self._register_command_execution(args)
            return response

        return self._execute_with_failure_detection(callback, args)

    def execute_pipeline(self, command_stack: tuple):
        """Execute a stack of commands as a pipeline on the active database."""

        def callback():
            with self._active_database.client.pipeline() as pipe:
                for command, options in command_stack:
                    pipe.execute_command(*command, **options)
                response = pipe.execute()
            self._register_command_execution(command_stack)
            return response

        return self._execute_with_failure_detection(callback, command_stack)

    def execute_transaction(
        self, transaction: Callable[[Pipeline], None], *watches, **options
    ):
        """Execute a transaction callback on the active database."""

        def callback():
            response = self._active_database.client.transaction(
                transaction, *watches, **options
            )
            self._register_command_execution(())
            return response

        return self._execute_with_failure_detection(callback)

    def pubsub(self, **kwargs):
        """Lazily create the pub/sub object on the active database."""

        def callback():
            if self._active_pubsub is None:
                self._active_pubsub = self._active_database.client.pubsub(**kwargs)
                self._active_pubsub_kwargs = kwargs
            return None

        return self._execute_with_failure_detection(callback)

    def execute_pubsub_method(self, method_name: str, *args, **kwargs):
        """Executes a given method on active pub/sub."""

        def callback():
            method = getattr(self.active_pubsub, method_name)
            response = method(*args, **kwargs)
            self._register_command_execution(args)
            return response

        # Bug fix: pass the args tuple itself as "cmds"; the previous
        # "*args" unpacked it into multiple positional arguments of
        # _execute_with_failure_detection (TypeError for >1 argument).
        return self._execute_with_failure_detection(callback, args)

    def execute_pubsub_run(self, sleep_time, **kwargs) -> "PubSubWorkerThread":
        """Start the pub/sub worker thread on the active pub/sub object."""

        def callback():
            return self._active_pubsub.run_in_thread(sleep_time, **kwargs)

        return self._execute_with_failure_detection(callback)

    def _execute_with_failure_detection(self, callback: Callable, cmds: tuple = ()):
        """
        Execute a commands execution callback with failure detection.
        """

        def wrapper():
            # On each retry we need to check active database as it might change.
            self._check_active_database()
            return callback()

        return self._command_retry.call_with_retry(
            wrapper,
            lambda error: self._on_command_fail(error, *cmds),
        )

    def _on_command_fail(self, error, *args):
        self._event_dispatcher.dispatch(OnCommandsFailEvent(args, error))

    def _check_active_database(self):
        """
        Checks if the active database needs to be updated.

        Failover happens when there is no active database, its circuit is
        not closed, or the auto-fallback deadline has passed.
        """
        if (
            self._active_database is None
            or self._active_database.circuit.state != CBState.CLOSED
            or (
                self._auto_fallback_interval != DEFAULT_AUTO_FALLBACK_INTERVAL
                and self._next_fallback_attempt <= datetime.now()
            )
        ):
            self.active_database = self._failover_strategy_executor.execute()
            self._schedule_next_fallback()

    def _register_command_execution(self, cmd: tuple):
        # Let every failure detector see the executed command for its stats.
        for detector in self._failure_detectors:
            detector.register_command_execution(cmd)

    def _setup_event_dispatcher(self):
        """
        Registers necessary listeners.
        """
        failure_listener = RegisterCommandFailure(self._failure_detectors)
        resubscribe_listener = ResubscribeOnActiveDatabaseChanged()
        close_connection_listener = CloseConnectionOnActiveDatabaseChanged()
        self._event_dispatcher.register_listeners(
            {
                OnCommandsFailEvent: [failure_listener],
                ActiveDatabaseChanged: [
                    close_connection_listener,
                    resubscribe_listener,
                ],
            }
        )

View File

@@ -0,0 +1,207 @@
from dataclasses import dataclass, field
from typing import List, Type, Union
import pybreaker
from typing_extensions import Optional
from redis import ConnectionPool, Redis, RedisCluster
from redis.backoff import ExponentialWithJitterBackoff, NoBackoff
from redis.data_structure import WeightedList
from redis.event import EventDispatcher, EventDispatcherInterface
from redis.multidb.circuit import (
DEFAULT_GRACE_PERIOD,
CircuitBreaker,
PBCircuitBreakerAdapter,
)
from redis.multidb.database import Database, Databases
from redis.multidb.failover import (
DEFAULT_FAILOVER_ATTEMPTS,
DEFAULT_FAILOVER_DELAY,
FailoverStrategy,
WeightBasedFailoverStrategy,
)
from redis.multidb.failure_detector import (
DEFAULT_FAILURE_RATE_THRESHOLD,
DEFAULT_FAILURES_DETECTION_WINDOW,
DEFAULT_MIN_NUM_FAILURES,
CommandFailureDetector,
FailureDetector,
)
from redis.multidb.healthcheck import (
DEFAULT_HEALTH_CHECK_DELAY,
DEFAULT_HEALTH_CHECK_INTERVAL,
DEFAULT_HEALTH_CHECK_POLICY,
DEFAULT_HEALTH_CHECK_PROBES,
HealthCheck,
HealthCheckPolicies,
PingHealthCheck,
)
from redis.retry import Retry
DEFAULT_AUTO_FALLBACK_INTERVAL = 120
def default_event_dispatcher() -> EventDispatcherInterface:
    """Return the EventDispatcher used when no custom dispatcher is configured."""
    return EventDispatcher()
@dataclass
class DatabaseConfig:
    """
    Configuration for a single database connection.

    Holds client options, connection sourcing details (URL or pool),
    circuit breaker settings, and cluster-specific properties, allowing
    customized configurations for various database setups.

    Attributes:
        weight (float): Weight of the database used to pick the active one.
        client_kwargs (dict): Extra parameters for the database client connection.
        from_url (Optional[str]): Redis URL to connect to the database with.
        from_pool (Optional[ConnectionPool]): Pre-configured connection pool to use.
        circuit (Optional[CircuitBreaker]): Custom circuit breaker implementation.
        grace_period (float): Seconds to keep the circuit open before checking
            whether it can be closed again.
        health_check_url (Optional[str]): URL for health checks. Cluster FQDN is
            typically used on public Redis Enterprise endpoints.

    Methods:
        default_circuit_breaker:
            Builds the default CircuitBreaker instance adapted for use.
    """

    weight: float = 1.0
    client_kwargs: dict = field(default_factory=dict)
    from_url: Optional[str] = None
    from_pool: Optional[ConnectionPool] = None
    circuit: Optional[CircuitBreaker] = None
    grace_period: float = DEFAULT_GRACE_PERIOD
    health_check_url: Optional[str] = None

    def default_circuit_breaker(self) -> CircuitBreaker:
        """Return a pybreaker-backed circuit breaker honoring grace_period."""
        return PBCircuitBreakerAdapter(
            pybreaker.CircuitBreaker(reset_timeout=self.grace_period)
        )
@dataclass
class MultiDbConfig:
    """
    Configuration class for managing multiple database connections in a resilient and fail-safe manner.

    Attributes:
        databases_config: A list of database configurations.
        client_class: The client class used to manage database connections.
        command_retry: Retry strategy for executing database commands.
        failure_detectors: Optional list of additional failure detectors for monitoring database failures.
        min_num_failures: Minimal count of failures required for failover
        failure_rate_threshold: Percentage of failures required for failover
        failures_detection_window: Time interval for tracking database failures.
        health_checks: Optional list of additional health checks performed on databases.
        health_check_interval: Time interval for executing health checks.
        health_check_probes: Number of attempts to evaluate the health of a database.
        health_check_probes_delay: Delay between health check attempts.
        health_check_policy: Policy for determining database health based on health checks.
        failover_strategy: Optional strategy for handling database failover scenarios.
        failover_attempts: Number of retries allowed for failover operations.
        failover_delay: Delay between failover attempts.
        auto_fallback_interval: Time interval to trigger automatic fallback.
        event_dispatcher: Interface for dispatching events related to database operations.

    Methods:
        databases:
            Retrieves a collection of database clients managed by weighted configurations.
        default_failure_detectors:
            Returns the default list of failure detectors used to monitor database failures.
        default_health_checks:
            Returns the default list of health checks used to monitor database health.
        default_failover_strategy:
            Provides the default failover strategy used for handling failover scenarios.
    """

    databases_config: List[DatabaseConfig]
    client_class: Type[Union[Redis, RedisCluster]] = Redis
    # NOTE: this default Retry instance is shared across MultiDbConfig
    # instances that do not override it (dataclass defaults are evaluated once).
    command_retry: Retry = Retry(
        backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3
    )
    failure_detectors: Optional[List[FailureDetector]] = None
    min_num_failures: int = DEFAULT_MIN_NUM_FAILURES
    failure_rate_threshold: float = DEFAULT_FAILURE_RATE_THRESHOLD
    failures_detection_window: float = DEFAULT_FAILURES_DETECTION_WINDOW
    health_checks: Optional[List[HealthCheck]] = None
    health_check_interval: float = DEFAULT_HEALTH_CHECK_INTERVAL
    health_check_probes: int = DEFAULT_HEALTH_CHECK_PROBES
    health_check_probes_delay: float = DEFAULT_HEALTH_CHECK_DELAY
    health_check_policy: HealthCheckPolicies = DEFAULT_HEALTH_CHECK_POLICY
    failover_strategy: Optional[FailoverStrategy] = None
    failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS
    failover_delay: float = DEFAULT_FAILOVER_DELAY
    auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL
    event_dispatcher: EventDispatcherInterface = field(
        default_factory=default_event_dispatcher
    )

    def databases(self) -> Databases:
        """
        Build the weighted collection of Database instances described by
        databases_config.

        Lower-level clients are created with retries disabled: command_retry is
        the single global retry policy, so each client gets a zero-retry Retry
        object instead of whatever the config carried.
        """
        databases = WeightedList()
        for database_config in self.databases_config:
            # Override the retry without mutating the caller's DatabaseConfig:
            # build a fresh kwargs dict instead of updating client_kwargs in place.
            client_kwargs = {
                **database_config.client_kwargs,
                "retry": Retry(retries=0, backoff=NoBackoff()),
            }

            if database_config.from_url:
                client = self.client_class.from_url(
                    database_config.from_url, **client_kwargs
                )
            elif database_config.from_pool:
                # Pools carry their own retry; align it with the zero-retry
                # policy used for all lower-level clients.
                database_config.from_pool.set_retry(
                    Retry(retries=0, backoff=NoBackoff())
                )
                client = self.client_class.from_pool(
                    connection_pool=database_config.from_pool
                )
            else:
                client = self.client_class(**client_kwargs)

            circuit = (
                database_config.default_circuit_breaker()
                if database_config.circuit is None
                else database_config.circuit
            )
            databases.add(
                Database(
                    client=client,
                    circuit=circuit,
                    weight=database_config.weight,
                    health_check_url=database_config.health_check_url,
                ),
                database_config.weight,
            )
        return databases

    def default_failure_detectors(self) -> List[FailureDetector]:
        """Return the default failure detectors used to monitor database failures."""
        return [
            CommandFailureDetector(
                min_num_failures=self.min_num_failures,
                failure_rate_threshold=self.failure_rate_threshold,
                failure_detection_window=self.failures_detection_window,
            ),
        ]

    def default_health_checks(self) -> List[HealthCheck]:
        """Return the default health checks used to monitor database health."""
        return [
            PingHealthCheck(),
        ]

    def default_failover_strategy(self) -> FailoverStrategy:
        """Return the default (weight-based) failover strategy."""
        return WeightBasedFailoverStrategy()

View File

@@ -0,0 +1,130 @@
from abc import ABC, abstractmethod
from typing import Optional, Union
import redis
from redis import RedisCluster
from redis.data_structure import WeightedList
from redis.multidb.circuit import CircuitBreaker
from redis.typing import Number
class AbstractDatabase(ABC):
    """Interface shared by all database wrappers in a multi-database setup."""

    @property
    @abstractmethod
    def weight(self) -> float:
        """The weight of this database compared to others. Used to determine the database to fail over to."""
        pass

    @weight.setter
    @abstractmethod
    def weight(self, weight: float):
        """Set the weight of this database compared to others."""
        pass

    @property
    @abstractmethod
    def health_check_url(self) -> Optional[str]:
        """Health check URL associated with the current database."""
        pass

    @health_check_url.setter
    @abstractmethod
    def health_check_url(self, health_check_url: Optional[str]):
        """Set the health check URL associated with the current database."""
        pass
class BaseDatabase(AbstractDatabase):
    """Attribute-backed implementation of the AbstractDatabase accessors."""

    def __init__(
        self,
        weight: float,
        health_check_url: Optional[str] = None,
    ):
        """
        Args:
            weight: Weight of this database, used for failover prioritization.
            health_check_url: Optional URL used by URL-based health checks.
        """
        self._weight = weight
        self._health_check_url = health_check_url

    @property
    def weight(self) -> float:
        return self._weight

    @weight.setter
    def weight(self, weight: float):
        self._weight = weight

    @property
    def health_check_url(self) -> Optional[str]:
        return self._health_check_url

    @health_check_url.setter
    def health_check_url(self, health_check_url: Optional[str]):
        self._health_check_url = health_check_url
class SyncDatabase(AbstractDatabase):
    """Database with an underlying synchronous redis client."""

    @property
    @abstractmethod
    def client(self) -> Union[redis.Redis, RedisCluster]:
        """The underlying redis client."""
        pass

    @client.setter
    @abstractmethod
    def client(self, client: Union[redis.Redis, RedisCluster]):
        """Set the underlying redis client."""
        pass

    @property
    @abstractmethod
    def circuit(self) -> CircuitBreaker:
        """Circuit breaker for the current database."""
        pass

    @circuit.setter
    @abstractmethod
    def circuit(self, circuit: CircuitBreaker):
        """Set the circuit breaker for the current database."""
        pass
# Collection of (database, weight) pairs; presumably iterated highest weight
# first by WeightedList — confirm against redis.data_structure.WeightedList.
Databases = WeightedList[tuple[SyncDatabase, Number]]
class Database(BaseDatabase, SyncDatabase):
    """Synchronous database pairing a redis client with a circuit breaker."""

    def __init__(
        self,
        client: Union[redis.Redis, RedisCluster],
        circuit: CircuitBreaker,
        weight: float,
        health_check_url: Optional[str] = None,
    ):
        """
        Initialize a new Database instance.

        Args:
            client: Underlying Redis client instance for database operations
            circuit: Circuit breaker for handling database failures
            weight: Weight value used for database failover prioritization
            health_check_url: Health check URL associated with the current database
        """
        self._client = client
        self._cb = circuit
        # Back-reference so code holding the circuit breaker can reach the
        # database it guards (via CircuitBreaker.database).
        self._cb.database = self
        super().__init__(weight, health_check_url)

    @property
    def client(self) -> Union[redis.Redis, RedisCluster]:
        return self._client

    @client.setter
    def client(self, client: Union[redis.Redis, RedisCluster]):
        self._client = client

    @property
    def circuit(self) -> CircuitBreaker:
        return self._cb

    @circuit.setter
    def circuit(self, circuit: CircuitBreaker):
        # NOTE(review): replacing the circuit does not set circuit.database
        # back to this instance, unlike __init__ — confirm callers do it.
        self._cb = circuit

View File

@@ -0,0 +1,89 @@
from typing import List
from redis.client import Redis
from redis.event import EventListenerInterface, OnCommandsFailEvent
from redis.multidb.database import SyncDatabase
from redis.multidb.failure_detector import FailureDetector
class ActiveDatabaseChanged:
    """
    Event fired when an active database has been changed.

    Args:
        old_database: Database that was active before the change.
        new_database: Database that became active.
        command_executor: Executor coordinating the switch (exposes the
            active pub/sub, see ResubscribeOnActiveDatabaseChanged).
        **kwargs: Extra arguments forwarded to listeners (e.g. pubsub kwargs).
    """

    def __init__(
        self,
        old_database: SyncDatabase,
        new_database: SyncDatabase,
        command_executor,
        **kwargs,
    ):
        self._old_database = old_database
        self._new_database = new_database
        self._command_executor = command_executor
        self._kwargs = kwargs

    @property
    def old_database(self) -> SyncDatabase:
        return self._old_database

    @property
    def new_database(self) -> SyncDatabase:
        return self._new_database

    @property
    def command_executor(self):
        return self._command_executor

    @property
    def kwargs(self):
        return self._kwargs
class ResubscribeOnActiveDatabaseChanged(EventListenerInterface):
    """
    Re-subscribe the currently active pub / sub to a new active database.
    """

    def listen(self, event: ActiveDatabaseChanged):
        # Nothing to migrate if no pub/sub is currently active.
        old_pubsub = event.command_executor.active_pubsub
        if old_pubsub is not None:
            # Re-assign old channels and patterns so they will be automatically subscribed on connection.
            new_pubsub = event.new_database.client.pubsub(**event.kwargs)
            new_pubsub.channels = old_pubsub.channels
            new_pubsub.patterns = old_pubsub.patterns
            new_pubsub.shard_channels = old_pubsub.shard_channels
            # NOTE(review): on_connect(None) appears to trigger subscription of
            # the copied channels/patterns on the new connection — confirm
            # against the PubSub implementation.
            new_pubsub.on_connect(None)
            event.command_executor.active_pubsub = new_pubsub
            old_pubsub.close()
class CloseConnectionOnActiveDatabaseChanged(EventListenerInterface):
    """
    Close connection to the old active database.
    """

    def listen(self, event: ActiveDatabaseChanged):
        event.old_database.client.close()
        if isinstance(event.old_database.client, Redis):
            # Standalone client: mark active connections for reconnect and
            # drop the pool's connections.
            event.old_database.client.connection_pool.update_active_connections_for_reconnect()
            event.old_database.client.connection_pool.disconnect()
        else:
            # Cluster client: apply the same treatment to every node's pool.
            for node in event.old_database.client.nodes_manager.nodes_cache.values():
                node.redis_connection.connection_pool.update_active_connections_for_reconnect()
                node.redis_connection.connection_pool.disconnect()
class RegisterCommandFailure(EventListenerInterface):
    """
    Listener that forwards failed-command events to every configured
    failure detector.
    """

    def __init__(self, failure_detectors: List[FailureDetector]):
        self._failure_detectors = failure_detectors

    def listen(self, event: OnCommandsFailEvent) -> None:
        # Fan the failure out to each detector in registration order.
        for detector in self._failure_detectors:
            detector.register_failure(event.exception, event.commands)

View File

@@ -0,0 +1,17 @@
class NoValidDatabaseException(Exception):
    """Raised when no healthy database is available to serve commands."""

    pass
class UnhealthyDatabaseException(Exception):
    """Exception raised when a database is unhealthy due to an underlying exception.

    Attributes:
        database: The database whose health check failed.
        original_exception: The exception that caused the health check to fail.
    """

    def __init__(self, message, database, original_exception):
        super().__init__(message)
        self.database = database
        self.original_exception = original_exception
class TemporaryUnavailableException(Exception):
    """Exception raised when all databases in the setup are temporarily unavailable."""

    pass

View File

@@ -0,0 +1,125 @@
import time
from abc import ABC, abstractmethod
from redis.data_structure import WeightedList
from redis.multidb.circuit import State as CBState
from redis.multidb.database import Databases, SyncDatabase
from redis.multidb.exception import (
NoValidDatabaseException,
TemporaryUnavailableException,
)
# Default number of failover attempts before the failure is surfaced.
DEFAULT_FAILOVER_ATTEMPTS = 10
# Default delay (seconds) between failover attempts.
DEFAULT_FAILOVER_DELAY = 12
class FailoverStrategy(ABC):
    """Strategy for selecting which database to fail over to."""

    @abstractmethod
    def database(self) -> SyncDatabase:
        """Select the database according to the strategy."""
        pass

    @abstractmethod
    def set_databases(self, databases: Databases) -> None:
        """Set the databases the strategy operates on."""
        pass
class FailoverStrategyExecutor(ABC):
    """Executes a FailoverStrategy with retry/delay bookkeeping."""

    @property
    @abstractmethod
    def failover_attempts(self) -> int:
        """The number of failover attempts."""
        pass

    @property
    @abstractmethod
    def failover_delay(self) -> float:
        """The delay between failover attempts."""
        pass

    @property
    @abstractmethod
    def strategy(self) -> FailoverStrategy:
        """The strategy to execute."""
        pass

    @abstractmethod
    def execute(self) -> SyncDatabase:
        """Execute the failover strategy."""
        pass
class WeightBasedFailoverStrategy(FailoverStrategy):
"""
Failover strategy based on database weights.
"""
def __init__(self) -> None:
self._databases = WeightedList()
def database(self) -> SyncDatabase:
for database, _ in self._databases:
if database.circuit.state == CBState.CLOSED:
return database
raise NoValidDatabaseException("No valid database available for communication")
def set_databases(self, databases: Databases) -> None:
self._databases = databases
class DefaultFailoverStrategyExecutor(FailoverStrategyExecutor):
    """
    Executes the given failover strategy with bounded, rate-limited retries.

    While the strategy cannot produce a database, at most one failover attempt
    is consumed per ``failover_delay`` window; callers get a
    TemporaryUnavailableException until ``failover_attempts`` attempts are
    exhausted, after which the underlying NoValidDatabaseException is raised.
    """

    def __init__(
        self,
        strategy: FailoverStrategy,
        failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS,
        failover_delay: float = DEFAULT_FAILOVER_DELAY,
    ):
        """
        Args:
            strategy: Strategy used to select the next database.
            failover_attempts: Number of retries allowed for failover operations.
            failover_delay: Delay (seconds) between failover attempts.
        """
        self._strategy = strategy
        self._failover_attempts = failover_attempts
        self._failover_delay = failover_delay
        # Epoch timestamp after which the next attempt is counted; 0 means no
        # failed attempt is in progress. time.time() returns floats, so the
        # annotation is float (the original `int` annotation was wrong).
        self._next_attempt_ts: float = 0
        self._failover_counter: int = 0

    @property
    def failover_attempts(self) -> int:
        return self._failover_attempts

    @property
    def failover_delay(self) -> float:
        return self._failover_delay

    @property
    def strategy(self) -> FailoverStrategy:
        return self._strategy

    def execute(self) -> SyncDatabase:
        """Return a healthy database selected by the strategy.

        Raises:
            TemporaryUnavailableException: While attempts remain.
            NoValidDatabaseException: Once attempts are exhausted.
        """
        try:
            database = self._strategy.database()
            self._reset()
            return database
        except NoValidDatabaseException as e:
            if self._next_attempt_ts == 0:
                # First failure: open a new attempt window.
                self._next_attempt_ts = time.time() + self._failover_delay
                self._failover_counter += 1
            elif time.time() >= self._next_attempt_ts:
                # Window elapsed: count another attempt.
                self._next_attempt_ts += self._failover_delay
                self._failover_counter += 1

            if self._failover_counter > self._failover_attempts:
                self._reset()
                raise e

            # Chain the original cause so debugging keeps the full context.
            raise TemporaryUnavailableException(
                "No database connections currently available. "
                "This is a temporary condition - please retry the operation."
            ) from e

    def _reset(self) -> None:
        # Clear the attempt window and counter after success or exhaustion.
        self._next_attempt_ts = 0
        self._failover_counter = 0

View File

@@ -0,0 +1,104 @@
import math
import threading
from abc import ABC, abstractmethod
from datetime import datetime, timedelta
from typing import List, Type
from typing_extensions import Optional
from redis.multidb.circuit import State as CBState
# Minimal count of failed commands within the detection window before failover.
DEFAULT_MIN_NUM_FAILURES = 1000
# Fraction of executed commands that must fail before failover.
DEFAULT_FAILURE_RATE_THRESHOLD = 0.1
# Length (seconds) of the sliding window used to track failures.
DEFAULT_FAILURES_DETECTION_WINDOW = 2
class FailureDetector(ABC):
    """Tracks command failures and decides when a database should fail over."""

    @abstractmethod
    def register_failure(self, exception: Exception, cmd: tuple) -> None:
        """Register a failure that occurred during command execution."""
        pass

    @abstractmethod
    def register_command_execution(self, cmd: tuple) -> None:
        """Register a command execution."""
        pass

    @abstractmethod
    def set_command_executor(self, command_executor) -> None:
        """Set the command executor for this failure detector."""
        pass
class CommandFailureDetector(FailureDetector):
    """
    Detects a failure based on a threshold of failed commands during a specific period of time.
    """

    def __init__(
        self,
        min_num_failures: int = DEFAULT_MIN_NUM_FAILURES,
        failure_rate_threshold: float = DEFAULT_FAILURE_RATE_THRESHOLD,
        failure_detection_window: float = DEFAULT_FAILURES_DETECTION_WINDOW,
        error_types: Optional[List[Type[Exception]]] = None,
    ) -> None:
        """
        Initialize a new CommandFailureDetector instance.

        Failures are tracked within a sliding time window; once both the
        minimal failure count and the failure-rate threshold are exceeded
        inside the window, failure detection is triggered.

        Args:
            min_num_failures: Minimal count of failures required for failover.
            failure_rate_threshold: Fraction of executed commands that must fail.
            failure_detection_window: Length of the sliding window, in seconds.
            error_types: Optional list of exception types that count as
                failures. If None, every exception is counted.
        """
        self._command_executor = None
        self._min_num_failures = min_num_failures
        self._failure_rate_threshold = failure_rate_threshold
        self._failure_detection_window = failure_detection_window
        self._error_types = error_types
        self._commands_executed: int = 0
        self._start_time: datetime = datetime.now()
        self._end_time: datetime = self._start_time + timedelta(
            seconds=self._failure_detection_window
        )
        self._failures_count: int = 0
        self._lock = threading.RLock()

    def register_failure(self, exception: Exception, cmd: tuple) -> None:
        with self._lock:
            # Count every failure unless a filter is configured; the filter
            # matches exact types only (subclasses are not counted).
            relevant = not self._error_types or type(exception) in self._error_types
            if relevant:
                self._failures_count += 1
            self._check_threshold()

    def set_command_executor(self, command_executor) -> None:
        self._command_executor = command_executor

    def register_command_execution(self, cmd: tuple) -> None:
        with self._lock:
            # Start a fresh window once the current one has expired.
            now = datetime.now()
            if not (self._start_time < now < self._end_time):
                self._reset()
            self._commands_executed += 1

    def _check_threshold(self):
        # Failover requires both an absolute failure count and a failure rate.
        required = max(
            self._min_num_failures,
            math.ceil(self._commands_executed * self._failure_rate_threshold),
        )
        if self._failures_count >= required:
            self._command_executor.active_database.circuit.state = CBState.OPEN
            self._reset()

    def _reset(self) -> None:
        with self._lock:
            self._start_time = datetime.now()
            self._end_time = self._start_time + timedelta(
                seconds=self._failure_detection_window
            )
            self._failures_count = 0
            self._commands_executed = 0

View File

@@ -0,0 +1,282 @@
import logging
from abc import ABC, abstractmethod
from enum import Enum
from time import sleep
from typing import List, Optional, Tuple, Union
from redis import Redis
from redis.backoff import NoBackoff
from redis.http.http_client import DEFAULT_TIMEOUT, HttpClient
from redis.multidb.exception import UnhealthyDatabaseException
from redis.retry import Retry
# Number of probes used to evaluate a database's health in one run.
DEFAULT_HEALTH_CHECK_PROBES = 3
# Interval (seconds) between scheduled health check runs.
DEFAULT_HEALTH_CHECK_INTERVAL = 5
# Delay (seconds) between consecutive probes within one run.
DEFAULT_HEALTH_CHECK_DELAY = 0.5
# Default replication-lag tolerance (milliseconds) for LagAwareHealthCheck.
DEFAULT_LAG_AWARE_TOLERANCE = 5000
logger = logging.getLogger(__name__)
class HealthCheck(ABC):
    """Single health probe for a database."""

    @abstractmethod
    def check_health(self, database) -> bool:
        """Function to determine the health status."""
        pass
class HealthCheckPolicy(ABC):
    """
    Health checks execution policy: decides how many probes to run and how
    to combine their results into a single health verdict.
    """

    @property
    @abstractmethod
    def health_check_probes(self) -> int:
        """Number of probes to execute health checks."""
        pass

    @property
    @abstractmethod
    def health_check_delay(self) -> float:
        """Delay between health check probes."""
        pass

    @abstractmethod
    def execute(self, health_checks: List[HealthCheck], database) -> bool:
        """Execute health checks and return database health status."""
        pass
class AbstractHealthCheckPolicy(HealthCheckPolicy):
    """Base policy holding probe count and inter-probe delay configuration."""

    def __init__(self, health_check_probes: int, health_check_delay: float):
        """
        Args:
            health_check_probes: Number of probes per health check; must be >= 1.
            health_check_delay: Delay in seconds between consecutive probes.

        Raises:
            ValueError: If health_check_probes is less than 1.
        """
        if health_check_probes < 1:
            raise ValueError("health_check_probes must be greater than 0")
        self._health_check_probes = health_check_probes
        self._health_check_delay = health_check_delay

    @property
    def health_check_probes(self) -> int:
        return self._health_check_probes

    @property
    def health_check_delay(self) -> float:
        return self._health_check_delay

    @abstractmethod
    def execute(self, health_checks: List[HealthCheck], database) -> bool:
        pass
class HealthyAllPolicy(AbstractHealthCheckPolicy):
    """
    Policy that returns True only if every probe of every health check succeeds.

    Raises:
        UnhealthyDatabaseException: If any probe raises an exception.
    """

    # No __init__: the inherited AbstractHealthCheckPolicy constructor is
    # sufficient (the previous override only forwarded its arguments).

    def execute(self, health_checks: List[HealthCheck], database) -> bool:
        """Run each health check ``health_check_probes`` times; fail fast on
        the first unsuccessful probe."""
        for health_check in health_checks:
            for attempt in range(self.health_check_probes):
                try:
                    if not health_check.check_health(database):
                        return False
                except Exception as e:
                    raise UnhealthyDatabaseException(
                        "Unhealthy database", database, e
                    ) from e
                # Pause between probes, but not after the last one.
                if attempt < self.health_check_probes - 1:
                    sleep(self._health_check_delay)
        return True
class HealthyMajorityPolicy(AbstractHealthCheckPolicy):
    """
    Policy that returns True if a majority of health check probes are successful.

    Raises:
        UnhealthyDatabaseException: If probes raising exceptions already rule
            out a majority of successes for a health check.
    """

    # No __init__: the inherited AbstractHealthCheckPolicy constructor is
    # sufficient (the previous override only forwarded its arguments).

    def execute(self, health_checks: List[HealthCheck], database) -> bool:
        # Failure budget per health check: the database is deemed unhealthy as
        # soon as ceil(probes / 2) probes fail. This single integer expression
        # replaces the previous even/odd float-division split (probes/2 for
        # even, (probes+1)/2 for odd, which are both (probes+1)//2) and is
        # loop-invariant, so it is computed once.
        failure_budget = (self.health_check_probes + 1) // 2
        for health_check in health_checks:
            # The budget counter restarts for every health check.
            allowed_unsuccessful_probes = failure_budget
            for attempt in range(self.health_check_probes):
                try:
                    if not health_check.check_health(database):
                        allowed_unsuccessful_probes -= 1
                        if allowed_unsuccessful_probes <= 0:
                            return False
                except Exception as e:
                    allowed_unsuccessful_probes -= 1
                    if allowed_unsuccessful_probes <= 0:
                        raise UnhealthyDatabaseException(
                            "Unhealthy database", database, e
                        ) from e
                # Pause between probes, but not after the last one.
                if attempt < self.health_check_probes - 1:
                    sleep(self._health_check_delay)
        return True
class HealthyAnyPolicy(AbstractHealthCheckPolicy):
    """
    Policy that returns True if, for every health check, at least one probe
    is successful.

    Raises:
        UnhealthyDatabaseException: If a health check never succeeds and at
            least one of its probes raised an exception.
    """

    # No __init__: the inherited AbstractHealthCheckPolicy constructor is
    # sufficient (the previous override only forwarded its arguments).

    def execute(self, health_checks: List[HealthCheck], database) -> bool:
        is_healthy = False
        for health_check in health_checks:
            # Last probe error seen for this health check, if any.
            exception = None
            for attempt in range(self.health_check_probes):
                try:
                    if health_check.check_health(database):
                        is_healthy = True
                        break
                    else:
                        is_healthy = False
                except Exception as e:
                    exception = UnhealthyDatabaseException(
                        "Unhealthy database", database, e
                    )
                # Pause between probes, but not after the last one.
                if attempt < self.health_check_probes - 1:
                    sleep(self._health_check_delay)
            # No probe succeeded for this health check: report unhealthy,
            # propagating the last error if one occurred.
            if not is_healthy and not exception:
                return is_healthy
            elif not is_healthy and exception:
                raise exception
        return is_healthy
class HealthCheckPolicies(Enum):
    """Available health check execution policies (member values are policy classes)."""

    HEALTHY_ALL = HealthyAllPolicy
    HEALTHY_MAJORITY = HealthyMajorityPolicy
    HEALTHY_ANY = HealthyAnyPolicy


# Policy used when no explicit policy is configured.
DEFAULT_HEALTH_CHECK_POLICY: HealthCheckPolicies = HealthCheckPolicies.HEALTHY_ALL
class PingHealthCheck(HealthCheck):
    """
    Health check based on PING command.
    """

    def check_health(self, database) -> bool:
        client = database.client
        if isinstance(client, Redis):
            # Standalone client: healthy when PING succeeds.
            return client.execute_command("PING")
        # Cluster client: healthy only when every node answers PING.
        return all(
            node.redis_connection.execute_command("PING")
            for node in client.get_nodes()
        )
class LagAwareHealthCheck(HealthCheck):
    """
    Health check available for Redis Enterprise deployments.
    Verify via REST API that the database is healthy based on different lags.
    """

    def __init__(
        self,
        rest_api_port: int = 9443,
        lag_aware_tolerance: int = DEFAULT_LAG_AWARE_TOLERANCE,
        timeout: float = DEFAULT_TIMEOUT,
        auth_basic: Optional[Tuple[str, str]] = None,
        verify_tls: bool = True,
        # TLS verification (server) options
        ca_file: Optional[str] = None,
        ca_path: Optional[str] = None,
        ca_data: Optional[Union[str, bytes]] = None,
        # Mutual TLS (client cert) options
        client_cert_file: Optional[str] = None,
        client_key_file: Optional[str] = None,
        client_key_password: Optional[str] = None,
    ):
        """
        Initialize LagAwareHealthCheck with the specified parameters.

        Args:
            rest_api_port: Port number for Redis Enterprise REST API (default: 9443)
            lag_aware_tolerance: Tolerated lag between databases in milliseconds
                (default: DEFAULT_LAG_AWARE_TOLERANCE; the previous docstring's
                "default: 100" did not match the actual default)
            timeout: Request timeout in seconds (default: DEFAULT_TIMEOUT)
            auth_basic: Tuple of (username, password) for basic authentication
            verify_tls: Whether to verify TLS certificates (default: True)
            ca_file: Path to CA certificate file for TLS verification
            ca_path: Path to CA certificates directory for TLS verification
            ca_data: CA certificate data as string or bytes
            client_cert_file: Path to client certificate file for mutual TLS
            client_key_file: Path to client private key file for mutual TLS
            client_key_password: Password for encrypted client private key
        """
        self._http_client = HttpClient(
            timeout=timeout,
            auth_basic=auth_basic,
            retry=Retry(NoBackoff(), retries=0),
            verify_tls=verify_tls,
            ca_file=ca_file,
            ca_path=ca_path,
            ca_data=ca_data,
            client_cert_file=client_cert_file,
            client_key_file=client_key_file,
            client_key_password=client_key_password,
        )
        self._rest_api_port = rest_api_port
        self._lag_aware_tolerance = lag_aware_tolerance

    def check_health(self, database) -> bool:
        """Check availability of the matching Redis Enterprise bdb, including
        a lag check with the configured tolerance.

        Raises:
            ValueError: If the database has no health_check_url or no bdb
                matches the database host.
        """
        if database.health_check_url is None:
            raise ValueError(
                "Database health check url is not set. Please check DatabaseConfig for the current database."
            )

        if isinstance(database.client, Redis):
            db_host = database.client.get_connection_kwargs()["host"]
        else:
            db_host = database.client.startup_nodes[0].host

        self._http_client.base_url = (
            f"{database.health_check_url}:{self._rest_api_port}"
        )

        matching_bdb = self._find_matching_bdb(db_host)
        if matching_bdb is None:
            logger.warning("LagAwareHealthCheck failed: Couldn't find a matching bdb")
            raise ValueError("Could not find a matching bdb")

        url = (
            f"/v1/bdbs/{matching_bdb['uid']}/availability"
            f"?extend_check=lag&availability_lag_tolerance_ms={self._lag_aware_tolerance}"
        )
        self._http_client.get(url, expect_json=False)

        # Status checked in an http client, otherwise HttpError will be raised
        return True

    def _find_matching_bdb(self, db_host: str) -> Optional[dict]:
        """Return the first bdb whose endpoint dns name or public address
        matches db_host, or None.

        Unlike the previous inline loops, the scan stops at the first match
        instead of continuing (and potentially overwriting the match with a
        later bdb).
        """
        for bdb in self._http_client.get("/v1/bdbs"):
            for endpoint in bdb["endpoints"]:
                # Match either the DNS name or a public IP of the endpoint.
                if endpoint["dns_name"] == db_host or db_host in endpoint["addr"]:
                    return bdb
        return None