Skip to content

Commit

Permalink
Shared namespaces
Browse files Browse the repository at this point in the history
  • Loading branch information
anivegesana committed Aug 4, 2022
1 parent 87b8541 commit 32fdf11
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 6 deletions.
22 changes: 16 additions & 6 deletions dill/_dill.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,7 @@ def __init__(self, file, *args, **kwds):
self._recurse = settings['recurse'] if _recurse is None else _recurse
self._postproc = OrderedDict()
self._file = file
self._globals_cache = {}

def dump(self, obj): #NOTE: if settings change, need to update attributes
# register if the object is a numpy ufunc
Expand Down Expand Up @@ -1792,17 +1793,14 @@ def save_function(pickler, obj):
_postproc = getattr(pickler, '_postproc', None)
_main_modified = getattr(pickler, '_main_modified', None)
_original_main = getattr(pickler, '_original_main', __builtin__)#'None'
_globals_cache = getattr(pickler, '_globals_cache', None)
postproc_list = []

globs = None
if _recurse:
# recurse to get all globals referred to by obj
from .detect import globalvars
globs_copy = globalvars(obj, recurse=True, builtin=True)

# Add the name of the module to the globs dictionary to prevent
# the duplication of the dictionary. Pickle the unpopulated
# globals dictionary and set the remaining items after the function
# is created to correctly handle recursion.
globs = {'__name__': obj.__module__}
else:
globs_copy = obj.__globals__

Expand All @@ -1815,6 +1813,18 @@ def save_function(pickler, obj):
elif globs_copy is not None and obj.__module__ is not None and \
getattr(_import_module(obj.__module__, True), '__dict__', None) is globs_copy:
globs = globs_copy

if globs is None:
# Add the name of the module to the globs dictionary and prevent
# the duplication of the dictionary. Pickle the unpopulated
# globals dictionary and set the remaining items after the function
# is created to correctly handle recursion.
if _globals_cache is not None and obj.__globals__ is not None:
if id(obj.__globals__) not in _globals_cache:
globs = {'__name__': obj.__module__}
_globals_cache[id(obj.__globals__)] = globs
else:
globs = _globals_cache[id(obj.__globals__)]
else:
globs = {'__name__': obj.__module__}

Expand Down
10 changes: 10 additions & 0 deletions dill/tests/_globals_dummy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# This file is used by test_shared_globals in test_functions.py

x = 3

def h():
print(x)

def g():
h()

20 changes: 20 additions & 0 deletions dill/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,27 @@ def test_code_object():
except Exception as error:
raise Exception("failed to construct code object with format version {}".format(version)) from error

def test_shared_globals():
import dill, _globals_dummy as f, sys

for recurse in False, True:
g, h = dill.copy((f.g, f.h), recurse=recurse)
assert f.g.__globals__ is f.h.__globals__
assert g.__globals__ is h.__globals__
assert f.g.__globals__ is g.__globals__

del sys.modules['_globals_dummy']

g, h = dill.copy((f.g, f.h), recurse=recurse)
assert f.g.__globals__ is f.h.__globals__
assert g.__globals__ is h.__globals__
assert f.g.__globals__ is not g.__globals__

sys.modules['_globals_dummy'] = f


if __name__ == '__main__':
test_functions()
test_issue_510()
test_code_object()
test_shared_globals()

0 comments on commit 32fdf11

Please sign in to comment.