Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dump_session: add split_imports #101

Merged
merged 1 commit into from
May 31, 2015
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 47 additions & 1 deletion dill/dill.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,10 +222,55 @@ def loads(str):
### End: Shorthands ###

### Pickle the Interpreter Session
def dump_session(filename='/tmp/session.pkl', main_module=_main_module):
def _module_map():
from collections import defaultdict
modmap = defaultdict(list)
items = 'items' if PY3 else 'iteritems'
for name, module in getattr(sys.modules, items)():
if module is None:
continue
for objname, obj in module.__dict__.items():
modmap[objname].append((obj, name))
return modmap

def _find_source_module(modmap, name, obj, main_module):
for modobj, modname in modmap[name]:
if modobj is obj and modname != main_module.__name__:
return modname

def _split_module_imports(main_module):
modmap = _module_map()
imported = []
original = {}
items = 'items' if PY3 else 'iteritems'
for name, obj in getattr(main_module.__dict__, items)():
source_module = _find_source_module(modmap, name, obj, main_module)
if source_module:
imported.append((source_module, name))
else:
original[name] = obj
if len(imported):
import types
newmod = types.ModuleType(main_module.__name__)
newmod.__dict__.update(original)
newmod.__dill_imported = imported
return newmod
else:
return original

def _restore_module_imports(main_module):
if '__dill_imported' not in main_module.__dict__:
return
imports = main_module.__dict__.pop('__dill_imported')
for module, name in imports:
exec("from %s import %s" % (module, name), main_module.__dict__)

def dump_session(filename='/tmp/session.pkl', main_module=_main_module, byref=False):
"""pickle the current state of __main__ to a file"""
f = open(filename, 'wb')
try:
if byref:
main_module = _split_module_imports(main_module)
pickler = Pickler(f, 2)
pickler._main_module = main_module
_byref = pickler._byref
Expand All @@ -248,6 +293,7 @@ def load_session(filename='/tmp/session.pkl', main_module=_main_module):
module = unpickler.load()
unpickler._session = False
main_module.__dict__.update(module.__dict__)
_restore_module_imports(main_module)
finally:
f.close()
return
Expand Down