Skip to content

Commit

Permalink
Fix a minor bug
Browse files Browse the repository at this point in the history
  • Loading branch information
anivegesana committed Aug 5, 2022
1 parent 32fdf11 commit 6221041
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 8 deletions.
40 changes: 35 additions & 5 deletions dill/_dill.py
Original file line number Diff line number Diff line change
Expand Up @@ -1184,12 +1184,13 @@ def _repr_dict(obj):

@register(dict)
def save_module_dict(pickler, obj):
if is_dill(pickler, child=False) and obj == pickler._main.__dict__ and \
_is_dill = is_dill(pickler, child=False)
if _is_dill and obj == pickler._main.__dict__ and \
not (pickler._session and pickler._first_pass):
logger.trace(pickler, "D1: %s", _repr_dict(obj)) # obj
pickler.write(bytes('c__builtin__\n__main__\n', 'UTF-8'))
logger.trace(pickler, "# D1")
elif (not is_dill(pickler, child=False)) and (obj == _main_module.__dict__):
elif (not _is_dill) and (obj == _main_module.__dict__):
logger.trace(pickler, "D3: %s", _repr_dict(obj)) # obj
pickler.write(bytes('c__main__\n__dict__\n', 'UTF-8')) #XXX: works in general?
logger.trace(pickler, "# D3")
Expand All @@ -1199,12 +1200,37 @@ def save_module_dict(pickler, obj):
logger.trace(pickler, "D4: %s", _repr_dict(obj)) # obj
pickler.write(bytes('c%s\n__dict__\n' % obj['__name__'], 'UTF-8'))
logger.trace(pickler, "# D4")
elif _is_dill and id(obj) in pickler._globals_cache:
logger.trace(pickler, "D5: %s", _repr_dict(obj)) # obj
# This is a globals dictionary that was partially copied, but not fully saved.
# Save the dictionary again to ensure that everything is there.
globs_copy = pickler._globals_cache[id(obj)]
pickler.write(pickler.get(pickler.memo[id(globs_copy)][0]))
pickler._batch_setitems(iter(obj.items()))
del pickler._globals_cache[id(obj)]
pickler.memo[id(obj)] = (pickler.memo.pop(id(globs_copy))[0], obj)
logger.trace(pickler, "# D5")
else:
logger.trace(pickler, "D2: %s", _repr_dict(obj)) # obj
if is_dill(pickler, child=False) and pickler._session:
if _is_dill and pickler._session:
# we only care about session the first pass thru
pickler._first_pass = False
StockPickler.save_dict(pickler, obj)

from pickle import EMPTY_DICT, MARK, DICT, SETITEM
if pickler.bin:
pickler.write(EMPTY_DICT)
else: # proto 0 -- can't use EMPTY_DICT
pickler.write(MARK + DICT)

# StockPickler.save_dict(pickler, obj)
pickler.memoize(obj)
# add __name__ first
if '__name__' in obj:
pickler.save('__name__')
pickler.save(obj['__name__'])
pickler.write(SETITEM)
pickler._batch_setitems(obj.items())

logger.trace(pickler, "# D2")
return

Expand Down Expand Up @@ -1797,7 +1823,11 @@ def save_function(pickler, obj):
postproc_list = []

globs = None
if _recurse:
if id(obj.__globals__) in pickler.memo:
# It is possible that the globals dictionary itself is also being
# pickled directly.
globs = globs_copy = obj.__globals__
elif _recurse:
# recurse to get all globals referred to by obj
from .detect import globalvars
globs_copy = globalvars(obj, recurse=True, builtin=True)
Expand Down
4 changes: 2 additions & 2 deletions dill/tests/_globals_dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
x = 3

def h():
print(x)
return x

def g():
h()
return h()

8 changes: 7 additions & 1 deletion dill/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def test_code_object():
except Exception as error:
raise Exception("failed to construct code object with format version {}".format(version)) from error


def test_shared_globals():
import dill, _globals_dummy as f, sys

Expand All @@ -140,17 +141,22 @@ def test_shared_globals():
assert f.g.__globals__ is f.h.__globals__
assert g.__globals__ is h.__globals__
assert f.g.__globals__ is g.__globals__
assert g() == h() == 3

del sys.modules['_globals_dummy']

g, h = dill.copy((f.g, f.h), recurse=recurse)
assert f.g.__globals__ is f.h.__globals__
assert g.__globals__ is h.__globals__
assert f.g.__globals__ is not g.__globals__
assert g() == h() == 3
g1, g, g2 = dill.copy((f.__dict__, f.g, f.g.__globals__), recurse=recurse)
assert g1 is g.__globals__
assert g1 is g2

sys.modules['_globals_dummy'] = f


if __name__ == '__main__':
test_functions()
test_issue_510()
Expand Down

0 comments on commit 6221041

Please sign in to comment.