Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Shared namespaces #534

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 81 additions & 34 deletions dill/_dill.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,8 @@ def get_file_type(*args, **kwargs):
import inspect
import typing

from pickle import GLOBAL, EMPTY_DICT, MARK, DICT, SETITEM


### Shims for different versions of Python and dill
class Sentinel(object):
Expand Down Expand Up @@ -357,6 +359,7 @@ def __init__(self, file, *args, **kwds):
self._recurse = settings['recurse'] if _recurse is None else _recurse
self._postproc = OrderedDict()
self._file = file
self._globals_cache = {}

def save(self, obj, save_persistent_id=True):
# numpy hack
Expand Down Expand Up @@ -1182,12 +1185,13 @@ def _repr_dict(obj):

@register(dict)
def save_module_dict(pickler, obj):
if is_dill(pickler, child=False) and obj == pickler._main.__dict__ and \
_is_dill = is_dill(pickler, child=False)
if _is_dill and obj == pickler._main.__dict__ and \
not (pickler._session and pickler._first_pass):
logger.trace(pickler, "D1: %s", _repr_dict(obj)) # obj
pickler.write(bytes('c__builtin__\n__main__\n', 'UTF-8'))
logger.trace(pickler, "# D1")
elif (not is_dill(pickler, child=False)) and (obj == _main_module.__dict__):
elif (not _is_dill) and (obj == _main_module.__dict__):
logger.trace(pickler, "D3: %s", _repr_dict(obj)) # obj
pickler.write(bytes('c__main__\n__dict__\n', 'UTF-8')) #XXX: works in general?
logger.trace(pickler, "# D3")
Expand All @@ -1197,12 +1201,37 @@ def save_module_dict(pickler, obj):
logger.trace(pickler, "D4: %s", _repr_dict(obj)) # obj
pickler.write(bytes('c%s\n__dict__\n' % obj['__name__'], 'UTF-8'))
logger.trace(pickler, "# D4")
elif _is_dill and id(obj) in pickler._globals_cache:
logger.trace(pickler, "D5: %s", _repr_dict(obj)) # obj
# This is a globals dictionary that was partially copied, but not fully saved.
# Save the dictionary again to ensure that everything is there.
globs_copy = pickler._globals_cache[id(obj)]
pickler.write(pickler.get(pickler.memo[id(globs_copy)][0]))
pickler._batch_setitems(iter(obj.items()))
del pickler._globals_cache[id(obj)]
pickler.memo[id(obj)] = (pickler.memo.pop(id(globs_copy))[0], obj)
logger.trace(pickler, "# D5")
else:
logger.trace(pickler, "D2: %s", _repr_dict(obj)) # obj
if is_dill(pickler, child=False) and pickler._session:
if _is_dill and pickler._session:
# we only care about session the first pass thru
pickler._first_pass = False
StockPickler.save_dict(pickler, obj)

# IMPORTANT: update the following code whenever save_dict is changed in pickle.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It appears that the code for save_dict in all supported version of python are is the same (at least from if pickler.bin to pickler.memoize(obj)). The idea here is to insert the code with __name__ to save_dict while not yielding control to the StockPickler, correct? It seems that save_dict hasn't changed as far back as at least 3.1, so I'm not worried about the copy being made here.

# StockPickler.save_dict(pickler, obj)
if pickler.bin:
pickler.write(EMPTY_DICT)
else: # proto 0 -- can't use EMPTY_DICT
pickler.write(MARK + DICT)

pickler.memoize(obj)
# add __name__ first
if '__name__' in obj:
pickler.save('__name__')
pickler.save(obj['__name__'])
pickler.write(SETITEM)
pickler._batch_setitems(obj.items())

logger.trace(pickler, "# D2")
return

Expand Down Expand Up @@ -1803,45 +1832,63 @@ def save_function(pickler, obj):
logger.trace(pickler, "F1: %s", obj)
_recurse = getattr(pickler, '_recurse', None)
_postproc = getattr(pickler, '_postproc', None)
_main_modified = getattr(pickler, '_main_modified', None)
_original_main = getattr(pickler, '_original_main', __builtin__)#'None'
_original_main = getattr(pickler, '_original_main', None)
_globals_cache = getattr(pickler, '_globals_cache', None)
postproc_list = []
if _recurse:
# recurse to get all globals referred to by obj
from .detect import globalvars
globs_copy = globalvars(obj, recurse=True, builtin=True)

# Add the name of the module to the globs dictionary to prevent
# the duplication of the dictionary. Pickle the unpopulated
# globals dictionary and set the remaining items after the function
# is created to correctly handle recursion.
globs = {'__name__': obj.__module__}
else:
globs_copy = obj.__globals__
is_memoized = id(obj.__globals__) in pickler.memo
is_modified_main_dict = (
_original_main is not None
and obj.__globals__ is _original_main.__dict__
)
is_module_dict = (
not (_recurse or is_memoized or is_modified_main_dict)
and obj.__module__ is not None
and obj.__globals__ is getattr(_import_module(obj.__module__, safe=True), '__dict__', None)
)

if is_modified_main_dict:
# If the globals is the __dict__ from the module being saved as a
# session, substitute it by the dictionary being actually saved.
if _main_modified and globs_copy is _original_main.__dict__:
globs_copy = getattr(pickler, '_main', _original_main).__dict__
globs = globs_copy
globs = pickler._main.__dict__
elif is_memoized or is_module_dict:
# If the globals is a module __dict__, do not save it in the pickle.
elif globs_copy is not None and obj.__module__ is not None and \
getattr(_import_module(obj.__module__, True), '__dict__', None) is globs_copy:
globs = globs_copy
# It is possible that the globals dictionary itself is also being
# pickled directly.
globs = obj.__globals__
else:
if _recurse:
# recurse to get all globals referred to by obj
from .detect import globalvars
globs_copy = globalvars(obj, recurse=True, builtin=True)
else:
globs = {'__name__': obj.__module__}
# function not bound to an importable module
globs_copy = obj.__globals__

if globs_copy is not None and globs is not globs_copy:
# In the case that the globals are copied, we need to ensure that
# the globals dictionary is updated when all objects in the
# dictionary are already created.
glob_ids = {id(g) for g in globs_copy.values()}
for stack_element in _postproc:
if stack_element in glob_ids:
_postproc[stack_element].append((_setitems, (globs, globs_copy)))
break
# Add the name of the module to the globs dictionary and prevent
# the duplication of the dictionary. Pickle the unpopulated
# globals dictionary and set the remaining items after the function
# is created to correctly handle recursion.
if _globals_cache is not None and obj.__globals__ is not None:
if id(obj.__globals__) not in _globals_cache:
globs = {'__name__': obj.__module__}
_globals_cache[id(obj.__globals__)] = globs
else:
globs = _globals_cache[id(obj.__globals__)]
else:
postproc_list.append((_setitems, (globs, globs_copy)))
globs = {'__name__': obj.__module__}

if globs_copy is not None and globs is not globs_copy:
# In the case that the globals are copied, we need to ensure that
# the globals dictionary is updated when all objects in the
# dictionary are already created.
glob_ids = {id(g) for g in globs_copy.values()}
for stack_element in _postproc:
if stack_element in glob_ids:
_postproc[stack_element].append((_setitems, (globs, globs_copy)))
break
else:
postproc_list.append((_setitems, (globs, globs_copy)))

closure = obj.__closure__
state_dict = {}
Expand Down
9 changes: 5 additions & 4 deletions dill/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,17 +244,18 @@ def dump_module(
if filename is None:
filename = str(TEMPDIR/'session.pkl')
file = open(filename, 'wb')
original_main = main
if refimported:
main = _stash_modules(main)
try:
pickler = Pickler(file, protocol, **kwds)
pickler._original_main = main
if refimported:
main = _stash_modules(main)
pickler._main = main #FIXME: dill.settings are disabled
pickler._byref = False # disable pickling by name reference
pickler._recurse = False # disable pickling recursion for globals
pickler._session = True # is best indicator of when pickling a session
pickler._first_pass = True
pickler._main_modified = main is not pickler._original_main
if main is not original_main:
pickler._original_main = original_main
pickler.dump(main)
finally:
if file is not filename: # if newly opened file
Expand Down
26 changes: 26 additions & 0 deletions dill/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,33 @@ def test_code_object():
except Exception as error:
raise Exception("failed to construct code object with format version {}".format(version)) from error


def test_shared_globals():
import dill, test_functors as f, sys

for recurse in False, True:
g, h = dill.copy((f.g, f.h), recurse=recurse)
assert f.g.__globals__ is f.h.__globals__
assert g.__globals__ is h.__globals__
assert f.g.__globals__ is g.__globals__
assert g(1, 2) == h() == 3

del sys.modules['test_functors']

g, h = dill.copy((f.g, f.h), recurse=recurse)
assert f.g.__globals__ is f.h.__globals__
assert g.__globals__ is h.__globals__
assert f.g.__globals__ is not g.__globals__
assert g(1, 2) == h() == 3
g1, g, g2 = dill.copy((f.__dict__, f.g, f.g.__globals__), recurse=recurse)
assert g1 is g.__globals__
assert g1 is g2

sys.modules['test_functors'] = f


if __name__ == '__main__':
test_functions()
test_issue_510()
test_code_object()
test_shared_globals()
5 changes: 3 additions & 2 deletions dill/tests/test_functors.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,18 @@
import dill
dill.settings['recurse'] = True

x = 3

def f(a, b, c): # without keywords
pass


def g(a, b, c=2): # with keywords
pass
return h(a=a, b=b, c=c)


def h(a=1, b=2, c=3): # without args
pass
return x


def test_functools():
Expand Down