diff --git a/README.rst b/README.rst index 2fa5626..ed19c7f 100644 --- a/README.rst +++ b/README.rst @@ -171,6 +171,7 @@ to 10, and all ``websocket.send!`` channels to 20: If you want to enforce a matching order, use an ``OrderedDict`` as the argument; channels will then be matched in the order the dict provides them. +.. _encryption ``symmetric_encryption_keys`` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -237,6 +238,44 @@ And then in your channels consumer, you can implement the handler: async def redis_disconnect(self, *args): # Handle disconnect + + +``serializer_format`` +~~~~~~~~~~~~~~~~~~~~~~ +By default every message which reach redis is encoded using `msgpack `_. +It is also possible to switch to `JSON `_: + +.. code-block:: python + + CHANNEL_LAYERS = { + "default": { + "BACKEND": "channels_redis.core.RedisChannelLayer", + "CONFIG": { + "hosts": ["redis://:password@127.0.0.1:6379/0"], + "serializer_format": "json", + }, + }, + } + +A new serializer may be registered (or can be overriden) by using ``channels_redis.serializers.registry``, +providing a class which extends ``channels_redis.serializers.BaseMessageSerializer``, implementing ``dumps`` +and ``loads`` methods, or which provides ``serialize``/``deserialize`` methods and calling the registration method on registry: + +.. code-block:: python + + from channels_redis.serializers import registry + + class MyFormatSerializer: + def serialize(self, message): + ... + def deserialize(self, message): + ... + + registry.register_serializer('myformat', MyFormatSerializer) + +**NOTE**: Serializers also perform the encryption job see *symmetric_encryption_keys*. + + Dependencies ------------ diff --git a/channels_redis/core.py b/channels_redis/core.py index a164059..a70a600 100644 --- a/channels_redis/core.py +++ b/channels_redis/core.py @@ -5,16 +5,15 @@ import hashlib import itertools import logging -import random import time import uuid -import msgpack from redis import asyncio as aioredis from channels.exceptions import ChannelFull from channels.layers import BaseChannelLayer +from .serializers import registry from .utils import ( _close_redis, _consistent_hash, @@ -115,6 +114,7 @@ def __init__( capacity=100, channel_capacity=None, symmetric_encryption_keys=None, + serializer_format="msgpack", ): # Store basic information self.expiry = expiry @@ -126,6 +126,16 @@ def __init__( # Configure the host objects self.hosts = decode_hosts(hosts) self.ring_size = len(self.hosts) + # serialization + self._serializer = registry.get_serializer( + serializer_format, + # As we use an sorted set to expire messages we need to guarantee uniqueness, with 12 bytes. + random_prefix_length=12, + expiry=self.expiry, + symmetric_encryption_keys=symmetric_encryption_keys, + ) + self.serialize = self._serializer.serialize + self.deserialize = self._serializer.deserialize # Cached redis connection pools and the event loop they are from self._layers = {} # Normal channels choose a host index by cycling through the available hosts @@ -133,8 +143,6 @@ def __init__( self._send_index_generator = itertools.cycle(range(len(self.hosts))) # Decide on a unique client prefix to use in ! sections self.client_prefix = uuid.uuid4().hex - # Set up any encryption objects - self._setup_encryption(symmetric_encryption_keys) # Number of coroutines trying to receive right now self.receive_count = 0 # The receive lock @@ -154,24 +162,6 @@ def __init__( def create_pool(self, index): return create_pool(self.hosts[index]) - def _setup_encryption(self, symmetric_encryption_keys): - # See if we can do encryption if they asked - if symmetric_encryption_keys: - if isinstance(symmetric_encryption_keys, (str, bytes)): - raise ValueError( - "symmetric_encryption_keys must be a list of possible keys" - ) - try: - from cryptography.fernet import MultiFernet - except ImportError: - raise ValueError( - "Cannot run with encryption without 'cryptography' installed." - ) - sub_fernets = [self.make_fernet(key) for key in symmetric_encryption_keys] - self.crypter = MultiFernet(sub_fernets) - else: - self.crypter = None - ### Channel layer API ### extensions = ["groups", "flush"] @@ -650,31 +640,6 @@ def _group_key(self, group): """ return f"{self.prefix}:group:{group}".encode("utf8") - ### Serialization ### - - def serialize(self, message): - """ - Serializes message to a byte string. - """ - value = msgpack.packb(message, use_bin_type=True) - if self.crypter: - value = self.crypter.encrypt(value) - - # As we use an sorted set to expire messages we need to guarantee uniqueness, with 12 bytes. - random_prefix = random.getrandbits(8 * 12).to_bytes(12, "big") - return random_prefix + value - - def deserialize(self, message): - """ - Deserializes from a byte string. - """ - # Removes the random prefix - message = message[12:] - - if self.crypter: - message = self.crypter.decrypt(message, self.expiry + 10) - return msgpack.unpackb(message, raw=False) - ### Internal functions ### def consistent_hash(self, value): diff --git a/channels_redis/pubsub.py b/channels_redis/pubsub.py index 6957b0a..f595d80 100644 --- a/channels_redis/pubsub.py +++ b/channels_redis/pubsub.py @@ -3,9 +3,9 @@ import logging import uuid -import msgpack from redis import asyncio as aioredis +from .serializers import registry from .utils import ( _close_redis, _consistent_hash, @@ -25,10 +25,23 @@ async def _async_proxy(obj, name, *args, **kwargs): class RedisPubSubChannelLayer: - def __init__(self, *args, **kwargs) -> None: + def __init__( + self, + *args, + symmetric_encryption_keys=None, + serializer_format="msgpack", + **kwargs, + ) -> None: self._args = args self._kwargs = kwargs self._layers = {} + # serialization + self._serializer = registry.get_serializer( + serializer_format, + symmetric_encryption_keys=symmetric_encryption_keys, + ) + self.serialize = self._serializer.serialize + self.deserialize = self._serializer.deserialize def __getattr__(self, name): if name in ( @@ -44,18 +57,6 @@ def __getattr__(self, name): else: return getattr(self._get_layer(), name) - def serialize(self, message): - """ - Serializes message to a byte string. - """ - return msgpack.packb(message) - - def deserialize(self, message): - """ - Deserializes from a byte string. - """ - return msgpack.unpackb(message) - def _get_layer(self): loop = asyncio.get_running_loop() diff --git a/channels_redis/serializers.py b/channels_redis/serializers.py new file mode 100644 index 0000000..1e89e5f --- /dev/null +++ b/channels_redis/serializers.py @@ -0,0 +1,141 @@ +import json +import random +import abc + + +class SerializerDoesNotExist(KeyError): + """The requested serializer was not found.""" + + +class BaseMessageSerializer(abc.ABC): + + def __init__( + self, + symmetric_encryption_keys=None, + random_prefix_length=0, + expiry=None, + ): + self.random_prefix_length = random_prefix_length + self.expiry = expiry + # Set up any encryption objects + self._setup_encryption(symmetric_encryption_keys) + + def _setup_encryption(self, symmetric_encryption_keys): + # See if we can do encryption if they asked + if symmetric_encryption_keys: + if isinstance(symmetric_encryption_keys, (str, bytes)): + raise ValueError( + "symmetric_encryption_keys must be a list of possible keys" + ) + try: + from cryptography.fernet import MultiFernet + except ImportError: + raise ValueError( + "Cannot run with encryption without 'cryptography' installed." + ) + sub_fernets = [self.make_fernet(key) for key in symmetric_encryption_keys] + self.crypter = MultiFernet(sub_fernets) + else: + self.crypter = None + + @abc.abstractmethod + def dumps(self, message): + raise NotImplementedError + + @abc.abstractmethod + def loads(self, message): + raise NotImplementedError + + def serialize(self, message): + """ + Serializes message to a byte string. + """ + message = self.dumps(message) + # ensure message is bytes + if isinstance(message, str): + message = message.encode("utf-8") + if self.crypter: + message = self.crypter.encrypt(message) + + if self.random_prefix_length > 0: + # provide random prefix + message = ( + random.getrandbits(8 * self.random_prefix_length).to_bytes( + self.random_prefix_length, "big" + ) + + message + ) + return message + + def deserialize(self, message): + """ + Deserializes from a byte string. + """ + if self.random_prefix_length > 0: + # Removes the random prefix + message = message[self.random_prefix_length :] # noqa: E203 + + if self.crypter: + ttl = self.expiry if self.expiry is None else self.expiry + 10 + message = self.crypter.decrypt(message, ttl) + return self.loads(message) + + +class MissingSerializer(BaseMessageSerializer): + exception = None + + def __init__(self, *args, **kwargs): + raise self.exception + + +class JSONSerializer(BaseMessageSerializer): + dumps = staticmethod(json.dumps) + loads = staticmethod(json.loads) + + +# code ready for a future in which msgpack may become an optional dependency +try: + import msgpack +except ImportError as exc: + + class MsgPackSerializer(MissingSerializer): + exception = exc + +else: + + class MsgPackSerializer(BaseMessageSerializer): + dumps = staticmethod(msgpack.packb) + loads = staticmethod(msgpack.unpackb) + + +class SerializersRegistry: + def __init__(self): + self._registry = {} + + def register_serializer(self, format, serializer_class): + """ + Register a new serializer for given format + """ + assert isinstance(serializer_class, type) and ( + issubclass(serializer_class, BaseMessageSerializer) + or hasattr(serializer_class, "serialize") + and hasattr(serializer_class, "deserialize") + ), """ + `serializer_class` should be a class which implements `serialize` and `deserialize` method + or a subclass of `channels_redis.serializers.BaseMessageSerializer` + """ + + self._registry[format] = serializer_class + + def get_serializer(self, format, *args, **kwargs): + try: + serializer_class = self._registry[format] + except KeyError: + raise SerializerDoesNotExist(format) + + return serializer_class(*args, **kwargs) + + +registry = SerializersRegistry() +registry.register_serializer("json", JSONSerializer) +registry.register_serializer("msgpack", MsgPackSerializer) diff --git a/tests/test_serializers.py b/tests/test_serializers.py new file mode 100644 index 0000000..85907d1 --- /dev/null +++ b/tests/test_serializers.py @@ -0,0 +1,48 @@ +import pytest + +from channels_redis.serializers import SerializerDoesNotExist, SerializersRegistry + + +@pytest.fixture +def registry(): + return SerializersRegistry() + + +class OnlySerialize: + def serialize(self, message): + return message + + +class OnlyDeserialize: + def deserialize(self, message): + return message + + +def bad_serializer(): + pass + + +class NoopSerializer: + def serialize(self, message): + return message + + def deserialize(self, message): + return message + + +@pytest.mark.parametrize( + "serializer_class", (OnlyDeserialize, OnlySerialize, bad_serializer) +) +def test_refuse_to_register_bad_serializers(registry, serializer_class): + with pytest.raises(AssertionError): + registry.register_serializer("custom", serializer_class) + + +def test_raise_error_for_unregistered_serializer(registry): + with pytest.raises(SerializerDoesNotExist): + registry.get_serializer("unexistent") + + +def test_register_custom_serializer(registry): + registry.register_serializer("custom", NoopSerializer) + registry.get_serializer("custom")