Initial commit: igny8 project

This commit is contained in:
igny8
2025-11-09 10:27:02 +00:00
commit 60b8188111
27265 changed files with 4360521 additions and 0 deletions

View File

@@ -0,0 +1,64 @@
from redis.asyncio.client import Redis, StrictRedis
from redis.asyncio.cluster import RedisCluster
from redis.asyncio.connection import (
BlockingConnectionPool,
Connection,
ConnectionPool,
SSLConnection,
UnixDomainSocketConnection,
)
from redis.asyncio.sentinel import (
Sentinel,
SentinelConnectionPool,
SentinelManagedConnection,
SentinelManagedSSLConnection,
)
from redis.asyncio.utils import from_url
from redis.backoff import default_backoff
from redis.exceptions import (
AuthenticationError,
AuthenticationWrongNumberOfArgsError,
BusyLoadingError,
ChildDeadlockedError,
ConnectionError,
DataError,
InvalidResponse,
OutOfMemoryError,
PubSubError,
ReadOnlyError,
RedisError,
ResponseError,
TimeoutError,
WatchError,
)
__all__ = [
"AuthenticationError",
"AuthenticationWrongNumberOfArgsError",
"BlockingConnectionPool",
"BusyLoadingError",
"ChildDeadlockedError",
"Connection",
"ConnectionError",
"ConnectionPool",
"DataError",
"from_url",
"default_backoff",
"InvalidResponse",
"PubSubError",
"OutOfMemoryError",
"ReadOnlyError",
"Redis",
"RedisCluster",
"RedisError",
"ResponseError",
"Sentinel",
"SentinelConnectionPool",
"SentinelManagedConnection",
"SentinelManagedSSLConnection",
"SSLConnection",
"StrictRedis",
"TimeoutError",
"UnixDomainSocketConnection",
"WatchError",
]

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,265 @@
import asyncio
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Mapping, Optional, Union
from redis.http.http_client import HttpClient, HttpResponse
DEFAULT_USER_AGENT = "HttpClient/1.0 (+https://example.invalid)"
DEFAULT_TIMEOUT = 30.0
RETRY_STATUS_CODES = {429, 500, 502, 503, 504}
class AsyncHTTPClient(ABC):
@abstractmethod
async def get(
self,
path: str,
params: Optional[
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
] = None,
headers: Optional[Mapping[str, str]] = None,
timeout: Optional[float] = None,
expect_json: bool = True,
) -> Union[HttpResponse, Any]:
"""
Invoke HTTP GET request."""
pass
@abstractmethod
async def delete(
self,
path: str,
params: Optional[
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
] = None,
headers: Optional[Mapping[str, str]] = None,
timeout: Optional[float] = None,
expect_json: bool = True,
) -> Union[HttpResponse, Any]:
"""
Invoke HTTP DELETE request."""
pass
@abstractmethod
async def post(
self,
path: str,
json_body: Optional[Any] = None,
data: Optional[Union[bytes, str]] = None,
params: Optional[
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
] = None,
headers: Optional[Mapping[str, str]] = None,
timeout: Optional[float] = None,
expect_json: bool = True,
) -> Union[HttpResponse, Any]:
"""
Invoke HTTP POST request."""
pass
@abstractmethod
async def put(
self,
path: str,
json_body: Optional[Any] = None,
data: Optional[Union[bytes, str]] = None,
params: Optional[
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
] = None,
headers: Optional[Mapping[str, str]] = None,
timeout: Optional[float] = None,
expect_json: bool = True,
) -> Union[HttpResponse, Any]:
"""
Invoke HTTP PUT request."""
pass
@abstractmethod
async def patch(
self,
path: str,
json_body: Optional[Any] = None,
data: Optional[Union[bytes, str]] = None,
params: Optional[
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
] = None,
headers: Optional[Mapping[str, str]] = None,
timeout: Optional[float] = None,
expect_json: bool = True,
) -> Union[HttpResponse, Any]:
"""
Invoke HTTP PATCH request."""
pass
@abstractmethod
async def request(
self,
method: str,
path: str,
params: Optional[
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
] = None,
headers: Optional[Mapping[str, str]] = None,
body: Optional[Union[bytes, str]] = None,
timeout: Optional[float] = None,
) -> HttpResponse:
"""
Invoke HTTP request with given method."""
pass
class AsyncHTTPClientWrapper(AsyncHTTPClient):
"""
An async wrapper around sync HTTP client with thread pool execution.
"""
def __init__(self, client: HttpClient, max_workers: int = 10) -> None:
"""
Initialize a new HTTP client instance.
Args:
client: Sync HTTP client instance.
max_workers: Maximum number of concurrent requests.
The client supports both regular HTTPS with server verification and mutual TLS
authentication. For server verification, provide CA certificate information via
ca_file, ca_path or ca_data. For mutual TLS, additionally provide a client
certificate and key via client_cert_file and client_key_file.
"""
self.client = client
self._executor = ThreadPoolExecutor(max_workers=max_workers)
async def get(
self,
path: str,
params: Optional[
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
] = None,
headers: Optional[Mapping[str, str]] = None,
timeout: Optional[float] = None,
expect_json: bool = True,
) -> Union[HttpResponse, Any]:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self._executor, self.client.get, path, params, headers, timeout, expect_json
)
async def delete(
self,
path: str,
params: Optional[
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
] = None,
headers: Optional[Mapping[str, str]] = None,
timeout: Optional[float] = None,
expect_json: bool = True,
) -> Union[HttpResponse, Any]:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self._executor,
self.client.delete,
path,
params,
headers,
timeout,
expect_json,
)
async def post(
self,
path: str,
json_body: Optional[Any] = None,
data: Optional[Union[bytes, str]] = None,
params: Optional[
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
] = None,
headers: Optional[Mapping[str, str]] = None,
timeout: Optional[float] = None,
expect_json: bool = True,
) -> Union[HttpResponse, Any]:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self._executor,
self.client.post,
path,
json_body,
data,
params,
headers,
timeout,
expect_json,
)
async def put(
self,
path: str,
json_body: Optional[Any] = None,
data: Optional[Union[bytes, str]] = None,
params: Optional[
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
] = None,
headers: Optional[Mapping[str, str]] = None,
timeout: Optional[float] = None,
expect_json: bool = True,
) -> Union[HttpResponse, Any]:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self._executor,
self.client.put,
path,
json_body,
data,
params,
headers,
timeout,
expect_json,
)
async def patch(
self,
path: str,
json_body: Optional[Any] = None,
data: Optional[Union[bytes, str]] = None,
params: Optional[
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
] = None,
headers: Optional[Mapping[str, str]] = None,
timeout: Optional[float] = None,
expect_json: bool = True,
) -> Union[HttpResponse, Any]:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self._executor,
self.client.patch,
path,
json_body,
data,
params,
headers,
timeout,
expect_json,
)
async def request(
self,
method: str,
path: str,
params: Optional[
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
] = None,
headers: Optional[Mapping[str, str]] = None,
body: Optional[Union[bytes, str]] = None,
timeout: Optional[float] = None,
) -> HttpResponse:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self._executor,
self.client.request,
method,
path,
params,
headers,
body,
timeout,
)

View File

@@ -0,0 +1,334 @@
import asyncio
import logging
import threading
import uuid
from types import SimpleNamespace
from typing import TYPE_CHECKING, Awaitable, Optional, Union
from redis.exceptions import LockError, LockNotOwnedError
from redis.typing import Number
if TYPE_CHECKING:
from redis.asyncio import Redis, RedisCluster
logger = logging.getLogger(__name__)
class Lock:
"""
A shared, distributed Lock. Using Redis for locking allows the Lock
to be shared across processes and/or machines.
It's left to the user to resolve deadlock issues and make sure
multiple clients play nicely together.
"""
lua_release = None
lua_extend = None
lua_reacquire = None
# KEYS[1] - lock name
# ARGV[1] - token
# return 1 if the lock was released, otherwise 0
LUA_RELEASE_SCRIPT = """
local token = redis.call('get', KEYS[1])
if not token or token ~= ARGV[1] then
return 0
end
redis.call('del', KEYS[1])
return 1
"""
# KEYS[1] - lock name
# ARGV[1] - token
# ARGV[2] - additional milliseconds
# ARGV[3] - "0" if the additional time should be added to the lock's
# existing ttl or "1" if the existing ttl should be replaced
# return 1 if the locks time was extended, otherwise 0
LUA_EXTEND_SCRIPT = """
local token = redis.call('get', KEYS[1])
if not token or token ~= ARGV[1] then
return 0
end
local expiration = redis.call('pttl', KEYS[1])
if not expiration then
expiration = 0
end
if expiration < 0 then
return 0
end
local newttl = ARGV[2]
if ARGV[3] == "0" then
newttl = ARGV[2] + expiration
end
redis.call('pexpire', KEYS[1], newttl)
return 1
"""
# KEYS[1] - lock name
# ARGV[1] - token
# ARGV[2] - milliseconds
# return 1 if the locks time was reacquired, otherwise 0
LUA_REACQUIRE_SCRIPT = """
local token = redis.call('get', KEYS[1])
if not token or token ~= ARGV[1] then
return 0
end
redis.call('pexpire', KEYS[1], ARGV[2])
return 1
"""
def __init__(
self,
redis: Union["Redis", "RedisCluster"],
name: Union[str, bytes, memoryview],
timeout: Optional[float] = None,
sleep: float = 0.1,
blocking: bool = True,
blocking_timeout: Optional[Number] = None,
thread_local: bool = True,
raise_on_release_error: bool = True,
):
"""
Create a new Lock instance named ``name`` using the Redis client
supplied by ``redis``.
``timeout`` indicates a maximum life for the lock in seconds.
By default, it will remain locked until release() is called.
``timeout`` can be specified as a float or integer, both representing
the number of seconds to wait.
``sleep`` indicates the amount of time to sleep in seconds per loop
iteration when the lock is in blocking mode and another client is
currently holding the lock.
``blocking`` indicates whether calling ``acquire`` should block until
the lock has been acquired or to fail immediately, causing ``acquire``
to return False and the lock not being acquired. Defaults to True.
Note this value can be overridden by passing a ``blocking``
argument to ``acquire``.
``blocking_timeout`` indicates the maximum amount of time in seconds to
spend trying to acquire the lock. A value of ``None`` indicates
continue trying forever. ``blocking_timeout`` can be specified as a
float or integer, both representing the number of seconds to wait.
``thread_local`` indicates whether the lock token is placed in
thread-local storage. By default, the token is placed in thread local
storage so that a thread only sees its token, not a token set by
another thread. Consider the following timeline:
time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds.
thread-1 sets the token to "abc"
time: 1, thread-2 blocks trying to acquire `my-lock` using the
Lock instance.
time: 5, thread-1 has not yet completed. redis expires the lock
key.
time: 5, thread-2 acquired `my-lock` now that it's available.
thread-2 sets the token to "xyz"
time: 6, thread-1 finishes its work and calls release(). if the
token is *not* stored in thread local storage, then
thread-1 would see the token value as "xyz" and would be
able to successfully release the thread-2's lock.
``raise_on_release_error`` indicates whether to raise an exception when
the lock is no longer owned when exiting the context manager. By default,
this is True, meaning an exception will be raised. If False, the warning
will be logged and the exception will be suppressed.
In some use cases it's necessary to disable thread local storage. For
example, if you have code where one thread acquires a lock and passes
that lock instance to a worker thread to release later. If thread
local storage isn't disabled in this case, the worker thread won't see
the token set by the thread that acquired the lock. Our assumption
is that these cases aren't common and as such default to using
thread local storage.
"""
self.redis = redis
self.name = name
self.timeout = timeout
self.sleep = sleep
self.blocking = blocking
self.blocking_timeout = blocking_timeout
self.thread_local = bool(thread_local)
self.local = threading.local() if self.thread_local else SimpleNamespace()
self.raise_on_release_error = raise_on_release_error
self.local.token = None
self.register_scripts()
def register_scripts(self):
cls = self.__class__
client = self.redis
if cls.lua_release is None:
cls.lua_release = client.register_script(cls.LUA_RELEASE_SCRIPT)
if cls.lua_extend is None:
cls.lua_extend = client.register_script(cls.LUA_EXTEND_SCRIPT)
if cls.lua_reacquire is None:
cls.lua_reacquire = client.register_script(cls.LUA_REACQUIRE_SCRIPT)
async def __aenter__(self):
if await self.acquire():
return self
raise LockError("Unable to acquire lock within the time specified")
async def __aexit__(self, exc_type, exc_value, traceback):
try:
await self.release()
except LockError:
if self.raise_on_release_error:
raise
logger.warning(
"Lock was unlocked or no longer owned when exiting context manager."
)
async def acquire(
self,
blocking: Optional[bool] = None,
blocking_timeout: Optional[Number] = None,
token: Optional[Union[str, bytes]] = None,
):
"""
Use Redis to hold a shared, distributed lock named ``name``.
Returns True once the lock is acquired.
If ``blocking`` is False, always return immediately. If the lock
was acquired, return True, otherwise return False.
``blocking_timeout`` specifies the maximum number of seconds to
wait trying to acquire the lock.
``token`` specifies the token value to be used. If provided, token
must be a bytes object or a string that can be encoded to a bytes
object with the default encoding. If a token isn't specified, a UUID
will be generated.
"""
sleep = self.sleep
if token is None:
token = uuid.uuid1().hex.encode()
else:
try:
encoder = self.redis.connection_pool.get_encoder()
except AttributeError:
# Cluster
encoder = self.redis.get_encoder()
token = encoder.encode(token)
if blocking is None:
blocking = self.blocking
if blocking_timeout is None:
blocking_timeout = self.blocking_timeout
stop_trying_at = None
if blocking_timeout is not None:
stop_trying_at = asyncio.get_running_loop().time() + blocking_timeout
while True:
if await self.do_acquire(token):
self.local.token = token
return True
if not blocking:
return False
next_try_at = asyncio.get_running_loop().time() + sleep
if stop_trying_at is not None and next_try_at > stop_trying_at:
return False
await asyncio.sleep(sleep)
async def do_acquire(self, token: Union[str, bytes]) -> bool:
if self.timeout:
# convert to milliseconds
timeout = int(self.timeout * 1000)
else:
timeout = None
if await self.redis.set(self.name, token, nx=True, px=timeout):
return True
return False
async def locked(self) -> bool:
"""
Returns True if this key is locked by any process, otherwise False.
"""
return await self.redis.get(self.name) is not None
async def owned(self) -> bool:
"""
Returns True if this key is locked by this lock, otherwise False.
"""
stored_token = await self.redis.get(self.name)
# need to always compare bytes to bytes
# TODO: this can be simplified when the context manager is finished
if stored_token and not isinstance(stored_token, bytes):
try:
encoder = self.redis.connection_pool.get_encoder()
except AttributeError:
# Cluster
encoder = self.redis.get_encoder()
stored_token = encoder.encode(stored_token)
return self.local.token is not None and stored_token == self.local.token
def release(self) -> Awaitable[None]:
"""Releases the already acquired lock"""
expected_token = self.local.token
if expected_token is None:
raise LockError(
"Cannot release a lock that's not owned or is already unlocked.",
lock_name=self.name,
)
self.local.token = None
return self.do_release(expected_token)
async def do_release(self, expected_token: bytes) -> None:
if not bool(
await self.lua_release(
keys=[self.name], args=[expected_token], client=self.redis
)
):
raise LockNotOwnedError("Cannot release a lock that's no longer owned")
def extend(
self, additional_time: Number, replace_ttl: bool = False
) -> Awaitable[bool]:
"""
Adds more time to an already acquired lock.
``additional_time`` can be specified as an integer or a float, both
representing the number of seconds to add.
``replace_ttl`` if False (the default), add `additional_time` to
the lock's existing ttl. If True, replace the lock's ttl with
`additional_time`.
"""
if self.local.token is None:
raise LockError("Cannot extend an unlocked lock")
if self.timeout is None:
raise LockError("Cannot extend a lock with no timeout")
return self.do_extend(additional_time, replace_ttl)
async def do_extend(self, additional_time, replace_ttl) -> bool:
additional_time = int(additional_time * 1000)
if not bool(
await self.lua_extend(
keys=[self.name],
args=[self.local.token, additional_time, replace_ttl and "1" or "0"],
client=self.redis,
)
):
raise LockNotOwnedError("Cannot extend a lock that's no longer owned")
return True
def reacquire(self) -> Awaitable[bool]:
"""
Resets a TTL of an already acquired lock back to a timeout value.
"""
if self.local.token is None:
raise LockError("Cannot reacquire an unlocked lock")
if self.timeout is None:
raise LockError("Cannot reacquire a lock with no timeout")
return self.do_reacquire()
async def do_reacquire(self) -> bool:
timeout = int(self.timeout * 1000)
if not bool(
await self.lua_reacquire(
keys=[self.name], args=[self.local.token, timeout], client=self.redis
)
):
raise LockNotOwnedError("Cannot reacquire a lock that's no longer owned")
return True

View File

@@ -0,0 +1,530 @@
import asyncio
import logging
from typing import Any, Awaitable, Callable, Coroutine, List, Optional, Union
from redis.asyncio.client import PubSubHandler
from redis.asyncio.multidb.command_executor import DefaultCommandExecutor
from redis.asyncio.multidb.config import DEFAULT_GRACE_PERIOD, MultiDbConfig
from redis.asyncio.multidb.database import AsyncDatabase, Databases
from redis.asyncio.multidb.failure_detector import AsyncFailureDetector
from redis.asyncio.multidb.healthcheck import HealthCheck, HealthCheckPolicy
from redis.background import BackgroundScheduler
from redis.commands import AsyncCoreCommands, AsyncRedisModuleCommands
from redis.multidb.circuit import CircuitBreaker
from redis.multidb.circuit import State as CBState
from redis.multidb.exception import NoValidDatabaseException, UnhealthyDatabaseException
from redis.typing import ChannelT, EncodableT, KeyT
from redis.utils import experimental
logger = logging.getLogger(__name__)
@experimental
class MultiDBClient(AsyncRedisModuleCommands, AsyncCoreCommands):
"""
Client that operates on multiple logical Redis databases.
Should be used in Active-Active database setups.
"""
def __init__(self, config: MultiDbConfig):
self._databases = config.databases()
self._health_checks = config.default_health_checks()
if config.health_checks is not None:
self._health_checks.extend(config.health_checks)
self._health_check_interval = config.health_check_interval
self._health_check_policy: HealthCheckPolicy = config.health_check_policy.value(
config.health_check_probes, config.health_check_delay
)
self._failure_detectors = config.default_failure_detectors()
if config.failure_detectors is not None:
self._failure_detectors.extend(config.failure_detectors)
self._failover_strategy = (
config.default_failover_strategy()
if config.failover_strategy is None
else config.failover_strategy
)
self._failover_strategy.set_databases(self._databases)
self._auto_fallback_interval = config.auto_fallback_interval
self._event_dispatcher = config.event_dispatcher
self._command_retry = config.command_retry
self._command_retry.update_supported_errors([ConnectionRefusedError])
self.command_executor = DefaultCommandExecutor(
failure_detectors=self._failure_detectors,
databases=self._databases,
command_retry=self._command_retry,
failover_strategy=self._failover_strategy,
failover_attempts=config.failover_attempts,
failover_delay=config.failover_delay,
event_dispatcher=self._event_dispatcher,
auto_fallback_interval=self._auto_fallback_interval,
)
self.initialized = False
self._hc_lock = asyncio.Lock()
self._bg_scheduler = BackgroundScheduler()
self._config = config
self._recurring_hc_task = None
self._hc_tasks = []
self._half_open_state_task = None
async def __aenter__(self: "MultiDBClient") -> "MultiDBClient":
if not self.initialized:
await self.initialize()
return self
async def __aexit__(self, exc_type, exc_value, traceback):
if self._recurring_hc_task:
self._recurring_hc_task.cancel()
if self._half_open_state_task:
self._half_open_state_task.cancel()
for hc_task in self._hc_tasks:
hc_task.cancel()
async def initialize(self):
"""
Perform initialization of databases to define their initial state.
"""
async def raise_exception_on_failed_hc(error):
raise error
# Initial databases check to define initial state
await self._check_databases_health(on_error=raise_exception_on_failed_hc)
# Starts recurring health checks on the background.
self._recurring_hc_task = asyncio.create_task(
self._bg_scheduler.run_recurring_async(
self._health_check_interval,
self._check_databases_health,
)
)
is_active_db_found = False
for database, weight in self._databases:
# Set on state changed callback for each circuit.
database.circuit.on_state_changed(self._on_circuit_state_change_callback)
# Set states according to a weights and circuit state
if database.circuit.state == CBState.CLOSED and not is_active_db_found:
await self.command_executor.set_active_database(database)
is_active_db_found = True
if not is_active_db_found:
raise NoValidDatabaseException(
"Initial connection failed - no active database found"
)
self.initialized = True
def get_databases(self) -> Databases:
"""
Returns a sorted (by weight) list of all databases.
"""
return self._databases
async def set_active_database(self, database: AsyncDatabase) -> None:
"""
Promote one of the existing databases to become an active.
"""
exists = None
for existing_db, _ in self._databases:
if existing_db == database:
exists = True
break
if not exists:
raise ValueError("Given database is not a member of database list")
await self._check_db_health(database)
if database.circuit.state == CBState.CLOSED:
highest_weighted_db, _ = self._databases.get_top_n(1)[0]
await self.command_executor.set_active_database(database)
return
raise NoValidDatabaseException(
"Cannot set active database, database is unhealthy"
)
async def add_database(self, database: AsyncDatabase):
"""
Adds a new database to the database list.
"""
for existing_db, _ in self._databases:
if existing_db == database:
raise ValueError("Given database already exists")
await self._check_db_health(database)
highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0]
self._databases.add(database, database.weight)
await self._change_active_database(database, highest_weighted_db)
async def _change_active_database(
self, new_database: AsyncDatabase, highest_weight_database: AsyncDatabase
):
if (
new_database.weight > highest_weight_database.weight
and new_database.circuit.state == CBState.CLOSED
):
await self.command_executor.set_active_database(new_database)
async def remove_database(self, database: AsyncDatabase):
"""
Removes a database from the database list.
"""
weight = self._databases.remove(database)
highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0]
if (
highest_weight <= weight
and highest_weighted_db.circuit.state == CBState.CLOSED
):
await self.command_executor.set_active_database(highest_weighted_db)
async def update_database_weight(self, database: AsyncDatabase, weight: float):
"""
Updates a database from the database list.
"""
exists = None
for existing_db, _ in self._databases:
if existing_db == database:
exists = True
break
if not exists:
raise ValueError("Given database is not a member of database list")
highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0]
self._databases.update_weight(database, weight)
database.weight = weight
await self._change_active_database(database, highest_weighted_db)
def add_failure_detector(self, failure_detector: AsyncFailureDetector):
"""
Adds a new failure detector to the database.
"""
self._failure_detectors.append(failure_detector)
async def add_health_check(self, healthcheck: HealthCheck):
"""
Adds a new health check to the database.
"""
async with self._hc_lock:
self._health_checks.append(healthcheck)
async def execute_command(self, *args, **options):
"""
Executes a single command and return its result.
"""
if not self.initialized:
await self.initialize()
return await self.command_executor.execute_command(*args, **options)
def pipeline(self):
"""
Enters into pipeline mode of the client.
"""
return Pipeline(self)
async def transaction(
self,
func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]],
*watches: KeyT,
shard_hint: Optional[str] = None,
value_from_callable: bool = False,
watch_delay: Optional[float] = None,
):
"""
Executes callable as transaction.
"""
if not self.initialized:
await self.initialize()
return await self.command_executor.execute_transaction(
func,
*watches,
shard_hint=shard_hint,
value_from_callable=value_from_callable,
watch_delay=watch_delay,
)
async def pubsub(self, **kwargs):
"""
Return a Publish/Subscribe object. With this object, you can
subscribe to channels and listen for messages that get published to
them.
"""
if not self.initialized:
await self.initialize()
return PubSub(self, **kwargs)
async def _check_databases_health(
self,
on_error: Optional[Callable[[Exception], Coroutine[Any, Any, None]]] = None,
):
"""
Runs health checks as a recurring task.
Runs health checks against all databases.
"""
try:
self._hc_tasks = [
asyncio.create_task(self._check_db_health(database))
for database, _ in self._databases
]
results = await asyncio.wait_for(
asyncio.gather(
*self._hc_tasks,
return_exceptions=True,
),
timeout=self._health_check_interval,
)
except asyncio.TimeoutError:
raise asyncio.TimeoutError(
"Health check execution exceeds health_check_interval"
)
for result in results:
if isinstance(result, UnhealthyDatabaseException):
unhealthy_db = result.database
unhealthy_db.circuit.state = CBState.OPEN
logger.exception(
"Health check failed, due to exception",
exc_info=result.original_exception,
)
if on_error:
on_error(result.original_exception)
async def _check_db_health(self, database: AsyncDatabase) -> bool:
"""
Runs health checks on the given database until first failure.
"""
# Health check will setup circuit state
is_healthy = await self._health_check_policy.execute(
self._health_checks, database
)
if not is_healthy:
if database.circuit.state != CBState.OPEN:
database.circuit.state = CBState.OPEN
return is_healthy
elif is_healthy and database.circuit.state != CBState.CLOSED:
database.circuit.state = CBState.CLOSED
return is_healthy
def _on_circuit_state_change_callback(
self, circuit: CircuitBreaker, old_state: CBState, new_state: CBState
):
loop = asyncio.get_running_loop()
if new_state == CBState.HALF_OPEN:
self._half_open_state_task = asyncio.create_task(
self._check_db_health(circuit.database)
)
return
if old_state == CBState.CLOSED and new_state == CBState.OPEN:
loop.call_later(DEFAULT_GRACE_PERIOD, _half_open_circuit, circuit)
async def aclose(self):
if self.command_executor.active_database:
await self.command_executor.active_database.client.aclose()
def _half_open_circuit(circuit: CircuitBreaker):
circuit.state = CBState.HALF_OPEN
class Pipeline(AsyncRedisModuleCommands, AsyncCoreCommands):
"""
Pipeline implementation for multiple logical Redis databases.
"""
def __init__(self, client: MultiDBClient):
self._command_stack = []
self._client = client
async def __aenter__(self: "Pipeline") -> "Pipeline":
return self
async def __aexit__(self, exc_type, exc_value, traceback):
await self.reset()
await self._client.__aexit__(exc_type, exc_value, traceback)
def __await__(self):
return self._async_self().__await__()
async def _async_self(self):
return self
def __len__(self) -> int:
return len(self._command_stack)
def __bool__(self) -> bool:
"""Pipeline instances should always evaluate to True"""
return True
async def reset(self) -> None:
self._command_stack = []
async def aclose(self) -> None:
"""Close the pipeline"""
await self.reset()
def pipeline_execute_command(self, *args, **options) -> "Pipeline":
"""
Stage a command to be executed when execute() is next called
Returns the current Pipeline object back so commands can be
chained together, such as:
pipe = pipe.set('foo', 'bar').incr('baz').decr('bang')
At some other point, you can then run: pipe.execute(),
which will execute all commands queued in the pipe.
"""
self._command_stack.append((args, options))
return self
def execute_command(self, *args, **kwargs):
"""Adds a command to the stack"""
return self.pipeline_execute_command(*args, **kwargs)
async def execute(self) -> List[Any]:
"""Execute all the commands in the current pipeline"""
if not self._client.initialized:
await self._client.initialize()
try:
return await self._client.command_executor.execute_pipeline(
tuple(self._command_stack)
)
finally:
await self.reset()
class PubSub:
"""
PubSub object for multi database client.
"""
def __init__(self, client: MultiDBClient, **kwargs):
"""Initialize the PubSub object for a multi-database client.
Args:
client: MultiDBClient instance to use for pub/sub operations
**kwargs: Additional keyword arguments to pass to the underlying pubsub implementation
"""
self._client = client
self._client.command_executor.pubsub(**kwargs)
async def __aenter__(self) -> "PubSub":
return self
async def __aexit__(self, exc_type, exc_value, traceback) -> None:
await self.aclose()
async def aclose(self):
return await self._client.command_executor.execute_pubsub_method("aclose")
@property
def subscribed(self) -> bool:
return self._client.command_executor.active_pubsub.subscribed
async def execute_command(self, *args: EncodableT):
return await self._client.command_executor.execute_pubsub_method(
"execute_command", *args
)
async def psubscribe(self, *args: ChannelT, **kwargs: PubSubHandler):
"""
Subscribe to channel patterns. Patterns supplied as keyword arguments
expect a pattern name as the key and a callable as the value. A
pattern's callable will be invoked automatically when a message is
received on that pattern rather than producing a message via
``listen()``.
"""
return await self._client.command_executor.execute_pubsub_method(
"psubscribe", *args, **kwargs
)
async def punsubscribe(self, *args: ChannelT):
"""
Unsubscribe from the supplied patterns. If empty, unsubscribe from
all patterns.
"""
return await self._client.command_executor.execute_pubsub_method(
"punsubscribe", *args
)
async def subscribe(self, *args: ChannelT, **kwargs: Callable):
"""
Subscribe to channels. Channels supplied as keyword arguments expect
a channel name as the key and a callable as the value. A channel's
callable will be invoked automatically when a message is received on
that channel rather than producing a message via ``listen()`` or
``get_message()``.
"""
return await self._client.command_executor.execute_pubsub_method(
"subscribe", *args, **kwargs
)
async def unsubscribe(self, *args):
"""
Unsubscribe from the supplied channels. If empty, unsubscribe from
all channels
"""
return await self._client.command_executor.execute_pubsub_method(
"unsubscribe", *args
)
async def get_message(
self, ignore_subscribe_messages: bool = False, timeout: Optional[float] = 0.0
):
"""
Get the next message if one is available, otherwise None.
If timeout is specified, the system will wait for `timeout` seconds
before returning. Timeout should be specified as a floating point
number or None to wait indefinitely.
"""
return await self._client.command_executor.execute_pubsub_method(
"get_message",
ignore_subscribe_messages=ignore_subscribe_messages,
timeout=timeout,
)
async def run(
self,
*,
exception_handler=None,
poll_timeout: float = 1.0,
) -> None:
"""Process pub/sub messages using registered callbacks.
This is the equivalent of :py:meth:`redis.PubSub.run_in_thread` in
redis-py, but it is a coroutine. To launch it as a separate task, use
``asyncio.create_task``:
>>> task = asyncio.create_task(pubsub.run())
To shut it down, use asyncio cancellation:
>>> task.cancel()
>>> await task
"""
return await self._client.command_executor.execute_pubsub_run(
sleep_time=poll_timeout, exception_handler=exception_handler, pubsub=self
)

View File

@@ -0,0 +1,339 @@
from abc import abstractmethod
from asyncio import iscoroutinefunction
from datetime import datetime
from typing import Any, Awaitable, Callable, List, Optional, Union
from redis.asyncio import RedisCluster
from redis.asyncio.client import Pipeline, PubSub
from redis.asyncio.multidb.database import AsyncDatabase, Database, Databases
from redis.asyncio.multidb.event import (
AsyncActiveDatabaseChanged,
CloseConnectionOnActiveDatabaseChanged,
RegisterCommandFailure,
ResubscribeOnActiveDatabaseChanged,
)
from redis.asyncio.multidb.failover import (
DEFAULT_FAILOVER_ATTEMPTS,
DEFAULT_FAILOVER_DELAY,
AsyncFailoverStrategy,
DefaultFailoverStrategyExecutor,
FailoverStrategyExecutor,
)
from redis.asyncio.multidb.failure_detector import AsyncFailureDetector
from redis.asyncio.retry import Retry
from redis.event import AsyncOnCommandsFailEvent, EventDispatcherInterface
from redis.multidb.circuit import State as CBState
from redis.multidb.command_executor import BaseCommandExecutor, CommandExecutor
from redis.multidb.config import DEFAULT_AUTO_FALLBACK_INTERVAL
from redis.typing import KeyT
class AsyncCommandExecutor(CommandExecutor):
@property
@abstractmethod
def databases(self) -> Databases:
"""Returns a list of databases."""
pass
@property
@abstractmethod
def failure_detectors(self) -> List[AsyncFailureDetector]:
"""Returns a list of failure detectors."""
pass
@abstractmethod
def add_failure_detector(self, failure_detector: AsyncFailureDetector) -> None:
"""Adds a new failure detector to the list of failure detectors."""
pass
@property
@abstractmethod
def active_database(self) -> Optional[AsyncDatabase]:
"""Returns currently active database."""
pass
@abstractmethod
async def set_active_database(self, database: AsyncDatabase) -> None:
"""Sets the currently active database."""
pass
@property
@abstractmethod
def active_pubsub(self) -> Optional[PubSub]:
"""Returns currently active pubsub."""
pass
@active_pubsub.setter
@abstractmethod
def active_pubsub(self, pubsub: PubSub) -> None:
"""Sets currently active pubsub."""
pass
@property
@abstractmethod
def failover_strategy_executor(self) -> FailoverStrategyExecutor:
"""Returns failover strategy executor."""
pass
@property
@abstractmethod
def command_retry(self) -> Retry:
"""Returns command retry object."""
pass
@abstractmethod
async def pubsub(self, **kwargs):
"""Initializes a PubSub object on a currently active database"""
pass
@abstractmethod
async def execute_command(self, *args, **options):
"""Executes a command and returns the result."""
pass
@abstractmethod
async def execute_pipeline(self, command_stack: tuple):
"""Executes a stack of commands in pipeline."""
pass
@abstractmethod
async def execute_transaction(
self, transaction: Callable[[Pipeline], None], *watches, **options
):
"""Executes a transaction block wrapped in callback."""
pass
@abstractmethod
async def execute_pubsub_method(self, method_name: str, *args, **kwargs):
"""Executes a given method on active pub/sub."""
pass
@abstractmethod
async def execute_pubsub_run(self, sleep_time: float, **kwargs) -> Any:
"""Executes pub/sub run in a thread."""
pass
class DefaultCommandExecutor(BaseCommandExecutor, AsyncCommandExecutor):
def __init__(
self,
failure_detectors: List[AsyncFailureDetector],
databases: Databases,
command_retry: Retry,
failover_strategy: AsyncFailoverStrategy,
event_dispatcher: EventDispatcherInterface,
failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS,
failover_delay: float = DEFAULT_FAILOVER_DELAY,
auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL,
):
"""
Initialize the DefaultCommandExecutor instance.
Args:
failure_detectors: List of failure detector instances to monitor database health
databases: Collection of available databases to execute commands on
command_retry: Retry policy for failed command execution
failover_strategy: Strategy for handling database failover
event_dispatcher: Interface for dispatching events
failover_attempts: Number of failover attempts
failover_delay: Delay between failover attempts
auto_fallback_interval: Time interval in seconds between attempts to fall back to a primary database
"""
super().__init__(auto_fallback_interval)
for fd in failure_detectors:
fd.set_command_executor(command_executor=self)
self._databases = databases
self._failure_detectors = failure_detectors
self._command_retry = command_retry
self._failover_strategy_executor = DefaultFailoverStrategyExecutor(
failover_strategy, failover_attempts, failover_delay
)
self._event_dispatcher = event_dispatcher
self._active_database: Optional[Database] = None
self._active_pubsub: Optional[PubSub] = None
self._active_pubsub_kwargs = {}
self._setup_event_dispatcher()
self._schedule_next_fallback()
@property
def databases(self) -> Databases:
return self._databases
@property
def failure_detectors(self) -> List[AsyncFailureDetector]:
return self._failure_detectors
def add_failure_detector(self, failure_detector: AsyncFailureDetector) -> None:
self._failure_detectors.append(failure_detector)
@property
def active_database(self) -> Optional[AsyncDatabase]:
return self._active_database
async def set_active_database(self, database: AsyncDatabase) -> None:
old_active = self._active_database
self._active_database = database
if old_active is not None and old_active is not database:
await self._event_dispatcher.dispatch_async(
AsyncActiveDatabaseChanged(
old_active,
self._active_database,
self,
**self._active_pubsub_kwargs,
)
)
@property
def active_pubsub(self) -> Optional[PubSub]:
return self._active_pubsub
@active_pubsub.setter
def active_pubsub(self, pubsub: PubSub) -> None:
self._active_pubsub = pubsub
@property
def failover_strategy_executor(self) -> FailoverStrategyExecutor:
return self._failover_strategy_executor
@property
def command_retry(self) -> Retry:
return self._command_retry
def pubsub(self, **kwargs):
if self._active_pubsub is None:
if isinstance(self._active_database.client, RedisCluster):
raise ValueError("PubSub is not supported for RedisCluster")
self._active_pubsub = self._active_database.client.pubsub(**kwargs)
self._active_pubsub_kwargs = kwargs
async def execute_command(self, *args, **options):
async def callback():
response = await self._active_database.client.execute_command(
*args, **options
)
await self._register_command_execution(args)
return response
return await self._execute_with_failure_detection(callback, args)
async def execute_pipeline(self, command_stack: tuple):
async def callback():
async with self._active_database.client.pipeline() as pipe:
for command, options in command_stack:
pipe.execute_command(*command, **options)
response = await pipe.execute()
await self._register_command_execution(command_stack)
return response
return await self._execute_with_failure_detection(callback, command_stack)
async def execute_transaction(
self,
func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]],
*watches: KeyT,
shard_hint: Optional[str] = None,
value_from_callable: bool = False,
watch_delay: Optional[float] = None,
):
async def callback():
response = await self._active_database.client.transaction(
func,
*watches,
shard_hint=shard_hint,
value_from_callable=value_from_callable,
watch_delay=watch_delay,
)
await self._register_command_execution(())
return response
return await self._execute_with_failure_detection(callback)
async def execute_pubsub_method(self, method_name: str, *args, **kwargs):
async def callback():
method = getattr(self.active_pubsub, method_name)
if iscoroutinefunction(method):
response = await method(*args, **kwargs)
else:
response = method(*args, **kwargs)
await self._register_command_execution(args)
return response
return await self._execute_with_failure_detection(callback, *args)
async def execute_pubsub_run(
self, sleep_time: float, exception_handler=None, pubsub=None
) -> Any:
async def callback():
return await self._active_pubsub.run(
poll_timeout=sleep_time,
exception_handler=exception_handler,
pubsub=pubsub,
)
return await self._execute_with_failure_detection(callback)
async def _execute_with_failure_detection(
self, callback: Callable, cmds: tuple = ()
):
"""
Execute a commands execution callback with failure detection.
"""
async def wrapper():
# On each retry we need to check active database as it might change.
await self._check_active_database()
return await callback()
return await self._command_retry.call_with_retry(
lambda: wrapper(),
lambda error: self._on_command_fail(error, *cmds),
)
async def _check_active_database(self):
"""
Checks if active a database needs to be updated.
"""
if (
self._active_database is None
or self._active_database.circuit.state != CBState.CLOSED
or (
self._auto_fallback_interval != DEFAULT_AUTO_FALLBACK_INTERVAL
and self._next_fallback_attempt <= datetime.now()
)
):
await self.set_active_database(
await self._failover_strategy_executor.execute()
)
self._schedule_next_fallback()
async def _on_command_fail(self, error, *args):
await self._event_dispatcher.dispatch_async(
AsyncOnCommandsFailEvent(args, error)
)
async def _register_command_execution(self, cmd: tuple):
for detector in self._failure_detectors:
await detector.register_command_execution(cmd)
def _setup_event_dispatcher(self):
"""
Registers necessary listeners.
"""
failure_listener = RegisterCommandFailure(self._failure_detectors)
resubscribe_listener = ResubscribeOnActiveDatabaseChanged()
close_connection_listener = CloseConnectionOnActiveDatabaseChanged()
self._event_dispatcher.register_listeners(
{
AsyncOnCommandsFailEvent: [failure_listener],
AsyncActiveDatabaseChanged: [
close_connection_listener,
resubscribe_listener,
],
}
)

View File

@@ -0,0 +1,210 @@
from dataclasses import dataclass, field
from typing import List, Optional, Type, Union
import pybreaker
from redis.asyncio import ConnectionPool, Redis, RedisCluster
from redis.asyncio.multidb.database import Database, Databases
from redis.asyncio.multidb.failover import (
DEFAULT_FAILOVER_ATTEMPTS,
DEFAULT_FAILOVER_DELAY,
AsyncFailoverStrategy,
WeightBasedFailoverStrategy,
)
from redis.asyncio.multidb.failure_detector import (
AsyncFailureDetector,
FailureDetectorAsyncWrapper,
)
from redis.asyncio.multidb.healthcheck import (
DEFAULT_HEALTH_CHECK_DELAY,
DEFAULT_HEALTH_CHECK_INTERVAL,
DEFAULT_HEALTH_CHECK_POLICY,
DEFAULT_HEALTH_CHECK_PROBES,
HealthCheck,
HealthCheckPolicies,
PingHealthCheck,
)
from redis.asyncio.retry import Retry
from redis.backoff import ExponentialWithJitterBackoff, NoBackoff
from redis.data_structure import WeightedList
from redis.event import EventDispatcher, EventDispatcherInterface
from redis.multidb.circuit import (
DEFAULT_GRACE_PERIOD,
CircuitBreaker,
PBCircuitBreakerAdapter,
)
from redis.multidb.failure_detector import (
DEFAULT_FAILURE_RATE_THRESHOLD,
DEFAULT_FAILURES_DETECTION_WINDOW,
DEFAULT_MIN_NUM_FAILURES,
CommandFailureDetector,
)
DEFAULT_AUTO_FALLBACK_INTERVAL = 120
def default_event_dispatcher() -> EventDispatcherInterface:
return EventDispatcher()
@dataclass
class DatabaseConfig:
"""
Dataclass representing the configuration for a database connection.
This class is used to store configuration settings for a database connection,
including client options, connection sourcing details, circuit breaker settings,
and cluster-specific properties. It provides a structure for defining these
attributes and allows for the creation of customized configurations for various
database setups.
Attributes:
weight (float): Weight of the database to define the active one.
client_kwargs (dict): Additional parameters for the database client connection.
from_url (Optional[str]): Redis URL way of connecting to the database.
from_pool (Optional[ConnectionPool]): A pre-configured connection pool to use.
circuit (Optional[CircuitBreaker]): Custom circuit breaker implementation.
grace_period (float): Grace period after which we need to check if the circuit could be closed again.
health_check_url (Optional[str]): URL for health checks. Cluster FQDN is typically used
on public Redis Enterprise endpoints.
Methods:
default_circuit_breaker:
Generates and returns a default CircuitBreaker instance adapted for use.
"""
weight: float = 1.0
client_kwargs: dict = field(default_factory=dict)
from_url: Optional[str] = None
from_pool: Optional[ConnectionPool] = None
circuit: Optional[CircuitBreaker] = None
grace_period: float = DEFAULT_GRACE_PERIOD
health_check_url: Optional[str] = None
def default_circuit_breaker(self) -> CircuitBreaker:
circuit_breaker = pybreaker.CircuitBreaker(reset_timeout=self.grace_period)
return PBCircuitBreakerAdapter(circuit_breaker)
@dataclass
class MultiDbConfig:
"""
Configuration class for managing multiple database connections in a resilient and fail-safe manner.
Attributes:
databases_config: A list of database configurations.
client_class: The client class used to manage database connections.
command_retry: Retry strategy for executing database commands.
failure_detectors: Optional list of additional failure detectors for monitoring database failures.
min_num_failures: Minimal count of failures required for failover
failure_rate_threshold: Percentage of failures required for failover
failures_detection_window: Time interval for tracking database failures.
health_checks: Optional list of additional health checks performed on databases.
health_check_interval: Time interval for executing health checks.
health_check_probes: Number of attempts to evaluate the health of a database.
health_check_delay: Delay between health check attempts.
failover_strategy: Optional strategy for handling database failover scenarios.
failover_attempts: Number of retries allowed for failover operations.
failover_delay: Delay between failover attempts.
auto_fallback_interval: Time interval to trigger automatic fallback.
event_dispatcher: Interface for dispatching events related to database operations.
Methods:
databases:
Retrieves a collection of database clients managed by weighted configurations.
Initializes database clients based on the provided configuration and removes
redundant retry objects for lower-level clients to rely on global retry logic.
default_failure_detectors:
Returns the default list of failure detectors used to monitor database failures.
default_health_checks:
Returns the default list of health checks used to monitor database health
with specific retry and backoff strategies.
default_failover_strategy:
Provides the default failover strategy used for handling failover scenarios
with defined retry and backoff configurations.
"""
databases_config: List[DatabaseConfig]
client_class: Type[Union[Redis, RedisCluster]] = Redis
command_retry: Retry = Retry(
backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3
)
failure_detectors: Optional[List[AsyncFailureDetector]] = None
min_num_failures: int = DEFAULT_MIN_NUM_FAILURES
failure_rate_threshold: float = DEFAULT_FAILURE_RATE_THRESHOLD
failures_detection_window: float = DEFAULT_FAILURES_DETECTION_WINDOW
health_checks: Optional[List[HealthCheck]] = None
health_check_interval: float = DEFAULT_HEALTH_CHECK_INTERVAL
health_check_probes: int = DEFAULT_HEALTH_CHECK_PROBES
health_check_delay: float = DEFAULT_HEALTH_CHECK_DELAY
health_check_policy: HealthCheckPolicies = DEFAULT_HEALTH_CHECK_POLICY
failover_strategy: Optional[AsyncFailoverStrategy] = None
failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS
failover_delay: float = DEFAULT_FAILOVER_DELAY
auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL
event_dispatcher: EventDispatcherInterface = field(
default_factory=default_event_dispatcher
)
def databases(self) -> Databases:
databases = WeightedList()
for database_config in self.databases_config:
# The retry object is not used in the lower level clients, so we can safely remove it.
# We rely on command_retry in terms of global retries.
database_config.client_kwargs.update(
{"retry": Retry(retries=0, backoff=NoBackoff())}
)
if database_config.from_url:
client = self.client_class.from_url(
database_config.from_url, **database_config.client_kwargs
)
elif database_config.from_pool:
database_config.from_pool.set_retry(
Retry(retries=0, backoff=NoBackoff())
)
client = self.client_class.from_pool(
connection_pool=database_config.from_pool
)
else:
client = self.client_class(**database_config.client_kwargs)
circuit = (
database_config.default_circuit_breaker()
if database_config.circuit is None
else database_config.circuit
)
databases.add(
Database(
client=client,
circuit=circuit,
weight=database_config.weight,
health_check_url=database_config.health_check_url,
),
database_config.weight,
)
return databases
def default_failure_detectors(self) -> List[AsyncFailureDetector]:
return [
FailureDetectorAsyncWrapper(
CommandFailureDetector(
min_num_failures=self.min_num_failures,
failure_rate_threshold=self.failure_rate_threshold,
failure_detection_window=self.failures_detection_window,
)
),
]
def default_health_checks(self) -> List[HealthCheck]:
return [
PingHealthCheck(),
]
def default_failover_strategy(self) -> AsyncFailoverStrategy:
return WeightBasedFailoverStrategy()

View File

@@ -0,0 +1,69 @@
from abc import abstractmethod
from typing import Optional, Union
from redis.asyncio import Redis, RedisCluster
from redis.data_structure import WeightedList
from redis.multidb.circuit import CircuitBreaker
from redis.multidb.database import AbstractDatabase, BaseDatabase
from redis.typing import Number
class AsyncDatabase(AbstractDatabase):
"""Database with an underlying asynchronous redis client."""
@property
@abstractmethod
def client(self) -> Union[Redis, RedisCluster]:
"""The underlying redis client."""
pass
@client.setter
@abstractmethod
def client(self, client: Union[Redis, RedisCluster]):
"""Set the underlying redis client."""
pass
@property
@abstractmethod
def circuit(self) -> CircuitBreaker:
"""Circuit breaker for the current database."""
pass
@circuit.setter
@abstractmethod
def circuit(self, circuit: CircuitBreaker):
"""Set the circuit breaker for the current database."""
pass
Databases = WeightedList[tuple[AsyncDatabase, Number]]
class Database(BaseDatabase, AsyncDatabase):
def __init__(
self,
client: Union[Redis, RedisCluster],
circuit: CircuitBreaker,
weight: float,
health_check_url: Optional[str] = None,
):
self._client = client
self._cb = circuit
self._cb.database = self
super().__init__(weight, health_check_url)
@property
def client(self) -> Union[Redis, RedisCluster]:
return self._client
@client.setter
def client(self, client: Union[Redis, RedisCluster]):
self._client = client
@property
def circuit(self) -> CircuitBreaker:
return self._cb
@circuit.setter
def circuit(self, circuit: CircuitBreaker):
self._cb = circuit

View File

@@ -0,0 +1,84 @@
from typing import List
from redis.asyncio import Redis
from redis.asyncio.multidb.database import AsyncDatabase
from redis.asyncio.multidb.failure_detector import AsyncFailureDetector
from redis.event import AsyncEventListenerInterface, AsyncOnCommandsFailEvent
class AsyncActiveDatabaseChanged:
"""
Event fired when an async active database has been changed.
"""
def __init__(
self,
old_database: AsyncDatabase,
new_database: AsyncDatabase,
command_executor,
**kwargs,
):
self._old_database = old_database
self._new_database = new_database
self._command_executor = command_executor
self._kwargs = kwargs
@property
def old_database(self) -> AsyncDatabase:
return self._old_database
@property
def new_database(self) -> AsyncDatabase:
return self._new_database
@property
def command_executor(self):
return self._command_executor
@property
def kwargs(self):
return self._kwargs
class ResubscribeOnActiveDatabaseChanged(AsyncEventListenerInterface):
"""
Re-subscribe the currently active pub / sub to a new active database.
"""
async def listen(self, event: AsyncActiveDatabaseChanged):
old_pubsub = event.command_executor.active_pubsub
if old_pubsub is not None:
# Re-assign old channels and patterns so they will be automatically subscribed on connection.
new_pubsub = event.new_database.client.pubsub(**event.kwargs)
new_pubsub.channels = old_pubsub.channels
new_pubsub.patterns = old_pubsub.patterns
await new_pubsub.on_connect(None)
event.command_executor.active_pubsub = new_pubsub
await old_pubsub.aclose()
class CloseConnectionOnActiveDatabaseChanged(AsyncEventListenerInterface):
"""
Close connection to the old active database.
"""
async def listen(self, event: AsyncActiveDatabaseChanged):
await event.old_database.client.aclose()
if isinstance(event.old_database.client, Redis):
await event.old_database.client.connection_pool.update_active_connections_for_reconnect()
await event.old_database.client.connection_pool.disconnect()
class RegisterCommandFailure(AsyncEventListenerInterface):
"""
Event listener that registers command failures and passing it to the failure detectors.
"""
def __init__(self, failure_detectors: List[AsyncFailureDetector]):
self._failure_detectors = failure_detectors
async def listen(self, event: AsyncOnCommandsFailEvent) -> None:
for failure_detector in self._failure_detectors:
await failure_detector.register_failure(event.exception, event.commands)

View File

@@ -0,0 +1,125 @@
import time
from abc import ABC, abstractmethod
from redis.asyncio.multidb.database import AsyncDatabase, Databases
from redis.data_structure import WeightedList
from redis.multidb.circuit import State as CBState
from redis.multidb.exception import (
NoValidDatabaseException,
TemporaryUnavailableException,
)
DEFAULT_FAILOVER_ATTEMPTS = 10
DEFAULT_FAILOVER_DELAY = 12
class AsyncFailoverStrategy(ABC):
@abstractmethod
async def database(self) -> AsyncDatabase:
"""Select the database according to the strategy."""
pass
@abstractmethod
def set_databases(self, databases: Databases) -> None:
"""Set the database strategy operates on."""
pass
class FailoverStrategyExecutor(ABC):
@property
@abstractmethod
def failover_attempts(self) -> int:
"""The number of failover attempts."""
pass
@property
@abstractmethod
def failover_delay(self) -> float:
"""The delay between failover attempts."""
pass
@property
@abstractmethod
def strategy(self) -> AsyncFailoverStrategy:
"""The strategy to execute."""
pass
@abstractmethod
async def execute(self) -> AsyncDatabase:
"""Execute the failover strategy."""
pass
class WeightBasedFailoverStrategy(AsyncFailoverStrategy):
"""
Failover strategy based on database weights.
"""
def __init__(self):
self._databases = WeightedList()
async def database(self) -> AsyncDatabase:
for database, _ in self._databases:
if database.circuit.state == CBState.CLOSED:
return database
raise NoValidDatabaseException("No valid database available for communication")
def set_databases(self, databases: Databases) -> None:
self._databases = databases
class DefaultFailoverStrategyExecutor(FailoverStrategyExecutor):
"""
Executes given failover strategy.
"""
def __init__(
self,
strategy: AsyncFailoverStrategy,
failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS,
failover_delay: float = DEFAULT_FAILOVER_DELAY,
):
self._strategy = strategy
self._failover_attempts = failover_attempts
self._failover_delay = failover_delay
self._next_attempt_ts: int = 0
self._failover_counter: int = 0
@property
def failover_attempts(self) -> int:
return self._failover_attempts
@property
def failover_delay(self) -> float:
return self._failover_delay
@property
def strategy(self) -> AsyncFailoverStrategy:
return self._strategy
async def execute(self) -> AsyncDatabase:
try:
database = await self._strategy.database()
self._reset()
return database
except NoValidDatabaseException as e:
if self._next_attempt_ts == 0:
self._next_attempt_ts = time.time() + self._failover_delay
self._failover_counter += 1
elif time.time() >= self._next_attempt_ts:
self._next_attempt_ts += self._failover_delay
self._failover_counter += 1
if self._failover_counter > self._failover_attempts:
self._reset()
raise e
else:
raise TemporaryUnavailableException(
"No database connections currently available. "
"This is a temporary condition - please retry the operation."
)
def _reset(self) -> None:
self._next_attempt_ts = 0
self._failover_counter = 0

View File

@@ -0,0 +1,38 @@
from abc import ABC, abstractmethod
from redis.multidb.failure_detector import FailureDetector
class AsyncFailureDetector(ABC):
@abstractmethod
async def register_failure(self, exception: Exception, cmd: tuple) -> None:
"""Register a failure that occurred during command execution."""
pass
@abstractmethod
async def register_command_execution(self, cmd: tuple) -> None:
"""Register a command execution."""
pass
@abstractmethod
def set_command_executor(self, command_executor) -> None:
"""Set the command executor for this failure."""
pass
class FailureDetectorAsyncWrapper(AsyncFailureDetector):
"""
Async wrapper for the failure detector.
"""
def __init__(self, failure_detector: FailureDetector) -> None:
self._failure_detector = failure_detector
async def register_failure(self, exception: Exception, cmd: tuple) -> None:
self._failure_detector.register_failure(exception, cmd)
async def register_command_execution(self, cmd: tuple) -> None:
self._failure_detector.register_command_execution(cmd)
def set_command_executor(self, command_executor) -> None:
self._failure_detector.set_command_executor(command_executor)

View File

@@ -0,0 +1,285 @@
import asyncio
import logging
from abc import ABC, abstractmethod
from enum import Enum
from typing import List, Optional, Tuple, Union
from redis.asyncio import Redis
from redis.asyncio.http.http_client import DEFAULT_TIMEOUT, AsyncHTTPClientWrapper
from redis.backoff import NoBackoff
from redis.http.http_client import HttpClient
from redis.multidb.exception import UnhealthyDatabaseException
from redis.retry import Retry
DEFAULT_HEALTH_CHECK_PROBES = 3
DEFAULT_HEALTH_CHECK_INTERVAL = 5
DEFAULT_HEALTH_CHECK_DELAY = 0.5
DEFAULT_LAG_AWARE_TOLERANCE = 5000
logger = logging.getLogger(__name__)
class HealthCheck(ABC):
@abstractmethod
async def check_health(self, database) -> bool:
"""Function to determine the health status."""
pass
class HealthCheckPolicy(ABC):
"""
Health checks execution policy.
"""
@property
@abstractmethod
def health_check_probes(self) -> int:
"""Number of probes to execute health checks."""
pass
@property
@abstractmethod
def health_check_delay(self) -> float:
"""Delay between health check probes."""
pass
@abstractmethod
async def execute(self, health_checks: List[HealthCheck], database) -> bool:
"""Execute health checks and return database health status."""
pass
class AbstractHealthCheckPolicy(HealthCheckPolicy):
def __init__(self, health_check_probes: int, health_check_delay: float):
if health_check_probes < 1:
raise ValueError("health_check_probes must be greater than 0")
self._health_check_probes = health_check_probes
self._health_check_delay = health_check_delay
@property
def health_check_probes(self) -> int:
return self._health_check_probes
@property
def health_check_delay(self) -> float:
return self._health_check_delay
@abstractmethod
async def execute(self, health_checks: List[HealthCheck], database) -> bool:
pass
class HealthyAllPolicy(AbstractHealthCheckPolicy):
"""
Policy that returns True if all health check probes are successful.
"""
def __init__(self, health_check_probes: int, health_check_delay: float):
super().__init__(health_check_probes, health_check_delay)
async def execute(self, health_checks: List[HealthCheck], database) -> bool:
for health_check in health_checks:
for attempt in range(self.health_check_probes):
try:
if not await health_check.check_health(database):
return False
except Exception as e:
raise UnhealthyDatabaseException("Unhealthy database", database, e)
if attempt < self.health_check_probes - 1:
await asyncio.sleep(self._health_check_delay)
return True
class HealthyMajorityPolicy(AbstractHealthCheckPolicy):
"""
Policy that returns True if a majority of health check probes are successful.
"""
def __init__(self, health_check_probes: int, health_check_delay: float):
super().__init__(health_check_probes, health_check_delay)
async def execute(self, health_checks: List[HealthCheck], database) -> bool:
for health_check in health_checks:
if self.health_check_probes % 2 == 0:
allowed_unsuccessful_probes = self.health_check_probes / 2
else:
allowed_unsuccessful_probes = (self.health_check_probes + 1) / 2
for attempt in range(self.health_check_probes):
try:
if not await health_check.check_health(database):
allowed_unsuccessful_probes -= 1
if allowed_unsuccessful_probes <= 0:
return False
except Exception as e:
allowed_unsuccessful_probes -= 1
if allowed_unsuccessful_probes <= 0:
raise UnhealthyDatabaseException(
"Unhealthy database", database, e
)
if attempt < self.health_check_probes - 1:
await asyncio.sleep(self._health_check_delay)
return True
class HealthyAnyPolicy(AbstractHealthCheckPolicy):
"""
Policy that returns True if at least one health check probe is successful.
"""
def __init__(self, health_check_probes: int, health_check_delay: float):
super().__init__(health_check_probes, health_check_delay)
async def execute(self, health_checks: List[HealthCheck], database) -> bool:
is_healthy = False
for health_check in health_checks:
exception = None
for attempt in range(self.health_check_probes):
try:
if await health_check.check_health(database):
is_healthy = True
break
else:
is_healthy = False
except Exception as e:
exception = UnhealthyDatabaseException(
"Unhealthy database", database, e
)
if attempt < self.health_check_probes - 1:
await asyncio.sleep(self._health_check_delay)
if not is_healthy and not exception:
return is_healthy
elif not is_healthy and exception:
raise exception
return is_healthy
class HealthCheckPolicies(Enum):
HEALTHY_ALL = HealthyAllPolicy
HEALTHY_MAJORITY = HealthyMajorityPolicy
HEALTHY_ANY = HealthyAnyPolicy
DEFAULT_HEALTH_CHECK_POLICY: HealthCheckPolicies = HealthCheckPolicies.HEALTHY_ALL
class PingHealthCheck(HealthCheck):
"""
Health check based on PING command.
"""
async def check_health(self, database) -> bool:
if isinstance(database.client, Redis):
return await database.client.execute_command("PING")
else:
# For a cluster checks if all nodes are healthy.
all_nodes = database.client.get_nodes()
for node in all_nodes:
if not await node.redis_connection.execute_command("PING"):
return False
return True
class LagAwareHealthCheck(HealthCheck):
"""
Health check available for Redis Enterprise deployments.
Verify via REST API that the database is healthy based on different lags.
"""
def __init__(
self,
rest_api_port: int = 9443,
lag_aware_tolerance: int = DEFAULT_LAG_AWARE_TOLERANCE,
timeout: float = DEFAULT_TIMEOUT,
auth_basic: Optional[Tuple[str, str]] = None,
verify_tls: bool = True,
# TLS verification (server) options
ca_file: Optional[str] = None,
ca_path: Optional[str] = None,
ca_data: Optional[Union[str, bytes]] = None,
# Mutual TLS (client cert) options
client_cert_file: Optional[str] = None,
client_key_file: Optional[str] = None,
client_key_password: Optional[str] = None,
):
"""
Initialize LagAwareHealthCheck with the specified parameters.
Args:
rest_api_port: Port number for Redis Enterprise REST API (default: 9443)
lag_aware_tolerance: Tolerance in lag between databases in MS (default: 100)
timeout: Request timeout in seconds (default: DEFAULT_TIMEOUT)
auth_basic: Tuple of (username, password) for basic authentication
verify_tls: Whether to verify TLS certificates (default: True)
ca_file: Path to CA certificate file for TLS verification
ca_path: Path to CA certificates directory for TLS verification
ca_data: CA certificate data as string or bytes
client_cert_file: Path to client certificate file for mutual TLS
client_key_file: Path to client private key file for mutual TLS
client_key_password: Password for encrypted client private key
"""
self._http_client = AsyncHTTPClientWrapper(
HttpClient(
timeout=timeout,
auth_basic=auth_basic,
retry=Retry(NoBackoff(), retries=0),
verify_tls=verify_tls,
ca_file=ca_file,
ca_path=ca_path,
ca_data=ca_data,
client_cert_file=client_cert_file,
client_key_file=client_key_file,
client_key_password=client_key_password,
)
)
self._rest_api_port = rest_api_port
self._lag_aware_tolerance = lag_aware_tolerance
async def check_health(self, database) -> bool:
if database.health_check_url is None:
raise ValueError(
"Database health check url is not set. Please check DatabaseConfig for the current database."
)
if isinstance(database.client, Redis):
db_host = database.client.get_connection_kwargs()["host"]
else:
db_host = database.client.startup_nodes[0].host
base_url = f"{database.health_check_url}:{self._rest_api_port}"
self._http_client.client.base_url = base_url
# Find bdb matching to the current database host
matching_bdb = None
for bdb in await self._http_client.get("/v1/bdbs"):
for endpoint in bdb["endpoints"]:
if endpoint["dns_name"] == db_host:
matching_bdb = bdb
break
# In case if the host was set as public IP
for addr in endpoint["addr"]:
if addr == db_host:
matching_bdb = bdb
break
if matching_bdb is None:
logger.warning("LagAwareHealthCheck failed: Couldn't find a matching bdb")
raise ValueError("Could not find a matching bdb")
url = (
f"/v1/bdbs/{matching_bdb['uid']}/availability"
f"?extend_check=lag&availability_lag_tolerance_ms={self._lag_aware_tolerance}"
)
await self._http_client.get(url, expect_json=False)
# Status checked in an http client, otherwise HttpError will be raised
return True

View File

@@ -0,0 +1,58 @@
from asyncio import sleep
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Tuple, Type, TypeVar
from redis.exceptions import ConnectionError, RedisError, TimeoutError
from redis.retry import AbstractRetry
T = TypeVar("T")
if TYPE_CHECKING:
from redis.backoff import AbstractBackoff
class Retry(AbstractRetry[RedisError]):
__hash__ = AbstractRetry.__hash__
def __init__(
self,
backoff: "AbstractBackoff",
retries: int,
supported_errors: Tuple[Type[RedisError], ...] = (
ConnectionError,
TimeoutError,
),
):
super().__init__(backoff, retries, supported_errors)
def __eq__(self, other: Any) -> bool:
if not isinstance(other, Retry):
return NotImplemented
return (
self._backoff == other._backoff
and self._retries == other._retries
and set(self._supported_errors) == set(other._supported_errors)
)
async def call_with_retry(
self, do: Callable[[], Awaitable[T]], fail: Callable[[RedisError], Any]
) -> T:
"""
Execute an operation that might fail and returns its result, or
raise the exception that was thrown depending on the `Backoff` object.
`do`: the operation to call. Expects no argument.
`fail`: the failure handler, expects the last error that was thrown
"""
self._backoff.reset()
failures = 0
while True:
try:
return await do()
except self._supported_errors as error:
failures += 1
await fail(error)
if self._retries >= 0 and failures > self._retries:
raise error
backoff = self._backoff.compute(failures)
if backoff > 0:
await sleep(backoff)

View File

@@ -0,0 +1,404 @@
import asyncio
import random
import weakref
from typing import AsyncIterator, Iterable, Mapping, Optional, Sequence, Tuple, Type
from redis.asyncio.client import Redis
from redis.asyncio.connection import (
Connection,
ConnectionPool,
EncodableT,
SSLConnection,
)
from redis.commands import AsyncSentinelCommands
from redis.exceptions import (
ConnectionError,
ReadOnlyError,
ResponseError,
TimeoutError,
)
class MasterNotFoundError(ConnectionError):
pass
class SlaveNotFoundError(ConnectionError):
pass
class SentinelManagedConnection(Connection):
def __init__(self, **kwargs):
self.connection_pool = kwargs.pop("connection_pool")
super().__init__(**kwargs)
def __repr__(self):
s = f"<{self.__class__.__module__}.{self.__class__.__name__}"
if self.host:
host_info = f",host={self.host},port={self.port}"
s += host_info
return s + ")>"
async def connect_to(self, address):
self.host, self.port = address
await self.connect_check_health(
check_health=self.connection_pool.check_connection,
retry_socket_connect=False,
)
async def _connect_retry(self):
if self._reader:
return # already connected
if self.connection_pool.is_master:
await self.connect_to(await self.connection_pool.get_master_address())
else:
async for slave in self.connection_pool.rotate_slaves():
try:
return await self.connect_to(slave)
except ConnectionError:
continue
raise SlaveNotFoundError # Never be here
async def connect(self):
return await self.retry.call_with_retry(
self._connect_retry,
lambda error: asyncio.sleep(0),
)
async def read_response(
self,
disable_decoding: bool = False,
timeout: Optional[float] = None,
*,
disconnect_on_error: Optional[float] = True,
push_request: Optional[bool] = False,
):
try:
return await super().read_response(
disable_decoding=disable_decoding,
timeout=timeout,
disconnect_on_error=disconnect_on_error,
push_request=push_request,
)
except ReadOnlyError:
if self.connection_pool.is_master:
# When talking to a master, a ReadOnlyError when likely
# indicates that the previous master that we're still connected
# to has been demoted to a slave and there's a new master.
# calling disconnect will force the connection to re-query
# sentinel during the next connect() attempt.
await self.disconnect()
raise ConnectionError("The previous master is now a slave")
raise
class SentinelManagedSSLConnection(SentinelManagedConnection, SSLConnection):
pass
class SentinelConnectionPool(ConnectionPool):
"""
Sentinel backed connection pool.
If ``check_connection`` flag is set to True, SentinelManagedConnection
sends a PING command right after establishing the connection.
"""
def __init__(self, service_name, sentinel_manager, **kwargs):
kwargs["connection_class"] = kwargs.get(
"connection_class",
(
SentinelManagedSSLConnection
if kwargs.pop("ssl", False)
else SentinelManagedConnection
),
)
self.is_master = kwargs.pop("is_master", True)
self.check_connection = kwargs.pop("check_connection", False)
super().__init__(**kwargs)
self.connection_kwargs["connection_pool"] = weakref.proxy(self)
self.service_name = service_name
self.sentinel_manager = sentinel_manager
self.master_address = None
self.slave_rr_counter = None
def __repr__(self):
return (
f"<{self.__class__.__module__}.{self.__class__.__name__}"
f"(service={self.service_name}({self.is_master and 'master' or 'slave'}))>"
)
def reset(self):
super().reset()
self.master_address = None
self.slave_rr_counter = None
def owns_connection(self, connection: Connection):
check = not self.is_master or (
self.is_master and self.master_address == (connection.host, connection.port)
)
return check and super().owns_connection(connection)
async def get_master_address(self):
master_address = await self.sentinel_manager.discover_master(self.service_name)
if self.is_master:
if self.master_address != master_address:
self.master_address = master_address
# disconnect any idle connections so that they reconnect
# to the new master the next time that they are used.
await self.disconnect(inuse_connections=False)
return master_address
async def rotate_slaves(self) -> AsyncIterator:
"""Round-robin slave balancer"""
slaves = await self.sentinel_manager.discover_slaves(self.service_name)
if slaves:
if self.slave_rr_counter is None:
self.slave_rr_counter = random.randint(0, len(slaves) - 1)
for _ in range(len(slaves)):
self.slave_rr_counter = (self.slave_rr_counter + 1) % len(slaves)
slave = slaves[self.slave_rr_counter]
yield slave
# Fallback to the master connection
try:
yield await self.get_master_address()
except MasterNotFoundError:
pass
raise SlaveNotFoundError(f"No slave found for {self.service_name!r}")
class Sentinel(AsyncSentinelCommands):
"""
Redis Sentinel cluster client
>>> from redis.sentinel import Sentinel
>>> sentinel = Sentinel([('localhost', 26379)], socket_timeout=0.1)
>>> master = sentinel.master_for('mymaster', socket_timeout=0.1)
>>> await master.set('foo', 'bar')
>>> slave = sentinel.slave_for('mymaster', socket_timeout=0.1)
>>> await slave.get('foo')
b'bar'
``sentinels`` is a list of sentinel nodes. Each node is represented by
a pair (hostname, port).
``min_other_sentinels`` defined a minimum number of peers for a sentinel.
When querying a sentinel, if it doesn't meet this threshold, responses
from that sentinel won't be considered valid.
``sentinel_kwargs`` is a dictionary of connection arguments used when
connecting to sentinel instances. Any argument that can be passed to
a normal Redis connection can be specified here. If ``sentinel_kwargs`` is
not specified, any socket_timeout and socket_keepalive options specified
in ``connection_kwargs`` will be used.
``connection_kwargs`` are keyword arguments that will be used when
establishing a connection to a Redis server.
"""
def __init__(
self,
sentinels,
min_other_sentinels=0,
sentinel_kwargs=None,
force_master_ip=None,
**connection_kwargs,
):
# if sentinel_kwargs isn't defined, use the socket_* options from
# connection_kwargs
if sentinel_kwargs is None:
sentinel_kwargs = {
k: v for k, v in connection_kwargs.items() if k.startswith("socket_")
}
self.sentinel_kwargs = sentinel_kwargs
self.sentinels = [
Redis(host=hostname, port=port, **self.sentinel_kwargs)
for hostname, port in sentinels
]
self.min_other_sentinels = min_other_sentinels
self.connection_kwargs = connection_kwargs
self._force_master_ip = force_master_ip
async def execute_command(self, *args, **kwargs):
"""
Execute Sentinel command in sentinel nodes.
once - If set to True, then execute the resulting command on a single
node at random, rather than across the entire sentinel cluster.
"""
once = bool(kwargs.pop("once", False))
# Check if command is supposed to return the original
# responses instead of boolean value.
return_responses = bool(kwargs.pop("return_responses", False))
if once:
response = await random.choice(self.sentinels).execute_command(
*args, **kwargs
)
if return_responses:
return [response]
else:
return True if response else False
tasks = [
asyncio.Task(sentinel.execute_command(*args, **kwargs))
for sentinel in self.sentinels
]
responses = await asyncio.gather(*tasks)
if return_responses:
return responses
return all(responses)
def __repr__(self):
sentinel_addresses = []
for sentinel in self.sentinels:
sentinel_addresses.append(
f"{sentinel.connection_pool.connection_kwargs['host']}:"
f"{sentinel.connection_pool.connection_kwargs['port']}"
)
return (
f"<{self.__class__}.{self.__class__.__name__}"
f"(sentinels=[{','.join(sentinel_addresses)}])>"
)
def check_master_state(self, state: dict, service_name: str) -> bool:
if not state["is_master"] or state["is_sdown"] or state["is_odown"]:
return False
# Check if our sentinel doesn't see other nodes
if state["num-other-sentinels"] < self.min_other_sentinels:
return False
return True
async def discover_master(self, service_name: str):
"""
Asks sentinel servers for the Redis master's address corresponding
to the service labeled ``service_name``.
Returns a pair (address, port) or raises MasterNotFoundError if no
master is found.
"""
collected_errors = list()
for sentinel_no, sentinel in enumerate(self.sentinels):
try:
masters = await sentinel.sentinel_masters()
except (ConnectionError, TimeoutError) as e:
collected_errors.append(f"{sentinel} - {e!r}")
continue
state = masters.get(service_name)
if state and self.check_master_state(state, service_name):
# Put this sentinel at the top of the list
self.sentinels[0], self.sentinels[sentinel_no] = (
sentinel,
self.sentinels[0],
)
ip = (
self._force_master_ip
if self._force_master_ip is not None
else state["ip"]
)
return ip, state["port"]
error_info = ""
if len(collected_errors) > 0:
error_info = f" : {', '.join(collected_errors)}"
raise MasterNotFoundError(f"No master found for {service_name!r}{error_info}")
def filter_slaves(
self, slaves: Iterable[Mapping]
) -> Sequence[Tuple[EncodableT, EncodableT]]:
"""Remove slaves that are in an ODOWN or SDOWN state"""
slaves_alive = []
for slave in slaves:
if slave["is_odown"] or slave["is_sdown"]:
continue
slaves_alive.append((slave["ip"], slave["port"]))
return slaves_alive
async def discover_slaves(
self, service_name: str
) -> Sequence[Tuple[EncodableT, EncodableT]]:
"""Returns a list of alive slaves for service ``service_name``"""
for sentinel in self.sentinels:
try:
slaves = await sentinel.sentinel_slaves(service_name)
except (ConnectionError, ResponseError, TimeoutError):
continue
slaves = self.filter_slaves(slaves)
if slaves:
return slaves
return []
def master_for(
self,
service_name: str,
redis_class: Type[Redis] = Redis,
connection_pool_class: Type[SentinelConnectionPool] = SentinelConnectionPool,
**kwargs,
):
"""
Returns a redis client instance for the ``service_name`` master.
Sentinel client will detect failover and reconnect Redis clients
automatically.
A :py:class:`~redis.sentinel.SentinelConnectionPool` class is
used to retrieve the master's address before establishing a new
connection.
NOTE: If the master's address has changed, any cached connections to
the old master are closed.
By default clients will be a :py:class:`~redis.Redis` instance.
Specify a different class to the ``redis_class`` argument if you
desire something different.
The ``connection_pool_class`` specifies the connection pool to
use. The :py:class:`~redis.sentinel.SentinelConnectionPool`
will be used by default.
All other keyword arguments are merged with any connection_kwargs
passed to this class and passed to the connection pool as keyword
arguments to be used to initialize Redis connections.
"""
kwargs["is_master"] = True
connection_kwargs = dict(self.connection_kwargs)
connection_kwargs.update(kwargs)
connection_pool = connection_pool_class(service_name, self, **connection_kwargs)
# The Redis object "owns" the pool
return redis_class.from_pool(connection_pool)
def slave_for(
self,
service_name: str,
redis_class: Type[Redis] = Redis,
connection_pool_class: Type[SentinelConnectionPool] = SentinelConnectionPool,
**kwargs,
):
"""
Returns redis client instance for the ``service_name`` slave(s).
A SentinelConnectionPool class is used to retrieve the slave's
address before establishing a new connection.
By default clients will be a :py:class:`~redis.Redis` instance.
Specify a different class to the ``redis_class`` argument if you
desire something different.
The ``connection_pool_class`` specifies the connection pool to use.
The SentinelConnectionPool will be used by default.
All other keyword arguments are merged with any connection_kwargs
passed to this class and passed to the connection pool as keyword
arguments to be used to initialize Redis connections.
"""
kwargs["is_master"] = False
connection_kwargs = dict(self.connection_kwargs)
connection_kwargs.update(kwargs)
connection_pool = connection_pool_class(service_name, self, **connection_kwargs)
# The Redis object "owns" the pool
return redis_class.from_pool(connection_pool)

View File

@@ -0,0 +1,28 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from redis.asyncio.client import Pipeline, Redis
def from_url(url, **kwargs):
"""
Returns an active Redis client generated from the given database URL.
Will attempt to extract the database id from the path url fragment, if
none is provided.
"""
from redis.asyncio.client import Redis
return Redis.from_url(url, **kwargs)
class pipeline: # noqa: N801
def __init__(self, redis_obj: "Redis"):
self.p: "Pipeline" = redis_obj.pipeline()
async def __aenter__(self) -> "Pipeline":
return self.p
async def __aexit__(self, exc_type, exc_value, traceback):
await self.p.execute()
del self.p