From 7fe9bb55d48fde49e296d8d0f1aecc48a3ff1387 Mon Sep 17 00:00:00 2001 From: anivegesana Date: Thu, 4 Aug 2022 19:53:27 -0700 Subject: [PATCH] Fix a minor bug --- dill/_dill.py | 42 ++++++++++++++++++++++++++++++------ dill/tests/_globals_dummy.py | 4 ++-- dill/tests/test_functions.py | 8 ++++++- 3 files changed, 45 insertions(+), 9 deletions(-) diff --git a/dill/_dill.py b/dill/_dill.py index b81f62a1..955bff1e 100644 --- a/dill/_dill.py +++ b/dill/_dill.py @@ -199,7 +199,7 @@ def get_file_type(*args, **kwargs): import dataclasses import typing -from pickle import GLOBAL +from pickle import GLOBAL, EMPTY_DICT, MARK, DICT, SETITEM ### Shims for different versions of Python and dill @@ -1184,12 +1184,13 @@ def _repr_dict(obj): @register(dict) def save_module_dict(pickler, obj): - if is_dill(pickler, child=False) and obj == pickler._main.__dict__ and \ + _is_dill = is_dill(pickler, child=False) + if _is_dill and obj == pickler._main.__dict__ and \ not (pickler._session and pickler._first_pass): logger.trace(pickler, "D1: %s", _repr_dict(obj)) # obj pickler.write(bytes('c__builtin__\n__main__\n', 'UTF-8')) logger.trace(pickler, "# D1") - elif (not is_dill(pickler, child=False)) and (obj == _main_module.__dict__): + elif (not _is_dill) and (obj == _main_module.__dict__): logger.trace(pickler, "D3: %s", _repr_dict(obj)) # obj pickler.write(bytes('c__main__\n__dict__\n', 'UTF-8')) #XXX: works in general? logger.trace(pickler, "# D3") @@ -1199,12 +1200,37 @@ def save_module_dict(pickler, obj): logger.trace(pickler, "D4: %s", _repr_dict(obj)) # obj pickler.write(bytes('c%s\n__dict__\n' % obj['__name__'], 'UTF-8')) logger.trace(pickler, "# D4") + elif _is_dill and id(obj) in pickler._globals_cache: + logger.trace(pickler, "D5: %s", _repr_dict(obj)) # obj + # This is a globals dictionary that was partially copied, but not fully saved. + # Save the dictionary again to ensure that everything is there. + globs_copy = pickler._globals_cache[id(obj)] + pickler.write(pickler.get(pickler.memo[id(globs_copy)][0])) + pickler._batch_setitems(iter(obj.items())) + del pickler._globals_cache[id(obj)] + pickler.memo[id(obj)] = (pickler.memo.pop(id(globs_copy))[0], obj) + logger.trace(pickler, "# D5") else: logger.trace(pickler, "D2: %s", _repr_dict(obj)) # obj - if is_dill(pickler, child=False) and pickler._session: + if _is_dill and pickler._session: # we only care about session the first pass thru pickler._first_pass = False - StockPickler.save_dict(pickler, obj) + + # IMPORTANT: update the following code whenever save_dict is changed in pickle.py + # StockPickler.save_dict(pickler, obj) + if pickler.bin: + pickler.write(EMPTY_DICT) + else: # proto 0 -- can't use EMPTY_DICT + pickler.write(MARK + DICT) + + pickler.memoize(obj) + # add __name__ first + if '__name__' in obj: + pickler.save('__name__') + pickler.save(obj['__name__']) + pickler.write(SETITEM) + pickler._batch_setitems(obj.items()) + logger.trace(pickler, "# D2") return @@ -1797,7 +1823,11 @@ def save_function(pickler, obj): postproc_list = [] globs = None - if _recurse: + if id(obj.__globals__) in pickler.memo: + # It is possible that the globals dictionary itself is also being + # pickled directly. + globs = globs_copy = obj.__globals__ + elif _recurse: # recurse to get all globals referred to by obj from .detect import globalvars globs_copy = globalvars(obj, recurse=True, builtin=True) diff --git a/dill/tests/_globals_dummy.py b/dill/tests/_globals_dummy.py index 128a1ddc..27b86127 100644 --- a/dill/tests/_globals_dummy.py +++ b/dill/tests/_globals_dummy.py @@ -3,8 +3,8 @@ x = 3 def h(): - print(x) + return x def g(): - h() + return h() diff --git a/dill/tests/test_functions.py b/dill/tests/test_functions.py index 36a0e052..f5bc5a7f 100644 --- a/dill/tests/test_functions.py +++ b/dill/tests/test_functions.py @@ -132,6 +132,7 @@ def test_code_object(): except Exception as error: raise Exception("failed to construct code object with format version {}".format(version)) from error + def test_shared_globals(): import dill, _globals_dummy as f, sys @@ -140,6 +141,7 @@ def test_shared_globals(): assert f.g.__globals__ is f.h.__globals__ assert g.__globals__ is h.__globals__ assert f.g.__globals__ is g.__globals__ + assert g() == h() == 3 del sys.modules['_globals_dummy'] @@ -147,10 +149,14 @@ def test_shared_globals(): assert f.g.__globals__ is f.h.__globals__ assert g.__globals__ is h.__globals__ assert f.g.__globals__ is not g.__globals__ + assert g() == h() == 3 + g1, g, g2 = dill.copy((f.__dict__, f.g, f.g.__globals__), recurse=recurse) + assert g1 is g.__globals__ + assert g1 is g2 sys.modules['_globals_dummy'] = f - + if __name__ == '__main__': test_functions() test_issue_510()