From 32fdf11de5a0fa09cf88accac60ead2b3ac72088 Mon Sep 17 00:00:00 2001 From: anivegesana Date: Wed, 3 Aug 2022 23:47:55 -0700 Subject: [PATCH] Shared namespaces --- dill/_dill.py | 22 ++++++++++++++++------ dill/tests/_globals_dummy.py | 10 ++++++++++ dill/tests/test_functions.py | 20 ++++++++++++++++++++ 3 files changed, 46 insertions(+), 6 deletions(-) create mode 100644 dill/tests/_globals_dummy.py diff --git a/dill/_dill.py b/dill/_dill.py index be0aeba2..b81f62a1 100644 --- a/dill/_dill.py +++ b/dill/_dill.py @@ -366,6 +366,7 @@ def __init__(self, file, *args, **kwds): self._recurse = settings['recurse'] if _recurse is None else _recurse self._postproc = OrderedDict() self._file = file + self._globals_cache = {} def dump(self, obj): #NOTE: if settings change, need to update attributes # register if the object is a numpy ufunc @@ -1792,17 +1793,14 @@ def save_function(pickler, obj): _postproc = getattr(pickler, '_postproc', None) _main_modified = getattr(pickler, '_main_modified', None) _original_main = getattr(pickler, '_original_main', __builtin__)#'None' + _globals_cache = getattr(pickler, '_globals_cache', None) postproc_list = [] + + globs = None if _recurse: # recurse to get all globals referred to by obj from .detect import globalvars globs_copy = globalvars(obj, recurse=True, builtin=True) - - # Add the name of the module to the globs dictionary to prevent - # the duplication of the dictionary. Pickle the unpopulated - # globals dictionary and set the remaining items after the function - # is created to correctly handle recursion. - globs = {'__name__': obj.__module__} else: globs_copy = obj.__globals__ @@ -1815,6 +1813,18 @@ def save_function(pickler, obj): elif globs_copy is not None and obj.__module__ is not None and \ getattr(_import_module(obj.__module__, True), '__dict__', None) is globs_copy: globs = globs_copy + + if globs is None: + # Add the name of the module to the globs dictionary and prevent + # the duplication of the dictionary. Pickle the unpopulated + # globals dictionary and set the remaining items after the function + # is created to correctly handle recursion. + if _globals_cache is not None and obj.__globals__ is not None: + if id(obj.__globals__) not in _globals_cache: + globs = {'__name__': obj.__module__} + _globals_cache[id(obj.__globals__)] = globs + else: + globs = _globals_cache[id(obj.__globals__)] else: globs = {'__name__': obj.__module__} diff --git a/dill/tests/_globals_dummy.py b/dill/tests/_globals_dummy.py new file mode 100644 index 00000000..128a1ddc --- /dev/null +++ b/dill/tests/_globals_dummy.py @@ -0,0 +1,10 @@ +# This file is used by test_shared_globals in test_functions.py + +x = 3 + +def h(): + print(x) + +def g(): + h() + diff --git a/dill/tests/test_functions.py b/dill/tests/test_functions.py index d8c73396..36a0e052 100644 --- a/dill/tests/test_functions.py +++ b/dill/tests/test_functions.py @@ -132,7 +132,27 @@ def test_code_object(): except Exception as error: raise Exception("failed to construct code object with format version {}".format(version)) from error +def test_shared_globals(): + import dill, _globals_dummy as f, sys + + for recurse in False, True: + g, h = dill.copy((f.g, f.h), recurse=recurse) + assert f.g.__globals__ is f.h.__globals__ + assert g.__globals__ is h.__globals__ + assert f.g.__globals__ is g.__globals__ + + del sys.modules['_globals_dummy'] + + g, h = dill.copy((f.g, f.h), recurse=recurse) + assert f.g.__globals__ is f.h.__globals__ + assert g.__globals__ is h.__globals__ + assert f.g.__globals__ is not g.__globals__ + + sys.modules['_globals_dummy'] = f + + if __name__ == '__main__': test_functions() test_issue_510() test_code_object() + test_shared_globals()