From 20d7c806138f6d63ccd630411b3e4df37eeba44b Mon Sep 17 00:00:00 2001 From: cassiobotaro Date: Wed, 12 Jan 2022 19:23:18 -0300 Subject: [PATCH] Guarantee empty dictionary when memory is initialized --- shared_memory_dict/dict.py | 19 ++++++++++++++----- tests/test_dict.py | 13 +++++++++++++ 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/shared_memory_dict/dict.py b/shared_memory_dict/dict.py index ec66df6..dd83489 100644 --- a/shared_memory_dict/dict.py +++ b/shared_memory_dict/dict.py @@ -15,7 +15,11 @@ ) from .lock import lock -from .serializers import PickleSerializer, SharedMemoryDictSerializer +from .serializers import ( + NULL_BYTE, + PickleSerializer, + SharedMemoryDictSerializer, +) from .templates import MEMORY_NAME NOT_GIVEN = object() @@ -38,6 +42,14 @@ def __init__( self._memory_block = self._get_or_create_memory_block( MEMORY_NAME.format(name=name), size ) + self._ensure_memory_initialization() + + def _ensure_memory_initialization(self): + memory_is_empty = ( + bytes(self._memory_block.buf).split(NULL_BYTE, 1)[0] == b'' + ) + if memory_is_empty: + self._save_memory({}) def cleanup(self) -> None: if not hasattr(self, '_memory_block'): @@ -159,10 +171,7 @@ def _get_or_create_memory_block( try: return SharedMemory(name=name) except FileNotFoundError: - shm = SharedMemory(name=name, create=True, size=size) - data = self._serializer.dumps({}) - shm.buf[: len(data)] = data - return shm + return SharedMemory(name=name, create=True, size=size) def _save_memory(self, db: Dict[str, Any]) -> None: data = self._serializer.dumps(db) diff --git a/tests/test_dict.py b/tests/test_dict.py index a37e387..e3e3675 100644 --- a/tests/test_dict.py +++ b/tests/test_dict.py @@ -5,6 +5,7 @@ from shared_memory_dict import SharedMemoryDict from shared_memory_dict.dict import DEFAULT_SERIALIZER from shared_memory_dict.serializers import JSONSerializer +from multiprocessing.shared_memory import SharedMemory DEFAULT_MEMORY_SIZE = 1024 @@ -16,6 +17,7 @@ def shared_memory_dict(self): yield smd smd.clear() smd.cleanup() + smd.shm.unlink() @pytest.fixture def key(self): @@ -215,3 +217,14 @@ def test_use_custom_serializer_when_specified(self): name='unit-tests', size=64, serializer=serializer ) assert smd._serializer is serializer + + def test_shoud_initialize_when_memory_is_empty(self): + SharedMemory(name='sm_ut', create=True, size=64) + smd = SharedMemoryDict(name='ut', size=64) + try: + print(smd) + except Exception as e: + pytest.fail(f'Its should not raises: {e}') + smd.clear() + smd.cleanup() + smd.shm.unlink()