Skip to content

Commit

Permalink
Guarantee empty dictionary when memory is initialized
Browse files Browse the repository at this point in the history
  • Loading branch information
cassiobotaro committed Jan 12, 2022
1 parent ab9c014 commit 20d7c80
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 5 deletions.
19 changes: 14 additions & 5 deletions shared_memory_dict/dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@
)

from .lock import lock
from .serializers import PickleSerializer, SharedMemoryDictSerializer
from .serializers import (
NULL_BYTE,
PickleSerializer,
SharedMemoryDictSerializer,
)
from .templates import MEMORY_NAME

NOT_GIVEN = object()
Expand All @@ -38,6 +42,14 @@ def __init__(
self._memory_block = self._get_or_create_memory_block(
MEMORY_NAME.format(name=name), size
)
self._ensure_memory_initialization()

def _ensure_memory_initialization(self):
memory_is_empty = (
bytes(self._memory_block.buf).split(NULL_BYTE, 1)[0] == b''
)
if memory_is_empty:
self._save_memory({})

def cleanup(self) -> None:
if not hasattr(self, '_memory_block'):
Expand Down Expand Up @@ -159,10 +171,7 @@ def _get_or_create_memory_block(
try:
return SharedMemory(name=name)
except FileNotFoundError:
shm = SharedMemory(name=name, create=True, size=size)
data = self._serializer.dumps({})
shm.buf[: len(data)] = data
return shm
return SharedMemory(name=name, create=True, size=size)

def _save_memory(self, db: Dict[str, Any]) -> None:
data = self._serializer.dumps(db)
Expand Down
13 changes: 13 additions & 0 deletions tests/test_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from shared_memory_dict import SharedMemoryDict
from shared_memory_dict.dict import DEFAULT_SERIALIZER
from shared_memory_dict.serializers import JSONSerializer
from multiprocessing.shared_memory import SharedMemory

DEFAULT_MEMORY_SIZE = 1024

Expand All @@ -16,6 +17,7 @@ def shared_memory_dict(self):
yield smd
smd.clear()
smd.cleanup()
smd.shm.unlink()

@pytest.fixture
def key(self):
Expand Down Expand Up @@ -215,3 +217,14 @@ def test_use_custom_serializer_when_specified(self):
name='unit-tests', size=64, serializer=serializer
)
assert smd._serializer is serializer

def test_shoud_initialize_when_memory_is_empty(self):
SharedMemory(name='sm_ut', create=True, size=64)
smd = SharedMemoryDict(name='ut', size=64)
try:
print(smd)
except Exception as e:
pytest.fail(f'Its should not raises: {e}')
smd.clear()
smd.cleanup()
smd.shm.unlink()

0 comments on commit 20d7c80

Please sign in to comment.