Skip to content

Commit

Permalink
add ext_serializable() decorator
Browse files Browse the repository at this point in the history
Co-authored-by: Gabe Appleton <[email protected]>
  • Loading branch information
vsergeev and LivInTheLookingGlass committed Apr 25, 2020
1 parent 871a953 commit 8aa0ee8
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 0 deletions.
63 changes: 63 additions & 0 deletions test_umsgpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,7 @@
"DuplicateKeyException",
"KeyNotPrimitiveException",
"KeyDuplicateException",
"ext_serializable",
"pack",
"packb",
"unpack",
Expand Down Expand Up @@ -609,6 +610,68 @@ def test_unpack_ext_override(self):
unpacked = umsgpack.unpackb(data, ext_handlers=override_ext_handlers)
self.assertEqual(unpacked, obj)

def test_ext_serializable(self):
# Register test class
@umsgpack.ext_serializable(0x20)
class CustomComplex:
def __init__(self, real, imag):
self.real = real
self.imag = imag

def __eq__(self, other):
return self.real == other.real and self.imag == other.imag

def packb(self):
return struct.pack("<II", self.real, self.imag)

@classmethod
def unpackb(cls, data):
return cls(*struct.unpack("<II", data))

obj, data = CustomComplex(123, 456), b"\xd7\x20\x7b\x00\x00\x00\xc8\x01\x00\x00"

# Test pack
packed = umsgpack.packb(obj)
self.assertEqual(packed, data)

# Test unpack
unpacked = umsgpack.unpackb(packed)
self.assertTrue(isinstance(unpacked, CustomComplex))
self.assertEqual(unpacked, obj)

_, obj, data = ext_handlers_test_vectors[0]

# Test pack priority of ext_handlers over ext_serializable()
packed = umsgpack.packb(obj, ext_handlers=ext_handlers)
self.assertEqual(packed, data)

# Test unpack priority of ext_handlers over ext_serializable()
unpacked = umsgpack.unpackb(data, ext_handlers=ext_handlers)
self.assertTrue(isinstance(unpacked, complex))
self.assertEqual(unpacked, obj)

# Test registration collision
with self.assertRaises(ValueError):
@umsgpack.ext_serializable(0x20)
class DummyClass:
pass

# Register class with missing packb() and unpackb()
@umsgpack.ext_serializable(0x21)
class IncompleteClass:
pass

# Test unimplemented packb()
with self.assertRaises(NotImplementedError):
umsgpack.packb(IncompleteClass())

# Test unimplemented unpackb()
with self.assertRaises(NotImplementedError):
umsgpack.unpackb(b"\xd4\x21\x00")

# Unregister Ext serializable classes for future tests
umsgpack._ext_classes = {}

def test_streaming_writer(self):
# Try first composite test vector
(_, obj, data) = composite_test_vectors[0]
Expand Down
54 changes: 54 additions & 0 deletions umsgpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,43 @@ def __hash__(self):
class InvalidString(bytes):
"""Subclass of bytes to hold invalid UTF-8 strings."""


##############################################################################
# Ext Serializable Decorator
##############################################################################

_ext_classes = {}


def ext_serializable(ext_type):
"""
Return a decorator to register a class for automatic packing and unpacking
with the specified Ext type code. The application class should implement a
`packb()` method that returns serialized bytes, and an `unpackb()` class
method or static method that accepts serialized bytes and returns an
instance of the application class.
Args:
ext_type: application-defined Ext type code
Raises:
ValueError:
Ext type or class already registered.
"""
def wrapper(cls):
if ext_type in _ext_classes:
raise ValueError("Ext type 0x{:02x} already registered with class {:s}".format(ext_type, repr(_ext_classes[ext_type])))
elif cls in _ext_classes:
raise ValueError("Class {:s} already registered with Ext type 0x{:02x}".format(repr(cls), ext_type))

_ext_classes[ext_type] = cls
_ext_classes[cls] = ext_type

return cls

return wrapper


##############################################################################
# Exceptions
##############################################################################
Expand Down Expand Up @@ -435,6 +472,11 @@ def _pack2(obj, fp, **options):
_pack_nil(obj, fp, options)
elif ext_handlers and obj.__class__ in ext_handlers:
_pack_ext(ext_handlers[obj.__class__](obj), fp, options)
elif obj.__class__ in _ext_classes:
try:
_pack_ext(Ext(_ext_classes[obj.__class__], obj.packb()), fp, options)
except AttributeError:
raise NotImplementedError("Ext serializable class {:s} is missing implementation of packb()".format(repr(obj.__class__)))
elif isinstance(obj, bool):
_pack_boolean(obj, fp, options)
elif isinstance(obj, (int, long)):
Expand Down Expand Up @@ -507,6 +549,11 @@ def _pack3(obj, fp, **options):
_pack_nil(obj, fp, options)
elif ext_handlers and obj.__class__ in ext_handlers:
_pack_ext(ext_handlers[obj.__class__](obj), fp, options)
elif obj.__class__ in _ext_classes:
try:
_pack_ext(Ext(_ext_classes[obj.__class__], obj.packb()), fp, options)
except AttributeError:
raise NotImplementedError("Ext serializable class {:s} is missing implementation of packb()".format(repr(obj.__class__)))
elif isinstance(obj, bool):
_pack_boolean(obj, fp, options)
elif isinstance(obj, int):
Expand Down Expand Up @@ -751,6 +798,13 @@ def _unpack_ext(code, fp, options):
if ext_handlers and ext_type in ext_handlers:
return ext_handlers[ext_type](Ext(ext_type, ext_data))

# Unpack with ext classes, if type is registered
if ext_type in _ext_classes:
try:
return _ext_classes[ext_type].unpackb(ext_data)
except AttributeError:
raise NotImplementedError("Ext serializable class {:s} is missing implementation of unpackb()".format(repr(_ext_classes[ext_type])))

# Timestamp extension
if ext_type == -1:
return _unpack_ext_timestamp(ext_data, options)
Expand Down

0 comments on commit 8aa0ee8

Please sign in to comment.