Initial commit: igny8 project

This commit is contained in:
igny8
2025-11-09 10:27:02 +00:00
commit 60b8188111
27265 changed files with 4360521 additions and 0 deletions

View File

@@ -0,0 +1,88 @@
from redis import asyncio # noqa
from redis.backoff import default_backoff
from redis.client import Redis, StrictRedis
from redis.cluster import RedisCluster
from redis.connection import (
BlockingConnectionPool,
Connection,
ConnectionPool,
SSLConnection,
UnixDomainSocketConnection,
)
from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider
from redis.exceptions import (
AuthenticationError,
AuthenticationWrongNumberOfArgsError,
BusyLoadingError,
ChildDeadlockedError,
ConnectionError,
CrossSlotTransactionError,
DataError,
InvalidPipelineStack,
InvalidResponse,
MaxConnectionsError,
OutOfMemoryError,
PubSubError,
ReadOnlyError,
RedisClusterException,
RedisError,
ResponseError,
TimeoutError,
WatchError,
)
from redis.sentinel import (
Sentinel,
SentinelConnectionPool,
SentinelManagedConnection,
SentinelManagedSSLConnection,
)
from redis.utils import from_url
def int_or_str(value):
try:
return int(value)
except ValueError:
return value
__version__ = "7.0.1"
VERSION = tuple(map(int_or_str, __version__.split(".")))
__all__ = [
"AuthenticationError",
"AuthenticationWrongNumberOfArgsError",
"BlockingConnectionPool",
"BusyLoadingError",
"ChildDeadlockedError",
"Connection",
"ConnectionError",
"ConnectionPool",
"CredentialProvider",
"CrossSlotTransactionError",
"DataError",
"from_url",
"default_backoff",
"InvalidPipelineStack",
"InvalidResponse",
"MaxConnectionsError",
"OutOfMemoryError",
"PubSubError",
"ReadOnlyError",
"Redis",
"RedisCluster",
"RedisClusterException",
"RedisError",
"ResponseError",
"Sentinel",
"SentinelConnectionPool",
"SentinelManagedConnection",
"SentinelManagedSSLConnection",
"SSLConnection",
"UsernamePasswordCredentialProvider",
"StrictRedis",
"TimeoutError",
"UnixDomainSocketConnection",
"WatchError",
]

View File

@@ -0,0 +1,27 @@
from .base import (
AsyncPushNotificationsParser,
BaseParser,
PushNotificationsParser,
_AsyncRESPBase,
)
from .commands import AsyncCommandsParser, CommandsParser
from .encoders import Encoder
from .hiredis import _AsyncHiredisParser, _HiredisParser
from .resp2 import _AsyncRESP2Parser, _RESP2Parser
from .resp3 import _AsyncRESP3Parser, _RESP3Parser
__all__ = [
"AsyncCommandsParser",
"_AsyncHiredisParser",
"_AsyncRESPBase",
"_AsyncRESP2Parser",
"_AsyncRESP3Parser",
"AsyncPushNotificationsParser",
"CommandsParser",
"Encoder",
"BaseParser",
"_HiredisParser",
"_RESP2Parser",
"_RESP3Parser",
"PushNotificationsParser",
]

View File

@@ -0,0 +1,474 @@
import logging
import sys
from abc import ABC
from asyncio import IncompleteReadError, StreamReader, TimeoutError
from typing import Awaitable, Callable, List, Optional, Protocol, Union
from redis.maint_notifications import (
MaintenanceNotification,
NodeFailedOverNotification,
NodeFailingOverNotification,
NodeMigratedNotification,
NodeMigratingNotification,
NodeMovingNotification,
)
if sys.version_info.major >= 3 and sys.version_info.minor >= 11:
from asyncio import timeout as async_timeout
else:
from async_timeout import timeout as async_timeout
from ..exceptions import (
AskError,
AuthenticationError,
AuthenticationWrongNumberOfArgsError,
BusyLoadingError,
ClusterCrossSlotError,
ClusterDownError,
ConnectionError,
ExecAbortError,
ExternalAuthProviderError,
MasterDownError,
ModuleError,
MovedError,
NoPermissionError,
NoScriptError,
OutOfMemoryError,
ReadOnlyError,
RedisError,
ResponseError,
TryAgainError,
)
from ..typing import EncodableT
from .encoders import Encoder
from .socket import SERVER_CLOSED_CONNECTION_ERROR, SocketBuffer
MODULE_LOAD_ERROR = "Error loading the extension. Please check the server logs."
NO_SUCH_MODULE_ERROR = "Error unloading module: no such module with that name"
MODULE_UNLOAD_NOT_POSSIBLE_ERROR = "Error unloading module: operation not possible."
MODULE_EXPORTS_DATA_TYPES_ERROR = (
"Error unloading module: the module "
"exports one or more module-side data "
"types, can't unload"
)
# user send an AUTH cmd to a server without authorization configured
NO_AUTH_SET_ERROR = {
# Redis >= 6.0
"AUTH <password> called without any password "
"configured for the default user. Are you sure "
"your configuration is correct?": AuthenticationError,
# Redis < 6.0
"Client sent AUTH, but no password is set": AuthenticationError,
}
EXTERNAL_AUTH_PROVIDER_ERROR = {
"problem with LDAP service": ExternalAuthProviderError,
}
logger = logging.getLogger(__name__)
class BaseParser(ABC):
EXCEPTION_CLASSES = {
"ERR": {
"max number of clients reached": ConnectionError,
"invalid password": AuthenticationError,
# some Redis server versions report invalid command syntax
# in lowercase
"wrong number of arguments "
"for 'auth' command": AuthenticationWrongNumberOfArgsError,
# some Redis server versions report invalid command syntax
# in uppercase
"wrong number of arguments "
"for 'AUTH' command": AuthenticationWrongNumberOfArgsError,
MODULE_LOAD_ERROR: ModuleError,
MODULE_EXPORTS_DATA_TYPES_ERROR: ModuleError,
NO_SUCH_MODULE_ERROR: ModuleError,
MODULE_UNLOAD_NOT_POSSIBLE_ERROR: ModuleError,
**NO_AUTH_SET_ERROR,
**EXTERNAL_AUTH_PROVIDER_ERROR,
},
"OOM": OutOfMemoryError,
"WRONGPASS": AuthenticationError,
"EXECABORT": ExecAbortError,
"LOADING": BusyLoadingError,
"NOSCRIPT": NoScriptError,
"READONLY": ReadOnlyError,
"NOAUTH": AuthenticationError,
"NOPERM": NoPermissionError,
"ASK": AskError,
"TRYAGAIN": TryAgainError,
"MOVED": MovedError,
"CLUSTERDOWN": ClusterDownError,
"CROSSSLOT": ClusterCrossSlotError,
"MASTERDOWN": MasterDownError,
}
@classmethod
def parse_error(cls, response):
"Parse an error response"
error_code = response.split(" ")[0]
if error_code in cls.EXCEPTION_CLASSES:
response = response[len(error_code) + 1 :]
exception_class = cls.EXCEPTION_CLASSES[error_code]
if isinstance(exception_class, dict):
exception_class = exception_class.get(response, ResponseError)
return exception_class(response)
return ResponseError(response)
def on_disconnect(self):
raise NotImplementedError()
def on_connect(self, connection):
raise NotImplementedError()
class _RESPBase(BaseParser):
"""Base class for sync-based resp parsing"""
def __init__(self, socket_read_size):
self.socket_read_size = socket_read_size
self.encoder = None
self._sock = None
self._buffer = None
def __del__(self):
try:
self.on_disconnect()
except Exception:
pass
def on_connect(self, connection):
"Called when the socket connects"
self._sock = connection._sock
self._buffer = SocketBuffer(
self._sock, self.socket_read_size, connection.socket_timeout
)
self.encoder = connection.encoder
def on_disconnect(self):
"Called when the socket disconnects"
self._sock = None
if self._buffer is not None:
self._buffer.close()
self._buffer = None
self.encoder = None
def can_read(self, timeout):
return self._buffer and self._buffer.can_read(timeout)
class AsyncBaseParser(BaseParser):
"""Base parsing class for the python-backed async parser"""
__slots__ = "_stream", "_read_size"
def __init__(self, socket_read_size: int):
self._stream: Optional[StreamReader] = None
self._read_size = socket_read_size
async def can_read_destructive(self) -> bool:
raise NotImplementedError()
async def read_response(
self, disable_decoding: bool = False
) -> Union[EncodableT, ResponseError, None, List[EncodableT]]:
raise NotImplementedError()
class MaintenanceNotificationsParser:
"""Protocol defining maintenance push notification parsing functionality"""
@staticmethod
def parse_maintenance_start_msg(response, notification_type):
# Expected message format is: <notification_type> <seq_number> <time>
id = response[1]
ttl = response[2]
return notification_type(id, ttl)
@staticmethod
def parse_maintenance_completed_msg(response, notification_type):
# Expected message format is: <notification_type> <seq_number>
id = response[1]
return notification_type(id)
@staticmethod
def parse_moving_msg(response):
# Expected message format is: MOVING <seq_number> <time> <endpoint>
id = response[1]
ttl = response[2]
if response[3] is None:
host, port = None, None
else:
value = response[3]
if isinstance(value, bytes):
value = value.decode()
host, port = value.split(":")
port = int(port) if port is not None else None
return NodeMovingNotification(id, host, port, ttl)
_INVALIDATION_MESSAGE = "invalidate"
_MOVING_MESSAGE = "MOVING"
_MIGRATING_MESSAGE = "MIGRATING"
_MIGRATED_MESSAGE = "MIGRATED"
_FAILING_OVER_MESSAGE = "FAILING_OVER"
_FAILED_OVER_MESSAGE = "FAILED_OVER"
_MAINTENANCE_MESSAGES = (
_MIGRATING_MESSAGE,
_MIGRATED_MESSAGE,
_FAILING_OVER_MESSAGE,
_FAILED_OVER_MESSAGE,
)
MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING: dict[
str, tuple[type[MaintenanceNotification], Callable]
] = {
_MIGRATING_MESSAGE: (
NodeMigratingNotification,
MaintenanceNotificationsParser.parse_maintenance_start_msg,
),
_MIGRATED_MESSAGE: (
NodeMigratedNotification,
MaintenanceNotificationsParser.parse_maintenance_completed_msg,
),
_FAILING_OVER_MESSAGE: (
NodeFailingOverNotification,
MaintenanceNotificationsParser.parse_maintenance_start_msg,
),
_FAILED_OVER_MESSAGE: (
NodeFailedOverNotification,
MaintenanceNotificationsParser.parse_maintenance_completed_msg,
),
_MOVING_MESSAGE: (
NodeMovingNotification,
MaintenanceNotificationsParser.parse_moving_msg,
),
}
class PushNotificationsParser(Protocol):
"""Protocol defining RESP3-specific parsing functionality"""
pubsub_push_handler_func: Callable
invalidation_push_handler_func: Optional[Callable] = None
node_moving_push_handler_func: Optional[Callable] = None
maintenance_push_handler_func: Optional[Callable] = None
def handle_pubsub_push_response(self, response):
"""Handle pubsub push responses"""
raise NotImplementedError()
def handle_push_response(self, response, **kwargs):
msg_type = response[0]
if isinstance(msg_type, bytes):
msg_type = msg_type.decode()
if msg_type not in (
_INVALIDATION_MESSAGE,
*_MAINTENANCE_MESSAGES,
_MOVING_MESSAGE,
):
return self.pubsub_push_handler_func(response)
try:
if (
msg_type == _INVALIDATION_MESSAGE
and self.invalidation_push_handler_func
):
return self.invalidation_push_handler_func(response)
if msg_type == _MOVING_MESSAGE and self.node_moving_push_handler_func:
parser_function = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[
msg_type
][1]
notification = parser_function(response)
return self.node_moving_push_handler_func(notification)
if msg_type in _MAINTENANCE_MESSAGES and self.maintenance_push_handler_func:
parser_function = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[
msg_type
][1]
notification_type = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[
msg_type
][0]
notification = parser_function(response, notification_type)
if notification is not None:
return self.maintenance_push_handler_func(notification)
except Exception as e:
logger.error(
"Error handling {} message ({}): {}".format(msg_type, response, e)
)
return None
def set_pubsub_push_handler(self, pubsub_push_handler_func):
self.pubsub_push_handler_func = pubsub_push_handler_func
def set_invalidation_push_handler(self, invalidation_push_handler_func):
self.invalidation_push_handler_func = invalidation_push_handler_func
def set_node_moving_push_handler(self, node_moving_push_handler_func):
self.node_moving_push_handler_func = node_moving_push_handler_func
def set_maintenance_push_handler(self, maintenance_push_handler_func):
self.maintenance_push_handler_func = maintenance_push_handler_func
class AsyncPushNotificationsParser(Protocol):
"""Protocol defining async RESP3-specific parsing functionality"""
pubsub_push_handler_func: Callable
invalidation_push_handler_func: Optional[Callable] = None
node_moving_push_handler_func: Optional[Callable[..., Awaitable[None]]] = None
maintenance_push_handler_func: Optional[Callable[..., Awaitable[None]]] = None
async def handle_pubsub_push_response(self, response):
"""Handle pubsub push responses asynchronously"""
raise NotImplementedError()
async def handle_push_response(self, response, **kwargs):
"""Handle push responses asynchronously"""
msg_type = response[0]
if isinstance(msg_type, bytes):
msg_type = msg_type.decode()
if msg_type not in (
_INVALIDATION_MESSAGE,
*_MAINTENANCE_MESSAGES,
_MOVING_MESSAGE,
):
return await self.pubsub_push_handler_func(response)
try:
if (
msg_type == _INVALIDATION_MESSAGE
and self.invalidation_push_handler_func
):
return await self.invalidation_push_handler_func(response)
if isinstance(msg_type, bytes):
msg_type = msg_type.decode()
if msg_type == _MOVING_MESSAGE and self.node_moving_push_handler_func:
parser_function = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[
msg_type
][1]
notification = parser_function(response)
return await self.node_moving_push_handler_func(notification)
if msg_type in _MAINTENANCE_MESSAGES and self.maintenance_push_handler_func:
parser_function = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[
msg_type
][1]
notification_type = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[
msg_type
][0]
notification = parser_function(response, notification_type)
if notification is not None:
return await self.maintenance_push_handler_func(notification)
except Exception as e:
logger.error(
"Error handling {} message ({}): {}".format(msg_type, response, e)
)
return None
def set_pubsub_push_handler(self, pubsub_push_handler_func):
"""Set the pubsub push handler function"""
self.pubsub_push_handler_func = pubsub_push_handler_func
def set_invalidation_push_handler(self, invalidation_push_handler_func):
"""Set the invalidation push handler function"""
self.invalidation_push_handler_func = invalidation_push_handler_func
def set_node_moving_push_handler(self, node_moving_push_handler_func):
self.node_moving_push_handler_func = node_moving_push_handler_func
def set_maintenance_push_handler(self, maintenance_push_handler_func):
self.maintenance_push_handler_func = maintenance_push_handler_func
class _AsyncRESPBase(AsyncBaseParser):
"""Base class for async resp parsing"""
__slots__ = AsyncBaseParser.__slots__ + ("encoder", "_buffer", "_pos", "_chunks")
def __init__(self, socket_read_size: int):
super().__init__(socket_read_size)
self.encoder: Optional[Encoder] = None
self._buffer = b""
self._chunks = []
self._pos = 0
def _clear(self):
self._buffer = b""
self._chunks.clear()
def on_connect(self, connection):
"""Called when the stream connects"""
self._stream = connection._reader
if self._stream is None:
raise RedisError("Buffer is closed.")
self.encoder = connection.encoder
self._clear()
self._connected = True
def on_disconnect(self):
"""Called when the stream disconnects"""
self._connected = False
async def can_read_destructive(self) -> bool:
if not self._connected:
raise RedisError("Buffer is closed.")
if self._buffer:
return True
try:
async with async_timeout(0):
return self._stream.at_eof()
except TimeoutError:
return False
async def _read(self, length: int) -> bytes:
"""
Read `length` bytes of data. These are assumed to be followed
by a '\r\n' terminator which is subsequently discarded.
"""
want = length + 2
end = self._pos + want
if len(self._buffer) >= end:
result = self._buffer[self._pos : end - 2]
else:
tail = self._buffer[self._pos :]
try:
data = await self._stream.readexactly(want - len(tail))
except IncompleteReadError as error:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from error
result = (tail + data)[:-2]
self._chunks.append(data)
self._pos += want
return result
async def _readline(self) -> bytes:
"""
read an unknown number of bytes up to the next '\r\n'
line separator, which is discarded.
"""
found = self._buffer.find(b"\r\n", self._pos)
if found >= 0:
result = self._buffer[self._pos : found]
else:
tail = self._buffer[self._pos :]
data = await self._stream.readline()
if not data.endswith(b"\r\n"):
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
result = (tail + data)[:-2]
self._chunks.append(data)
self._pos += len(result) + 2
return result

View File

@@ -0,0 +1,281 @@
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
from redis.exceptions import RedisError, ResponseError
from redis.utils import str_if_bytes
if TYPE_CHECKING:
from redis.asyncio.cluster import ClusterNode
class AbstractCommandsParser:
def _get_pubsub_keys(self, *args):
"""
Get the keys from pubsub command.
Although PubSub commands have predetermined key locations, they are not
supported in the 'COMMAND's output, so the key positions are hardcoded
in this method
"""
if len(args) < 2:
# The command has no keys in it
return None
args = [str_if_bytes(arg) for arg in args]
command = args[0].upper()
keys = None
if command == "PUBSUB":
# the second argument is a part of the command name, e.g.
# ['PUBSUB', 'NUMSUB', 'foo'].
pubsub_type = args[1].upper()
if pubsub_type in ["CHANNELS", "NUMSUB", "SHARDCHANNELS", "SHARDNUMSUB"]:
keys = args[2:]
elif command in ["SUBSCRIBE", "PSUBSCRIBE", "UNSUBSCRIBE", "PUNSUBSCRIBE"]:
# format example:
# SUBSCRIBE channel [channel ...]
keys = list(args[1:])
elif command in ["PUBLISH", "SPUBLISH"]:
# format example:
# PUBLISH channel message
keys = [args[1]]
return keys
def parse_subcommand(self, command, **options):
cmd_dict = {}
cmd_name = str_if_bytes(command[0])
cmd_dict["name"] = cmd_name
cmd_dict["arity"] = int(command[1])
cmd_dict["flags"] = [str_if_bytes(flag) for flag in command[2]]
cmd_dict["first_key_pos"] = command[3]
cmd_dict["last_key_pos"] = command[4]
cmd_dict["step_count"] = command[5]
if len(command) > 7:
cmd_dict["tips"] = command[7]
cmd_dict["key_specifications"] = command[8]
cmd_dict["subcommands"] = command[9]
return cmd_dict
class CommandsParser(AbstractCommandsParser):
"""
Parses Redis commands to get command keys.
COMMAND output is used to determine key locations.
Commands that do not have a predefined key location are flagged with
'movablekeys', and these commands' keys are determined by the command
'COMMAND GETKEYS'.
"""
def __init__(self, redis_connection):
self.commands = {}
self.initialize(redis_connection)
def initialize(self, r):
commands = r.command()
uppercase_commands = []
for cmd in commands:
if any(x.isupper() for x in cmd):
uppercase_commands.append(cmd)
for cmd in uppercase_commands:
commands[cmd.lower()] = commands.pop(cmd)
self.commands = commands
# As soon as this PR is merged into Redis, we should reimplement
# our logic to use COMMAND INFO changes to determine the key positions
# https://github.com/redis/redis/pull/8324
def get_keys(self, redis_conn, *args):
"""
Get the keys from the passed command.
NOTE: Due to a bug in redis<7.0, this function does not work properly
for EVAL or EVALSHA when the `numkeys` arg is 0.
- issue: https://github.com/redis/redis/issues/9493
- fix: https://github.com/redis/redis/pull/9733
So, don't use this function with EVAL or EVALSHA.
"""
if len(args) < 2:
# The command has no keys in it
return None
cmd_name = args[0].lower()
if cmd_name not in self.commands:
# try to split the command name and to take only the main command,
# e.g. 'memory' for 'memory usage'
cmd_name_split = cmd_name.split()
cmd_name = cmd_name_split[0]
if cmd_name in self.commands:
# save the splitted command to args
args = cmd_name_split + list(args[1:])
else:
# We'll try to reinitialize the commands cache, if the engine
# version has changed, the commands may not be current
self.initialize(redis_conn)
if cmd_name not in self.commands:
raise RedisError(
f"{cmd_name.upper()} command doesn't exist in Redis commands"
)
command = self.commands.get(cmd_name)
if "movablekeys" in command["flags"]:
keys = self._get_moveable_keys(redis_conn, *args)
elif "pubsub" in command["flags"] or command["name"] == "pubsub":
keys = self._get_pubsub_keys(*args)
else:
if (
command["step_count"] == 0
and command["first_key_pos"] == 0
and command["last_key_pos"] == 0
):
is_subcmd = False
if "subcommands" in command:
subcmd_name = f"{cmd_name}|{args[1].lower()}"
for subcmd in command["subcommands"]:
if str_if_bytes(subcmd[0]) == subcmd_name:
command = self.parse_subcommand(subcmd)
is_subcmd = True
# The command doesn't have keys in it
if not is_subcmd:
return None
last_key_pos = command["last_key_pos"]
if last_key_pos < 0:
last_key_pos = len(args) - abs(last_key_pos)
keys_pos = list(
range(command["first_key_pos"], last_key_pos + 1, command["step_count"])
)
keys = [args[pos] for pos in keys_pos]
return keys
def _get_moveable_keys(self, redis_conn, *args):
"""
NOTE: Due to a bug in redis<7.0, this function does not work properly
for EVAL or EVALSHA when the `numkeys` arg is 0.
- issue: https://github.com/redis/redis/issues/9493
- fix: https://github.com/redis/redis/pull/9733
So, don't use this function with EVAL or EVALSHA.
"""
# The command name should be splitted into separate arguments,
# e.g. 'MEMORY USAGE' will be splitted into ['MEMORY', 'USAGE']
pieces = args[0].split() + list(args[1:])
try:
keys = redis_conn.execute_command("COMMAND GETKEYS", *pieces)
except ResponseError as e:
message = e.__str__()
if (
"Invalid arguments" in message
or "The command has no key arguments" in message
):
return None
else:
raise e
return keys
class AsyncCommandsParser(AbstractCommandsParser):
"""
Parses Redis commands to get command keys.
COMMAND output is used to determine key locations.
Commands that do not have a predefined key location are flagged with 'movablekeys',
and these commands' keys are determined by the command 'COMMAND GETKEYS'.
NOTE: Due to a bug in redis<7.0, this does not work properly
for EVAL or EVALSHA when the `numkeys` arg is 0.
- issue: https://github.com/redis/redis/issues/9493
- fix: https://github.com/redis/redis/pull/9733
So, don't use this with EVAL or EVALSHA.
"""
__slots__ = ("commands", "node")
def __init__(self) -> None:
self.commands: Dict[str, Union[int, Dict[str, Any]]] = {}
async def initialize(self, node: Optional["ClusterNode"] = None) -> None:
if node:
self.node = node
commands = await self.node.execute_command("COMMAND")
self.commands = {cmd.lower(): command for cmd, command in commands.items()}
# As soon as this PR is merged into Redis, we should reimplement
# our logic to use COMMAND INFO changes to determine the key positions
# https://github.com/redis/redis/pull/8324
async def get_keys(self, *args: Any) -> Optional[Tuple[str, ...]]:
"""
Get the keys from the passed command.
NOTE: Due to a bug in redis<7.0, this function does not work properly
for EVAL or EVALSHA when the `numkeys` arg is 0.
- issue: https://github.com/redis/redis/issues/9493
- fix: https://github.com/redis/redis/pull/9733
So, don't use this function with EVAL or EVALSHA.
"""
if len(args) < 2:
# The command has no keys in it
return None
cmd_name = args[0].lower()
if cmd_name not in self.commands:
# try to split the command name and to take only the main command,
# e.g. 'memory' for 'memory usage'
cmd_name_split = cmd_name.split()
cmd_name = cmd_name_split[0]
if cmd_name in self.commands:
# save the splitted command to args
args = cmd_name_split + list(args[1:])
else:
# We'll try to reinitialize the commands cache, if the engine
# version has changed, the commands may not be current
await self.initialize()
if cmd_name not in self.commands:
raise RedisError(
f"{cmd_name.upper()} command doesn't exist in Redis commands"
)
command = self.commands.get(cmd_name)
if "movablekeys" in command["flags"]:
keys = await self._get_moveable_keys(*args)
elif "pubsub" in command["flags"] or command["name"] == "pubsub":
keys = self._get_pubsub_keys(*args)
else:
if (
command["step_count"] == 0
and command["first_key_pos"] == 0
and command["last_key_pos"] == 0
):
is_subcmd = False
if "subcommands" in command:
subcmd_name = f"{cmd_name}|{args[1].lower()}"
for subcmd in command["subcommands"]:
if str_if_bytes(subcmd[0]) == subcmd_name:
command = self.parse_subcommand(subcmd)
is_subcmd = True
# The command doesn't have keys in it
if not is_subcmd:
return None
last_key_pos = command["last_key_pos"]
if last_key_pos < 0:
last_key_pos = len(args) - abs(last_key_pos)
keys_pos = list(
range(command["first_key_pos"], last_key_pos + 1, command["step_count"])
)
keys = [args[pos] for pos in keys_pos]
return keys
async def _get_moveable_keys(self, *args: Any) -> Optional[Tuple[str, ...]]:
try:
keys = await self.node.execute_command("COMMAND GETKEYS", *args)
except ResponseError as e:
message = e.__str__()
if (
"Invalid arguments" in message
or "The command has no key arguments" in message
):
return None
else:
raise e
return keys

View File

@@ -0,0 +1,44 @@
from ..exceptions import DataError
class Encoder:
"Encode strings to bytes-like and decode bytes-like to strings"
__slots__ = "encoding", "encoding_errors", "decode_responses"
def __init__(self, encoding, encoding_errors, decode_responses):
self.encoding = encoding
self.encoding_errors = encoding_errors
self.decode_responses = decode_responses
def encode(self, value):
"Return a bytestring or bytes-like representation of the value"
if isinstance(value, (bytes, memoryview)):
return value
elif isinstance(value, bool):
# special case bool since it is a subclass of int
raise DataError(
"Invalid input of type: 'bool'. Convert to a "
"bytes, string, int or float first."
)
elif isinstance(value, (int, float)):
value = repr(value).encode()
elif not isinstance(value, str):
# a value we don't know how to deal with. throw an error
typename = type(value).__name__
raise DataError(
f"Invalid input of type: '{typename}'. "
f"Convert to a bytes, string, int or float first."
)
if isinstance(value, str):
value = value.encode(self.encoding, self.encoding_errors)
return value
def decode(self, value, force=False):
"Return a unicode string from the bytes-like representation"
if self.decode_responses or force:
if isinstance(value, memoryview):
value = value.tobytes()
if isinstance(value, bytes):
value = value.decode(self.encoding, self.encoding_errors)
return value

View File

@@ -0,0 +1,941 @@
import datetime
from redis.utils import str_if_bytes
def timestamp_to_datetime(response):
"Converts a unix timestamp to a Python datetime object"
if not response:
return None
try:
response = int(response)
except ValueError:
return None
return datetime.datetime.fromtimestamp(response)
def parse_debug_object(response):
"Parse the results of Redis's DEBUG OBJECT command into a Python dict"
# The 'type' of the object is the first item in the response, but isn't
# prefixed with a name
response = str_if_bytes(response)
response = "type:" + response
response = dict(kv.split(":") for kv in response.split())
# parse some expected int values from the string response
# note: this cmd isn't spec'd so these may not appear in all redis versions
int_fields = ("refcount", "serializedlength", "lru", "lru_seconds_idle")
for field in int_fields:
if field in response:
response[field] = int(response[field])
return response
def parse_info(response):
"""Parse the result of Redis's INFO command into a Python dict"""
info = {}
response = str_if_bytes(response)
def get_value(value):
if "," not in value and "=" not in value:
try:
if "." in value:
return float(value)
else:
return int(value)
except ValueError:
return value
elif "=" not in value:
return [get_value(v) for v in value.split(",") if v]
else:
sub_dict = {}
for item in value.split(","):
if not item:
continue
if "=" in item:
k, v = item.rsplit("=", 1)
sub_dict[k] = get_value(v)
else:
sub_dict[item] = True
return sub_dict
for line in response.splitlines():
if line and not line.startswith("#"):
if line.find(":") != -1:
# Split, the info fields keys and values.
# Note that the value may contain ':'. but the 'host:'
# pseudo-command is the only case where the key contains ':'
key, value = line.split(":", 1)
if key == "cmdstat_host":
key, value = line.rsplit(":", 1)
if key == "module":
# Hardcode a list for key 'modules' since there could be
# multiple lines that started with 'module'
info.setdefault("modules", []).append(get_value(value))
else:
info[key] = get_value(value)
else:
# if the line isn't splittable, append it to the "__raw__" key
info.setdefault("__raw__", []).append(line)
return info
def parse_memory_stats(response, **kwargs):
"""Parse the results of MEMORY STATS"""
stats = pairs_to_dict(response, decode_keys=True, decode_string_values=True)
for key, value in stats.items():
if key.startswith("db.") and isinstance(value, list):
stats[key] = pairs_to_dict(
value, decode_keys=True, decode_string_values=True
)
return stats
SENTINEL_STATE_TYPES = {
"can-failover-its-master": int,
"config-epoch": int,
"down-after-milliseconds": int,
"failover-timeout": int,
"info-refresh": int,
"last-hello-message": int,
"last-ok-ping-reply": int,
"last-ping-reply": int,
"last-ping-sent": int,
"master-link-down-time": int,
"master-port": int,
"num-other-sentinels": int,
"num-slaves": int,
"o-down-time": int,
"pending-commands": int,
"parallel-syncs": int,
"port": int,
"quorum": int,
"role-reported-time": int,
"s-down-time": int,
"slave-priority": int,
"slave-repl-offset": int,
"voted-leader-epoch": int,
}
def parse_sentinel_state(item):
result = pairs_to_dict_typed(item, SENTINEL_STATE_TYPES)
flags = set(result["flags"].split(","))
for name, flag in (
("is_master", "master"),
("is_slave", "slave"),
("is_sdown", "s_down"),
("is_odown", "o_down"),
("is_sentinel", "sentinel"),
("is_disconnected", "disconnected"),
("is_master_down", "master_down"),
):
result[name] = flag in flags
return result
def parse_sentinel_master(response):
return parse_sentinel_state(map(str_if_bytes, response))
def parse_sentinel_state_resp3(response):
result = {}
for key in response:
try:
value = SENTINEL_STATE_TYPES[key](str_if_bytes(response[key]))
result[str_if_bytes(key)] = value
except Exception:
result[str_if_bytes(key)] = response[str_if_bytes(key)]
flags = set(result["flags"].split(","))
result["flags"] = flags
return result
def parse_sentinel_masters(response):
result = {}
for item in response:
state = parse_sentinel_state(map(str_if_bytes, item))
result[state["name"]] = state
return result
def parse_sentinel_masters_resp3(response):
return [parse_sentinel_state(master) for master in response]
def parse_sentinel_slaves_and_sentinels(response):
return [parse_sentinel_state(map(str_if_bytes, item)) for item in response]
def parse_sentinel_slaves_and_sentinels_resp3(response):
return [parse_sentinel_state_resp3(item) for item in response]
def parse_sentinel_get_master(response):
return response and (response[0], int(response[1])) or None
def pairs_to_dict(response, decode_keys=False, decode_string_values=False):
"""Create a dict given a list of key/value pairs"""
if response is None:
return {}
if decode_keys or decode_string_values:
# the iter form is faster, but I don't know how to make that work
# with a str_if_bytes() map
keys = response[::2]
if decode_keys:
keys = map(str_if_bytes, keys)
values = response[1::2]
if decode_string_values:
values = map(str_if_bytes, values)
return dict(zip(keys, values))
else:
it = iter(response)
return dict(zip(it, it))
def pairs_to_dict_typed(response, type_info):
it = iter(response)
result = {}
for key, value in zip(it, it):
if key in type_info:
try:
value = type_info[key](value)
except Exception:
# if for some reason the value can't be coerced, just use
# the string value
pass
result[key] = value
return result
def zset_score_pairs(response, **options):
"""
If ``withscores`` is specified in the options, return the response as
a list of (value, score) pairs
"""
if not response or not options.get("withscores"):
return response
score_cast_func = options.get("score_cast_func", float)
it = iter(response)
return list(zip(it, map(score_cast_func, it)))
def zset_score_for_rank(response, **options):
"""
If ``withscores`` is specified in the options, return the response as
a [value, score] pair
"""
if not response or not options.get("withscore"):
return response
score_cast_func = options.get("score_cast_func", float)
return [response[0], score_cast_func(response[1])]
def zset_score_pairs_resp3(response, **options):
"""
If ``withscores`` is specified in the options, return the response as
a list of [value, score] pairs
"""
if not response or not options.get("withscores"):
return response
score_cast_func = options.get("score_cast_func", float)
return [[name, score_cast_func(val)] for name, val in response]
def zset_score_for_rank_resp3(response, **options):
"""
If ``withscores`` is specified in the options, return the response as
a [value, score] pair
"""
if not response or not options.get("withscore"):
return response
score_cast_func = options.get("score_cast_func", float)
return [response[0], score_cast_func(response[1])]
def sort_return_tuples(response, **options):
"""
If ``groups`` is specified, return the response as a list of
n-element tuples with n being the value found in options['groups']
"""
if not response or not options.get("groups"):
return response
n = options["groups"]
return list(zip(*[response[i::n] for i in range(n)]))
def parse_stream_list(response):
if response is None:
return None
data = []
for r in response:
if r is not None:
data.append((r[0], pairs_to_dict(r[1])))
else:
data.append((None, None))
return data
def pairs_to_dict_with_str_keys(response):
return pairs_to_dict(response, decode_keys=True)
def parse_list_of_dicts(response):
return list(map(pairs_to_dict_with_str_keys, response))
def parse_xclaim(response, **options):
if options.get("parse_justid", False):
return response
return parse_stream_list(response)
def parse_xautoclaim(response, **options):
if options.get("parse_justid", False):
return response[1]
response[1] = parse_stream_list(response[1])
return response
def parse_xinfo_stream(response, **options):
if isinstance(response, list):
data = pairs_to_dict(response, decode_keys=True)
else:
data = {str_if_bytes(k): v for k, v in response.items()}
if not options.get("full", False):
first = data.get("first-entry")
if first is not None and first[0] is not None:
data["first-entry"] = (first[0], pairs_to_dict(first[1]))
last = data["last-entry"]
if last is not None and last[0] is not None:
data["last-entry"] = (last[0], pairs_to_dict(last[1]))
else:
data["entries"] = {_id: pairs_to_dict(entry) for _id, entry in data["entries"]}
if len(data["groups"]) > 0 and isinstance(data["groups"][0], list):
data["groups"] = [
pairs_to_dict(group, decode_keys=True) for group in data["groups"]
]
for g in data["groups"]:
if g["consumers"] and g["consumers"][0] is not None:
g["consumers"] = [
pairs_to_dict(c, decode_keys=True) for c in g["consumers"]
]
else:
data["groups"] = [
{str_if_bytes(k): v for k, v in group.items()}
for group in data["groups"]
]
return data
def parse_xread(response):
if response is None:
return []
return [[r[0], parse_stream_list(r[1])] for r in response]
def parse_xread_resp3(response):
if response is None:
return {}
return {key: [parse_stream_list(value)] for key, value in response.items()}
def parse_xpending(response, **options):
if options.get("parse_detail", False):
return parse_xpending_range(response)
consumers = [{"name": n, "pending": int(p)} for n, p in response[3] or []]
return {
"pending": response[0],
"min": response[1],
"max": response[2],
"consumers": consumers,
}
def parse_xpending_range(response):
k = ("message_id", "consumer", "time_since_delivered", "times_delivered")
return [dict(zip(k, r)) for r in response]
def float_or_none(response):
if response is None:
return None
return float(response)
def bool_ok(response, **options):
return str_if_bytes(response) == "OK"
def parse_zadd(response, **options):
if response is None:
return None
if options.get("as_score"):
return float(response)
return int(response)
def parse_client_list(response, **options):
clients = []
for c in str_if_bytes(response).splitlines():
client_dict = {}
tokens = c.split(" ")
last_key = None
for token in tokens:
if "=" in token:
# Values might contain '='
key, value = token.split("=", 1)
client_dict[key] = value
last_key = key
else:
# Values may include spaces. For instance, when running Redis via a Unix socket — such as
# "/tmp/redis sock/redis.sock" — the addr or laddr field will include a space.
client_dict[last_key] += " " + token
if client_dict:
clients.append(client_dict)
return clients
def parse_config_get(response, **options):
response = [str_if_bytes(i) if i is not None else None for i in response]
return response and pairs_to_dict(response) or {}
def parse_scan(response, **options):
cursor, r = response
return int(cursor), r
def parse_hscan(response, **options):
cursor, r = response
no_values = options.get("no_values", False)
if no_values:
payload = r or []
else:
payload = r and pairs_to_dict(r) or {}
return int(cursor), payload
def parse_zscan(response, **options):
score_cast_func = options.get("score_cast_func", float)
cursor, r = response
it = iter(r)
return int(cursor), list(zip(it, map(score_cast_func, it)))
def parse_zmscore(response, **options):
# zmscore: list of scores (double precision floating point number) or nil
return [float(score) if score is not None else None for score in response]
def parse_slowlog_get(response, **options):
space = " " if options.get("decode_responses", False) else b" "
def parse_item(item):
result = {"id": item[0], "start_time": int(item[1]), "duration": int(item[2])}
# Redis Enterprise injects another entry at index [3], which has
# the complexity info (i.e. the value N in case the command has
# an O(N) complexity) instead of the command.
if isinstance(item[3], list):
result["command"] = space.join(item[3])
# These fields are optional, depends on environment.
if len(item) >= 6:
result["client_address"] = item[4]
result["client_name"] = item[5]
else:
result["complexity"] = item[3]
result["command"] = space.join(item[4])
# These fields are optional, depends on environment.
if len(item) >= 7:
result["client_address"] = item[5]
result["client_name"] = item[6]
return result
return [parse_item(item) for item in response]
def parse_stralgo(response, **options):
"""
Parse the response from `STRALGO` command.
Without modifiers the returned value is string.
When LEN is given the command returns the length of the result
(i.e integer).
When IDX is given the command returns a dictionary with the LCS
length and all the ranges in both the strings, start and end
offset for each string, where there are matches.
When WITHMATCHLEN is given, each array representing a match will
also have the length of the match at the beginning of the array.
"""
if options.get("len", False):
return int(response)
if options.get("idx", False):
if options.get("withmatchlen", False):
matches = [
[(int(match[-1]))] + list(map(tuple, match[:-1]))
for match in response[1]
]
else:
matches = [list(map(tuple, match)) for match in response[1]]
return {
str_if_bytes(response[0]): matches,
str_if_bytes(response[2]): int(response[3]),
}
return str_if_bytes(response)
def parse_cluster_info(response, **options):
response = str_if_bytes(response)
return dict(line.split(":") for line in response.splitlines() if line)
def _parse_node_line(line):
line_items = line.split(" ")
node_id, addr, flags, master_id, ping, pong, epoch, connected = line.split(" ")[:8]
ip = addr.split("@")[0]
hostname = addr.split("@")[1].split(",")[1] if "@" in addr and "," in addr else ""
node_dict = {
"node_id": node_id,
"hostname": hostname,
"flags": flags,
"master_id": master_id,
"last_ping_sent": ping,
"last_pong_rcvd": pong,
"epoch": epoch,
"slots": [],
"migrations": [],
"connected": True if connected == "connected" else False,
}
if len(line_items) >= 9:
slots, migrations = _parse_slots(line_items[8:])
node_dict["slots"], node_dict["migrations"] = slots, migrations
return ip, node_dict
def _parse_slots(slot_ranges):
slots, migrations = [], []
for s_range in slot_ranges:
if "->-" in s_range:
slot_id, dst_node_id = s_range[1:-1].split("->-", 1)
migrations.append(
{"slot": slot_id, "node_id": dst_node_id, "state": "migrating"}
)
elif "-<-" in s_range:
slot_id, src_node_id = s_range[1:-1].split("-<-", 1)
migrations.append(
{"slot": slot_id, "node_id": src_node_id, "state": "importing"}
)
else:
s_range = [sl for sl in s_range.split("-")]
slots.append(s_range)
return slots, migrations
def parse_cluster_nodes(response, **options):
"""
@see: https://redis.io/commands/cluster-nodes # string / bytes
@see: https://redis.io/commands/cluster-replicas # list of string / bytes
"""
if isinstance(response, (str, bytes)):
response = response.splitlines()
return dict(_parse_node_line(str_if_bytes(node)) for node in response)
def parse_geosearch_generic(response, **options):
"""
Parse the response of 'GEOSEARCH', GEORADIUS' and 'GEORADIUSBYMEMBER'
commands according to 'withdist', 'withhash' and 'withcoord' labels.
"""
try:
if options["store"] or options["store_dist"]:
# `store` and `store_dist` cant be combined
# with other command arguments.
# relevant to 'GEORADIUS' and 'GEORADIUSBYMEMBER'
return response
except KeyError: # it means the command was sent via execute_command
return response
if not isinstance(response, list):
response_list = [response]
else:
response_list = response
if not options["withdist"] and not options["withcoord"] and not options["withhash"]:
# just a bunch of places
return response_list
cast = {
"withdist": float,
"withcoord": lambda ll: (float(ll[0]), float(ll[1])),
"withhash": int,
}
# zip all output results with each casting function to get
# the properly native Python value.
f = [lambda x: x]
f += [cast[o] for o in ["withdist", "withhash", "withcoord"] if options[o]]
return [list(map(lambda fv: fv[0](fv[1]), zip(f, r))) for r in response_list]
def parse_command(response, **options):
commands = {}
for command in response:
cmd_dict = {}
cmd_name = str_if_bytes(command[0])
cmd_dict["name"] = cmd_name
cmd_dict["arity"] = int(command[1])
cmd_dict["flags"] = [str_if_bytes(flag) for flag in command[2]]
cmd_dict["first_key_pos"] = command[3]
cmd_dict["last_key_pos"] = command[4]
cmd_dict["step_count"] = command[5]
if len(command) > 7:
cmd_dict["tips"] = command[7]
cmd_dict["key_specifications"] = command[8]
cmd_dict["subcommands"] = command[9]
commands[cmd_name] = cmd_dict
return commands
def parse_command_resp3(response, **options):
commands = {}
for command in response:
cmd_dict = {}
cmd_name = str_if_bytes(command[0])
cmd_dict["name"] = cmd_name
cmd_dict["arity"] = command[1]
cmd_dict["flags"] = {str_if_bytes(flag) for flag in command[2]}
cmd_dict["first_key_pos"] = command[3]
cmd_dict["last_key_pos"] = command[4]
cmd_dict["step_count"] = command[5]
cmd_dict["acl_categories"] = command[6]
if len(command) > 7:
cmd_dict["tips"] = command[7]
cmd_dict["key_specifications"] = command[8]
cmd_dict["subcommands"] = command[9]
commands[cmd_name] = cmd_dict
return commands
def parse_pubsub_numsub(response, **options):
return list(zip(response[0::2], response[1::2]))
def parse_client_kill(response, **options):
if isinstance(response, int):
return response
return str_if_bytes(response) == "OK"
def parse_acl_getuser(response, **options):
if response is None:
return None
if isinstance(response, list):
data = pairs_to_dict(response, decode_keys=True)
else:
data = {str_if_bytes(key): value for key, value in response.items()}
# convert everything but user-defined data in 'keys' to native strings
data["flags"] = list(map(str_if_bytes, data["flags"]))
data["passwords"] = list(map(str_if_bytes, data["passwords"]))
data["commands"] = str_if_bytes(data["commands"])
if isinstance(data["keys"], str) or isinstance(data["keys"], bytes):
data["keys"] = list(str_if_bytes(data["keys"]).split(" "))
if data["keys"] == [""]:
data["keys"] = []
if "channels" in data:
if isinstance(data["channels"], str) or isinstance(data["channels"], bytes):
data["channels"] = list(str_if_bytes(data["channels"]).split(" "))
if data["channels"] == [""]:
data["channels"] = []
if "selectors" in data:
if data["selectors"] != [] and isinstance(data["selectors"][0], list):
data["selectors"] = [
list(map(str_if_bytes, selector)) for selector in data["selectors"]
]
elif data["selectors"] != []:
data["selectors"] = [
{str_if_bytes(k): str_if_bytes(v) for k, v in selector.items()}
for selector in data["selectors"]
]
# split 'commands' into separate 'categories' and 'commands' lists
commands, categories = [], []
for command in data["commands"].split(" "):
categories.append(command) if "@" in command else commands.append(command)
data["commands"] = commands
data["categories"] = categories
data["enabled"] = "on" in data["flags"]
return data
def parse_acl_log(response, **options):
if response is None:
return None
if isinstance(response, list):
data = []
for log in response:
log_data = pairs_to_dict(log, True, True)
client_info = log_data.get("client-info", "")
log_data["client-info"] = parse_client_info(client_info)
# float() is lossy comparing to the "double" in C
log_data["age-seconds"] = float(log_data["age-seconds"])
data.append(log_data)
else:
data = bool_ok(response)
return data
def parse_client_info(value):
"""
Parsing client-info in ACL Log in following format.
"key1=value1 key2=value2 key3=value3"
"""
client_info = {}
for info in str_if_bytes(value).strip().split():
key, value = info.split("=")
client_info[key] = value
# Those fields are defined as int in networking.c
for int_key in {
"id",
"age",
"idle",
"db",
"sub",
"psub",
"multi",
"qbuf",
"qbuf-free",
"obl",
"argv-mem",
"oll",
"omem",
"tot-mem",
}:
if int_key in client_info:
client_info[int_key] = int(client_info[int_key])
return client_info
def parse_set_result(response, **options):
"""
Handle SET result since GET argument is available since Redis 6.2.
Parsing SET result into:
- BOOL
- String when GET argument is used
"""
if options.get("get"):
# Redis will return a getCommand result.
# See `setGenericCommand` in t_string.c
return response
return response and str_if_bytes(response) == "OK"
def string_keys_to_dict(key_string, callback):
return dict.fromkeys(key_string.split(), callback)
_RedisCallbacks = {
**string_keys_to_dict(
"AUTH COPY EXPIRE EXPIREAT HEXISTS HMSET MOVE MSETNX PERSIST PSETEX "
"PEXPIRE PEXPIREAT RENAMENX SETEX SETNX SMOVE",
bool,
),
**string_keys_to_dict("HINCRBYFLOAT INCRBYFLOAT", float),
**string_keys_to_dict(
"ASKING FLUSHALL FLUSHDB LSET LTRIM MSET PFMERGE READONLY READWRITE "
"RENAME SAVE SELECT SHUTDOWN SLAVEOF SWAPDB WATCH UNWATCH",
bool_ok,
),
**string_keys_to_dict("XREAD XREADGROUP", parse_xread),
**string_keys_to_dict(
"GEORADIUS GEORADIUSBYMEMBER GEOSEARCH",
parse_geosearch_generic,
),
**string_keys_to_dict("XRANGE XREVRANGE", parse_stream_list),
"ACL GETUSER": parse_acl_getuser,
"ACL LOAD": bool_ok,
"ACL LOG": parse_acl_log,
"ACL SETUSER": bool_ok,
"ACL SAVE": bool_ok,
"CLIENT INFO": parse_client_info,
"CLIENT KILL": parse_client_kill,
"CLIENT LIST": parse_client_list,
"CLIENT PAUSE": bool_ok,
"CLIENT SETINFO": bool_ok,
"CLIENT SETNAME": bool_ok,
"CLIENT UNBLOCK": bool,
"CLUSTER ADDSLOTS": bool_ok,
"CLUSTER ADDSLOTSRANGE": bool_ok,
"CLUSTER DELSLOTS": bool_ok,
"CLUSTER DELSLOTSRANGE": bool_ok,
"CLUSTER FAILOVER": bool_ok,
"CLUSTER FORGET": bool_ok,
"CLUSTER INFO": parse_cluster_info,
"CLUSTER MEET": bool_ok,
"CLUSTER NODES": parse_cluster_nodes,
"CLUSTER REPLICAS": parse_cluster_nodes,
"CLUSTER REPLICATE": bool_ok,
"CLUSTER RESET": bool_ok,
"CLUSTER SAVECONFIG": bool_ok,
"CLUSTER SET-CONFIG-EPOCH": bool_ok,
"CLUSTER SETSLOT": bool_ok,
"CLUSTER SLAVES": parse_cluster_nodes,
"COMMAND": parse_command,
"CONFIG RESETSTAT": bool_ok,
"CONFIG SET": bool_ok,
"FUNCTION DELETE": bool_ok,
"FUNCTION FLUSH": bool_ok,
"FUNCTION RESTORE": bool_ok,
"GEODIST": float_or_none,
"HSCAN": parse_hscan,
"INFO": parse_info,
"LASTSAVE": timestamp_to_datetime,
"MEMORY PURGE": bool_ok,
"MODULE LOAD": bool,
"MODULE UNLOAD": bool,
"PING": lambda r: str_if_bytes(r) == "PONG",
"PUBSUB NUMSUB": parse_pubsub_numsub,
"PUBSUB SHARDNUMSUB": parse_pubsub_numsub,
"QUIT": bool_ok,
"SET": parse_set_result,
"SCAN": parse_scan,
"SCRIPT EXISTS": lambda r: list(map(bool, r)),
"SCRIPT FLUSH": bool_ok,
"SCRIPT KILL": bool_ok,
"SCRIPT LOAD": str_if_bytes,
"SENTINEL CKQUORUM": bool_ok,
"SENTINEL FAILOVER": bool_ok,
"SENTINEL FLUSHCONFIG": bool_ok,
"SENTINEL GET-MASTER-ADDR-BY-NAME": parse_sentinel_get_master,
"SENTINEL MONITOR": bool_ok,
"SENTINEL RESET": bool_ok,
"SENTINEL REMOVE": bool_ok,
"SENTINEL SET": bool_ok,
"SLOWLOG GET": parse_slowlog_get,
"SLOWLOG RESET": bool_ok,
"SORT": sort_return_tuples,
"SSCAN": parse_scan,
"TIME": lambda x: (int(x[0]), int(x[1])),
"XAUTOCLAIM": parse_xautoclaim,
"XCLAIM": parse_xclaim,
"XGROUP CREATE": bool_ok,
"XGROUP DESTROY": bool,
"XGROUP SETID": bool_ok,
"XINFO STREAM": parse_xinfo_stream,
"XPENDING": parse_xpending,
"ZSCAN": parse_zscan,
}
_RedisCallbacksRESP2 = {
**string_keys_to_dict(
"SDIFF SINTER SMEMBERS SUNION", lambda r: r and set(r) or set()
),
**string_keys_to_dict(
"ZDIFF ZINTER ZPOPMAX ZPOPMIN ZRANGE ZRANGEBYSCORE ZREVRANGE "
"ZREVRANGEBYSCORE ZUNION",
zset_score_pairs,
),
**string_keys_to_dict(
"ZREVRANK ZRANK",
zset_score_for_rank,
),
**string_keys_to_dict("ZINCRBY ZSCORE", float_or_none),
**string_keys_to_dict("BGREWRITEAOF BGSAVE", lambda r: True),
**string_keys_to_dict("BLPOP BRPOP", lambda r: r and tuple(r) or None),
**string_keys_to_dict(
"BZPOPMAX BZPOPMIN", lambda r: r and (r[0], r[1], float(r[2])) or None
),
"ACL CAT": lambda r: list(map(str_if_bytes, r)),
"ACL GENPASS": str_if_bytes,
"ACL HELP": lambda r: list(map(str_if_bytes, r)),
"ACL LIST": lambda r: list(map(str_if_bytes, r)),
"ACL USERS": lambda r: list(map(str_if_bytes, r)),
"ACL WHOAMI": str_if_bytes,
"CLIENT GETNAME": str_if_bytes,
"CLIENT TRACKINGINFO": lambda r: list(map(str_if_bytes, r)),
"CLUSTER GETKEYSINSLOT": lambda r: list(map(str_if_bytes, r)),
"COMMAND GETKEYS": lambda r: list(map(str_if_bytes, r)),
"CONFIG GET": parse_config_get,
"DEBUG OBJECT": parse_debug_object,
"GEOHASH": lambda r: list(map(str_if_bytes, r)),
"GEOPOS": lambda r: list(
map(lambda ll: (float(ll[0]), float(ll[1])) if ll is not None else None, r)
),
"HGETALL": lambda r: r and pairs_to_dict(r) or {},
"MEMORY STATS": parse_memory_stats,
"MODULE LIST": lambda r: [pairs_to_dict(m) for m in r],
"RESET": str_if_bytes,
"SENTINEL MASTER": parse_sentinel_master,
"SENTINEL MASTERS": parse_sentinel_masters,
"SENTINEL SENTINELS": parse_sentinel_slaves_and_sentinels,
"SENTINEL SLAVES": parse_sentinel_slaves_and_sentinels,
"STRALGO": parse_stralgo,
"XINFO CONSUMERS": parse_list_of_dicts,
"XINFO GROUPS": parse_list_of_dicts,
"ZADD": parse_zadd,
"ZMSCORE": parse_zmscore,
}
_RedisCallbacksRESP3 = {
**string_keys_to_dict(
"SDIFF SINTER SMEMBERS SUNION", lambda r: r and set(r) or set()
),
**string_keys_to_dict(
"ZRANGE ZINTER ZPOPMAX ZPOPMIN HGETALL XREADGROUP",
lambda r, **kwargs: r,
),
**string_keys_to_dict(
"ZRANGE ZRANGEBYSCORE ZREVRANGE ZREVRANGEBYSCORE ZUNION",
zset_score_pairs_resp3,
),
**string_keys_to_dict(
"ZREVRANK ZRANK",
zset_score_for_rank_resp3,
),
**string_keys_to_dict("XREAD XREADGROUP", parse_xread_resp3),
"ACL LOG": lambda r: (
[
{str_if_bytes(key): str_if_bytes(value) for key, value in x.items()}
for x in r
]
if isinstance(r, list)
else bool_ok(r)
),
"COMMAND": parse_command_resp3,
"CONFIG GET": lambda r: {
str_if_bytes(key) if key is not None else None: (
str_if_bytes(value) if value is not None else None
)
for key, value in r.items()
},
"MEMORY STATS": lambda r: {str_if_bytes(key): value for key, value in r.items()},
"SENTINEL MASTER": parse_sentinel_state_resp3,
"SENTINEL MASTERS": parse_sentinel_masters_resp3,
"SENTINEL SENTINELS": parse_sentinel_slaves_and_sentinels_resp3,
"SENTINEL SLAVES": parse_sentinel_slaves_and_sentinels_resp3,
"STRALGO": lambda r, **options: (
{str_if_bytes(key): str_if_bytes(value) for key, value in r.items()}
if isinstance(r, dict)
else str_if_bytes(r)
),
"XINFO CONSUMERS": lambda r: [
{str_if_bytes(key): value for key, value in x.items()} for x in r
],
"XINFO GROUPS": lambda r: [
{str_if_bytes(key): value for key, value in d.items()} for d in r
],
}

View File

@@ -0,0 +1,301 @@
import asyncio
import socket
import sys
from logging import getLogger
from typing import Callable, List, Optional, TypedDict, Union
if sys.version_info.major >= 3 and sys.version_info.minor >= 11:
from asyncio import timeout as async_timeout
else:
from async_timeout import timeout as async_timeout
from ..exceptions import ConnectionError, InvalidResponse, RedisError
from ..typing import EncodableT
from ..utils import HIREDIS_AVAILABLE
from .base import (
AsyncBaseParser,
AsyncPushNotificationsParser,
BaseParser,
PushNotificationsParser,
)
from .socket import (
NONBLOCKING_EXCEPTION_ERROR_NUMBERS,
NONBLOCKING_EXCEPTIONS,
SENTINEL,
SERVER_CLOSED_CONNECTION_ERROR,
)
# Used to signal that hiredis-py does not have enough data to parse.
# Using `False` or `None` is not reliable, given that the parser can
# return `False` or `None` for legitimate reasons from RESP payloads.
NOT_ENOUGH_DATA = object()
class _HiredisReaderArgs(TypedDict, total=False):
protocolError: Callable[[str], Exception]
replyError: Callable[[str], Exception]
encoding: Optional[str]
errors: Optional[str]
class _HiredisParser(BaseParser, PushNotificationsParser):
"Parser class for connections using Hiredis"
def __init__(self, socket_read_size):
if not HIREDIS_AVAILABLE:
raise RedisError("Hiredis is not installed")
self.socket_read_size = socket_read_size
self._buffer = bytearray(socket_read_size)
self.pubsub_push_handler_func = self.handle_pubsub_push_response
self.node_moving_push_handler_func = None
self.maintenance_push_handler_func = None
self.invalidation_push_handler_func = None
self._hiredis_PushNotificationType = None
def __del__(self):
try:
self.on_disconnect()
except Exception:
pass
def handle_pubsub_push_response(self, response):
logger = getLogger("push_response")
logger.debug("Push response: " + str(response))
return response
def on_connect(self, connection, **kwargs):
import hiredis
self._sock = connection._sock
self._socket_timeout = connection.socket_timeout
kwargs = {
"protocolError": InvalidResponse,
"replyError": self.parse_error,
"errors": connection.encoder.encoding_errors,
"notEnoughData": NOT_ENOUGH_DATA,
}
if connection.encoder.decode_responses:
kwargs["encoding"] = connection.encoder.encoding
self._reader = hiredis.Reader(**kwargs)
self._next_response = NOT_ENOUGH_DATA
try:
self._hiredis_PushNotificationType = hiredis.PushNotification
except AttributeError:
# hiredis < 3.2
self._hiredis_PushNotificationType = None
def on_disconnect(self):
self._sock = None
self._reader = None
self._next_response = NOT_ENOUGH_DATA
def can_read(self, timeout):
if not self._reader:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
if self._next_response is NOT_ENOUGH_DATA:
self._next_response = self._reader.gets()
if self._next_response is NOT_ENOUGH_DATA:
return self.read_from_socket(timeout=timeout, raise_on_timeout=False)
return True
def read_from_socket(self, timeout=SENTINEL, raise_on_timeout=True):
sock = self._sock
custom_timeout = timeout is not SENTINEL
try:
if custom_timeout:
sock.settimeout(timeout)
bufflen = self._sock.recv_into(self._buffer)
if bufflen == 0:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
self._reader.feed(self._buffer, 0, bufflen)
# data was read from the socket and added to the buffer.
# return True to indicate that data was read.
return True
except socket.timeout:
if raise_on_timeout:
raise TimeoutError("Timeout reading from socket")
return False
except NONBLOCKING_EXCEPTIONS as ex:
# if we're in nonblocking mode and the recv raises a
# blocking error, simply return False indicating that
# there's no data to be read. otherwise raise the
# original exception.
allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1)
if not raise_on_timeout and ex.errno == allowed:
return False
raise ConnectionError(f"Error while reading from socket: {ex.args}")
finally:
if custom_timeout:
sock.settimeout(self._socket_timeout)
def read_response(self, disable_decoding=False, push_request=False):
if not self._reader:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
# _next_response might be cached from a can_read() call
if self._next_response is not NOT_ENOUGH_DATA:
response = self._next_response
self._next_response = NOT_ENOUGH_DATA
if self._hiredis_PushNotificationType is not None and isinstance(
response, self._hiredis_PushNotificationType
):
response = self.handle_push_response(response)
# if this is a push request return the push response
if push_request:
return response
return self.read_response(
disable_decoding=disable_decoding,
push_request=push_request,
)
return response
if disable_decoding:
response = self._reader.gets(False)
else:
response = self._reader.gets()
while response is NOT_ENOUGH_DATA:
self.read_from_socket()
if disable_decoding:
response = self._reader.gets(False)
else:
response = self._reader.gets()
# if the response is a ConnectionError or the response is a list and
# the first item is a ConnectionError, raise it as something bad
# happened
if isinstance(response, ConnectionError):
raise response
elif self._hiredis_PushNotificationType is not None and isinstance(
response, self._hiredis_PushNotificationType
):
response = self.handle_push_response(response)
if push_request:
return response
return self.read_response(
disable_decoding=disable_decoding,
push_request=push_request,
)
elif (
isinstance(response, list)
and response
and isinstance(response[0], ConnectionError)
):
raise response[0]
return response
class _AsyncHiredisParser(AsyncBaseParser, AsyncPushNotificationsParser):
"""Async implementation of parser class for connections using Hiredis"""
__slots__ = ("_reader",)
def __init__(self, socket_read_size: int):
if not HIREDIS_AVAILABLE:
raise RedisError("Hiredis is not available.")
super().__init__(socket_read_size=socket_read_size)
self._reader = None
self.pubsub_push_handler_func = self.handle_pubsub_push_response
self.invalidation_push_handler_func = None
self._hiredis_PushNotificationType = None
async def handle_pubsub_push_response(self, response):
logger = getLogger("push_response")
logger.debug("Push response: " + str(response))
return response
def on_connect(self, connection):
import hiredis
self._stream = connection._reader
kwargs: _HiredisReaderArgs = {
"protocolError": InvalidResponse,
"replyError": self.parse_error,
"notEnoughData": NOT_ENOUGH_DATA,
}
if connection.encoder.decode_responses:
kwargs["encoding"] = connection.encoder.encoding
kwargs["errors"] = connection.encoder.encoding_errors
self._reader = hiredis.Reader(**kwargs)
self._connected = True
try:
self._hiredis_PushNotificationType = getattr(
hiredis, "PushNotification", None
)
except AttributeError:
# hiredis < 3.2
self._hiredis_PushNotificationType = None
def on_disconnect(self):
self._connected = False
async def can_read_destructive(self):
if not self._connected:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
if self._reader.gets() is not NOT_ENOUGH_DATA:
return True
try:
async with async_timeout(0):
return await self.read_from_socket()
except asyncio.TimeoutError:
return False
async def read_from_socket(self):
buffer = await self._stream.read(self._read_size)
if not buffer or not isinstance(buffer, bytes):
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None
self._reader.feed(buffer)
# data was read from the socket and added to the buffer.
# return True to indicate that data was read.
return True
async def read_response(
self, disable_decoding: bool = False, push_request: bool = False
) -> Union[EncodableT, List[EncodableT]]:
# If `on_disconnect()` has been called, prohibit any more reads
# even if they could happen because data might be present.
# We still allow reads in progress to finish
if not self._connected:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None
if disable_decoding:
response = self._reader.gets(False)
else:
response = self._reader.gets()
while response is NOT_ENOUGH_DATA:
await self.read_from_socket()
if disable_decoding:
response = self._reader.gets(False)
else:
response = self._reader.gets()
# if the response is a ConnectionError or the response is a list and
# the first item is a ConnectionError, raise it as something bad
# happened
if isinstance(response, ConnectionError):
raise response
elif self._hiredis_PushNotificationType is not None and isinstance(
response, self._hiredis_PushNotificationType
):
response = await self.handle_push_response(response)
if not push_request:
return await self.read_response(
disable_decoding=disable_decoding, push_request=push_request
)
else:
return response
elif (
isinstance(response, list)
and response
and isinstance(response[0], ConnectionError)
):
raise response[0]
return response

View File

@@ -0,0 +1,132 @@
from typing import Any, Union
from ..exceptions import ConnectionError, InvalidResponse, ResponseError
from ..typing import EncodableT
from .base import _AsyncRESPBase, _RESPBase
from .socket import SERVER_CLOSED_CONNECTION_ERROR
class _RESP2Parser(_RESPBase):
"""RESP2 protocol implementation"""
def read_response(self, disable_decoding=False):
pos = self._buffer.get_pos() if self._buffer else None
try:
result = self._read_response(disable_decoding=disable_decoding)
except BaseException:
if self._buffer:
self._buffer.rewind(pos)
raise
else:
self._buffer.purge()
return result
def _read_response(self, disable_decoding=False):
raw = self._buffer.readline()
if not raw:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
byte, response = raw[:1], raw[1:]
# server returned an error
if byte == b"-":
response = response.decode("utf-8", errors="replace")
error = self.parse_error(response)
# if the error is a ConnectionError, raise immediately so the user
# is notified
if isinstance(error, ConnectionError):
raise error
# otherwise, we're dealing with a ResponseError that might belong
# inside a pipeline response. the connection's read_response()
# and/or the pipeline's execute() will raise this error if
# necessary, so just return the exception instance here.
return error
# single value
elif byte == b"+":
pass
# int value
elif byte == b":":
return int(response)
# bulk response
elif byte == b"$" and response == b"-1":
return None
elif byte == b"$":
response = self._buffer.read(int(response))
# multi-bulk response
elif byte == b"*" and response == b"-1":
return None
elif byte == b"*":
response = [
self._read_response(disable_decoding=disable_decoding)
for i in range(int(response))
]
else:
raise InvalidResponse(f"Protocol Error: {raw!r}")
if disable_decoding is False:
response = self.encoder.decode(response)
return response
class _AsyncRESP2Parser(_AsyncRESPBase):
"""Async class for the RESP2 protocol"""
async def read_response(self, disable_decoding: bool = False):
if not self._connected:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
if self._chunks:
# augment parsing buffer with previously read data
self._buffer += b"".join(self._chunks)
self._chunks.clear()
self._pos = 0
response = await self._read_response(disable_decoding=disable_decoding)
# Successfully parsing a response allows us to clear our parsing buffer
self._clear()
return response
async def _read_response(
self, disable_decoding: bool = False
) -> Union[EncodableT, ResponseError, None]:
raw = await self._readline()
response: Any
byte, response = raw[:1], raw[1:]
# server returned an error
if byte == b"-":
response = response.decode("utf-8", errors="replace")
error = self.parse_error(response)
# if the error is a ConnectionError, raise immediately so the user
# is notified
if isinstance(error, ConnectionError):
self._clear() # Successful parse
raise error
# otherwise, we're dealing with a ResponseError that might belong
# inside a pipeline response. the connection's read_response()
# and/or the pipeline's execute() will raise this error if
# necessary, so just return the exception instance here.
return error
# single value
elif byte == b"+":
pass
# int value
elif byte == b":":
return int(response)
# bulk response
elif byte == b"$" and response == b"-1":
return None
elif byte == b"$":
response = await self._read(int(response))
# multi-bulk response
elif byte == b"*" and response == b"-1":
return None
elif byte == b"*":
response = [
(await self._read_response(disable_decoding))
for _ in range(int(response)) # noqa
]
else:
raise InvalidResponse(f"Protocol Error: {raw!r}")
if disable_decoding is False:
response = self.encoder.decode(response)
return response

View File

@@ -0,0 +1,263 @@
from logging import getLogger
from typing import Any, Union
from ..exceptions import ConnectionError, InvalidResponse, ResponseError
from ..typing import EncodableT
from .base import (
AsyncPushNotificationsParser,
PushNotificationsParser,
_AsyncRESPBase,
_RESPBase,
)
from .socket import SERVER_CLOSED_CONNECTION_ERROR
class _RESP3Parser(_RESPBase, PushNotificationsParser):
"""RESP3 protocol implementation"""
def __init__(self, socket_read_size):
super().__init__(socket_read_size)
self.pubsub_push_handler_func = self.handle_pubsub_push_response
self.node_moving_push_handler_func = None
self.maintenance_push_handler_func = None
self.invalidation_push_handler_func = None
def handle_pubsub_push_response(self, response):
logger = getLogger("push_response")
logger.debug("Push response: " + str(response))
return response
def read_response(self, disable_decoding=False, push_request=False):
pos = self._buffer.get_pos() if self._buffer else None
try:
result = self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
except BaseException:
if self._buffer:
self._buffer.rewind(pos)
raise
else:
self._buffer.purge()
return result
def _read_response(self, disable_decoding=False, push_request=False):
raw = self._buffer.readline()
if not raw:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
byte, response = raw[:1], raw[1:]
# server returned an error
if byte in (b"-", b"!"):
if byte == b"!":
response = self._buffer.read(int(response))
response = response.decode("utf-8", errors="replace")
error = self.parse_error(response)
# if the error is a ConnectionError, raise immediately so the user
# is notified
if isinstance(error, ConnectionError):
raise error
# otherwise, we're dealing with a ResponseError that might belong
# inside a pipeline response. the connection's read_response()
# and/or the pipeline's execute() will raise this error if
# necessary, so just return the exception instance here.
return error
# single value
elif byte == b"+":
pass
# null value
elif byte == b"_":
return None
# int and big int values
elif byte in (b":", b"("):
return int(response)
# double value
elif byte == b",":
return float(response)
# bool value
elif byte == b"#":
return response == b"t"
# bulk response
elif byte == b"$":
response = self._buffer.read(int(response))
# verbatim string response
elif byte == b"=":
response = self._buffer.read(int(response))[4:]
# array response
elif byte == b"*":
response = [
self._read_response(disable_decoding=disable_decoding)
for _ in range(int(response))
]
# set response
elif byte == b"~":
# redis can return unhashable types (like dict) in a set,
# so we return sets as list, all the time, for predictability
response = [
self._read_response(disable_decoding=disable_decoding)
for _ in range(int(response))
]
# map response
elif byte == b"%":
# We cannot use a dict-comprehension to parse stream.
# Evaluation order of key:val expression in dict comprehension only
# became defined to be left-right in version 3.8
resp_dict = {}
for _ in range(int(response)):
key = self._read_response(disable_decoding=disable_decoding)
resp_dict[key] = self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
response = resp_dict
# push response
elif byte == b">":
response = [
self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
for _ in range(int(response))
]
response = self.handle_push_response(response)
# if this is a push request return the push response
if push_request:
return response
return self._read_response(
disable_decoding=disable_decoding,
push_request=push_request,
)
else:
raise InvalidResponse(f"Protocol Error: {raw!r}")
if isinstance(response, bytes) and disable_decoding is False:
response = self.encoder.decode(response)
return response
class _AsyncRESP3Parser(_AsyncRESPBase, AsyncPushNotificationsParser):
def __init__(self, socket_read_size):
super().__init__(socket_read_size)
self.pubsub_push_handler_func = self.handle_pubsub_push_response
self.invalidation_push_handler_func = None
async def handle_pubsub_push_response(self, response):
logger = getLogger("push_response")
logger.debug("Push response: " + str(response))
return response
async def read_response(
self, disable_decoding: bool = False, push_request: bool = False
):
if self._chunks:
# augment parsing buffer with previously read data
self._buffer += b"".join(self._chunks)
self._chunks.clear()
self._pos = 0
response = await self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
# Successfully parsing a response allows us to clear our parsing buffer
self._clear()
return response
async def _read_response(
self, disable_decoding: bool = False, push_request: bool = False
) -> Union[EncodableT, ResponseError, None]:
if not self._stream or not self.encoder:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
raw = await self._readline()
response: Any
byte, response = raw[:1], raw[1:]
# if byte not in (b"-", b"+", b":", b"$", b"*"):
# raise InvalidResponse(f"Protocol Error: {raw!r}")
# server returned an error
if byte in (b"-", b"!"):
if byte == b"!":
response = await self._read(int(response))
response = response.decode("utf-8", errors="replace")
error = self.parse_error(response)
# if the error is a ConnectionError, raise immediately so the user
# is notified
if isinstance(error, ConnectionError):
self._clear() # Successful parse
raise error
# otherwise, we're dealing with a ResponseError that might belong
# inside a pipeline response. the connection's read_response()
# and/or the pipeline's execute() will raise this error if
# necessary, so just return the exception instance here.
return error
# single value
elif byte == b"+":
pass
# null value
elif byte == b"_":
return None
# int and big int values
elif byte in (b":", b"("):
return int(response)
# double value
elif byte == b",":
return float(response)
# bool value
elif byte == b"#":
return response == b"t"
# bulk response
elif byte == b"$":
response = await self._read(int(response))
# verbatim string response
elif byte == b"=":
response = (await self._read(int(response)))[4:]
# array response
elif byte == b"*":
response = [
(await self._read_response(disable_decoding=disable_decoding))
for _ in range(int(response))
]
# set response
elif byte == b"~":
# redis can return unhashable types (like dict) in a set,
# so we always convert to a list, to have predictable return types
response = [
(await self._read_response(disable_decoding=disable_decoding))
for _ in range(int(response))
]
# map response
elif byte == b"%":
# We cannot use a dict-comprehension to parse stream.
# Evaluation order of key:val expression in dict comprehension only
# became defined to be left-right in version 3.8
resp_dict = {}
for _ in range(int(response)):
key = await self._read_response(disable_decoding=disable_decoding)
resp_dict[key] = await self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
response = resp_dict
# push response
elif byte == b">":
response = [
(
await self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
)
for _ in range(int(response))
]
response = await self.handle_push_response(response)
if not push_request:
return await self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
else:
return response
else:
raise InvalidResponse(f"Protocol Error: {raw!r}")
if isinstance(response, bytes) and disable_decoding is False:
response = self.encoder.decode(response)
return response

View File

@@ -0,0 +1,162 @@
import errno
import io
import socket
from io import SEEK_END
from typing import Optional, Union
from ..exceptions import ConnectionError, TimeoutError
from ..utils import SSL_AVAILABLE
NONBLOCKING_EXCEPTION_ERROR_NUMBERS = {BlockingIOError: errno.EWOULDBLOCK}
if SSL_AVAILABLE:
import ssl
if hasattr(ssl, "SSLWantReadError"):
NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLWantReadError] = 2
NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLWantWriteError] = 2
else:
NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLError] = 2
NONBLOCKING_EXCEPTIONS = tuple(NONBLOCKING_EXCEPTION_ERROR_NUMBERS.keys())
SERVER_CLOSED_CONNECTION_ERROR = "Connection closed by server."
SENTINEL = object()
SYM_CRLF = b"\r\n"
class SocketBuffer:
def __init__(
self, socket: socket.socket, socket_read_size: int, socket_timeout: float
):
self._sock = socket
self.socket_read_size = socket_read_size
self.socket_timeout = socket_timeout
self._buffer = io.BytesIO()
def unread_bytes(self) -> int:
"""
Remaining unread length of buffer
"""
pos = self._buffer.tell()
end = self._buffer.seek(0, SEEK_END)
self._buffer.seek(pos)
return end - pos
def _read_from_socket(
self,
length: Optional[int] = None,
timeout: Union[float, object] = SENTINEL,
raise_on_timeout: Optional[bool] = True,
) -> bool:
sock = self._sock
socket_read_size = self.socket_read_size
marker = 0
custom_timeout = timeout is not SENTINEL
buf = self._buffer
current_pos = buf.tell()
buf.seek(0, SEEK_END)
if custom_timeout:
sock.settimeout(timeout)
try:
while True:
data = self._sock.recv(socket_read_size)
# an empty string indicates the server shutdown the socket
if isinstance(data, bytes) and len(data) == 0:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
buf.write(data)
data_length = len(data)
marker += data_length
if length is not None and length > marker:
continue
return True
except socket.timeout:
if raise_on_timeout:
raise TimeoutError("Timeout reading from socket")
return False
except NONBLOCKING_EXCEPTIONS as ex:
# if we're in nonblocking mode and the recv raises a
# blocking error, simply return False indicating that
# there's no data to be read. otherwise raise the
# original exception.
allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1)
if not raise_on_timeout and ex.errno == allowed:
return False
raise ConnectionError(f"Error while reading from socket: {ex.args}")
finally:
buf.seek(current_pos)
if custom_timeout:
sock.settimeout(self.socket_timeout)
def can_read(self, timeout: float) -> bool:
return bool(self.unread_bytes()) or self._read_from_socket(
timeout=timeout, raise_on_timeout=False
)
def read(self, length: int) -> bytes:
length = length + 2 # make sure to read the \r\n terminator
# BufferIO will return less than requested if buffer is short
data = self._buffer.read(length)
missing = length - len(data)
if missing:
# fill up the buffer and read the remainder
self._read_from_socket(missing)
data += self._buffer.read(missing)
return data[:-2]
def readline(self) -> bytes:
buf = self._buffer
data = buf.readline()
while not data.endswith(SYM_CRLF):
# there's more data in the socket that we need
self._read_from_socket()
data += buf.readline()
return data[:-2]
def get_pos(self) -> int:
"""
Get current read position
"""
return self._buffer.tell()
def rewind(self, pos: int) -> None:
"""
Rewind the buffer to a specific position, to re-start reading
"""
self._buffer.seek(pos)
def purge(self) -> None:
"""
After a successful read, purge the read part of buffer
"""
unread = self.unread_bytes()
# Only if we have read all of the buffer do we truncate, to
# reduce the amount of memory thrashing. This heuristic
# can be changed or removed later.
if unread > 0:
return
if unread > 0:
# move unread data to the front
view = self._buffer.getbuffer()
view[:unread] = view[-unread:]
self._buffer.truncate(unread)
self._buffer.seek(0)
def close(self) -> None:
try:
self._buffer.close()
except Exception:
# issue #633 suggests the purge/close somehow raised a
# BadFileDescriptor error. Perhaps the client ran out of
# memory or something else? It's probably OK to ignore
# any error being raised from purge/close since we're
# removing the reference to the instance below.
pass
self._buffer = None
self._sock = None

View File

@@ -0,0 +1,64 @@
from redis.asyncio.client import Redis, StrictRedis
from redis.asyncio.cluster import RedisCluster
from redis.asyncio.connection import (
BlockingConnectionPool,
Connection,
ConnectionPool,
SSLConnection,
UnixDomainSocketConnection,
)
from redis.asyncio.sentinel import (
Sentinel,
SentinelConnectionPool,
SentinelManagedConnection,
SentinelManagedSSLConnection,
)
from redis.asyncio.utils import from_url
from redis.backoff import default_backoff
from redis.exceptions import (
AuthenticationError,
AuthenticationWrongNumberOfArgsError,
BusyLoadingError,
ChildDeadlockedError,
ConnectionError,
DataError,
InvalidResponse,
OutOfMemoryError,
PubSubError,
ReadOnlyError,
RedisError,
ResponseError,
TimeoutError,
WatchError,
)
__all__ = [
"AuthenticationError",
"AuthenticationWrongNumberOfArgsError",
"BlockingConnectionPool",
"BusyLoadingError",
"ChildDeadlockedError",
"Connection",
"ConnectionError",
"ConnectionPool",
"DataError",
"from_url",
"default_backoff",
"InvalidResponse",
"PubSubError",
"OutOfMemoryError",
"ReadOnlyError",
"Redis",
"RedisCluster",
"RedisError",
"ResponseError",
"Sentinel",
"SentinelConnectionPool",
"SentinelManagedConnection",
"SentinelManagedSSLConnection",
"SSLConnection",
"StrictRedis",
"TimeoutError",
"UnixDomainSocketConnection",
"WatchError",
]

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,265 @@
import asyncio
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Mapping, Optional, Union
from redis.http.http_client import HttpClient, HttpResponse
DEFAULT_USER_AGENT = "HttpClient/1.0 (+https://example.invalid)"
DEFAULT_TIMEOUT = 30.0
RETRY_STATUS_CODES = {429, 500, 502, 503, 504}
class AsyncHTTPClient(ABC):
@abstractmethod
async def get(
self,
path: str,
params: Optional[
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
] = None,
headers: Optional[Mapping[str, str]] = None,
timeout: Optional[float] = None,
expect_json: bool = True,
) -> Union[HttpResponse, Any]:
"""
Invoke HTTP GET request."""
pass
@abstractmethod
async def delete(
self,
path: str,
params: Optional[
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
] = None,
headers: Optional[Mapping[str, str]] = None,
timeout: Optional[float] = None,
expect_json: bool = True,
) -> Union[HttpResponse, Any]:
"""
Invoke HTTP DELETE request."""
pass
@abstractmethod
async def post(
self,
path: str,
json_body: Optional[Any] = None,
data: Optional[Union[bytes, str]] = None,
params: Optional[
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
] = None,
headers: Optional[Mapping[str, str]] = None,
timeout: Optional[float] = None,
expect_json: bool = True,
) -> Union[HttpResponse, Any]:
"""
Invoke HTTP POST request."""
pass
@abstractmethod
async def put(
self,
path: str,
json_body: Optional[Any] = None,
data: Optional[Union[bytes, str]] = None,
params: Optional[
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
] = None,
headers: Optional[Mapping[str, str]] = None,
timeout: Optional[float] = None,
expect_json: bool = True,
) -> Union[HttpResponse, Any]:
"""
Invoke HTTP PUT request."""
pass
@abstractmethod
async def patch(
self,
path: str,
json_body: Optional[Any] = None,
data: Optional[Union[bytes, str]] = None,
params: Optional[
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
] = None,
headers: Optional[Mapping[str, str]] = None,
timeout: Optional[float] = None,
expect_json: bool = True,
) -> Union[HttpResponse, Any]:
"""
Invoke HTTP PATCH request."""
pass
@abstractmethod
async def request(
self,
method: str,
path: str,
params: Optional[
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
] = None,
headers: Optional[Mapping[str, str]] = None,
body: Optional[Union[bytes, str]] = None,
timeout: Optional[float] = None,
) -> HttpResponse:
"""
Invoke HTTP request with given method."""
pass
class AsyncHTTPClientWrapper(AsyncHTTPClient):
"""
An async wrapper around sync HTTP client with thread pool execution.
"""
def __init__(self, client: HttpClient, max_workers: int = 10) -> None:
"""
Initialize a new HTTP client instance.
Args:
client: Sync HTTP client instance.
max_workers: Maximum number of concurrent requests.
The client supports both regular HTTPS with server verification and mutual TLS
authentication. For server verification, provide CA certificate information via
ca_file, ca_path or ca_data. For mutual TLS, additionally provide a client
certificate and key via client_cert_file and client_key_file.
"""
self.client = client
self._executor = ThreadPoolExecutor(max_workers=max_workers)
async def get(
self,
path: str,
params: Optional[
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
] = None,
headers: Optional[Mapping[str, str]] = None,
timeout: Optional[float] = None,
expect_json: bool = True,
) -> Union[HttpResponse, Any]:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self._executor, self.client.get, path, params, headers, timeout, expect_json
)
async def delete(
self,
path: str,
params: Optional[
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
] = None,
headers: Optional[Mapping[str, str]] = None,
timeout: Optional[float] = None,
expect_json: bool = True,
) -> Union[HttpResponse, Any]:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self._executor,
self.client.delete,
path,
params,
headers,
timeout,
expect_json,
)
async def post(
self,
path: str,
json_body: Optional[Any] = None,
data: Optional[Union[bytes, str]] = None,
params: Optional[
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
] = None,
headers: Optional[Mapping[str, str]] = None,
timeout: Optional[float] = None,
expect_json: bool = True,
) -> Union[HttpResponse, Any]:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self._executor,
self.client.post,
path,
json_body,
data,
params,
headers,
timeout,
expect_json,
)
async def put(
self,
path: str,
json_body: Optional[Any] = None,
data: Optional[Union[bytes, str]] = None,
params: Optional[
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
] = None,
headers: Optional[Mapping[str, str]] = None,
timeout: Optional[float] = None,
expect_json: bool = True,
) -> Union[HttpResponse, Any]:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self._executor,
self.client.put,
path,
json_body,
data,
params,
headers,
timeout,
expect_json,
)
async def patch(
self,
path: str,
json_body: Optional[Any] = None,
data: Optional[Union[bytes, str]] = None,
params: Optional[
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
] = None,
headers: Optional[Mapping[str, str]] = None,
timeout: Optional[float] = None,
expect_json: bool = True,
) -> Union[HttpResponse, Any]:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self._executor,
self.client.patch,
path,
json_body,
data,
params,
headers,
timeout,
expect_json,
)
async def request(
self,
method: str,
path: str,
params: Optional[
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
] = None,
headers: Optional[Mapping[str, str]] = None,
body: Optional[Union[bytes, str]] = None,
timeout: Optional[float] = None,
) -> HttpResponse:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self._executor,
self.client.request,
method,
path,
params,
headers,
body,
timeout,
)

View File

@@ -0,0 +1,334 @@
import asyncio
import logging
import threading
import uuid
from types import SimpleNamespace
from typing import TYPE_CHECKING, Awaitable, Optional, Union
from redis.exceptions import LockError, LockNotOwnedError
from redis.typing import Number
if TYPE_CHECKING:
from redis.asyncio import Redis, RedisCluster
logger = logging.getLogger(__name__)
class Lock:
"""
A shared, distributed Lock. Using Redis for locking allows the Lock
to be shared across processes and/or machines.
It's left to the user to resolve deadlock issues and make sure
multiple clients play nicely together.
"""
lua_release = None
lua_extend = None
lua_reacquire = None
# KEYS[1] - lock name
# ARGV[1] - token
# return 1 if the lock was released, otherwise 0
LUA_RELEASE_SCRIPT = """
local token = redis.call('get', KEYS[1])
if not token or token ~= ARGV[1] then
return 0
end
redis.call('del', KEYS[1])
return 1
"""
# KEYS[1] - lock name
# ARGV[1] - token
# ARGV[2] - additional milliseconds
# ARGV[3] - "0" if the additional time should be added to the lock's
# existing ttl or "1" if the existing ttl should be replaced
# return 1 if the locks time was extended, otherwise 0
LUA_EXTEND_SCRIPT = """
local token = redis.call('get', KEYS[1])
if not token or token ~= ARGV[1] then
return 0
end
local expiration = redis.call('pttl', KEYS[1])
if not expiration then
expiration = 0
end
if expiration < 0 then
return 0
end
local newttl = ARGV[2]
if ARGV[3] == "0" then
newttl = ARGV[2] + expiration
end
redis.call('pexpire', KEYS[1], newttl)
return 1
"""
# KEYS[1] - lock name
# ARGV[1] - token
# ARGV[2] - milliseconds
# return 1 if the locks time was reacquired, otherwise 0
LUA_REACQUIRE_SCRIPT = """
local token = redis.call('get', KEYS[1])
if not token or token ~= ARGV[1] then
return 0
end
redis.call('pexpire', KEYS[1], ARGV[2])
return 1
"""
def __init__(
self,
redis: Union["Redis", "RedisCluster"],
name: Union[str, bytes, memoryview],
timeout: Optional[float] = None,
sleep: float = 0.1,
blocking: bool = True,
blocking_timeout: Optional[Number] = None,
thread_local: bool = True,
raise_on_release_error: bool = True,
):
"""
Create a new Lock instance named ``name`` using the Redis client
supplied by ``redis``.
``timeout`` indicates a maximum life for the lock in seconds.
By default, it will remain locked until release() is called.
``timeout`` can be specified as a float or integer, both representing
the number of seconds to wait.
``sleep`` indicates the amount of time to sleep in seconds per loop
iteration when the lock is in blocking mode and another client is
currently holding the lock.
``blocking`` indicates whether calling ``acquire`` should block until
the lock has been acquired or to fail immediately, causing ``acquire``
to return False and the lock not being acquired. Defaults to True.
Note this value can be overridden by passing a ``blocking``
argument to ``acquire``.
``blocking_timeout`` indicates the maximum amount of time in seconds to
spend trying to acquire the lock. A value of ``None`` indicates
continue trying forever. ``blocking_timeout`` can be specified as a
float or integer, both representing the number of seconds to wait.
``thread_local`` indicates whether the lock token is placed in
thread-local storage. By default, the token is placed in thread local
storage so that a thread only sees its token, not a token set by
another thread. Consider the following timeline:
time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds.
thread-1 sets the token to "abc"
time: 1, thread-2 blocks trying to acquire `my-lock` using the
Lock instance.
time: 5, thread-1 has not yet completed. redis expires the lock
key.
time: 5, thread-2 acquired `my-lock` now that it's available.
thread-2 sets the token to "xyz"
time: 6, thread-1 finishes its work and calls release(). if the
token is *not* stored in thread local storage, then
thread-1 would see the token value as "xyz" and would be
able to successfully release the thread-2's lock.
``raise_on_release_error`` indicates whether to raise an exception when
the lock is no longer owned when exiting the context manager. By default,
this is True, meaning an exception will be raised. If False, the warning
will be logged and the exception will be suppressed.
In some use cases it's necessary to disable thread local storage. For
example, if you have code where one thread acquires a lock and passes
that lock instance to a worker thread to release later. If thread
local storage isn't disabled in this case, the worker thread won't see
the token set by the thread that acquired the lock. Our assumption
is that these cases aren't common and as such default to using
thread local storage.
"""
self.redis = redis
self.name = name
self.timeout = timeout
self.sleep = sleep
self.blocking = blocking
self.blocking_timeout = blocking_timeout
self.thread_local = bool(thread_local)
self.local = threading.local() if self.thread_local else SimpleNamespace()
self.raise_on_release_error = raise_on_release_error
self.local.token = None
self.register_scripts()
def register_scripts(self):
cls = self.__class__
client = self.redis
if cls.lua_release is None:
cls.lua_release = client.register_script(cls.LUA_RELEASE_SCRIPT)
if cls.lua_extend is None:
cls.lua_extend = client.register_script(cls.LUA_EXTEND_SCRIPT)
if cls.lua_reacquire is None:
cls.lua_reacquire = client.register_script(cls.LUA_REACQUIRE_SCRIPT)
async def __aenter__(self):
if await self.acquire():
return self
raise LockError("Unable to acquire lock within the time specified")
async def __aexit__(self, exc_type, exc_value, traceback):
try:
await self.release()
except LockError:
if self.raise_on_release_error:
raise
logger.warning(
"Lock was unlocked or no longer owned when exiting context manager."
)
async def acquire(
self,
blocking: Optional[bool] = None,
blocking_timeout: Optional[Number] = None,
token: Optional[Union[str, bytes]] = None,
):
"""
Use Redis to hold a shared, distributed lock named ``name``.
Returns True once the lock is acquired.
If ``blocking`` is False, always return immediately. If the lock
was acquired, return True, otherwise return False.
``blocking_timeout`` specifies the maximum number of seconds to
wait trying to acquire the lock.
``token`` specifies the token value to be used. If provided, token
must be a bytes object or a string that can be encoded to a bytes
object with the default encoding. If a token isn't specified, a UUID
will be generated.
"""
sleep = self.sleep
if token is None:
token = uuid.uuid1().hex.encode()
else:
try:
encoder = self.redis.connection_pool.get_encoder()
except AttributeError:
# Cluster
encoder = self.redis.get_encoder()
token = encoder.encode(token)
if blocking is None:
blocking = self.blocking
if blocking_timeout is None:
blocking_timeout = self.blocking_timeout
stop_trying_at = None
if blocking_timeout is not None:
stop_trying_at = asyncio.get_running_loop().time() + blocking_timeout
while True:
if await self.do_acquire(token):
self.local.token = token
return True
if not blocking:
return False
next_try_at = asyncio.get_running_loop().time() + sleep
if stop_trying_at is not None and next_try_at > stop_trying_at:
return False
await asyncio.sleep(sleep)
async def do_acquire(self, token: Union[str, bytes]) -> bool:
if self.timeout:
# convert to milliseconds
timeout = int(self.timeout * 1000)
else:
timeout = None
if await self.redis.set(self.name, token, nx=True, px=timeout):
return True
return False
async def locked(self) -> bool:
"""
Returns True if this key is locked by any process, otherwise False.
"""
return await self.redis.get(self.name) is not None
async def owned(self) -> bool:
"""
Returns True if this key is locked by this lock, otherwise False.
"""
stored_token = await self.redis.get(self.name)
# need to always compare bytes to bytes
# TODO: this can be simplified when the context manager is finished
if stored_token and not isinstance(stored_token, bytes):
try:
encoder = self.redis.connection_pool.get_encoder()
except AttributeError:
# Cluster
encoder = self.redis.get_encoder()
stored_token = encoder.encode(stored_token)
return self.local.token is not None and stored_token == self.local.token
def release(self) -> Awaitable[None]:
"""Releases the already acquired lock"""
expected_token = self.local.token
if expected_token is None:
raise LockError(
"Cannot release a lock that's not owned or is already unlocked.",
lock_name=self.name,
)
self.local.token = None
return self.do_release(expected_token)
async def do_release(self, expected_token: bytes) -> None:
if not bool(
await self.lua_release(
keys=[self.name], args=[expected_token], client=self.redis
)
):
raise LockNotOwnedError("Cannot release a lock that's no longer owned")
def extend(
self, additional_time: Number, replace_ttl: bool = False
) -> Awaitable[bool]:
"""
Adds more time to an already acquired lock.
``additional_time`` can be specified as an integer or a float, both
representing the number of seconds to add.
``replace_ttl`` if False (the default), add `additional_time` to
the lock's existing ttl. If True, replace the lock's ttl with
`additional_time`.
"""
if self.local.token is None:
raise LockError("Cannot extend an unlocked lock")
if self.timeout is None:
raise LockError("Cannot extend a lock with no timeout")
return self.do_extend(additional_time, replace_ttl)
async def do_extend(self, additional_time, replace_ttl) -> bool:
additional_time = int(additional_time * 1000)
if not bool(
await self.lua_extend(
keys=[self.name],
args=[self.local.token, additional_time, replace_ttl and "1" or "0"],
client=self.redis,
)
):
raise LockNotOwnedError("Cannot extend a lock that's no longer owned")
return True
def reacquire(self) -> Awaitable[bool]:
"""
Resets a TTL of an already acquired lock back to a timeout value.
"""
if self.local.token is None:
raise LockError("Cannot reacquire an unlocked lock")
if self.timeout is None:
raise LockError("Cannot reacquire a lock with no timeout")
return self.do_reacquire()
async def do_reacquire(self) -> bool:
timeout = int(self.timeout * 1000)
if not bool(
await self.lua_reacquire(
keys=[self.name], args=[self.local.token, timeout], client=self.redis
)
):
raise LockNotOwnedError("Cannot reacquire a lock that's no longer owned")
return True

View File

@@ -0,0 +1,530 @@
import asyncio
import logging
from typing import Any, Awaitable, Callable, Coroutine, List, Optional, Union
from redis.asyncio.client import PubSubHandler
from redis.asyncio.multidb.command_executor import DefaultCommandExecutor
from redis.asyncio.multidb.config import DEFAULT_GRACE_PERIOD, MultiDbConfig
from redis.asyncio.multidb.database import AsyncDatabase, Databases
from redis.asyncio.multidb.failure_detector import AsyncFailureDetector
from redis.asyncio.multidb.healthcheck import HealthCheck, HealthCheckPolicy
from redis.background import BackgroundScheduler
from redis.commands import AsyncCoreCommands, AsyncRedisModuleCommands
from redis.multidb.circuit import CircuitBreaker
from redis.multidb.circuit import State as CBState
from redis.multidb.exception import NoValidDatabaseException, UnhealthyDatabaseException
from redis.typing import ChannelT, EncodableT, KeyT
from redis.utils import experimental
logger = logging.getLogger(__name__)
@experimental
class MultiDBClient(AsyncRedisModuleCommands, AsyncCoreCommands):
"""
Client that operates on multiple logical Redis databases.
Should be used in Active-Active database setups.
"""
def __init__(self, config: MultiDbConfig):
self._databases = config.databases()
self._health_checks = config.default_health_checks()
if config.health_checks is not None:
self._health_checks.extend(config.health_checks)
self._health_check_interval = config.health_check_interval
self._health_check_policy: HealthCheckPolicy = config.health_check_policy.value(
config.health_check_probes, config.health_check_delay
)
self._failure_detectors = config.default_failure_detectors()
if config.failure_detectors is not None:
self._failure_detectors.extend(config.failure_detectors)
self._failover_strategy = (
config.default_failover_strategy()
if config.failover_strategy is None
else config.failover_strategy
)
self._failover_strategy.set_databases(self._databases)
self._auto_fallback_interval = config.auto_fallback_interval
self._event_dispatcher = config.event_dispatcher
self._command_retry = config.command_retry
self._command_retry.update_supported_errors([ConnectionRefusedError])
self.command_executor = DefaultCommandExecutor(
failure_detectors=self._failure_detectors,
databases=self._databases,
command_retry=self._command_retry,
failover_strategy=self._failover_strategy,
failover_attempts=config.failover_attempts,
failover_delay=config.failover_delay,
event_dispatcher=self._event_dispatcher,
auto_fallback_interval=self._auto_fallback_interval,
)
self.initialized = False
self._hc_lock = asyncio.Lock()
self._bg_scheduler = BackgroundScheduler()
self._config = config
self._recurring_hc_task = None
self._hc_tasks = []
self._half_open_state_task = None
async def __aenter__(self: "MultiDBClient") -> "MultiDBClient":
if not self.initialized:
await self.initialize()
return self
async def __aexit__(self, exc_type, exc_value, traceback):
if self._recurring_hc_task:
self._recurring_hc_task.cancel()
if self._half_open_state_task:
self._half_open_state_task.cancel()
for hc_task in self._hc_tasks:
hc_task.cancel()
async def initialize(self):
"""
Perform initialization of databases to define their initial state.
"""
async def raise_exception_on_failed_hc(error):
raise error
# Initial databases check to define initial state
await self._check_databases_health(on_error=raise_exception_on_failed_hc)
# Starts recurring health checks on the background.
self._recurring_hc_task = asyncio.create_task(
self._bg_scheduler.run_recurring_async(
self._health_check_interval,
self._check_databases_health,
)
)
is_active_db_found = False
for database, weight in self._databases:
# Set on state changed callback for each circuit.
database.circuit.on_state_changed(self._on_circuit_state_change_callback)
# Set states according to a weights and circuit state
if database.circuit.state == CBState.CLOSED and not is_active_db_found:
await self.command_executor.set_active_database(database)
is_active_db_found = True
if not is_active_db_found:
raise NoValidDatabaseException(
"Initial connection failed - no active database found"
)
self.initialized = True
def get_databases(self) -> Databases:
"""
Returns a sorted (by weight) list of all databases.
"""
return self._databases
async def set_active_database(self, database: AsyncDatabase) -> None:
"""
Promote one of the existing databases to become an active.
"""
exists = None
for existing_db, _ in self._databases:
if existing_db == database:
exists = True
break
if not exists:
raise ValueError("Given database is not a member of database list")
await self._check_db_health(database)
if database.circuit.state == CBState.CLOSED:
highest_weighted_db, _ = self._databases.get_top_n(1)[0]
await self.command_executor.set_active_database(database)
return
raise NoValidDatabaseException(
"Cannot set active database, database is unhealthy"
)
async def add_database(self, database: AsyncDatabase):
"""
Adds a new database to the database list.
"""
for existing_db, _ in self._databases:
if existing_db == database:
raise ValueError("Given database already exists")
await self._check_db_health(database)
highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0]
self._databases.add(database, database.weight)
await self._change_active_database(database, highest_weighted_db)
async def _change_active_database(
self, new_database: AsyncDatabase, highest_weight_database: AsyncDatabase
):
if (
new_database.weight > highest_weight_database.weight
and new_database.circuit.state == CBState.CLOSED
):
await self.command_executor.set_active_database(new_database)
async def remove_database(self, database: AsyncDatabase):
"""
Removes a database from the database list.
"""
weight = self._databases.remove(database)
highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0]
if (
highest_weight <= weight
and highest_weighted_db.circuit.state == CBState.CLOSED
):
await self.command_executor.set_active_database(highest_weighted_db)
async def update_database_weight(self, database: AsyncDatabase, weight: float):
"""
Updates a database from the database list.
"""
exists = None
for existing_db, _ in self._databases:
if existing_db == database:
exists = True
break
if not exists:
raise ValueError("Given database is not a member of database list")
highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0]
self._databases.update_weight(database, weight)
database.weight = weight
await self._change_active_database(database, highest_weighted_db)
def add_failure_detector(self, failure_detector: AsyncFailureDetector):
"""
Adds a new failure detector to the database.
"""
self._failure_detectors.append(failure_detector)
async def add_health_check(self, healthcheck: HealthCheck):
"""
Adds a new health check to the database.
"""
async with self._hc_lock:
self._health_checks.append(healthcheck)
async def execute_command(self, *args, **options):
"""
Executes a single command and return its result.
"""
if not self.initialized:
await self.initialize()
return await self.command_executor.execute_command(*args, **options)
def pipeline(self):
"""
Enters into pipeline mode of the client.
"""
return Pipeline(self)
async def transaction(
self,
func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]],
*watches: KeyT,
shard_hint: Optional[str] = None,
value_from_callable: bool = False,
watch_delay: Optional[float] = None,
):
"""
Executes callable as transaction.
"""
if not self.initialized:
await self.initialize()
return await self.command_executor.execute_transaction(
func,
*watches,
shard_hint=shard_hint,
value_from_callable=value_from_callable,
watch_delay=watch_delay,
)
async def pubsub(self, **kwargs):
"""
Return a Publish/Subscribe object. With this object, you can
subscribe to channels and listen for messages that get published to
them.
"""
if not self.initialized:
await self.initialize()
return PubSub(self, **kwargs)
async def _check_databases_health(
self,
on_error: Optional[Callable[[Exception], Coroutine[Any, Any, None]]] = None,
):
"""
Runs health checks as a recurring task.
Runs health checks against all databases.
"""
try:
self._hc_tasks = [
asyncio.create_task(self._check_db_health(database))
for database, _ in self._databases
]
results = await asyncio.wait_for(
asyncio.gather(
*self._hc_tasks,
return_exceptions=True,
),
timeout=self._health_check_interval,
)
except asyncio.TimeoutError:
raise asyncio.TimeoutError(
"Health check execution exceeds health_check_interval"
)
for result in results:
if isinstance(result, UnhealthyDatabaseException):
unhealthy_db = result.database
unhealthy_db.circuit.state = CBState.OPEN
logger.exception(
"Health check failed, due to exception",
exc_info=result.original_exception,
)
if on_error:
on_error(result.original_exception)
async def _check_db_health(self, database: AsyncDatabase) -> bool:
"""
Runs health checks on the given database until first failure.
"""
# Health check will setup circuit state
is_healthy = await self._health_check_policy.execute(
self._health_checks, database
)
if not is_healthy:
if database.circuit.state != CBState.OPEN:
database.circuit.state = CBState.OPEN
return is_healthy
elif is_healthy and database.circuit.state != CBState.CLOSED:
database.circuit.state = CBState.CLOSED
return is_healthy
def _on_circuit_state_change_callback(
self, circuit: CircuitBreaker, old_state: CBState, new_state: CBState
):
loop = asyncio.get_running_loop()
if new_state == CBState.HALF_OPEN:
self._half_open_state_task = asyncio.create_task(
self._check_db_health(circuit.database)
)
return
if old_state == CBState.CLOSED and new_state == CBState.OPEN:
loop.call_later(DEFAULT_GRACE_PERIOD, _half_open_circuit, circuit)
async def aclose(self):
if self.command_executor.active_database:
await self.command_executor.active_database.client.aclose()
def _half_open_circuit(circuit: CircuitBreaker):
circuit.state = CBState.HALF_OPEN
class Pipeline(AsyncRedisModuleCommands, AsyncCoreCommands):
"""
Pipeline implementation for multiple logical Redis databases.
"""
def __init__(self, client: MultiDBClient):
self._command_stack = []
self._client = client
async def __aenter__(self: "Pipeline") -> "Pipeline":
return self
async def __aexit__(self, exc_type, exc_value, traceback):
await self.reset()
await self._client.__aexit__(exc_type, exc_value, traceback)
def __await__(self):
return self._async_self().__await__()
async def _async_self(self):
return self
def __len__(self) -> int:
return len(self._command_stack)
def __bool__(self) -> bool:
"""Pipeline instances should always evaluate to True"""
return True
async def reset(self) -> None:
self._command_stack = []
async def aclose(self) -> None:
"""Close the pipeline"""
await self.reset()
def pipeline_execute_command(self, *args, **options) -> "Pipeline":
"""
Stage a command to be executed when execute() is next called
Returns the current Pipeline object back so commands can be
chained together, such as:
pipe = pipe.set('foo', 'bar').incr('baz').decr('bang')
At some other point, you can then run: pipe.execute(),
which will execute all commands queued in the pipe.
"""
self._command_stack.append((args, options))
return self
def execute_command(self, *args, **kwargs):
"""Adds a command to the stack"""
return self.pipeline_execute_command(*args, **kwargs)
async def execute(self) -> List[Any]:
"""Execute all the commands in the current pipeline"""
if not self._client.initialized:
await self._client.initialize()
try:
return await self._client.command_executor.execute_pipeline(
tuple(self._command_stack)
)
finally:
await self.reset()
class PubSub:
"""
PubSub object for multi database client.
"""
def __init__(self, client: MultiDBClient, **kwargs):
"""Initialize the PubSub object for a multi-database client.
Args:
client: MultiDBClient instance to use for pub/sub operations
**kwargs: Additional keyword arguments to pass to the underlying pubsub implementation
"""
self._client = client
self._client.command_executor.pubsub(**kwargs)
async def __aenter__(self) -> "PubSub":
return self
async def __aexit__(self, exc_type, exc_value, traceback) -> None:
await self.aclose()
async def aclose(self):
return await self._client.command_executor.execute_pubsub_method("aclose")
@property
def subscribed(self) -> bool:
return self._client.command_executor.active_pubsub.subscribed
async def execute_command(self, *args: EncodableT):
return await self._client.command_executor.execute_pubsub_method(
"execute_command", *args
)
async def psubscribe(self, *args: ChannelT, **kwargs: PubSubHandler):
"""
Subscribe to channel patterns. Patterns supplied as keyword arguments
expect a pattern name as the key and a callable as the value. A
pattern's callable will be invoked automatically when a message is
received on that pattern rather than producing a message via
``listen()``.
"""
return await self._client.command_executor.execute_pubsub_method(
"psubscribe", *args, **kwargs
)
async def punsubscribe(self, *args: ChannelT):
"""
Unsubscribe from the supplied patterns. If empty, unsubscribe from
all patterns.
"""
return await self._client.command_executor.execute_pubsub_method(
"punsubscribe", *args
)
async def subscribe(self, *args: ChannelT, **kwargs: Callable):
"""
Subscribe to channels. Channels supplied as keyword arguments expect
a channel name as the key and a callable as the value. A channel's
callable will be invoked automatically when a message is received on
that channel rather than producing a message via ``listen()`` or
``get_message()``.
"""
return await self._client.command_executor.execute_pubsub_method(
"subscribe", *args, **kwargs
)
async def unsubscribe(self, *args):
"""
Unsubscribe from the supplied channels. If empty, unsubscribe from
all channels
"""
return await self._client.command_executor.execute_pubsub_method(
"unsubscribe", *args
)
async def get_message(
self, ignore_subscribe_messages: bool = False, timeout: Optional[float] = 0.0
):
"""
Get the next message if one is available, otherwise None.
If timeout is specified, the system will wait for `timeout` seconds
before returning. Timeout should be specified as a floating point
number or None to wait indefinitely.
"""
return await self._client.command_executor.execute_pubsub_method(
"get_message",
ignore_subscribe_messages=ignore_subscribe_messages,
timeout=timeout,
)
async def run(
self,
*,
exception_handler=None,
poll_timeout: float = 1.0,
) -> None:
"""Process pub/sub messages using registered callbacks.
This is the equivalent of :py:meth:`redis.PubSub.run_in_thread` in
redis-py, but it is a coroutine. To launch it as a separate task, use
``asyncio.create_task``:
>>> task = asyncio.create_task(pubsub.run())
To shut it down, use asyncio cancellation:
>>> task.cancel()
>>> await task
"""
return await self._client.command_executor.execute_pubsub_run(
sleep_time=poll_timeout, exception_handler=exception_handler, pubsub=self
)

View File

@@ -0,0 +1,339 @@
from abc import abstractmethod
from asyncio import iscoroutinefunction
from datetime import datetime
from typing import Any, Awaitable, Callable, List, Optional, Union
from redis.asyncio import RedisCluster
from redis.asyncio.client import Pipeline, PubSub
from redis.asyncio.multidb.database import AsyncDatabase, Database, Databases
from redis.asyncio.multidb.event import (
AsyncActiveDatabaseChanged,
CloseConnectionOnActiveDatabaseChanged,
RegisterCommandFailure,
ResubscribeOnActiveDatabaseChanged,
)
from redis.asyncio.multidb.failover import (
DEFAULT_FAILOVER_ATTEMPTS,
DEFAULT_FAILOVER_DELAY,
AsyncFailoverStrategy,
DefaultFailoverStrategyExecutor,
FailoverStrategyExecutor,
)
from redis.asyncio.multidb.failure_detector import AsyncFailureDetector
from redis.asyncio.retry import Retry
from redis.event import AsyncOnCommandsFailEvent, EventDispatcherInterface
from redis.multidb.circuit import State as CBState
from redis.multidb.command_executor import BaseCommandExecutor, CommandExecutor
from redis.multidb.config import DEFAULT_AUTO_FALLBACK_INTERVAL
from redis.typing import KeyT
class AsyncCommandExecutor(CommandExecutor):
@property
@abstractmethod
def databases(self) -> Databases:
"""Returns a list of databases."""
pass
@property
@abstractmethod
def failure_detectors(self) -> List[AsyncFailureDetector]:
"""Returns a list of failure detectors."""
pass
@abstractmethod
def add_failure_detector(self, failure_detector: AsyncFailureDetector) -> None:
"""Adds a new failure detector to the list of failure detectors."""
pass
@property
@abstractmethod
def active_database(self) -> Optional[AsyncDatabase]:
"""Returns currently active database."""
pass
@abstractmethod
async def set_active_database(self, database: AsyncDatabase) -> None:
"""Sets the currently active database."""
pass
@property
@abstractmethod
def active_pubsub(self) -> Optional[PubSub]:
"""Returns currently active pubsub."""
pass
@active_pubsub.setter
@abstractmethod
def active_pubsub(self, pubsub: PubSub) -> None:
"""Sets currently active pubsub."""
pass
@property
@abstractmethod
def failover_strategy_executor(self) -> FailoverStrategyExecutor:
"""Returns failover strategy executor."""
pass
@property
@abstractmethod
def command_retry(self) -> Retry:
"""Returns command retry object."""
pass
@abstractmethod
async def pubsub(self, **kwargs):
"""Initializes a PubSub object on a currently active database"""
pass
@abstractmethod
async def execute_command(self, *args, **options):
"""Executes a command and returns the result."""
pass
@abstractmethod
async def execute_pipeline(self, command_stack: tuple):
"""Executes a stack of commands in pipeline."""
pass
@abstractmethod
async def execute_transaction(
self, transaction: Callable[[Pipeline], None], *watches, **options
):
"""Executes a transaction block wrapped in callback."""
pass
@abstractmethod
async def execute_pubsub_method(self, method_name: str, *args, **kwargs):
"""Executes a given method on active pub/sub."""
pass
@abstractmethod
async def execute_pubsub_run(self, sleep_time: float, **kwargs) -> Any:
"""Executes pub/sub run in a thread."""
pass
class DefaultCommandExecutor(BaseCommandExecutor, AsyncCommandExecutor):
def __init__(
self,
failure_detectors: List[AsyncFailureDetector],
databases: Databases,
command_retry: Retry,
failover_strategy: AsyncFailoverStrategy,
event_dispatcher: EventDispatcherInterface,
failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS,
failover_delay: float = DEFAULT_FAILOVER_DELAY,
auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL,
):
"""
Initialize the DefaultCommandExecutor instance.
Args:
failure_detectors: List of failure detector instances to monitor database health
databases: Collection of available databases to execute commands on
command_retry: Retry policy for failed command execution
failover_strategy: Strategy for handling database failover
event_dispatcher: Interface for dispatching events
failover_attempts: Number of failover attempts
failover_delay: Delay between failover attempts
auto_fallback_interval: Time interval in seconds between attempts to fall back to a primary database
"""
super().__init__(auto_fallback_interval)
for fd in failure_detectors:
fd.set_command_executor(command_executor=self)
self._databases = databases
self._failure_detectors = failure_detectors
self._command_retry = command_retry
self._failover_strategy_executor = DefaultFailoverStrategyExecutor(
failover_strategy, failover_attempts, failover_delay
)
self._event_dispatcher = event_dispatcher
self._active_database: Optional[Database] = None
self._active_pubsub: Optional[PubSub] = None
self._active_pubsub_kwargs = {}
self._setup_event_dispatcher()
self._schedule_next_fallback()
@property
def databases(self) -> Databases:
return self._databases
@property
def failure_detectors(self) -> List[AsyncFailureDetector]:
return self._failure_detectors
def add_failure_detector(self, failure_detector: AsyncFailureDetector) -> None:
self._failure_detectors.append(failure_detector)
@property
def active_database(self) -> Optional[AsyncDatabase]:
return self._active_database
async def set_active_database(self, database: AsyncDatabase) -> None:
old_active = self._active_database
self._active_database = database
if old_active is not None and old_active is not database:
await self._event_dispatcher.dispatch_async(
AsyncActiveDatabaseChanged(
old_active,
self._active_database,
self,
**self._active_pubsub_kwargs,
)
)
@property
def active_pubsub(self) -> Optional[PubSub]:
return self._active_pubsub
@active_pubsub.setter
def active_pubsub(self, pubsub: PubSub) -> None:
self._active_pubsub = pubsub
@property
def failover_strategy_executor(self) -> FailoverStrategyExecutor:
return self._failover_strategy_executor
@property
def command_retry(self) -> Retry:
return self._command_retry
def pubsub(self, **kwargs):
if self._active_pubsub is None:
if isinstance(self._active_database.client, RedisCluster):
raise ValueError("PubSub is not supported for RedisCluster")
self._active_pubsub = self._active_database.client.pubsub(**kwargs)
self._active_pubsub_kwargs = kwargs
async def execute_command(self, *args, **options):
async def callback():
response = await self._active_database.client.execute_command(
*args, **options
)
await self._register_command_execution(args)
return response
return await self._execute_with_failure_detection(callback, args)
async def execute_pipeline(self, command_stack: tuple):
async def callback():
async with self._active_database.client.pipeline() as pipe:
for command, options in command_stack:
pipe.execute_command(*command, **options)
response = await pipe.execute()
await self._register_command_execution(command_stack)
return response
return await self._execute_with_failure_detection(callback, command_stack)
async def execute_transaction(
self,
func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]],
*watches: KeyT,
shard_hint: Optional[str] = None,
value_from_callable: bool = False,
watch_delay: Optional[float] = None,
):
async def callback():
response = await self._active_database.client.transaction(
func,
*watches,
shard_hint=shard_hint,
value_from_callable=value_from_callable,
watch_delay=watch_delay,
)
await self._register_command_execution(())
return response
return await self._execute_with_failure_detection(callback)
async def execute_pubsub_method(self, method_name: str, *args, **kwargs):
async def callback():
method = getattr(self.active_pubsub, method_name)
if iscoroutinefunction(method):
response = await method(*args, **kwargs)
else:
response = method(*args, **kwargs)
await self._register_command_execution(args)
return response
return await self._execute_with_failure_detection(callback, *args)
async def execute_pubsub_run(
self, sleep_time: float, exception_handler=None, pubsub=None
) -> Any:
async def callback():
return await self._active_pubsub.run(
poll_timeout=sleep_time,
exception_handler=exception_handler,
pubsub=pubsub,
)
return await self._execute_with_failure_detection(callback)
async def _execute_with_failure_detection(
self, callback: Callable, cmds: tuple = ()
):
"""
Execute a commands execution callback with failure detection.
"""
async def wrapper():
# On each retry we need to check active database as it might change.
await self._check_active_database()
return await callback()
return await self._command_retry.call_with_retry(
lambda: wrapper(),
lambda error: self._on_command_fail(error, *cmds),
)
async def _check_active_database(self):
"""
Checks if active a database needs to be updated.
"""
if (
self._active_database is None
or self._active_database.circuit.state != CBState.CLOSED
or (
self._auto_fallback_interval != DEFAULT_AUTO_FALLBACK_INTERVAL
and self._next_fallback_attempt <= datetime.now()
)
):
await self.set_active_database(
await self._failover_strategy_executor.execute()
)
self._schedule_next_fallback()
async def _on_command_fail(self, error, *args):
await self._event_dispatcher.dispatch_async(
AsyncOnCommandsFailEvent(args, error)
)
async def _register_command_execution(self, cmd: tuple):
for detector in self._failure_detectors:
await detector.register_command_execution(cmd)
def _setup_event_dispatcher(self):
"""
Registers necessary listeners.
"""
failure_listener = RegisterCommandFailure(self._failure_detectors)
resubscribe_listener = ResubscribeOnActiveDatabaseChanged()
close_connection_listener = CloseConnectionOnActiveDatabaseChanged()
self._event_dispatcher.register_listeners(
{
AsyncOnCommandsFailEvent: [failure_listener],
AsyncActiveDatabaseChanged: [
close_connection_listener,
resubscribe_listener,
],
}
)

View File

@@ -0,0 +1,210 @@
from dataclasses import dataclass, field
from typing import List, Optional, Type, Union
import pybreaker
from redis.asyncio import ConnectionPool, Redis, RedisCluster
from redis.asyncio.multidb.database import Database, Databases
from redis.asyncio.multidb.failover import (
DEFAULT_FAILOVER_ATTEMPTS,
DEFAULT_FAILOVER_DELAY,
AsyncFailoverStrategy,
WeightBasedFailoverStrategy,
)
from redis.asyncio.multidb.failure_detector import (
AsyncFailureDetector,
FailureDetectorAsyncWrapper,
)
from redis.asyncio.multidb.healthcheck import (
DEFAULT_HEALTH_CHECK_DELAY,
DEFAULT_HEALTH_CHECK_INTERVAL,
DEFAULT_HEALTH_CHECK_POLICY,
DEFAULT_HEALTH_CHECK_PROBES,
HealthCheck,
HealthCheckPolicies,
PingHealthCheck,
)
from redis.asyncio.retry import Retry
from redis.backoff import ExponentialWithJitterBackoff, NoBackoff
from redis.data_structure import WeightedList
from redis.event import EventDispatcher, EventDispatcherInterface
from redis.multidb.circuit import (
DEFAULT_GRACE_PERIOD,
CircuitBreaker,
PBCircuitBreakerAdapter,
)
from redis.multidb.failure_detector import (
DEFAULT_FAILURE_RATE_THRESHOLD,
DEFAULT_FAILURES_DETECTION_WINDOW,
DEFAULT_MIN_NUM_FAILURES,
CommandFailureDetector,
)
DEFAULT_AUTO_FALLBACK_INTERVAL = 120
def default_event_dispatcher() -> EventDispatcherInterface:
return EventDispatcher()
@dataclass
class DatabaseConfig:
"""
Dataclass representing the configuration for a database connection.
This class is used to store configuration settings for a database connection,
including client options, connection sourcing details, circuit breaker settings,
and cluster-specific properties. It provides a structure for defining these
attributes and allows for the creation of customized configurations for various
database setups.
Attributes:
weight (float): Weight of the database to define the active one.
client_kwargs (dict): Additional parameters for the database client connection.
from_url (Optional[str]): Redis URL way of connecting to the database.
from_pool (Optional[ConnectionPool]): A pre-configured connection pool to use.
circuit (Optional[CircuitBreaker]): Custom circuit breaker implementation.
grace_period (float): Grace period after which we need to check if the circuit could be closed again.
health_check_url (Optional[str]): URL for health checks. Cluster FQDN is typically used
on public Redis Enterprise endpoints.
Methods:
default_circuit_breaker:
Generates and returns a default CircuitBreaker instance adapted for use.
"""
weight: float = 1.0
client_kwargs: dict = field(default_factory=dict)
from_url: Optional[str] = None
from_pool: Optional[ConnectionPool] = None
circuit: Optional[CircuitBreaker] = None
grace_period: float = DEFAULT_GRACE_PERIOD
health_check_url: Optional[str] = None
def default_circuit_breaker(self) -> CircuitBreaker:
circuit_breaker = pybreaker.CircuitBreaker(reset_timeout=self.grace_period)
return PBCircuitBreakerAdapter(circuit_breaker)
@dataclass
class MultiDbConfig:
"""
Configuration class for managing multiple database connections in a resilient and fail-safe manner.
Attributes:
databases_config: A list of database configurations.
client_class: The client class used to manage database connections.
command_retry: Retry strategy for executing database commands.
failure_detectors: Optional list of additional failure detectors for monitoring database failures.
min_num_failures: Minimal count of failures required for failover
failure_rate_threshold: Percentage of failures required for failover
failures_detection_window: Time interval for tracking database failures.
health_checks: Optional list of additional health checks performed on databases.
health_check_interval: Time interval for executing health checks.
health_check_probes: Number of attempts to evaluate the health of a database.
health_check_delay: Delay between health check attempts.
failover_strategy: Optional strategy for handling database failover scenarios.
failover_attempts: Number of retries allowed for failover operations.
failover_delay: Delay between failover attempts.
auto_fallback_interval: Time interval to trigger automatic fallback.
event_dispatcher: Interface for dispatching events related to database operations.
Methods:
databases:
Retrieves a collection of database clients managed by weighted configurations.
Initializes database clients based on the provided configuration and removes
redundant retry objects for lower-level clients to rely on global retry logic.
default_failure_detectors:
Returns the default list of failure detectors used to monitor database failures.
default_health_checks:
Returns the default list of health checks used to monitor database health
with specific retry and backoff strategies.
default_failover_strategy:
Provides the default failover strategy used for handling failover scenarios
with defined retry and backoff configurations.
"""
databases_config: List[DatabaseConfig]
client_class: Type[Union[Redis, RedisCluster]] = Redis
command_retry: Retry = Retry(
backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3
)
failure_detectors: Optional[List[AsyncFailureDetector]] = None
min_num_failures: int = DEFAULT_MIN_NUM_FAILURES
failure_rate_threshold: float = DEFAULT_FAILURE_RATE_THRESHOLD
failures_detection_window: float = DEFAULT_FAILURES_DETECTION_WINDOW
health_checks: Optional[List[HealthCheck]] = None
health_check_interval: float = DEFAULT_HEALTH_CHECK_INTERVAL
health_check_probes: int = DEFAULT_HEALTH_CHECK_PROBES
health_check_delay: float = DEFAULT_HEALTH_CHECK_DELAY
health_check_policy: HealthCheckPolicies = DEFAULT_HEALTH_CHECK_POLICY
failover_strategy: Optional[AsyncFailoverStrategy] = None
failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS
failover_delay: float = DEFAULT_FAILOVER_DELAY
auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL
event_dispatcher: EventDispatcherInterface = field(
default_factory=default_event_dispatcher
)
def databases(self) -> Databases:
databases = WeightedList()
for database_config in self.databases_config:
# The retry object is not used in the lower level clients, so we can safely remove it.
# We rely on command_retry in terms of global retries.
database_config.client_kwargs.update(
{"retry": Retry(retries=0, backoff=NoBackoff())}
)
if database_config.from_url:
client = self.client_class.from_url(
database_config.from_url, **database_config.client_kwargs
)
elif database_config.from_pool:
database_config.from_pool.set_retry(
Retry(retries=0, backoff=NoBackoff())
)
client = self.client_class.from_pool(
connection_pool=database_config.from_pool
)
else:
client = self.client_class(**database_config.client_kwargs)
circuit = (
database_config.default_circuit_breaker()
if database_config.circuit is None
else database_config.circuit
)
databases.add(
Database(
client=client,
circuit=circuit,
weight=database_config.weight,
health_check_url=database_config.health_check_url,
),
database_config.weight,
)
return databases
def default_failure_detectors(self) -> List[AsyncFailureDetector]:
return [
FailureDetectorAsyncWrapper(
CommandFailureDetector(
min_num_failures=self.min_num_failures,
failure_rate_threshold=self.failure_rate_threshold,
failure_detection_window=self.failures_detection_window,
)
),
]
def default_health_checks(self) -> List[HealthCheck]:
return [
PingHealthCheck(),
]
def default_failover_strategy(self) -> AsyncFailoverStrategy:
return WeightBasedFailoverStrategy()

View File

@@ -0,0 +1,69 @@
from abc import abstractmethod
from typing import Optional, Union
from redis.asyncio import Redis, RedisCluster
from redis.data_structure import WeightedList
from redis.multidb.circuit import CircuitBreaker
from redis.multidb.database import AbstractDatabase, BaseDatabase
from redis.typing import Number
class AsyncDatabase(AbstractDatabase):
"""Database with an underlying asynchronous redis client."""
@property
@abstractmethod
def client(self) -> Union[Redis, RedisCluster]:
"""The underlying redis client."""
pass
@client.setter
@abstractmethod
def client(self, client: Union[Redis, RedisCluster]):
"""Set the underlying redis client."""
pass
@property
@abstractmethod
def circuit(self) -> CircuitBreaker:
"""Circuit breaker for the current database."""
pass
@circuit.setter
@abstractmethod
def circuit(self, circuit: CircuitBreaker):
"""Set the circuit breaker for the current database."""
pass
Databases = WeightedList[tuple[AsyncDatabase, Number]]
class Database(BaseDatabase, AsyncDatabase):
def __init__(
self,
client: Union[Redis, RedisCluster],
circuit: CircuitBreaker,
weight: float,
health_check_url: Optional[str] = None,
):
self._client = client
self._cb = circuit
self._cb.database = self
super().__init__(weight, health_check_url)
@property
def client(self) -> Union[Redis, RedisCluster]:
return self._client
@client.setter
def client(self, client: Union[Redis, RedisCluster]):
self._client = client
@property
def circuit(self) -> CircuitBreaker:
return self._cb
@circuit.setter
def circuit(self, circuit: CircuitBreaker):
self._cb = circuit

View File

@@ -0,0 +1,84 @@
from typing import List
from redis.asyncio import Redis
from redis.asyncio.multidb.database import AsyncDatabase
from redis.asyncio.multidb.failure_detector import AsyncFailureDetector
from redis.event import AsyncEventListenerInterface, AsyncOnCommandsFailEvent
class AsyncActiveDatabaseChanged:
"""
Event fired when an async active database has been changed.
"""
def __init__(
self,
old_database: AsyncDatabase,
new_database: AsyncDatabase,
command_executor,
**kwargs,
):
self._old_database = old_database
self._new_database = new_database
self._command_executor = command_executor
self._kwargs = kwargs
@property
def old_database(self) -> AsyncDatabase:
return self._old_database
@property
def new_database(self) -> AsyncDatabase:
return self._new_database
@property
def command_executor(self):
return self._command_executor
@property
def kwargs(self):
return self._kwargs
class ResubscribeOnActiveDatabaseChanged(AsyncEventListenerInterface):
"""
Re-subscribe the currently active pub / sub to a new active database.
"""
async def listen(self, event: AsyncActiveDatabaseChanged):
old_pubsub = event.command_executor.active_pubsub
if old_pubsub is not None:
# Re-assign old channels and patterns so they will be automatically subscribed on connection.
new_pubsub = event.new_database.client.pubsub(**event.kwargs)
new_pubsub.channels = old_pubsub.channels
new_pubsub.patterns = old_pubsub.patterns
await new_pubsub.on_connect(None)
event.command_executor.active_pubsub = new_pubsub
await old_pubsub.aclose()
class CloseConnectionOnActiveDatabaseChanged(AsyncEventListenerInterface):
"""
Close connection to the old active database.
"""
async def listen(self, event: AsyncActiveDatabaseChanged):
await event.old_database.client.aclose()
if isinstance(event.old_database.client, Redis):
await event.old_database.client.connection_pool.update_active_connections_for_reconnect()
await event.old_database.client.connection_pool.disconnect()
class RegisterCommandFailure(AsyncEventListenerInterface):
"""
Event listener that registers command failures and passing it to the failure detectors.
"""
def __init__(self, failure_detectors: List[AsyncFailureDetector]):
self._failure_detectors = failure_detectors
async def listen(self, event: AsyncOnCommandsFailEvent) -> None:
for failure_detector in self._failure_detectors:
await failure_detector.register_failure(event.exception, event.commands)

View File

@@ -0,0 +1,125 @@
import time
from abc import ABC, abstractmethod
from redis.asyncio.multidb.database import AsyncDatabase, Databases
from redis.data_structure import WeightedList
from redis.multidb.circuit import State as CBState
from redis.multidb.exception import (
NoValidDatabaseException,
TemporaryUnavailableException,
)
DEFAULT_FAILOVER_ATTEMPTS = 10
DEFAULT_FAILOVER_DELAY = 12
class AsyncFailoverStrategy(ABC):
@abstractmethod
async def database(self) -> AsyncDatabase:
"""Select the database according to the strategy."""
pass
@abstractmethod
def set_databases(self, databases: Databases) -> None:
"""Set the database strategy operates on."""
pass
class FailoverStrategyExecutor(ABC):
@property
@abstractmethod
def failover_attempts(self) -> int:
"""The number of failover attempts."""
pass
@property
@abstractmethod
def failover_delay(self) -> float:
"""The delay between failover attempts."""
pass
@property
@abstractmethod
def strategy(self) -> AsyncFailoverStrategy:
"""The strategy to execute."""
pass
@abstractmethod
async def execute(self) -> AsyncDatabase:
"""Execute the failover strategy."""
pass
class WeightBasedFailoverStrategy(AsyncFailoverStrategy):
"""
Failover strategy based on database weights.
"""
def __init__(self):
self._databases = WeightedList()
async def database(self) -> AsyncDatabase:
for database, _ in self._databases:
if database.circuit.state == CBState.CLOSED:
return database
raise NoValidDatabaseException("No valid database available for communication")
def set_databases(self, databases: Databases) -> None:
self._databases = databases
class DefaultFailoverStrategyExecutor(FailoverStrategyExecutor):
"""
Executes given failover strategy.
"""
def __init__(
self,
strategy: AsyncFailoverStrategy,
failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS,
failover_delay: float = DEFAULT_FAILOVER_DELAY,
):
self._strategy = strategy
self._failover_attempts = failover_attempts
self._failover_delay = failover_delay
self._next_attempt_ts: int = 0
self._failover_counter: int = 0
@property
def failover_attempts(self) -> int:
return self._failover_attempts
@property
def failover_delay(self) -> float:
return self._failover_delay
@property
def strategy(self) -> AsyncFailoverStrategy:
return self._strategy
async def execute(self) -> AsyncDatabase:
try:
database = await self._strategy.database()
self._reset()
return database
except NoValidDatabaseException as e:
if self._next_attempt_ts == 0:
self._next_attempt_ts = time.time() + self._failover_delay
self._failover_counter += 1
elif time.time() >= self._next_attempt_ts:
self._next_attempt_ts += self._failover_delay
self._failover_counter += 1
if self._failover_counter > self._failover_attempts:
self._reset()
raise e
else:
raise TemporaryUnavailableException(
"No database connections currently available. "
"This is a temporary condition - please retry the operation."
)
def _reset(self) -> None:
self._next_attempt_ts = 0
self._failover_counter = 0

View File

@@ -0,0 +1,38 @@
from abc import ABC, abstractmethod
from redis.multidb.failure_detector import FailureDetector
class AsyncFailureDetector(ABC):
@abstractmethod
async def register_failure(self, exception: Exception, cmd: tuple) -> None:
"""Register a failure that occurred during command execution."""
pass
@abstractmethod
async def register_command_execution(self, cmd: tuple) -> None:
"""Register a command execution."""
pass
@abstractmethod
def set_command_executor(self, command_executor) -> None:
"""Set the command executor for this failure."""
pass
class FailureDetectorAsyncWrapper(AsyncFailureDetector):
"""
Async wrapper for the failure detector.
"""
def __init__(self, failure_detector: FailureDetector) -> None:
self._failure_detector = failure_detector
async def register_failure(self, exception: Exception, cmd: tuple) -> None:
self._failure_detector.register_failure(exception, cmd)
async def register_command_execution(self, cmd: tuple) -> None:
self._failure_detector.register_command_execution(cmd)
def set_command_executor(self, command_executor) -> None:
self._failure_detector.set_command_executor(command_executor)

View File

@@ -0,0 +1,285 @@
import asyncio
import logging
from abc import ABC, abstractmethod
from enum import Enum
from typing import List, Optional, Tuple, Union
from redis.asyncio import Redis
from redis.asyncio.http.http_client import DEFAULT_TIMEOUT, AsyncHTTPClientWrapper
from redis.backoff import NoBackoff
from redis.http.http_client import HttpClient
from redis.multidb.exception import UnhealthyDatabaseException
from redis.retry import Retry
DEFAULT_HEALTH_CHECK_PROBES = 3
DEFAULT_HEALTH_CHECK_INTERVAL = 5
DEFAULT_HEALTH_CHECK_DELAY = 0.5
DEFAULT_LAG_AWARE_TOLERANCE = 5000
logger = logging.getLogger(__name__)
class HealthCheck(ABC):
@abstractmethod
async def check_health(self, database) -> bool:
"""Function to determine the health status."""
pass
class HealthCheckPolicy(ABC):
"""
Health checks execution policy.
"""
@property
@abstractmethod
def health_check_probes(self) -> int:
"""Number of probes to execute health checks."""
pass
@property
@abstractmethod
def health_check_delay(self) -> float:
"""Delay between health check probes."""
pass
@abstractmethod
async def execute(self, health_checks: List[HealthCheck], database) -> bool:
"""Execute health checks and return database health status."""
pass
class AbstractHealthCheckPolicy(HealthCheckPolicy):
def __init__(self, health_check_probes: int, health_check_delay: float):
if health_check_probes < 1:
raise ValueError("health_check_probes must be greater than 0")
self._health_check_probes = health_check_probes
self._health_check_delay = health_check_delay
@property
def health_check_probes(self) -> int:
return self._health_check_probes
@property
def health_check_delay(self) -> float:
return self._health_check_delay
@abstractmethod
async def execute(self, health_checks: List[HealthCheck], database) -> bool:
pass
class HealthyAllPolicy(AbstractHealthCheckPolicy):
"""
Policy that returns True if all health check probes are successful.
"""
def __init__(self, health_check_probes: int, health_check_delay: float):
super().__init__(health_check_probes, health_check_delay)
async def execute(self, health_checks: List[HealthCheck], database) -> bool:
for health_check in health_checks:
for attempt in range(self.health_check_probes):
try:
if not await health_check.check_health(database):
return False
except Exception as e:
raise UnhealthyDatabaseException("Unhealthy database", database, e)
if attempt < self.health_check_probes - 1:
await asyncio.sleep(self._health_check_delay)
return True
class HealthyMajorityPolicy(AbstractHealthCheckPolicy):
"""
Policy that returns True if a majority of health check probes are successful.
"""
def __init__(self, health_check_probes: int, health_check_delay: float):
super().__init__(health_check_probes, health_check_delay)
async def execute(self, health_checks: List[HealthCheck], database) -> bool:
for health_check in health_checks:
if self.health_check_probes % 2 == 0:
allowed_unsuccessful_probes = self.health_check_probes / 2
else:
allowed_unsuccessful_probes = (self.health_check_probes + 1) / 2
for attempt in range(self.health_check_probes):
try:
if not await health_check.check_health(database):
allowed_unsuccessful_probes -= 1
if allowed_unsuccessful_probes <= 0:
return False
except Exception as e:
allowed_unsuccessful_probes -= 1
if allowed_unsuccessful_probes <= 0:
raise UnhealthyDatabaseException(
"Unhealthy database", database, e
)
if attempt < self.health_check_probes - 1:
await asyncio.sleep(self._health_check_delay)
return True
class HealthyAnyPolicy(AbstractHealthCheckPolicy):
"""
Policy that returns True if at least one health check probe is successful.
"""
def __init__(self, health_check_probes: int, health_check_delay: float):
super().__init__(health_check_probes, health_check_delay)
async def execute(self, health_checks: List[HealthCheck], database) -> bool:
is_healthy = False
for health_check in health_checks:
exception = None
for attempt in range(self.health_check_probes):
try:
if await health_check.check_health(database):
is_healthy = True
break
else:
is_healthy = False
except Exception as e:
exception = UnhealthyDatabaseException(
"Unhealthy database", database, e
)
if attempt < self.health_check_probes - 1:
await asyncio.sleep(self._health_check_delay)
if not is_healthy and not exception:
return is_healthy
elif not is_healthy and exception:
raise exception
return is_healthy
class HealthCheckPolicies(Enum):
HEALTHY_ALL = HealthyAllPolicy
HEALTHY_MAJORITY = HealthyMajorityPolicy
HEALTHY_ANY = HealthyAnyPolicy
DEFAULT_HEALTH_CHECK_POLICY: HealthCheckPolicies = HealthCheckPolicies.HEALTHY_ALL
class PingHealthCheck(HealthCheck):
"""
Health check based on PING command.
"""
async def check_health(self, database) -> bool:
if isinstance(database.client, Redis):
return await database.client.execute_command("PING")
else:
# For a cluster checks if all nodes are healthy.
all_nodes = database.client.get_nodes()
for node in all_nodes:
if not await node.redis_connection.execute_command("PING"):
return False
return True
class LagAwareHealthCheck(HealthCheck):
"""
Health check available for Redis Enterprise deployments.
Verify via REST API that the database is healthy based on different lags.
"""
def __init__(
self,
rest_api_port: int = 9443,
lag_aware_tolerance: int = DEFAULT_LAG_AWARE_TOLERANCE,
timeout: float = DEFAULT_TIMEOUT,
auth_basic: Optional[Tuple[str, str]] = None,
verify_tls: bool = True,
# TLS verification (server) options
ca_file: Optional[str] = None,
ca_path: Optional[str] = None,
ca_data: Optional[Union[str, bytes]] = None,
# Mutual TLS (client cert) options
client_cert_file: Optional[str] = None,
client_key_file: Optional[str] = None,
client_key_password: Optional[str] = None,
):
"""
Initialize LagAwareHealthCheck with the specified parameters.
Args:
rest_api_port: Port number for Redis Enterprise REST API (default: 9443)
lag_aware_tolerance: Tolerance in lag between databases in MS (default: 100)
timeout: Request timeout in seconds (default: DEFAULT_TIMEOUT)
auth_basic: Tuple of (username, password) for basic authentication
verify_tls: Whether to verify TLS certificates (default: True)
ca_file: Path to CA certificate file for TLS verification
ca_path: Path to CA certificates directory for TLS verification
ca_data: CA certificate data as string or bytes
client_cert_file: Path to client certificate file for mutual TLS
client_key_file: Path to client private key file for mutual TLS
client_key_password: Password for encrypted client private key
"""
self._http_client = AsyncHTTPClientWrapper(
HttpClient(
timeout=timeout,
auth_basic=auth_basic,
retry=Retry(NoBackoff(), retries=0),
verify_tls=verify_tls,
ca_file=ca_file,
ca_path=ca_path,
ca_data=ca_data,
client_cert_file=client_cert_file,
client_key_file=client_key_file,
client_key_password=client_key_password,
)
)
self._rest_api_port = rest_api_port
self._lag_aware_tolerance = lag_aware_tolerance
async def check_health(self, database) -> bool:
if database.health_check_url is None:
raise ValueError(
"Database health check url is not set. Please check DatabaseConfig for the current database."
)
if isinstance(database.client, Redis):
db_host = database.client.get_connection_kwargs()["host"]
else:
db_host = database.client.startup_nodes[0].host
base_url = f"{database.health_check_url}:{self._rest_api_port}"
self._http_client.client.base_url = base_url
# Find bdb matching to the current database host
matching_bdb = None
for bdb in await self._http_client.get("/v1/bdbs"):
for endpoint in bdb["endpoints"]:
if endpoint["dns_name"] == db_host:
matching_bdb = bdb
break
# In case if the host was set as public IP
for addr in endpoint["addr"]:
if addr == db_host:
matching_bdb = bdb
break
if matching_bdb is None:
logger.warning("LagAwareHealthCheck failed: Couldn't find a matching bdb")
raise ValueError("Could not find a matching bdb")
url = (
f"/v1/bdbs/{matching_bdb['uid']}/availability"
f"?extend_check=lag&availability_lag_tolerance_ms={self._lag_aware_tolerance}"
)
await self._http_client.get(url, expect_json=False)
# Status checked in an http client, otherwise HttpError will be raised
return True

View File

@@ -0,0 +1,58 @@
from asyncio import sleep
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Tuple, Type, TypeVar
from redis.exceptions import ConnectionError, RedisError, TimeoutError
from redis.retry import AbstractRetry
T = TypeVar("T")
if TYPE_CHECKING:
from redis.backoff import AbstractBackoff
class Retry(AbstractRetry[RedisError]):
__hash__ = AbstractRetry.__hash__
def __init__(
self,
backoff: "AbstractBackoff",
retries: int,
supported_errors: Tuple[Type[RedisError], ...] = (
ConnectionError,
TimeoutError,
),
):
super().__init__(backoff, retries, supported_errors)
def __eq__(self, other: Any) -> bool:
if not isinstance(other, Retry):
return NotImplemented
return (
self._backoff == other._backoff
and self._retries == other._retries
and set(self._supported_errors) == set(other._supported_errors)
)
async def call_with_retry(
self, do: Callable[[], Awaitable[T]], fail: Callable[[RedisError], Any]
) -> T:
"""
Execute an operation that might fail and returns its result, or
raise the exception that was thrown depending on the `Backoff` object.
`do`: the operation to call. Expects no argument.
`fail`: the failure handler, expects the last error that was thrown
"""
self._backoff.reset()
failures = 0
while True:
try:
return await do()
except self._supported_errors as error:
failures += 1
await fail(error)
if self._retries >= 0 and failures > self._retries:
raise error
backoff = self._backoff.compute(failures)
if backoff > 0:
await sleep(backoff)

View File

@@ -0,0 +1,404 @@
import asyncio
import random
import weakref
from typing import AsyncIterator, Iterable, Mapping, Optional, Sequence, Tuple, Type
from redis.asyncio.client import Redis
from redis.asyncio.connection import (
Connection,
ConnectionPool,
EncodableT,
SSLConnection,
)
from redis.commands import AsyncSentinelCommands
from redis.exceptions import (
ConnectionError,
ReadOnlyError,
ResponseError,
TimeoutError,
)
class MasterNotFoundError(ConnectionError):
pass
class SlaveNotFoundError(ConnectionError):
pass
class SentinelManagedConnection(Connection):
def __init__(self, **kwargs):
self.connection_pool = kwargs.pop("connection_pool")
super().__init__(**kwargs)
def __repr__(self):
s = f"<{self.__class__.__module__}.{self.__class__.__name__}"
if self.host:
host_info = f",host={self.host},port={self.port}"
s += host_info
return s + ")>"
async def connect_to(self, address):
self.host, self.port = address
await self.connect_check_health(
check_health=self.connection_pool.check_connection,
retry_socket_connect=False,
)
async def _connect_retry(self):
if self._reader:
return # already connected
if self.connection_pool.is_master:
await self.connect_to(await self.connection_pool.get_master_address())
else:
async for slave in self.connection_pool.rotate_slaves():
try:
return await self.connect_to(slave)
except ConnectionError:
continue
raise SlaveNotFoundError # Never be here
async def connect(self):
return await self.retry.call_with_retry(
self._connect_retry,
lambda error: asyncio.sleep(0),
)
async def read_response(
self,
disable_decoding: bool = False,
timeout: Optional[float] = None,
*,
disconnect_on_error: Optional[float] = True,
push_request: Optional[bool] = False,
):
try:
return await super().read_response(
disable_decoding=disable_decoding,
timeout=timeout,
disconnect_on_error=disconnect_on_error,
push_request=push_request,
)
except ReadOnlyError:
if self.connection_pool.is_master:
# When talking to a master, a ReadOnlyError when likely
# indicates that the previous master that we're still connected
# to has been demoted to a slave and there's a new master.
# calling disconnect will force the connection to re-query
# sentinel during the next connect() attempt.
await self.disconnect()
raise ConnectionError("The previous master is now a slave")
raise
class SentinelManagedSSLConnection(SentinelManagedConnection, SSLConnection):
pass
class SentinelConnectionPool(ConnectionPool):
"""
Sentinel backed connection pool.
If ``check_connection`` flag is set to True, SentinelManagedConnection
sends a PING command right after establishing the connection.
"""
def __init__(self, service_name, sentinel_manager, **kwargs):
kwargs["connection_class"] = kwargs.get(
"connection_class",
(
SentinelManagedSSLConnection
if kwargs.pop("ssl", False)
else SentinelManagedConnection
),
)
self.is_master = kwargs.pop("is_master", True)
self.check_connection = kwargs.pop("check_connection", False)
super().__init__(**kwargs)
self.connection_kwargs["connection_pool"] = weakref.proxy(self)
self.service_name = service_name
self.sentinel_manager = sentinel_manager
self.master_address = None
self.slave_rr_counter = None
def __repr__(self):
return (
f"<{self.__class__.__module__}.{self.__class__.__name__}"
f"(service={self.service_name}({self.is_master and 'master' or 'slave'}))>"
)
def reset(self):
super().reset()
self.master_address = None
self.slave_rr_counter = None
def owns_connection(self, connection: Connection):
check = not self.is_master or (
self.is_master and self.master_address == (connection.host, connection.port)
)
return check and super().owns_connection(connection)
async def get_master_address(self):
master_address = await self.sentinel_manager.discover_master(self.service_name)
if self.is_master:
if self.master_address != master_address:
self.master_address = master_address
# disconnect any idle connections so that they reconnect
# to the new master the next time that they are used.
await self.disconnect(inuse_connections=False)
return master_address
async def rotate_slaves(self) -> AsyncIterator:
"""Round-robin slave balancer"""
slaves = await self.sentinel_manager.discover_slaves(self.service_name)
if slaves:
if self.slave_rr_counter is None:
self.slave_rr_counter = random.randint(0, len(slaves) - 1)
for _ in range(len(slaves)):
self.slave_rr_counter = (self.slave_rr_counter + 1) % len(slaves)
slave = slaves[self.slave_rr_counter]
yield slave
# Fallback to the master connection
try:
yield await self.get_master_address()
except MasterNotFoundError:
pass
raise SlaveNotFoundError(f"No slave found for {self.service_name!r}")
class Sentinel(AsyncSentinelCommands):
"""
Redis Sentinel cluster client
>>> from redis.sentinel import Sentinel
>>> sentinel = Sentinel([('localhost', 26379)], socket_timeout=0.1)
>>> master = sentinel.master_for('mymaster', socket_timeout=0.1)
>>> await master.set('foo', 'bar')
>>> slave = sentinel.slave_for('mymaster', socket_timeout=0.1)
>>> await slave.get('foo')
b'bar'
``sentinels`` is a list of sentinel nodes. Each node is represented by
a pair (hostname, port).
``min_other_sentinels`` defined a minimum number of peers for a sentinel.
When querying a sentinel, if it doesn't meet this threshold, responses
from that sentinel won't be considered valid.
``sentinel_kwargs`` is a dictionary of connection arguments used when
connecting to sentinel instances. Any argument that can be passed to
a normal Redis connection can be specified here. If ``sentinel_kwargs`` is
not specified, any socket_timeout and socket_keepalive options specified
in ``connection_kwargs`` will be used.
``connection_kwargs`` are keyword arguments that will be used when
establishing a connection to a Redis server.
"""
def __init__(
self,
sentinels,
min_other_sentinels=0,
sentinel_kwargs=None,
force_master_ip=None,
**connection_kwargs,
):
# if sentinel_kwargs isn't defined, use the socket_* options from
# connection_kwargs
if sentinel_kwargs is None:
sentinel_kwargs = {
k: v for k, v in connection_kwargs.items() if k.startswith("socket_")
}
self.sentinel_kwargs = sentinel_kwargs
self.sentinels = [
Redis(host=hostname, port=port, **self.sentinel_kwargs)
for hostname, port in sentinels
]
self.min_other_sentinels = min_other_sentinels
self.connection_kwargs = connection_kwargs
self._force_master_ip = force_master_ip
async def execute_command(self, *args, **kwargs):
"""
Execute Sentinel command in sentinel nodes.
once - If set to True, then execute the resulting command on a single
node at random, rather than across the entire sentinel cluster.
"""
once = bool(kwargs.pop("once", False))
# Check if command is supposed to return the original
# responses instead of boolean value.
return_responses = bool(kwargs.pop("return_responses", False))
if once:
response = await random.choice(self.sentinels).execute_command(
*args, **kwargs
)
if return_responses:
return [response]
else:
return True if response else False
tasks = [
asyncio.Task(sentinel.execute_command(*args, **kwargs))
for sentinel in self.sentinels
]
responses = await asyncio.gather(*tasks)
if return_responses:
return responses
return all(responses)
def __repr__(self):
sentinel_addresses = []
for sentinel in self.sentinels:
sentinel_addresses.append(
f"{sentinel.connection_pool.connection_kwargs['host']}:"
f"{sentinel.connection_pool.connection_kwargs['port']}"
)
return (
f"<{self.__class__}.{self.__class__.__name__}"
f"(sentinels=[{','.join(sentinel_addresses)}])>"
)
def check_master_state(self, state: dict, service_name: str) -> bool:
if not state["is_master"] or state["is_sdown"] or state["is_odown"]:
return False
# Check if our sentinel doesn't see other nodes
if state["num-other-sentinels"] < self.min_other_sentinels:
return False
return True
async def discover_master(self, service_name: str):
"""
Asks sentinel servers for the Redis master's address corresponding
to the service labeled ``service_name``.
Returns a pair (address, port) or raises MasterNotFoundError if no
master is found.
"""
collected_errors = list()
for sentinel_no, sentinel in enumerate(self.sentinels):
try:
masters = await sentinel.sentinel_masters()
except (ConnectionError, TimeoutError) as e:
collected_errors.append(f"{sentinel} - {e!r}")
continue
state = masters.get(service_name)
if state and self.check_master_state(state, service_name):
# Put this sentinel at the top of the list
self.sentinels[0], self.sentinels[sentinel_no] = (
sentinel,
self.sentinels[0],
)
ip = (
self._force_master_ip
if self._force_master_ip is not None
else state["ip"]
)
return ip, state["port"]
error_info = ""
if len(collected_errors) > 0:
error_info = f" : {', '.join(collected_errors)}"
raise MasterNotFoundError(f"No master found for {service_name!r}{error_info}")
def filter_slaves(
self, slaves: Iterable[Mapping]
) -> Sequence[Tuple[EncodableT, EncodableT]]:
"""Remove slaves that are in an ODOWN or SDOWN state"""
slaves_alive = []
for slave in slaves:
if slave["is_odown"] or slave["is_sdown"]:
continue
slaves_alive.append((slave["ip"], slave["port"]))
return slaves_alive
async def discover_slaves(
self, service_name: str
) -> Sequence[Tuple[EncodableT, EncodableT]]:
"""Returns a list of alive slaves for service ``service_name``"""
for sentinel in self.sentinels:
try:
slaves = await sentinel.sentinel_slaves(service_name)
except (ConnectionError, ResponseError, TimeoutError):
continue
slaves = self.filter_slaves(slaves)
if slaves:
return slaves
return []
def master_for(
self,
service_name: str,
redis_class: Type[Redis] = Redis,
connection_pool_class: Type[SentinelConnectionPool] = SentinelConnectionPool,
**kwargs,
):
"""
Returns a redis client instance for the ``service_name`` master.
Sentinel client will detect failover and reconnect Redis clients
automatically.
A :py:class:`~redis.sentinel.SentinelConnectionPool` class is
used to retrieve the master's address before establishing a new
connection.
NOTE: If the master's address has changed, any cached connections to
the old master are closed.
By default clients will be a :py:class:`~redis.Redis` instance.
Specify a different class to the ``redis_class`` argument if you
desire something different.
The ``connection_pool_class`` specifies the connection pool to
use. The :py:class:`~redis.sentinel.SentinelConnectionPool`
will be used by default.
All other keyword arguments are merged with any connection_kwargs
passed to this class and passed to the connection pool as keyword
arguments to be used to initialize Redis connections.
"""
kwargs["is_master"] = True
connection_kwargs = dict(self.connection_kwargs)
connection_kwargs.update(kwargs)
connection_pool = connection_pool_class(service_name, self, **connection_kwargs)
# The Redis object "owns" the pool
return redis_class.from_pool(connection_pool)
def slave_for(
self,
service_name: str,
redis_class: Type[Redis] = Redis,
connection_pool_class: Type[SentinelConnectionPool] = SentinelConnectionPool,
**kwargs,
):
"""
Returns redis client instance for the ``service_name`` slave(s).
A SentinelConnectionPool class is used to retrieve the slave's
address before establishing a new connection.
By default clients will be a :py:class:`~redis.Redis` instance.
Specify a different class to the ``redis_class`` argument if you
desire something different.
The ``connection_pool_class`` specifies the connection pool to use.
The SentinelConnectionPool will be used by default.
All other keyword arguments are merged with any connection_kwargs
passed to this class and passed to the connection pool as keyword
arguments to be used to initialize Redis connections.
"""
kwargs["is_master"] = False
connection_kwargs = dict(self.connection_kwargs)
connection_kwargs.update(kwargs)
connection_pool = connection_pool_class(service_name, self, **connection_kwargs)
# The Redis object "owns" the pool
return redis_class.from_pool(connection_pool)

View File

@@ -0,0 +1,28 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from redis.asyncio.client import Pipeline, Redis
def from_url(url, **kwargs):
"""
Returns an active Redis client generated from the given database URL.
Will attempt to extract the database id from the path url fragment, if
none is provided.
"""
from redis.asyncio.client import Redis
return Redis.from_url(url, **kwargs)
class pipeline: # noqa: N801
def __init__(self, redis_obj: "Redis"):
self.p: "Pipeline" = redis_obj.pipeline()
async def __aenter__(self) -> "Pipeline":
return self.p
async def __aexit__(self, exc_type, exc_value, traceback):
await self.p.execute()
del self.p

View File

@@ -0,0 +1,31 @@
from typing import Iterable
class RequestTokenErr(Exception):
"""
Represents an exception during token request.
"""
def __init__(self, *args):
super().__init__(*args)
class InvalidTokenSchemaErr(Exception):
"""
Represents an exception related to invalid token schema.
"""
def __init__(self, missing_fields: Iterable[str] = []):
super().__init__(
"Unexpected token schema. Following fields are missing: "
+ ", ".join(missing_fields)
)
class TokenRenewalErr(Exception):
"""
Represents an exception during token renewal process.
"""
def __init__(self, *args):
super().__init__(*args)

View File

@@ -0,0 +1,28 @@
from abc import ABC, abstractmethod
from redis.auth.token import TokenInterface
"""
This interface is the facade of an identity provider
"""
class IdentityProviderInterface(ABC):
"""
Receive a token from the identity provider.
Receiving a token only works when being authenticated.
"""
@abstractmethod
def request_token(self, force_refresh=False) -> TokenInterface:
pass
class IdentityProviderConfigInterface(ABC):
"""
Configuration class that provides a configured identity provider.
"""
@abstractmethod
def get_provider(self) -> IdentityProviderInterface:
pass

View File

@@ -0,0 +1,130 @@
from abc import ABC, abstractmethod
from datetime import datetime, timezone
from redis.auth.err import InvalidTokenSchemaErr
class TokenInterface(ABC):
@abstractmethod
def is_expired(self) -> bool:
pass
@abstractmethod
def ttl(self) -> float:
pass
@abstractmethod
def try_get(self, key: str) -> str:
pass
@abstractmethod
def get_value(self) -> str:
pass
@abstractmethod
def get_expires_at_ms(self) -> float:
pass
@abstractmethod
def get_received_at_ms(self) -> float:
pass
class TokenResponse:
def __init__(self, token: TokenInterface):
self._token = token
def get_token(self) -> TokenInterface:
return self._token
def get_ttl_ms(self) -> float:
return self._token.get_expires_at_ms() - self._token.get_received_at_ms()
class SimpleToken(TokenInterface):
def __init__(
self, value: str, expires_at_ms: float, received_at_ms: float, claims: dict
) -> None:
self.value = value
self.expires_at = expires_at_ms
self.received_at = received_at_ms
self.claims = claims
def ttl(self) -> float:
if self.expires_at == -1:
return -1
return self.expires_at - (datetime.now(timezone.utc).timestamp() * 1000)
def is_expired(self) -> bool:
if self.expires_at == -1:
return False
return self.ttl() <= 0
def try_get(self, key: str) -> str:
return self.claims.get(key)
def get_value(self) -> str:
return self.value
def get_expires_at_ms(self) -> float:
return self.expires_at
def get_received_at_ms(self) -> float:
return self.received_at
class JWToken(TokenInterface):
REQUIRED_FIELDS = {"exp"}
def __init__(self, token: str):
try:
import jwt
except ImportError as ie:
raise ImportError(
f"The PyJWT library is required for {self.__class__.__name__}.",
) from ie
self._value = token
self._decoded = jwt.decode(
self._value,
options={"verify_signature": False},
algorithms=[jwt.get_unverified_header(self._value).get("alg")],
)
self._validate_token()
def is_expired(self) -> bool:
exp = self._decoded["exp"]
if exp == -1:
return False
return (
self._decoded["exp"] * 1000 <= datetime.now(timezone.utc).timestamp() * 1000
)
def ttl(self) -> float:
exp = self._decoded["exp"]
if exp == -1:
return -1
return (
self._decoded["exp"] * 1000 - datetime.now(timezone.utc).timestamp() * 1000
)
def try_get(self, key: str) -> str:
return self._decoded.get(key)
def get_value(self) -> str:
return self._value
def get_expires_at_ms(self) -> float:
return float(self._decoded["exp"] * 1000)
def get_received_at_ms(self) -> float:
return datetime.now(timezone.utc).timestamp() * 1000
def _validate_token(self):
actual_fields = {x for x in self._decoded.keys()}
if len(self.REQUIRED_FIELDS - actual_fields) != 0:
raise InvalidTokenSchemaErr(self.REQUIRED_FIELDS - actual_fields)

View File

@@ -0,0 +1,370 @@
import asyncio
import logging
import threading
from datetime import datetime, timezone
from time import sleep
from typing import Any, Awaitable, Callable, Union
from redis.auth.err import RequestTokenErr, TokenRenewalErr
from redis.auth.idp import IdentityProviderInterface
from redis.auth.token import TokenResponse
logger = logging.getLogger(__name__)
class CredentialsListener:
"""
Listeners that will be notified on events related to credentials.
Accepts callbacks and awaitable callbacks.
"""
def __init__(self):
self._on_next = None
self._on_error = None
@property
def on_next(self) -> Union[Callable[[Any], None], Awaitable]:
return self._on_next
@on_next.setter
def on_next(self, callback: Union[Callable[[Any], None], Awaitable]) -> None:
self._on_next = callback
@property
def on_error(self) -> Union[Callable[[Exception], None], Awaitable]:
return self._on_error
@on_error.setter
def on_error(self, callback: Union[Callable[[Exception], None], Awaitable]) -> None:
self._on_error = callback
class RetryPolicy:
def __init__(self, max_attempts: int, delay_in_ms: float):
self.max_attempts = max_attempts
self.delay_in_ms = delay_in_ms
def get_max_attempts(self) -> int:
"""
Retry attempts before exception will be thrown.
:return: int
"""
return self.max_attempts
def get_delay_in_ms(self) -> float:
"""
Delay between retries in seconds.
:return: int
"""
return self.delay_in_ms
class TokenManagerConfig:
def __init__(
self,
expiration_refresh_ratio: float,
lower_refresh_bound_millis: int,
token_request_execution_timeout_in_ms: int,
retry_policy: RetryPolicy,
):
self._expiration_refresh_ratio = expiration_refresh_ratio
self._lower_refresh_bound_millis = lower_refresh_bound_millis
self._token_request_execution_timeout_in_ms = (
token_request_execution_timeout_in_ms
)
self._retry_policy = retry_policy
def get_expiration_refresh_ratio(self) -> float:
"""
Represents the ratio of a token's lifetime at which a refresh should be triggered. # noqa: E501
For example, a value of 0.75 means the token should be refreshed
when 75% of its lifetime has elapsed (or when 25% of its lifetime remains).
:return: float
"""
return self._expiration_refresh_ratio
def get_lower_refresh_bound_millis(self) -> int:
"""
Represents the minimum time in milliseconds before token expiration
to trigger a refresh, in milliseconds.
This value sets a fixed lower bound for when a token refresh should occur,
regardless of the token's total lifetime.
If set to 0 there will be no lower bound and the refresh will be triggered
based on the expirationRefreshRatio only.
:return: int
"""
return self._lower_refresh_bound_millis
def get_token_request_execution_timeout_in_ms(self) -> int:
"""
Represents the maximum time in milliseconds to wait
for a token request to complete.
:return: int
"""
return self._token_request_execution_timeout_in_ms
def get_retry_policy(self) -> RetryPolicy:
"""
Represents the retry policy for token requests.
:return: RetryPolicy
"""
return self._retry_policy
class TokenManager:
def __init__(
self, identity_provider: IdentityProviderInterface, config: TokenManagerConfig
):
self._idp = identity_provider
self._config = config
self._next_timer = None
self._listener = None
self._init_timer = None
self._retries = 0
def __del__(self):
logger.info("Token manager are disposed")
self.stop()
def start(
self,
listener: CredentialsListener,
skip_initial: bool = False,
) -> Callable[[], None]:
self._listener = listener
try:
loop = asyncio.get_running_loop()
except RuntimeError:
# Run loop in a separate thread to unblock main thread.
loop = asyncio.new_event_loop()
thread = threading.Thread(
target=_start_event_loop_in_thread, args=(loop,), daemon=True
)
thread.start()
# Event to block for initial execution.
init_event = asyncio.Event()
self._init_timer = loop.call_later(
0, self._renew_token, skip_initial, init_event
)
logger.info("Token manager started")
# Blocks in thread-safe manner.
asyncio.run_coroutine_threadsafe(init_event.wait(), loop).result()
return self.stop
async def start_async(
self,
listener: CredentialsListener,
block_for_initial: bool = False,
initial_delay_in_ms: float = 0,
skip_initial: bool = False,
) -> Callable[[], None]:
self._listener = listener
loop = asyncio.get_running_loop()
init_event = asyncio.Event()
# Wraps the async callback with async wrapper to schedule with loop.call_later()
wrapped = _async_to_sync_wrapper(
loop, self._renew_token_async, skip_initial, init_event
)
self._init_timer = loop.call_later(initial_delay_in_ms / 1000, wrapped)
logger.info("Token manager started")
if block_for_initial:
await init_event.wait()
return self.stop
def stop(self):
if self._init_timer is not None:
self._init_timer.cancel()
if self._next_timer is not None:
self._next_timer.cancel()
def acquire_token(self, force_refresh=False) -> TokenResponse:
try:
token = self._idp.request_token(force_refresh)
except RequestTokenErr as e:
if self._retries < self._config.get_retry_policy().get_max_attempts():
self._retries += 1
sleep(self._config.get_retry_policy().get_delay_in_ms() / 1000)
return self.acquire_token(force_refresh)
else:
raise e
self._retries = 0
return TokenResponse(token)
async def acquire_token_async(self, force_refresh=False) -> TokenResponse:
try:
token = self._idp.request_token(force_refresh)
except RequestTokenErr as e:
if self._retries < self._config.get_retry_policy().get_max_attempts():
self._retries += 1
await asyncio.sleep(
self._config.get_retry_policy().get_delay_in_ms() / 1000
)
return await self.acquire_token_async(force_refresh)
else:
raise e
self._retries = 0
return TokenResponse(token)
def _calculate_renewal_delay(self, expire_date: float, issue_date: float) -> float:
delay_for_lower_refresh = self._delay_for_lower_refresh(expire_date)
delay_for_ratio_refresh = self._delay_for_ratio_refresh(expire_date, issue_date)
delay = min(delay_for_ratio_refresh, delay_for_lower_refresh)
return 0 if delay < 0 else delay / 1000
def _delay_for_lower_refresh(self, expire_date: float):
return (
expire_date
- self._config.get_lower_refresh_bound_millis()
- (datetime.now(timezone.utc).timestamp() * 1000)
)
def _delay_for_ratio_refresh(self, expire_date: float, issue_date: float):
token_ttl = expire_date - issue_date
refresh_before = token_ttl - (
token_ttl * self._config.get_expiration_refresh_ratio()
)
return (
expire_date
- refresh_before
- (datetime.now(timezone.utc).timestamp() * 1000)
)
def _renew_token(
self, skip_initial: bool = False, init_event: asyncio.Event = None
):
"""
Task to renew token from identity provider.
Schedules renewal tasks based on token TTL.
"""
try:
token_res = self.acquire_token(force_refresh=True)
delay = self._calculate_renewal_delay(
token_res.get_token().get_expires_at_ms(),
token_res.get_token().get_received_at_ms(),
)
if token_res.get_token().is_expired():
raise TokenRenewalErr("Requested token is expired")
if self._listener.on_next is None:
logger.warning(
"No registered callback for token renewal task. Renewal cancelled"
)
return
if not skip_initial:
try:
self._listener.on_next(token_res.get_token())
except Exception as e:
raise TokenRenewalErr(e)
if delay <= 0:
return
loop = asyncio.get_running_loop()
self._next_timer = loop.call_later(delay, self._renew_token)
logger.info(f"Next token renewal scheduled in {delay} seconds")
return token_res
except Exception as e:
if self._listener.on_error is None:
raise e
self._listener.on_error(e)
finally:
if init_event:
init_event.set()
async def _renew_token_async(
self, skip_initial: bool = False, init_event: asyncio.Event = None
):
"""
Async task to renew tokens from identity provider.
Schedules renewal tasks based on token TTL.
"""
try:
token_res = await self.acquire_token_async(force_refresh=True)
delay = self._calculate_renewal_delay(
token_res.get_token().get_expires_at_ms(),
token_res.get_token().get_received_at_ms(),
)
if token_res.get_token().is_expired():
raise TokenRenewalErr("Requested token is expired")
if self._listener.on_next is None:
logger.warning(
"No registered callback for token renewal task. Renewal cancelled"
)
return
if not skip_initial:
try:
await self._listener.on_next(token_res.get_token())
except Exception as e:
raise TokenRenewalErr(e)
if delay <= 0:
return
loop = asyncio.get_running_loop()
wrapped = _async_to_sync_wrapper(loop, self._renew_token_async)
logger.info(f"Next token renewal scheduled in {delay} seconds")
loop.call_later(delay, wrapped)
except Exception as e:
if self._listener.on_error is None:
raise e
await self._listener.on_error(e)
finally:
if init_event:
init_event.set()
def _async_to_sync_wrapper(loop, coro_func, *args, **kwargs):
"""
Wraps an asynchronous function so it can be used with loop.call_later.
:param loop: The event loop in which the coroutine will be executed.
:param coro_func: The coroutine function to wrap.
:param args: Positional arguments to pass to the coroutine function.
:param kwargs: Keyword arguments to pass to the coroutine function.
:return: A regular function suitable for loop.call_later.
"""
def wrapped():
# Schedule the coroutine in the event loop
asyncio.ensure_future(coro_func(*args, **kwargs), loop=loop)
return wrapped
def _start_event_loop_in_thread(event_loop: asyncio.AbstractEventLoop):
"""
Starts event loop in a thread.
Used to be able to schedule tasks using loop.call_later.
:param event_loop:
:return:
"""
asyncio.set_event_loop(event_loop)
event_loop.run_forever()

View File

@@ -0,0 +1,204 @@
import asyncio
import threading
from typing import Any, Callable, Coroutine
class BackgroundScheduler:
"""
Schedules background tasks execution either in separate thread or in the running event loop.
"""
def __init__(self):
self._next_timer = None
self._event_loops = []
self._lock = threading.Lock()
self._stopped = False
def __del__(self):
self.stop()
def stop(self):
"""
Stop all scheduled tasks and clean up resources.
"""
with self._lock:
if self._stopped:
return
self._stopped = True
if self._next_timer:
self._next_timer.cancel()
self._next_timer = None
# Stop all event loops
for loop in self._event_loops:
if loop.is_running():
loop.call_soon_threadsafe(loop.stop)
self._event_loops.clear()
def run_once(self, delay: float, callback: Callable, *args):
"""
Runs callable task once after certain delay in seconds.
"""
with self._lock:
if self._stopped:
return
# Run loop in a separate thread to unblock main thread.
loop = asyncio.new_event_loop()
with self._lock:
self._event_loops.append(loop)
thread = threading.Thread(
target=_start_event_loop_in_thread,
args=(loop, self._call_later, delay, callback, *args),
daemon=True,
)
thread.start()
def run_recurring(self, interval: float, callback: Callable, *args):
"""
Runs recurring callable task with given interval in seconds.
"""
with self._lock:
if self._stopped:
return
# Run loop in a separate thread to unblock main thread.
loop = asyncio.new_event_loop()
with self._lock:
self._event_loops.append(loop)
thread = threading.Thread(
target=_start_event_loop_in_thread,
args=(loop, self._call_later_recurring, interval, callback, *args),
daemon=True,
)
thread.start()
async def run_recurring_async(
self, interval: float, coro: Callable[..., Coroutine[Any, Any, Any]], *args
):
"""
Runs recurring coroutine with given interval in seconds in the current event loop.
To be used only from an async context. No additional threads are created.
"""
with self._lock:
if self._stopped:
return
loop = asyncio.get_running_loop()
wrapped = _async_to_sync_wrapper(loop, coro, *args)
def tick():
with self._lock:
if self._stopped:
return
# Schedule the coroutine
wrapped()
# Schedule next tick
self._next_timer = loop.call_later(interval, tick)
# Schedule first tick
self._next_timer = loop.call_later(interval, tick)
def _call_later(
self, loop: asyncio.AbstractEventLoop, delay: float, callback: Callable, *args
):
with self._lock:
if self._stopped:
return
self._next_timer = loop.call_later(delay, callback, *args)
def _call_later_recurring(
self,
loop: asyncio.AbstractEventLoop,
interval: float,
callback: Callable,
*args,
):
with self._lock:
if self._stopped:
return
self._call_later(
loop, interval, self._execute_recurring, loop, interval, callback, *args
)
def _execute_recurring(
self,
loop: asyncio.AbstractEventLoop,
interval: float,
callback: Callable,
*args,
):
"""
Executes recurring callable task with given interval in seconds.
"""
with self._lock:
if self._stopped:
return
try:
callback(*args)
except Exception:
# Silently ignore exceptions during shutdown
pass
with self._lock:
if self._stopped:
return
self._call_later(
loop, interval, self._execute_recurring, loop, interval, callback, *args
)
def _start_event_loop_in_thread(
event_loop: asyncio.AbstractEventLoop, call_soon_cb: Callable, *args
):
"""
Starts event loop in a thread and schedule callback as soon as event loop is ready.
Used to be able to schedule tasks using loop.call_later.
:param event_loop:
:return:
"""
asyncio.set_event_loop(event_loop)
event_loop.call_soon(call_soon_cb, event_loop, *args)
try:
event_loop.run_forever()
finally:
try:
# Clean up pending tasks
pending = asyncio.all_tasks(event_loop)
for task in pending:
task.cancel()
# Run loop once more to process cancellations
event_loop.run_until_complete(
asyncio.gather(*pending, return_exceptions=True)
)
except Exception:
pass
finally:
event_loop.close()
def _async_to_sync_wrapper(loop, coro_func, *args, **kwargs):
"""
Wraps an asynchronous function so it can be used with loop.call_later.
:param loop: The event loop in which the coroutine will be executed.
:param coro_func: The coroutine function to wrap.
:param args: Positional arguments to pass to the coroutine function.
:param kwargs: Keyword arguments to pass to the coroutine function.
:return: A regular function suitable for loop.call_later.
"""
def wrapped():
# Schedule the coroutine in the event loop
asyncio.ensure_future(coro_func(*args, **kwargs), loop=loop)
return wrapped

View File

@@ -0,0 +1,183 @@
import random
from abc import ABC, abstractmethod
# Maximum backoff between each retry in seconds
DEFAULT_CAP = 0.512
# Minimum backoff between each retry in seconds
DEFAULT_BASE = 0.008
class AbstractBackoff(ABC):
"""Backoff interface"""
def reset(self):
"""
Reset internal state before an operation.
`reset` is called once at the beginning of
every call to `Retry.call_with_retry`
"""
pass
@abstractmethod
def compute(self, failures: int) -> float:
"""Compute backoff in seconds upon failure"""
pass
class ConstantBackoff(AbstractBackoff):
"""Constant backoff upon failure"""
def __init__(self, backoff: float) -> None:
"""`backoff`: backoff time in seconds"""
self._backoff = backoff
def __hash__(self) -> int:
return hash((self._backoff,))
def __eq__(self, other) -> bool:
if not isinstance(other, ConstantBackoff):
return NotImplemented
return self._backoff == other._backoff
def compute(self, failures: int) -> float:
return self._backoff
class NoBackoff(ConstantBackoff):
"""No backoff upon failure"""
def __init__(self) -> None:
super().__init__(0)
class ExponentialBackoff(AbstractBackoff):
"""Exponential backoff upon failure"""
def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE):
"""
`cap`: maximum backoff time in seconds
`base`: base backoff time in seconds
"""
self._cap = cap
self._base = base
def __hash__(self) -> int:
return hash((self._base, self._cap))
def __eq__(self, other) -> bool:
if not isinstance(other, ExponentialBackoff):
return NotImplemented
return self._base == other._base and self._cap == other._cap
def compute(self, failures: int) -> float:
return min(self._cap, self._base * 2**failures)
class FullJitterBackoff(AbstractBackoff):
"""Full jitter backoff upon failure"""
def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE) -> None:
"""
`cap`: maximum backoff time in seconds
`base`: base backoff time in seconds
"""
self._cap = cap
self._base = base
def __hash__(self) -> int:
return hash((self._base, self._cap))
def __eq__(self, other) -> bool:
if not isinstance(other, FullJitterBackoff):
return NotImplemented
return self._base == other._base and self._cap == other._cap
def compute(self, failures: int) -> float:
return random.uniform(0, min(self._cap, self._base * 2**failures))
class EqualJitterBackoff(AbstractBackoff):
"""Equal jitter backoff upon failure"""
def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE) -> None:
"""
`cap`: maximum backoff time in seconds
`base`: base backoff time in seconds
"""
self._cap = cap
self._base = base
def __hash__(self) -> int:
return hash((self._base, self._cap))
def __eq__(self, other) -> bool:
if not isinstance(other, EqualJitterBackoff):
return NotImplemented
return self._base == other._base and self._cap == other._cap
def compute(self, failures: int) -> float:
temp = min(self._cap, self._base * 2**failures) / 2
return temp + random.uniform(0, temp)
class DecorrelatedJitterBackoff(AbstractBackoff):
"""Decorrelated jitter backoff upon failure"""
def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE) -> None:
"""
`cap`: maximum backoff time in seconds
`base`: base backoff time in seconds
"""
self._cap = cap
self._base = base
self._previous_backoff = 0
def __hash__(self) -> int:
return hash((self._base, self._cap))
def __eq__(self, other) -> bool:
if not isinstance(other, DecorrelatedJitterBackoff):
return NotImplemented
return self._base == other._base and self._cap == other._cap
def reset(self) -> None:
self._previous_backoff = 0
def compute(self, failures: int) -> float:
max_backoff = max(self._base, self._previous_backoff * 3)
temp = random.uniform(self._base, max_backoff)
self._previous_backoff = min(self._cap, temp)
return self._previous_backoff
class ExponentialWithJitterBackoff(AbstractBackoff):
"""Exponential backoff upon failure, with jitter"""
def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE) -> None:
"""
`cap`: maximum backoff time in seconds
`base`: base backoff time in seconds
"""
self._cap = cap
self._base = base
def __hash__(self) -> int:
return hash((self._base, self._cap))
def __eq__(self, other) -> bool:
if not isinstance(other, ExponentialWithJitterBackoff):
return NotImplemented
return self._base == other._base and self._cap == other._cap
def compute(self, failures: int) -> float:
return min(self._cap, random.random() * self._base * 2**failures)
def default_backoff():
return EqualJitterBackoff()

View File

@@ -0,0 +1,402 @@
from abc import ABC, abstractmethod
from collections import OrderedDict
from dataclasses import dataclass
from enum import Enum
from typing import Any, List, Optional, Union
class CacheEntryStatus(Enum):
VALID = "VALID"
IN_PROGRESS = "IN_PROGRESS"
class EvictionPolicyType(Enum):
time_based = "time_based"
frequency_based = "frequency_based"
@dataclass(frozen=True)
class CacheKey:
command: str
redis_keys: tuple
class CacheEntry:
def __init__(
self,
cache_key: CacheKey,
cache_value: bytes,
status: CacheEntryStatus,
connection_ref,
):
self.cache_key = cache_key
self.cache_value = cache_value
self.status = status
self.connection_ref = connection_ref
def __hash__(self):
return hash(
(self.cache_key, self.cache_value, self.status, self.connection_ref)
)
def __eq__(self, other):
return hash(self) == hash(other)
class EvictionPolicyInterface(ABC):
@property
@abstractmethod
def cache(self):
pass
@cache.setter
@abstractmethod
def cache(self, value):
pass
@property
@abstractmethod
def type(self) -> EvictionPolicyType:
pass
@abstractmethod
def evict_next(self) -> CacheKey:
pass
@abstractmethod
def evict_many(self, count: int) -> List[CacheKey]:
pass
@abstractmethod
def touch(self, cache_key: CacheKey) -> None:
pass
class CacheConfigurationInterface(ABC):
@abstractmethod
def get_cache_class(self):
pass
@abstractmethod
def get_max_size(self) -> int:
pass
@abstractmethod
def get_eviction_policy(self):
pass
@abstractmethod
def is_exceeds_max_size(self, count: int) -> bool:
pass
@abstractmethod
def is_allowed_to_cache(self, command: str) -> bool:
pass
class CacheInterface(ABC):
@property
@abstractmethod
def collection(self) -> OrderedDict:
pass
@property
@abstractmethod
def config(self) -> CacheConfigurationInterface:
pass
@property
@abstractmethod
def eviction_policy(self) -> EvictionPolicyInterface:
pass
@property
@abstractmethod
def size(self) -> int:
pass
@abstractmethod
def get(self, key: CacheKey) -> Union[CacheEntry, None]:
pass
@abstractmethod
def set(self, entry: CacheEntry) -> bool:
pass
@abstractmethod
def delete_by_cache_keys(self, cache_keys: List[CacheKey]) -> List[bool]:
pass
@abstractmethod
def delete_by_redis_keys(self, redis_keys: List[bytes]) -> List[bool]:
pass
@abstractmethod
def flush(self) -> int:
pass
@abstractmethod
def is_cachable(self, key: CacheKey) -> bool:
pass
class DefaultCache(CacheInterface):
def __init__(
self,
cache_config: CacheConfigurationInterface,
) -> None:
self._cache = OrderedDict()
self._cache_config = cache_config
self._eviction_policy = self._cache_config.get_eviction_policy().value()
self._eviction_policy.cache = self
@property
def collection(self) -> OrderedDict:
return self._cache
@property
def config(self) -> CacheConfigurationInterface:
return self._cache_config
@property
def eviction_policy(self) -> EvictionPolicyInterface:
return self._eviction_policy
@property
def size(self) -> int:
return len(self._cache)
def set(self, entry: CacheEntry) -> bool:
if not self.is_cachable(entry.cache_key):
return False
self._cache[entry.cache_key] = entry
self._eviction_policy.touch(entry.cache_key)
if self._cache_config.is_exceeds_max_size(len(self._cache)):
self._eviction_policy.evict_next()
return True
def get(self, key: CacheKey) -> Union[CacheEntry, None]:
entry = self._cache.get(key, None)
if entry is None:
return None
self._eviction_policy.touch(key)
return entry
def delete_by_cache_keys(self, cache_keys: List[CacheKey]) -> List[bool]:
response = []
for key in cache_keys:
if self.get(key) is not None:
self._cache.pop(key)
response.append(True)
else:
response.append(False)
return response
def delete_by_redis_keys(self, redis_keys: List[bytes]) -> List[bool]:
response = []
keys_to_delete = []
for redis_key in redis_keys:
if isinstance(redis_key, bytes):
redis_key = redis_key.decode()
for cache_key in self._cache:
if redis_key in cache_key.redis_keys:
keys_to_delete.append(cache_key)
response.append(True)
for key in keys_to_delete:
self._cache.pop(key)
return response
def flush(self) -> int:
elem_count = len(self._cache)
self._cache.clear()
return elem_count
def is_cachable(self, key: CacheKey) -> bool:
return self._cache_config.is_allowed_to_cache(key.command)
class LRUPolicy(EvictionPolicyInterface):
def __init__(self):
self.cache = None
@property
def cache(self):
return self._cache
@cache.setter
def cache(self, cache: CacheInterface):
self._cache = cache
@property
def type(self) -> EvictionPolicyType:
return EvictionPolicyType.time_based
def evict_next(self) -> CacheKey:
self._assert_cache()
popped_entry = self._cache.collection.popitem(last=False)
return popped_entry[0]
def evict_many(self, count: int) -> List[CacheKey]:
self._assert_cache()
if count > len(self._cache.collection):
raise ValueError("Evictions count is above cache size")
popped_keys = []
for _ in range(count):
popped_entry = self._cache.collection.popitem(last=False)
popped_keys.append(popped_entry[0])
return popped_keys
def touch(self, cache_key: CacheKey) -> None:
self._assert_cache()
if self._cache.collection.get(cache_key) is None:
raise ValueError("Given entry does not belong to the cache")
self._cache.collection.move_to_end(cache_key)
def _assert_cache(self):
if self.cache is None or not isinstance(self.cache, CacheInterface):
raise ValueError("Eviction policy should be associated with valid cache.")
class EvictionPolicy(Enum):
LRU = LRUPolicy
class CacheConfig(CacheConfigurationInterface):
DEFAULT_CACHE_CLASS = DefaultCache
DEFAULT_EVICTION_POLICY = EvictionPolicy.LRU
DEFAULT_MAX_SIZE = 10000
DEFAULT_ALLOW_LIST = [
"BITCOUNT",
"BITFIELD_RO",
"BITPOS",
"EXISTS",
"GEODIST",
"GEOHASH",
"GEOPOS",
"GEORADIUSBYMEMBER_RO",
"GEORADIUS_RO",
"GEOSEARCH",
"GET",
"GETBIT",
"GETRANGE",
"HEXISTS",
"HGET",
"HGETALL",
"HKEYS",
"HLEN",
"HMGET",
"HSTRLEN",
"HVALS",
"JSON.ARRINDEX",
"JSON.ARRLEN",
"JSON.GET",
"JSON.MGET",
"JSON.OBJKEYS",
"JSON.OBJLEN",
"JSON.RESP",
"JSON.STRLEN",
"JSON.TYPE",
"LCS",
"LINDEX",
"LLEN",
"LPOS",
"LRANGE",
"MGET",
"SCARD",
"SDIFF",
"SINTER",
"SINTERCARD",
"SISMEMBER",
"SMEMBERS",
"SMISMEMBER",
"SORT_RO",
"STRLEN",
"SUBSTR",
"SUNION",
"TS.GET",
"TS.INFO",
"TS.RANGE",
"TS.REVRANGE",
"TYPE",
"XLEN",
"XPENDING",
"XRANGE",
"XREAD",
"XREVRANGE",
"ZCARD",
"ZCOUNT",
"ZDIFF",
"ZINTER",
"ZINTERCARD",
"ZLEXCOUNT",
"ZMSCORE",
"ZRANGE",
"ZRANGEBYLEX",
"ZRANGEBYSCORE",
"ZRANK",
"ZREVRANGE",
"ZREVRANGEBYLEX",
"ZREVRANGEBYSCORE",
"ZREVRANK",
"ZSCORE",
"ZUNION",
]
def __init__(
self,
max_size: int = DEFAULT_MAX_SIZE,
cache_class: Any = DEFAULT_CACHE_CLASS,
eviction_policy: EvictionPolicy = DEFAULT_EVICTION_POLICY,
):
self._cache_class = cache_class
self._max_size = max_size
self._eviction_policy = eviction_policy
def get_cache_class(self):
return self._cache_class
def get_max_size(self) -> int:
return self._max_size
def get_eviction_policy(self) -> EvictionPolicy:
return self._eviction_policy
def is_exceeds_max_size(self, count: int) -> bool:
return count > self._max_size
def is_allowed_to_cache(self, command: str) -> bool:
return command in self.DEFAULT_ALLOW_LIST
class CacheFactoryInterface(ABC):
@abstractmethod
def get_cache(self) -> CacheInterface:
pass
class CacheFactory(CacheFactoryInterface):
def __init__(self, cache_config: Optional[CacheConfig] = None):
self._config = cache_config
if self._config is None:
self._config = CacheConfig()
def get_cache(self) -> CacheInterface:
cache_class = self._config.get_cache_class()
return cache_class(cache_config=self._config)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,18 @@
from .cluster import READ_COMMANDS, AsyncRedisClusterCommands, RedisClusterCommands
from .core import AsyncCoreCommands, CoreCommands
from .helpers import list_or_args
from .redismodules import AsyncRedisModuleCommands, RedisModuleCommands
from .sentinel import AsyncSentinelCommands, SentinelCommands
__all__ = [
"AsyncCoreCommands",
"AsyncRedisClusterCommands",
"AsyncRedisModuleCommands",
"AsyncSentinelCommands",
"CoreCommands",
"READ_COMMANDS",
"RedisClusterCommands",
"RedisModuleCommands",
"SentinelCommands",
"list_or_args",
]

View File

@@ -0,0 +1,253 @@
from redis._parsers.helpers import bool_ok
from ..helpers import get_protocol_version, parse_to_list
from .commands import * # noqa
from .info import BFInfo, CFInfo, CMSInfo, TDigestInfo, TopKInfo
class AbstractBloom:
"""
The client allows to interact with RedisBloom and use all of
it's functionality.
- BF for Bloom Filter
- CF for Cuckoo Filter
- CMS for Count-Min Sketch
- TOPK for TopK Data Structure
- TDIGEST for estimate rank statistics
"""
@staticmethod
def append_items(params, items):
"""Append ITEMS to params."""
params.extend(["ITEMS"])
params += items
@staticmethod
def append_error(params, error):
"""Append ERROR to params."""
if error is not None:
params.extend(["ERROR", error])
@staticmethod
def append_capacity(params, capacity):
"""Append CAPACITY to params."""
if capacity is not None:
params.extend(["CAPACITY", capacity])
@staticmethod
def append_expansion(params, expansion):
"""Append EXPANSION to params."""
if expansion is not None:
params.extend(["EXPANSION", expansion])
@staticmethod
def append_no_scale(params, noScale):
"""Append NONSCALING tag to params."""
if noScale is not None:
params.extend(["NONSCALING"])
@staticmethod
def append_weights(params, weights):
"""Append WEIGHTS to params."""
if len(weights) > 0:
params.append("WEIGHTS")
params += weights
@staticmethod
def append_no_create(params, noCreate):
"""Append NOCREATE tag to params."""
if noCreate is not None:
params.extend(["NOCREATE"])
@staticmethod
def append_items_and_increments(params, items, increments):
"""Append pairs of items and increments to params."""
for i in range(len(items)):
params.append(items[i])
params.append(increments[i])
@staticmethod
def append_values_and_weights(params, items, weights):
"""Append pairs of items and weights to params."""
for i in range(len(items)):
params.append(items[i])
params.append(weights[i])
@staticmethod
def append_max_iterations(params, max_iterations):
"""Append MAXITERATIONS to params."""
if max_iterations is not None:
params.extend(["MAXITERATIONS", max_iterations])
@staticmethod
def append_bucket_size(params, bucket_size):
"""Append BUCKETSIZE to params."""
if bucket_size is not None:
params.extend(["BUCKETSIZE", bucket_size])
class CMSBloom(CMSCommands, AbstractBloom):
def __init__(self, client, **kwargs):
"""Create a new RedisBloom client."""
# Set the module commands' callbacks
_MODULE_CALLBACKS = {
CMS_INITBYDIM: bool_ok,
CMS_INITBYPROB: bool_ok,
# CMS_INCRBY: spaceHolder,
# CMS_QUERY: spaceHolder,
CMS_MERGE: bool_ok,
}
_RESP2_MODULE_CALLBACKS = {
CMS_INFO: CMSInfo,
}
_RESP3_MODULE_CALLBACKS = {}
self.client = client
self.commandmixin = CMSCommands
self.execute_command = client.execute_command
if get_protocol_version(self.client) in ["3", 3]:
_MODULE_CALLBACKS.update(_RESP3_MODULE_CALLBACKS)
else:
_MODULE_CALLBACKS.update(_RESP2_MODULE_CALLBACKS)
for k, v in _MODULE_CALLBACKS.items():
self.client.set_response_callback(k, v)
class TOPKBloom(TOPKCommands, AbstractBloom):
def __init__(self, client, **kwargs):
"""Create a new RedisBloom client."""
# Set the module commands' callbacks
_MODULE_CALLBACKS = {
TOPK_RESERVE: bool_ok,
# TOPK_QUERY: spaceHolder,
# TOPK_COUNT: spaceHolder,
}
_RESP2_MODULE_CALLBACKS = {
TOPK_ADD: parse_to_list,
TOPK_INCRBY: parse_to_list,
TOPK_INFO: TopKInfo,
TOPK_LIST: parse_to_list,
}
_RESP3_MODULE_CALLBACKS = {}
self.client = client
self.commandmixin = TOPKCommands
self.execute_command = client.execute_command
if get_protocol_version(self.client) in ["3", 3]:
_MODULE_CALLBACKS.update(_RESP3_MODULE_CALLBACKS)
else:
_MODULE_CALLBACKS.update(_RESP2_MODULE_CALLBACKS)
for k, v in _MODULE_CALLBACKS.items():
self.client.set_response_callback(k, v)
class CFBloom(CFCommands, AbstractBloom):
def __init__(self, client, **kwargs):
"""Create a new RedisBloom client."""
# Set the module commands' callbacks
_MODULE_CALLBACKS = {
CF_RESERVE: bool_ok,
# CF_ADD: spaceHolder,
# CF_ADDNX: spaceHolder,
# CF_INSERT: spaceHolder,
# CF_INSERTNX: spaceHolder,
# CF_EXISTS: spaceHolder,
# CF_DEL: spaceHolder,
# CF_COUNT: spaceHolder,
# CF_SCANDUMP: spaceHolder,
# CF_LOADCHUNK: spaceHolder,
}
_RESP2_MODULE_CALLBACKS = {
CF_INFO: CFInfo,
}
_RESP3_MODULE_CALLBACKS = {}
self.client = client
self.commandmixin = CFCommands
self.execute_command = client.execute_command
if get_protocol_version(self.client) in ["3", 3]:
_MODULE_CALLBACKS.update(_RESP3_MODULE_CALLBACKS)
else:
_MODULE_CALLBACKS.update(_RESP2_MODULE_CALLBACKS)
for k, v in _MODULE_CALLBACKS.items():
self.client.set_response_callback(k, v)
class TDigestBloom(TDigestCommands, AbstractBloom):
def __init__(self, client, **kwargs):
"""Create a new RedisBloom client."""
# Set the module commands' callbacks
_MODULE_CALLBACKS = {
TDIGEST_CREATE: bool_ok,
# TDIGEST_RESET: bool_ok,
# TDIGEST_ADD: spaceHolder,
# TDIGEST_MERGE: spaceHolder,
}
_RESP2_MODULE_CALLBACKS = {
TDIGEST_BYRANK: parse_to_list,
TDIGEST_BYREVRANK: parse_to_list,
TDIGEST_CDF: parse_to_list,
TDIGEST_INFO: TDigestInfo,
TDIGEST_MIN: float,
TDIGEST_MAX: float,
TDIGEST_TRIMMED_MEAN: float,
TDIGEST_QUANTILE: parse_to_list,
}
_RESP3_MODULE_CALLBACKS = {}
self.client = client
self.commandmixin = TDigestCommands
self.execute_command = client.execute_command
if get_protocol_version(self.client) in ["3", 3]:
_MODULE_CALLBACKS.update(_RESP3_MODULE_CALLBACKS)
else:
_MODULE_CALLBACKS.update(_RESP2_MODULE_CALLBACKS)
for k, v in _MODULE_CALLBACKS.items():
self.client.set_response_callback(k, v)
class BFBloom(BFCommands, AbstractBloom):
def __init__(self, client, **kwargs):
"""Create a new RedisBloom client."""
# Set the module commands' callbacks
_MODULE_CALLBACKS = {
BF_RESERVE: bool_ok,
# BF_ADD: spaceHolder,
# BF_MADD: spaceHolder,
# BF_INSERT: spaceHolder,
# BF_EXISTS: spaceHolder,
# BF_MEXISTS: spaceHolder,
# BF_SCANDUMP: spaceHolder,
# BF_LOADCHUNK: spaceHolder,
# BF_CARD: spaceHolder,
}
_RESP2_MODULE_CALLBACKS = {
BF_INFO: BFInfo,
}
_RESP3_MODULE_CALLBACKS = {}
self.client = client
self.commandmixin = BFCommands
self.execute_command = client.execute_command
if get_protocol_version(self.client) in ["3", 3]:
_MODULE_CALLBACKS.update(_RESP3_MODULE_CALLBACKS)
else:
_MODULE_CALLBACKS.update(_RESP2_MODULE_CALLBACKS)
for k, v in _MODULE_CALLBACKS.items():
self.client.set_response_callback(k, v)

Some files were not shown because too many files have changed in this diff Show More