Skip to content

Commit

Permalink
Merge pull request #45 from luizalabs/custom-serialization-exceptions
Browse files Browse the repository at this point in the history
Custom serialization exceptions
  • Loading branch information
cassiobotaro authored Nov 4, 2021
2 parents 3432a5e + bf22c57 commit ab9c014
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 26 deletions.
25 changes: 20 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Shared Memory Dict

A very simple [shared memory](https://docs.python.org/3/library/multiprocessing.shared_memory.html) dict implementation.

**Requires**: Python >= 3.8
Expand Down Expand Up @@ -35,12 +36,15 @@ A very simple [shared memory](https://docs.python.org/3/library/multiprocessing.
> The size (in bytes) occupied by the contents of the dictionary depends on the serialization used in storage. By default pickle is used.
## Installation

Using `pip`:

```shell
pip install shared-memory-dict
```

## Locks

To use [multiprocessing.Lock](https://docs.python.org/3.8/library/multiprocessing.html#multiprocessing.Lock) on write operations of shared memory dict set environment variable `SHARED_MEMORY_USE_LOCK=1`.

## Serialization
Expand All @@ -49,24 +53,34 @@ We use [pickle](https://docs.python.org/3/library/pickle.html) as default to rea

You can create a custom serializer by implementing the `dumps` and `loads` methods.

Custom serializers should raise `SerializationError` if the serialization fails and `DeserializationError` if the deserialization fails. Both are defined in the `shared_memory_dict.serializers` module.

An example of a JSON serializer extracted from serializers module:

```python
NULL_BYTE: Final = b"\x00"


class JSONSerializer:
def dumps(self, obj: dict) -> bytes:
return json.dumps(obj).encode() + NULL_BYTE
try:
return json.dumps(obj).encode() + NULL_BYTE
except (ValueError, TypeError):
raise SerializationError(obj)

def loads(self, data: bytes) -> dict:
data = data.split(NULL_BYTE, 1)[0]
return json.loads(data)
try:
return json.loads(data)
except json.JSONDecodeError:
raise DeserializationError(data)

```

Note: A null byte is used to separate the dictionary contents from the bytes that are in memory.

To use the custom serializer you must set it when creating a new shared memory dict instance:

```python
>>> smd = SharedMemoryDict(name='tokens', size=1024, serializer=JSONSerializer())
```
Expand All @@ -77,8 +91,8 @@ The pickle module is not secure. Only unpickle data you trust.

See more [here](https://docs.python.org/3/library/pickle.html).


## Django Cache Implementation

There's a [Django Cache Implementation](https://docs.djangoproject.com/en/3.0/topics/cache/) with Shared Memory Dict:

```python
Expand All @@ -95,10 +109,11 @@ CACHES = {
**Install with**: `pip install "shared-memory-dict[django]"`

### Caveat
With Django cache implementation the keys only expire when they're read. Be careful with memory usage

With Django cache implementation the keys only expire when they're read. Be careful with memory usage

## AioCache Backend

There's also a [AioCache Backend Implementation](https://aiocache.readthedocs.io/en/latest/caches.html) with Shared Memory Dict:

```python
Expand Down
6 changes: 1 addition & 5 deletions shared_memory_dict/dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,11 +172,7 @@ def _save_memory(self, db: Dict[str, Any]) -> None:
raise ValueError("exceeds available storage") from exc

def _read_memory(self) -> Dict[str, Any]:
try:
return self._serializer.loads(self._memory_block.buf.tobytes())
except Exception as exc:
logger.warning(f"Fail to load data: {exc!r}")
return {}
return self._serializer.loads(self._memory_block.buf.tobytes())

@property
def shm(self) -> SharedMemory:
Expand Down
30 changes: 26 additions & 4 deletions shared_memory_dict/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,16 @@
NULL_BYTE: Final = b"\x00"


class SerializationError(ValueError):
def __init__(self, data: dict) -> None:
super().__init__(f"Failed to serialize data: {data!r}")


class DeserializationError(ValueError):
def __init__(self, data: bytes) -> None:
super().__init__(f"Failed to deserialize data: {data!r}")


class SharedMemoryDictSerializer(Protocol):
def dumps(self, obj: dict) -> bytes:
...
Expand All @@ -15,16 +25,28 @@ def loads(self, data: bytes) -> dict:

class JSONSerializer:
def dumps(self, obj: dict) -> bytes:
return json.dumps(obj).encode() + NULL_BYTE
try:
return json.dumps(obj).encode() + NULL_BYTE
except (ValueError, TypeError):
raise SerializationError(obj)

def loads(self, data: bytes) -> dict:
data = data.split(NULL_BYTE, 1)[0]
return json.loads(data)
try:
return json.loads(data)
except json.JSONDecodeError:
raise DeserializationError(data)


class PickleSerializer:
def dumps(self, obj: dict) -> bytes:
return pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
try:
return pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
except pickle.PicklingError:
raise SerializationError(obj)

def loads(self, data: bytes) -> dict:
return pickle.loads(data)
try:
return pickle.loads(data)
except pickle.UnpicklingError:
raise DeserializationError(data)
15 changes: 4 additions & 11 deletions tests/test_dict.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import logging
import sys

import pytest
Expand Down Expand Up @@ -192,14 +191,16 @@ def test_raise_an_error_when_memory_is_full(
):
with pytest.raises(ValueError, match="exceeds available storage"):
shared_memory_dict[key] = big_value

def test_should_expose_shared_memory(self, shared_memory_dict):
try:
shared_memory_dict.shm
except AttributeError:
pytest.fail('Should expose shared memory')

def test_shared_memory_attribute_should_be_read_only(self, shared_memory_dict):
def test_shared_memory_attribute_should_be_read_only(
self, shared_memory_dict
):
with pytest.raises(AttributeError):
shared_memory_dict.shm = 'test'

Expand All @@ -214,11 +215,3 @@ def test_use_custom_serializer_when_specified(self):
name='unit-tests', size=64, serializer=serializer
)
assert smd._serializer is serializer

def test_should_log_when_failed_to_load_shared_memory_content(self, shared_memory_dict, key, value, caplog):
smd = SharedMemoryDict(
name='ut', size=DEFAULT_MEMORY_SIZE, serializer=JSONSerializer()
)
with caplog.at_level(logging.WARNING):
smd[key] = value
assert "Fail to load data:" in caplog.text
61 changes: 60 additions & 1 deletion tests/test_serializers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import pytest

from shared_memory_dict.serializers import JSONSerializer, PickleSerializer
from shared_memory_dict.serializers import (
JSONSerializer,
PickleSerializer,
SerializationError,
DeserializationError,
)


class TestPickleSerializer:
Expand All @@ -19,6 +24,18 @@ def bytes_content(self):
def dict_content(self):
return {"key": "value"}

@pytest.fixture
def dict_not_serializable(self):
class C:
def __reduce__(self):
return (C, None)

return {"key": C()}

@pytest.fixture
def bytes_content_with_invalid_pickle(self):
return b'not pickle'

def test_loads_should_transform_bytes_into_dict(
self, pickle_serializer, bytes_content, dict_content
):
Expand All @@ -29,6 +46,23 @@ def test_dumps_should_transform_dict_into_bytes(
):
assert pickle_serializer.dumps(dict_content) == bytes_content

def test_should_raise_deserialization_error_when_content_is_not_pickle(
self, pickle_serializer, bytes_content_with_invalid_pickle
):
with pytest.raises(
DeserializationError, match="Failed to deserialize data"
):
pickle_serializer.loads(bytes_content_with_invalid_pickle)

def test_should_raise_serialization_error_when_content_is_not_pickle(
self, pickle_serializer, dict_not_serializable
):
with pytest.raises(
SerializationError, match="Failed to serialize data"
):
# sets are not pickle serializable
pickle_serializer.dumps(dict_not_serializable)


class TestJSONSerializer:
@pytest.fixture
Expand All @@ -43,6 +77,14 @@ def bytes_content(self):
def dict_content(self):
return {"key": "value"}

@pytest.fixture
def dict_not_serializable(self):
return {"key": {1, 2, 3}}

@pytest.fixture
def bytes_content_with_invalid_json(self):
return b'not json'

def test_loads_should_transform_bytes_into_dict(
self, json_serializer, bytes_content, dict_content
):
Expand All @@ -52,3 +94,20 @@ def test_dumps_should_transform_dict_into_bytes(
self, json_serializer, bytes_content, dict_content
):
assert json_serializer.dumps(dict_content) == bytes_content

def test_should_raise_desserialization_error_when_content_is_not_json(
self, json_serializer, bytes_content_with_invalid_json
):
with pytest.raises(
DeserializationError, match="Failed to deserialize data"
):
json_serializer.loads(bytes_content_with_invalid_json)

def test_should_raise_serialization_error_when_content_is_not_json(
self, json_serializer, dict_not_serializable
):
with pytest.raises(
SerializationError, match="Failed to serialize data"
):
# sets are not json serializable
json_serializer.dumps(dict_not_serializable)

0 comments on commit ab9c014

Please sign in to comment.