Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support ABC and Enums #450

Closed
wants to merge 18 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 161 additions & 23 deletions dill/_dill.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,17 @@
TypeType = type # 'new-style' classes #XXX: unregistered
XRangeType = range
from types import MappingProxyType as DictProxyType
from pickle import DEFAULT_PROTOCOL, HIGHEST_PROTOCOL, PickleError, PicklingError, UnpicklingError
from pickle import DEFAULT_PROTOCOL, HIGHEST_PROTOCOL, PickleError, \
PicklingError, UnpicklingError
import __main__ as _main_module
import marshal
import gc
# import zlib
import abc
import dataclasses
from weakref import ReferenceType, ProxyType, CallableProxyType
from collections import OrderedDict
from enum import Enum, EnumMeta
from functools import partial
from operator import itemgetter, attrgetter
GENERATOR_FAIL = False
Expand Down Expand Up @@ -1057,7 +1060,6 @@ def _locate_function(obj, pickler=None):
found = _import_module(module_name + '.' + obj.__name__, safe=True)
return found is obj


def _setitems(dest, source):
for k, v in source.items():
dest[k] = v
Expand Down Expand Up @@ -1669,6 +1671,93 @@ def save_module(pickler, obj):
logger.trace(pickler, "# M2")
return

# The following function is based on '_extract_class_dict' from 'cloudpickle'
# Copyright (c) 2012, Regents of the University of California.
# Copyright (c) 2009 `PiCloud, Inc. <http://www.picloud.com>`_.
# License: https://github.com/cloudpipe/cloudpickle/blob/master/LICENSE
def _get_typedict_type(cls, clsdict, postproc_list):
"""Retrieve a copy of the dict of a class without the inherited methods"""
if len(cls.__bases__) == 1:
inherited_dict = cls.__bases__[0].__dict__
else:
inherited_dict = {}
for base in reversed(cls.__bases__):
inherited_dict.update(base.__dict__)
to_remove = []
for name, value in dict.items(clsdict):
try:
base_value = inherited_dict[name]
if value is base_value:
to_remove.append(name)
except KeyError:
pass
for name in to_remove:
dict.pop(clsdict, name)

if issubclass(type(cls), type):
clsdict.pop('__dict__', None)
clsdict.pop('__weakref__', None)
# clsdict.pop('__prepare__', None)
return clsdict

def _get_typedict_abc(obj, _dict, attrs, postproc_list):
if hasattr(abc, '_get_dump'):
(registry, _, _, _) = abc._get_dump(obj)
register = obj.register
postproc_list.extend((register, (reg(),)) for reg in registry)
elif hasattr(obj, '_abc_registry'):
registry = obj._abc_registry
register = obj.register
postproc_list.extend((register, (reg,)) for reg in registry)
else:
raise PicklingError("Cannot find registry of ABC %s", obj)

if '_abc_registry' in _dict:
del _dict['_abc_registry']
del _dict['_abc_cache']
del _dict['_abc_negative_cache']
# del _dict['_abc_negative_cache_version']
else:
del _dict['_abc_impl']
return _dict, attrs

CORE_CLASSES = {int, float, type(None), str, dict, tuple, set, list, frozenset}

def _get_typedict_enum(obj, _dict, attrs, postproc_list):
base = None

metacls = type(obj)
original_dict = {}
for name, enum_value in obj.__members__.items():
value = enum_value.value
if base is None:
import copyreg
base = type(value)
reducer = copyreg.dispatch_table.get(base, None)

if base is tuple:
init_value = (value,)
elif base in CORE_CLASSES:
init_value = value
else:
init_value = reducer(value) if reducer else value.__reduce__()
if init_value[0] is not base or len(init_value) != 2:
raise PickleError('Cannot pickle Enum class, reduction too complex')
init_value = init_value[1]
original_dict[name] = init_value
del _dict[name]

_dict.pop('_member_names_', None)
_dict.pop('_member_map_', None)
_dict.pop('_value2member_map_', None)
_dict.pop('_generate_next_value_', None)

if attrs is not None:
attrs.update(_dict)
_dict = attrs

return original_dict, _dict

@register(TypeType)
def save_type(pickler, obj, postproc_list=None):
if obj in _typemap:
Expand All @@ -1680,15 +1769,22 @@ def save_type(pickler, obj, postproc_list=None):
elif obj.__bases__ == (tuple,) and all([hasattr(obj, attr) for attr in ('_fields','_asdict','_make','_replace')]):
# special case: namedtuples
logger.trace(pickler, "T6: %s", obj)

obj_name = getattr(obj, '__qualname__', getattr(obj, '__name__', None))
if obj.__name__ != obj_name:
if postproc_list is None:
postproc_list = []
postproc_list.append((setattr, (obj, '__qualname__', obj_name)))

if not obj._field_defaults:
pickler.save_reduce(_create_namedtuple, (obj.__name__, obj._fields, obj.__module__), obj=obj)
_save_with_postproc(pickler, (_create_namedtuple, (obj.__name__, obj._fields, obj.__module__)), obj=obj, postproc_list=postproc_list)
else:
defaults = [obj._field_defaults[field] for field in obj._fields if field in obj._field_defaults]
pickler.save_reduce(_create_namedtuple, (obj.__name__, obj._fields, obj.__module__, defaults), obj=obj)
_save_with_postproc(pickler, (_create_namedtuple, (obj.__name__, obj._fields, obj.__module__, defaults)), obj=obj, postproc_list=postproc_list)
logger.trace(pickler, "# T6")
return

# special cases: NoneType, NotImplementedType, EllipsisType
# special cases: NoneType, NotImplementedType, EllipsisType, EnumMeta
elif obj is type(None):
logger.trace(pickler, "T7: %s", obj)
#XXX: pickler.save_reduce(type, (None,), obj=obj)
Expand All @@ -1702,35 +1798,74 @@ def save_type(pickler, obj, postproc_list=None):
logger.trace(pickler, "T7: %s", obj)
pickler.save_reduce(type, (Ellipsis,), obj=obj)
logger.trace(pickler, "# T7")
elif obj is EnumMeta:
logger.trace(pickler, "T7: %s", obj)
pickler.write(GLOBAL + b'enum\nEnumMeta\n')
logger.trace(pickler, "# T7")

else:
obj_name = getattr(obj, '__qualname__', getattr(obj, '__name__', None))
_byref = getattr(pickler, '_byref', None)
obj_recursive = id(obj) in getattr(pickler, '_postproc', ())
incorrectly_named = not _locate_function(obj, pickler)
if not _byref and not obj_recursive and incorrectly_named: # not a function, but the name was held over
if postproc_list is None:
postproc_list = []

# thanks to Tom Stepleton pointing out pickler._session unneeded
logger.trace(pickler, "T2: %s", obj)
_dict = obj.__dict__.copy() # convert dictproxy to dict
#print (_dict)
#print ("%s\n%s" % (type(obj), obj.__name__))
#print ("%s\n%s" % (obj.__bases__, obj.__dict__))
_dict = _get_typedict_type(obj, obj.__dict__.copy(), postproc_list) # copy dict proxy to a dict
attrs = None

slots = _dict.get('__slots__', ())
if type(slots) == str: slots = (slots,) # __slots__ accepts a single string
if type(slots) == str:
# __slots__ accepts a single string
slots = (slots,)
for name in slots:
del _dict[name]
_dict.pop('__dict__', None)
_dict.pop('__weakref__', None)
_dict.pop('__prepare__', None)
if obj_name != obj.__name__:
if postproc_list is None:
postproc_list = []
postproc_list.append((setattr, (obj, '__qualname__', obj_name)))
_save_with_postproc(pickler, (_create_type, (
type(obj), obj.__name__, obj.__bases__, _dict
)), obj=obj, postproc_list=postproc_list)
_dict.pop(name, None)

if isinstance(obj, abc.ABCMeta):
logger.trace(pickler, "ABC: %s", obj)
_dict, attrs = _get_typedict_abc(obj, _dict, attrs, postproc_list)
logger.trace(pickler, "# ABC")

if isinstance(obj, EnumMeta):
logger.trace(pickler, "E: %s", obj)
_dict, attrs = _get_typedict_enum(obj, _dict, attrs, postproc_list)
logger.trace(pickler, "# E")

qualname = getattr(obj, '__qualname__', None)
if attrs is not None:
if qualname is not None:
attrs['__qualname__'] = qualname
for k, v in attrs.items():
postproc_list.append((setattr, (obj, k, v)))
# TODO: Consider using the state argument to save_reduce?
elif qualname is not None:
postproc_list.append((setattr, (obj, '__qualname__', qualname)))

if False: # not hasattr(obj, '__orig_bases__'):
_save_with_postproc(pickler, (_create_type, (
type(obj), obj.__name__, obj.__bases__, _dict
)), obj=obj, postproc_list=postproc_list)
else:
# This case will always work, but might be overkill.
from types import new_class
_metadict = {
'metaclass': type(obj)
}

if _dict:
_dict_update = PartialType(_setitems, source=_dict)
else:
_dict_update = None

bases = getattr(obj, '__orig_bases__', obj.__bases__)
_save_with_postproc(pickler, (new_class, (
obj.__name__, bases, _metadict, _dict_update
)), obj=obj, postproc_list=postproc_list)
logger.trace(pickler, "# T2")
else:
obj_name = getattr(obj, '__qualname__', getattr(obj, '__name__', None))
logger.trace(pickler, "T4: %s", obj)
if incorrectly_named:
warnings.warn(
Expand All @@ -1753,14 +1888,17 @@ def save_type(pickler, obj, postproc_list=None):
return

@register(property)
@register(abc.abstractproperty)
def save_property(pickler, obj):
logger.trace(pickler, "Pr: %s", obj)
pickler.save_reduce(property, (obj.fget, obj.fset, obj.fdel, obj.__doc__),
pickler.save_reduce(type(obj), (obj.fget, obj.fset, obj.fdel, obj.__doc__),
obj=obj)
logger.trace(pickler, "# Pr")

@register(staticmethod)
@register(classmethod)
@register(abc.abstractstaticmethod)
@register(abc.abstractclassmethod)
def save_classmethod(pickler, obj):
logger.trace(pickler, "Cm: %s", obj)
orig_func = obj.__func__
Expand Down
Loading