diff --git a/dill/_dill.py b/dill/_dill.py
index be0aeba2..2aae255e 100644
--- a/dill/_dill.py
+++ b/dill/_dill.py
@@ -60,7 +60,7 @@
 # import zlib
 from weakref import ReferenceType, ProxyType, CallableProxyType
 from collections import OrderedDict
-from functools import partial
+from functools import lru_cache, partial
 from operator import itemgetter, attrgetter
 GENERATOR_FAIL = False
 import importlib.machinery
@@ -1028,6 +1028,23 @@ def _import_module(import_name, safe=False):
             return None
         raise
 
+@lru_cache(maxsize=2**3)
+def _import_module_copy(module_name, safe=False):
+    """load (a copy of) a module without adding it to sys.modules"""
+    import importlib.util
+    try:
+        spec = importlib.util.find_spec(module_name)
+        if spec is None or spec.loader is None:  # a spec without a loader cannot be executed
+            raise ModuleNotFoundError("No module named %r" % module_name)
+        module = importlib.util.module_from_spec(spec)
+        spec.loader.exec_module(module)
+    except ImportError:
+        if safe:
+            return None
+        else:
+            raise
+    return module
+
 # https://github.com/python/cpython/blob/a8912a0f8d9eba6d502c37d522221f9933e976db/Lib/pickle.py#L322-L333
 def _getattribute(obj, name):
     for subpath in name.split('.'):
@@ -1047,23 +1064,24 @@ def _locate_function(obj, pickler=None):
     if module_name in ['__main__', None] or \
             pickler and is_dill(pickler, child=False) and pickler._session and module_name == pickler._main.__name__:
         return False
-    if hasattr(obj, '__qualname__'):
+    unloaded_module = module_name not in sys.modules
+    if unloaded_module:
+        module = _import_module_copy(module_name, safe=True)
+    else:
         module = _import_module(module_name, safe=True)
-        try:
-            found, _ = _getattribute(module, obj.__qualname__)
-            return found is obj
-        except AttributeError:
-            return False
+    try:
+        found, _ = _getattribute(module, getattr(obj, '__qualname__', obj.__name__))
+    except AttributeError:
+        return False
+    if unloaded_module and inspect.isfunction(obj) and obj.__code__ == getattr(found, '__code__', None):
+        return found
     else:
-        found = _import_module(module_name + '.' + obj.__name__, safe=True)
         return found is obj
 
-
 def _setitems(dest, source):
     for k, v in source.items():
         dest[k] = v
 
-
 def _save_with_postproc(pickler, reduction, is_pickler_dill=None, obj=Getattr.NO_DEFAULT, postproc_list=None):
     if obj is Getattr.NO_DEFAULT:
         obj = Reduce(reduction) # pragma: no cover
@@ -1765,7 +1783,8 @@ def save_classmethod(pickler, obj):
 
 @register(FunctionType)
 def save_function(pickler, obj):
-    if not _locate_function(obj, pickler):
+    found_obj = _locate_function(obj, pickler)
+    if not found_obj:
         if type(obj.__code__) is not CodeType:
             # Some PyPy builtin functions have no module name, and thus are not
             # able to be located
@@ -1870,9 +1889,19 @@ def save_function(pickler, obj):
                     pickler.write(bytes('0', 'UTF-8'))
 
         logger.trace(pickler, "# F1")
+    elif type(found_obj) is not bool:
+        # Deal with global functions of unloaded modules.
+        logger.trace(pickler, "F3: %s", obj)
+        name = getattr(obj, '__qualname__', obj.__name__)
+        sys.modules[obj.__module__] = _import_module_copy(obj.__module__, safe=True) # cached
+        try:
+            StockPickler.save_global(pickler, found_obj, name=name)
+        finally:
+            sys.modules.pop(obj.__module__, None)  # never mask a pickling error with KeyError
+        logger.trace(pickler, "# F3")
     else:
         logger.trace(pickler, "F2: %s", obj)
-        name = getattr(obj, '__qualname__', getattr(obj, '__name__', None))
+        name = getattr(obj, '__qualname__', obj.__name__)
         StockPickler.save_global(pickler, obj, name=name)
         logger.trace(pickler, "# F2")
     return
diff --git a/dill/tests/test_functions.py b/dill/tests/test_functions.py
index d8c73396..e6c54c5d 100644
--- a/dill/tests/test_functions.py
+++ b/dill/tests/test_functions.py
@@ -132,7 +132,20 @@ def test_code_object():
         except Exception as error:
             raise Exception("failed to construct code object with format version {}".format(version)) from error
 
+def test_unloaded_module():
+    # Function should be saved as global until module is *re*loaded.
+    from statistics import mean
+    pickle_loaded = dill.dumps(mean)
+    del sys.modules['statistics']
+    pickle_unloaded = dill.dumps(mean)
+    assert 'statistics' not in sys.modules
+    import statistics
+    pickle_reloaded = dill.dumps(mean)
+    assert pickle_unloaded == pickle_loaded
+    assert pickle_reloaded != pickle_loaded
+
 if __name__ == '__main__':
     test_functions()
     test_issue_510()
     test_code_object()
+    test_unloaded_module()