diff --git a/README.rst b/README.rst
index 2fa5626..ed19c7f 100644
--- a/README.rst
+++ b/README.rst
@@ -171,6 +171,7 @@ to 10, and all ``websocket.send!`` channels to 20:
If you want to enforce a matching order, use an ``OrderedDict`` as the
argument; channels will then be matched in the order the dict provides them.
+.. _encryption
``symmetric_encryption_keys``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -237,6 +238,44 @@ And then in your channels consumer, you can implement the handler:
async def redis_disconnect(self, *args):
# Handle disconnect
+
+
+``serializer_format``
+~~~~~~~~~~~~~~~~~~~~~~
+By default every message which reach redis is encoded using `msgpack `_.
+It is also possible to switch to `JSON `_:
+
+.. code-block:: python
+
+ CHANNEL_LAYERS = {
+ "default": {
+ "BACKEND": "channels_redis.core.RedisChannelLayer",
+ "CONFIG": {
+ "hosts": ["redis://:password@127.0.0.1:6379/0"],
+ "serializer_format": "json",
+ },
+ },
+ }
+
+A new serializer may be registered (or can be overriden) by using ``channels_redis.serializers.registry``,
+providing a class which extends ``channels_redis.serializers.BaseMessageSerializer``, implementing ``dumps``
+and ``loads`` methods, or which provides ``serialize``/``deserialize`` methods and calling the registration method on registry:
+
+.. code-block:: python
+
+ from channels_redis.serializers import registry
+
+ class MyFormatSerializer:
+ def serialize(self, message):
+ ...
+ def deserialize(self, message):
+ ...
+
+ registry.register_serializer('myformat', MyFormatSerializer)
+
+**NOTE**: Serializers also perform the encryption job see *symmetric_encryption_keys*.
+
+
Dependencies
------------
diff --git a/channels_redis/core.py b/channels_redis/core.py
index a164059..a70a600 100644
--- a/channels_redis/core.py
+++ b/channels_redis/core.py
@@ -5,16 +5,15 @@
import hashlib
import itertools
import logging
-import random
import time
import uuid
-import msgpack
from redis import asyncio as aioredis
from channels.exceptions import ChannelFull
from channels.layers import BaseChannelLayer
+from .serializers import registry
from .utils import (
_close_redis,
_consistent_hash,
@@ -115,6 +114,7 @@ def __init__(
capacity=100,
channel_capacity=None,
symmetric_encryption_keys=None,
+ serializer_format="msgpack",
):
# Store basic information
self.expiry = expiry
@@ -126,6 +126,16 @@ def __init__(
# Configure the host objects
self.hosts = decode_hosts(hosts)
self.ring_size = len(self.hosts)
+ # serialization
+ self._serializer = registry.get_serializer(
+ serializer_format,
+ # As we use an sorted set to expire messages we need to guarantee uniqueness, with 12 bytes.
+ random_prefix_length=12,
+ expiry=self.expiry,
+ symmetric_encryption_keys=symmetric_encryption_keys,
+ )
+ self.serialize = self._serializer.serialize
+ self.deserialize = self._serializer.deserialize
# Cached redis connection pools and the event loop they are from
self._layers = {}
# Normal channels choose a host index by cycling through the available hosts
@@ -133,8 +143,6 @@ def __init__(
self._send_index_generator = itertools.cycle(range(len(self.hosts)))
# Decide on a unique client prefix to use in ! sections
self.client_prefix = uuid.uuid4().hex
- # Set up any encryption objects
- self._setup_encryption(symmetric_encryption_keys)
# Number of coroutines trying to receive right now
self.receive_count = 0
# The receive lock
@@ -154,24 +162,6 @@ def __init__(
def create_pool(self, index):
return create_pool(self.hosts[index])
- def _setup_encryption(self, symmetric_encryption_keys):
- # See if we can do encryption if they asked
- if symmetric_encryption_keys:
- if isinstance(symmetric_encryption_keys, (str, bytes)):
- raise ValueError(
- "symmetric_encryption_keys must be a list of possible keys"
- )
- try:
- from cryptography.fernet import MultiFernet
- except ImportError:
- raise ValueError(
- "Cannot run with encryption without 'cryptography' installed."
- )
- sub_fernets = [self.make_fernet(key) for key in symmetric_encryption_keys]
- self.crypter = MultiFernet(sub_fernets)
- else:
- self.crypter = None
-
### Channel layer API ###
extensions = ["groups", "flush"]
@@ -650,31 +640,6 @@ def _group_key(self, group):
"""
return f"{self.prefix}:group:{group}".encode("utf8")
- ### Serialization ###
-
- def serialize(self, message):
- """
- Serializes message to a byte string.
- """
- value = msgpack.packb(message, use_bin_type=True)
- if self.crypter:
- value = self.crypter.encrypt(value)
-
- # As we use an sorted set to expire messages we need to guarantee uniqueness, with 12 bytes.
- random_prefix = random.getrandbits(8 * 12).to_bytes(12, "big")
- return random_prefix + value
-
- def deserialize(self, message):
- """
- Deserializes from a byte string.
- """
- # Removes the random prefix
- message = message[12:]
-
- if self.crypter:
- message = self.crypter.decrypt(message, self.expiry + 10)
- return msgpack.unpackb(message, raw=False)
-
### Internal functions ###
def consistent_hash(self, value):
diff --git a/channels_redis/pubsub.py b/channels_redis/pubsub.py
index 6957b0a..f595d80 100644
--- a/channels_redis/pubsub.py
+++ b/channels_redis/pubsub.py
@@ -3,9 +3,9 @@
import logging
import uuid
-import msgpack
from redis import asyncio as aioredis
+from .serializers import registry
from .utils import (
_close_redis,
_consistent_hash,
@@ -25,10 +25,23 @@ async def _async_proxy(obj, name, *args, **kwargs):
class RedisPubSubChannelLayer:
- def __init__(self, *args, **kwargs) -> None:
+ def __init__(
+ self,
+ *args,
+ symmetric_encryption_keys=None,
+ serializer_format="msgpack",
+ **kwargs,
+ ) -> None:
self._args = args
self._kwargs = kwargs
self._layers = {}
+ # serialization
+ self._serializer = registry.get_serializer(
+ serializer_format,
+ symmetric_encryption_keys=symmetric_encryption_keys,
+ )
+ self.serialize = self._serializer.serialize
+ self.deserialize = self._serializer.deserialize
def __getattr__(self, name):
if name in (
@@ -44,18 +57,6 @@ def __getattr__(self, name):
else:
return getattr(self._get_layer(), name)
- def serialize(self, message):
- """
- Serializes message to a byte string.
- """
- return msgpack.packb(message)
-
- def deserialize(self, message):
- """
- Deserializes from a byte string.
- """
- return msgpack.unpackb(message)
-
def _get_layer(self):
loop = asyncio.get_running_loop()
diff --git a/channels_redis/serializers.py b/channels_redis/serializers.py
new file mode 100644
index 0000000..1e89e5f
--- /dev/null
+++ b/channels_redis/serializers.py
@@ -0,0 +1,141 @@
+import json
+import random
+import abc
+
+
+class SerializerDoesNotExist(KeyError):
+ """The requested serializer was not found."""
+
+
+class BaseMessageSerializer(abc.ABC):
+
+ def __init__(
+ self,
+ symmetric_encryption_keys=None,
+ random_prefix_length=0,
+ expiry=None,
+ ):
+ self.random_prefix_length = random_prefix_length
+ self.expiry = expiry
+ # Set up any encryption objects
+ self._setup_encryption(symmetric_encryption_keys)
+
+ def _setup_encryption(self, symmetric_encryption_keys):
+ # See if we can do encryption if they asked
+ if symmetric_encryption_keys:
+ if isinstance(symmetric_encryption_keys, (str, bytes)):
+ raise ValueError(
+ "symmetric_encryption_keys must be a list of possible keys"
+ )
+ try:
+ from cryptography.fernet import MultiFernet
+ except ImportError:
+ raise ValueError(
+ "Cannot run with encryption without 'cryptography' installed."
+ )
+ sub_fernets = [self.make_fernet(key) for key in symmetric_encryption_keys]
+ self.crypter = MultiFernet(sub_fernets)
+ else:
+ self.crypter = None
+
+ @abc.abstractmethod
+ def dumps(self, message):
+ raise NotImplementedError
+
+ @abc.abstractmethod
+ def loads(self, message):
+ raise NotImplementedError
+
+ def serialize(self, message):
+ """
+ Serializes message to a byte string.
+ """
+ message = self.dumps(message)
+ # ensure message is bytes
+ if isinstance(message, str):
+ message = message.encode("utf-8")
+ if self.crypter:
+ message = self.crypter.encrypt(message)
+
+ if self.random_prefix_length > 0:
+ # provide random prefix
+ message = (
+ random.getrandbits(8 * self.random_prefix_length).to_bytes(
+ self.random_prefix_length, "big"
+ )
+ + message
+ )
+ return message
+
+ def deserialize(self, message):
+ """
+ Deserializes from a byte string.
+ """
+ if self.random_prefix_length > 0:
+ # Removes the random prefix
+ message = message[self.random_prefix_length :] # noqa: E203
+
+ if self.crypter:
+ ttl = self.expiry if self.expiry is None else self.expiry + 10
+ message = self.crypter.decrypt(message, ttl)
+ return self.loads(message)
+
+
+class MissingSerializer(BaseMessageSerializer):
+ exception = None
+
+ def __init__(self, *args, **kwargs):
+ raise self.exception
+
+
+class JSONSerializer(BaseMessageSerializer):
+ dumps = staticmethod(json.dumps)
+ loads = staticmethod(json.loads)
+
+
+# code ready for a future in which msgpack may become an optional dependency
+try:
+ import msgpack
+except ImportError as exc:
+
+ class MsgPackSerializer(MissingSerializer):
+ exception = exc
+
+else:
+
+ class MsgPackSerializer(BaseMessageSerializer):
+ dumps = staticmethod(msgpack.packb)
+ loads = staticmethod(msgpack.unpackb)
+
+
+class SerializersRegistry:
+ def __init__(self):
+ self._registry = {}
+
+ def register_serializer(self, format, serializer_class):
+ """
+ Register a new serializer for given format
+ """
+ assert isinstance(serializer_class, type) and (
+ issubclass(serializer_class, BaseMessageSerializer)
+ or hasattr(serializer_class, "serialize")
+ and hasattr(serializer_class, "deserialize")
+ ), """
+ `serializer_class` should be a class which implements `serialize` and `deserialize` method
+ or a subclass of `channels_redis.serializers.BaseMessageSerializer`
+ """
+
+ self._registry[format] = serializer_class
+
+ def get_serializer(self, format, *args, **kwargs):
+ try:
+ serializer_class = self._registry[format]
+ except KeyError:
+ raise SerializerDoesNotExist(format)
+
+ return serializer_class(*args, **kwargs)
+
+
+registry = SerializersRegistry()
+registry.register_serializer("json", JSONSerializer)
+registry.register_serializer("msgpack", MsgPackSerializer)
diff --git a/tests/test_serializers.py b/tests/test_serializers.py
new file mode 100644
index 0000000..85907d1
--- /dev/null
+++ b/tests/test_serializers.py
@@ -0,0 +1,48 @@
+import pytest
+
+from channels_redis.serializers import SerializerDoesNotExist, SerializersRegistry
+
+
+@pytest.fixture
+def registry():
+ return SerializersRegistry()
+
+
+class OnlySerialize:
+ def serialize(self, message):
+ return message
+
+
+class OnlyDeserialize:
+ def deserialize(self, message):
+ return message
+
+
+def bad_serializer():
+ pass
+
+
+class NoopSerializer:
+ def serialize(self, message):
+ return message
+
+ def deserialize(self, message):
+ return message
+
+
+@pytest.mark.parametrize(
+ "serializer_class", (OnlyDeserialize, OnlySerialize, bad_serializer)
+)
+def test_refuse_to_register_bad_serializers(registry, serializer_class):
+ with pytest.raises(AssertionError):
+ registry.register_serializer("custom", serializer_class)
+
+
+def test_raise_error_for_unregistered_serializer(registry):
+ with pytest.raises(SerializerDoesNotExist):
+ registry.get_serializer("unexistent")
+
+
+def test_register_custom_serializer(registry):
+ registry.register_serializer("custom", NoopSerializer)
+ registry.get_serializer("custom")