Skip to content

Commit

Permalink
allow user override of reserved msgpack ext types
Browse files Browse the repository at this point in the history
  • Loading branch information
vsergeev committed Apr 1, 2018
1 parent 54d296f commit f035453
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 28 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ If a non-byte-string argument is passed to `umsgpack.unpackb()`, it will raise a
* Python tuples and lists are both packed into the msgpack array format
* Python float types are packed into the msgpack float32 or float64 format depending on the system's `sys.float_info`
* The Python `datetime.datetime` type is packed into, and unpacked from, the msgpack `timestamp` format
* Note that this Python type only supports microsecond resolution, while the msgpack `timestamp` format supports nanosecond resolution. Timestamps with finer than microsecond resolution will lose precision during unpacking.
* Note that this Python type only supports microsecond resolution, while the msgpack `timestamp` format supports nanosecond resolution. Timestamps with finer than microsecond resolution will lose precision during unpacking. Users may override the packing and unpacking of the msgpack `timestamp` format with a custom type for alternate behavior.

## Testing

Expand Down
41 changes: 37 additions & 4 deletions test_umsgpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,22 @@
b"\xd7\x30\x93\xc4\x03\x61\x62\x63\x7b\xc3"],
]

override_ext_handlers = {
datetime.datetime:
lambda obj: umsgpack.Ext(0x40, obj.strftime("%Y%m%dT%H:%M:%S.%f").encode()),
-0x01:
lambda ext: ext,
}

override_ext_handlers_test_vectors = [
["pack override",
datetime.datetime(2000, 1, 1, 10, 5, 2, 0, umsgpack._utc_tzinfo),
b'\xc7\x18@20000101T10:05:02.000000'],
["unpack override",
umsgpack.Ext(-0x01, b"\x00\xbb\xcc\xdd\x01\x02\x03\x04\x05\x06\x07\x08"),
b'\xc7\x0c\xff\x00\xbb\xcc\xdd\x01\x02\x03\x04\x05\x06\x07\x08'],
]

# These are the only global variables that should be exported by umsgpack
exported_vars_test_vector = [
"Ext",
Expand Down Expand Up @@ -492,10 +508,7 @@ def test_unpack_ordered_dict(self):

def test_ext_exceptions(self):
with self.assertRaises(TypeError):
_ = umsgpack.Ext(-1, b"")

with self.assertRaises(TypeError):
_ = umsgpack.Ext(128, b"")
_ = umsgpack.Ext(5.0, b"")

with self.assertRaises(TypeError):
_ = umsgpack.Ext(0, u"unicode string")
Expand Down Expand Up @@ -527,6 +540,26 @@ def test_pack_force_float_precision(self):
packed = umsgpack.packb(obj, force_float_precision=precision)
self.assertEqual(packed, data)

def test_pack_ext_override(self):
# Test overridden packing of datetime.datetime
(name, obj, data) = override_ext_handlers_test_vectors[0]
obj_repr = repr(obj)
print("\tTesting %s: object %s" %
(name, obj_repr if len(obj_repr) < 24 else obj_repr[0:24] + "..."))

packed = umsgpack.packb(obj, ext_handlers=override_ext_handlers)
self.assertEqual(packed, data)

def test_unpack_ext_override(self):
# Test overridden unpacking of Ext type -1
(name, obj, data) = override_ext_handlers_test_vectors[1]
obj_repr = repr(obj)
print("\tTesting %s: object %s" %
(name, obj_repr if len(obj_repr) < 24 else obj_repr[0:24] + "..."))

unpacked = umsgpack.unpackb(data, ext_handlers=override_ext_handlers)
self.assertEqual(unpacked, obj)

def test_streaming_writer(self):
# Try first composite test vector
(_, obj, data) = composite_test_vectors[0]
Expand Down
42 changes: 19 additions & 23 deletions umsgpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,9 @@ def __init__(self, type, data):
Construct a new Ext object.
Args:
type: application-defined type integer from 0 to 127
type: application-defined type integer
data: application-defined data byte array
Raises:
TypeError:
Specified ext type is outside of 0 to 127 range.
Example:
>>> foo = umsgpack.Ext(0x05, b"\x01\x02\x03")
>>> umsgpack.packb({u"special stuff": foo, u"awesome": True})
Expand All @@ -88,9 +84,9 @@ def __init__(self, type, data):
Ext Object (Type: 0x05, Data: 01 02 03)
>>>
"""
# Application ext type should be 0 <= type <= 127
if not isinstance(type, int) or not (type >= 0 and type <= 127):
raise TypeError("ext type out of range")
# Check type is type int
if not isinstance(type, int):
raise TypeError("ext type is not type integer")
# Check data is type bytes
elif sys.version_info[0] == 3 and not isinstance(data, bytes):
raise TypeError("ext data is not type \'bytes\'")
Expand Down Expand Up @@ -739,38 +735,38 @@ def _unpack_ext(code, fp, options):
ext_type = struct.unpack("b", _read_except(fp, 1))[0]
ext_data = _read_except(fp, length)

# Timestamp extension
if ext_type == -1:
return _unpack_ext_timestamp(code, ext_data, options)

# Application extension
# Create extension object
ext = Ext(ext_type, ext_data)

# Unpack with ext handler, if we have one
ext_handlers = options.get("ext_handlers")
if ext_handlers and ext.type in ext_handlers:
ext = ext_handlers[ext.type](ext)
return ext_handlers[ext.type](ext)

# Timestamp extension
if ext.type == -1:
return _unpack_ext_timestamp(ext, options)

return ext


def _unpack_ext_timestamp(code, data, options):
if len(data) == 4:
def _unpack_ext_timestamp(ext, options):
if len(ext.data) == 4:
# 32-bit timestamp
seconds = struct.unpack(">I", data)[0]
seconds = struct.unpack(">I", ext.data)[0]
microseconds = 0
elif len(data) == 8:
elif len(ext.data) == 8:
# 64-bit timestamp
value = struct.unpack(">Q", data)[0]
value = struct.unpack(">Q", ext.data)[0]
seconds = value & 0x3ffffffff
microseconds = (value >> 34) // 1000
elif len(data) == 12:
elif len(ext.data) == 12:
# 96-bit timestamp
seconds = struct.unpack(">q", data[4:12])[0]
microseconds = struct.unpack(">I", data[0:4])[0] // 1000
seconds = struct.unpack(">q", ext.data[4:12])[0]
microseconds = struct.unpack(">I", ext.data[0:4])[0] // 1000
else:
raise UnsupportedTimestampException(
"unsupported timestamp with data length %d" % len(data))
"unsupported timestamp with data length %d" % len(ext.data))

return _epoch + datetime.timedelta(seconds=seconds,
microseconds=microseconds)
Expand Down

0 comments on commit f035453

Please sign in to comment.