Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Alternative fix for bug when pickling function from unloaded module #536

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 41 additions & 12 deletions dill/_dill.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
# import zlib
from weakref import ReferenceType, ProxyType, CallableProxyType
from collections import OrderedDict
from functools import partial
from functools import lru_cache, partial
from operator import itemgetter, attrgetter
GENERATOR_FAIL = False
import importlib.machinery
Expand Down Expand Up @@ -1028,6 +1028,23 @@ def _import_module(import_name, safe=False):
return None
raise

@lru_cache(maxsize=2**3)
def _import_module_copy(module_name, safe=False):
"""load (a copy of) a module without adding it to sys.modules"""
import importlib.util
try:
spec = importlib.util.find_spec(module_name)
if spec is None:
raise ModuleNotFoundError("No module named %r" % module_name)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
except ImportError:
if safe:
return None
else:
raise
return module

# https://github.com/python/cpython/blob/a8912a0f8d9eba6d502c37d522221f9933e976db/Lib/pickle.py#L322-L333
def _getattribute(obj, name):
for subpath in name.split('.'):
Expand All @@ -1047,23 +1064,24 @@ def _locate_function(obj, pickler=None):
if module_name in ['__main__', None] or \
pickler and is_dill(pickler, child=False) and pickler._session and module_name == pickler._main.__name__:
return False
if hasattr(obj, '__qualname__'):
unloaded_module = module_name not in sys.modules
if unloaded_module:
module = _import_module_copy(module_name, safe=True)
else:
module = _import_module(module_name, safe=True)
try:
found, _ = _getattribute(module, obj.__qualname__)
return found is obj
except AttributeError:
return False
try:
found, _ = _getattribute(module, getattr(obj, '__qualname__', obj.__name__))
except AttributeError:
return False
if unloaded_module and inspect.isfunction(obj) and obj.__code__ == found.__code__:
return found
else:
found = _import_module(module_name + '.' + obj.__name__, safe=True)
return found is obj


def _setitems(dest, source):
for k, v in source.items():
dest[k] = v


def _save_with_postproc(pickler, reduction, is_pickler_dill=None, obj=Getattr.NO_DEFAULT, postproc_list=None):
if obj is Getattr.NO_DEFAULT:
obj = Reduce(reduction) # pragma: no cover
Expand Down Expand Up @@ -1765,7 +1783,8 @@ def save_classmethod(pickler, obj):

@register(FunctionType)
def save_function(pickler, obj):
if not _locate_function(obj, pickler):
found_obj = _locate_function(obj, pickler)
if not found_obj:
if type(obj.__code__) is not CodeType:
# Some PyPy builtin functions have no module name, and thus are not
# able to be located
Expand Down Expand Up @@ -1870,9 +1889,19 @@ def save_function(pickler, obj):
pickler.write(bytes('0', 'UTF-8'))

logger.trace(pickler, "# F1")
elif type(found_obj) is not bool:
# Deal with global functions of unloaded modules.
logger.trace(pickler, "F3: %s", obj)
name = getattr(obj, '__qualname__', obj.__name__)
sys.modules[obj.__module__] = _import_module_copy(obj.__module__, safe=True) # cached
try:
StockPickler.save_global(pickler, found_obj, name=name)
finally:
del sys.modules[obj.__module__]
logger.trace(pickler, "F3: %s", obj)
else:
logger.trace(pickler, "F2: %s", obj)
name = getattr(obj, '__qualname__', getattr(obj, '__name__', None))
name = getattr(obj, '__qualname__', obj.__name__)
StockPickler.save_global(pickler, obj, name=name)
logger.trace(pickler, "# F2")
return
Expand Down
13 changes: 13 additions & 0 deletions dill/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,20 @@ def test_code_object():
except Exception as error:
raise Exception("failed to construct code object with format version {}".format(version)) from error

def test_unloaded_module():
# Function should be saved as global until module is *re*loaded.
from statistics import mean
pickle_loaded = dill.dumps(mean)
del sys.modules['statistics']
pickle_unloaded = dill.dumps(mean)
assert 'statistics' not in sys.modules
import statistics
pickle_reloaded = dill.dumps(mean)
assert pickle_unloaded == pickle_loaded
assert pickle_reloaded != pickle_loaded

if __name__ == '__main__':
test_functions()
test_issue_510()
test_code_object()
test_unloaded_module()