diff --git a/dill/__init__.py b/dill/__init__.py index 6f71bbe5..a97e973d 100644 --- a/dill/__init__.py +++ b/dill/__init__.py @@ -34,13 +34,13 @@ dump_module, load_module, load_module_asdict, dump_session, load_session # backward compatibility ) -from . import detect, logger, session, source, temp +from . import detect, logging, session, source, temp # get global settings from .settings import settings # make sure "trace" is turned off -logger.trace(False) +logging.trace(False) from importlib import reload diff --git a/dill/_dill.py b/dill/_dill.py index 0130e709..85504c44 100644 --- a/dill/_dill.py +++ b/dill/_dill.py @@ -26,8 +26,10 @@ __module__ = 'dill' import warnings -from .logger import adapter as logger -from .logger import trace as _trace +from dill import logging +from .logging import adapter as logger +from .logging import trace as _trace +_logger = logging.getLogger(__name__) import os import sys @@ -39,6 +41,7 @@ #XXX: get types from .objtypes ? import builtins as __builtin__ from pickle import _Pickler as StockPickler, Unpickler as StockUnpickler +from pickle import DICT, GLOBAL, MARK, POP, SETITEM from _thread import LockType from _thread import RLock as RLockType #from io import IOBase @@ -58,13 +61,13 @@ import marshal import gc # import zlib +import weakref from weakref import ReferenceType, ProxyType, CallableProxyType from collections import OrderedDict -from functools import partial +from functools import partial, wraps from operator import itemgetter, attrgetter GENERATOR_FAIL = False import importlib.machinery -EXTENSION_SUFFIXES = tuple(importlib.machinery.EXTENSION_SUFFIXES) try: import ctypes HAS_CTYPES = True @@ -172,8 +175,6 @@ def get_file_type(*args, **kwargs): import dataclasses import typing -from pickle import GLOBAL - ### Shims for different versions of Python and dill class Sentinel(object): @@ -212,6 +213,9 @@ def __reduce_ex__(self, protocol): #: Pickles the entire file (handle and contents), preserving mode and position. 
FILE_FMODE = 2 +# Exceptions commonly raised by unpickleable objects in the Standard Library. +UNPICKLEABLE_ERRORS = (PicklingError, TypeError, ValueError, NotImplementedError) + ### Shorthands (modified from python2.5/lib/pickle.py) def copy(obj, *args, **kwds): """ @@ -321,8 +325,12 @@ class UnpicklingWarning(PickleWarning, UnpicklingError): class Pickler(StockPickler): """python's Pickler extended to interpreter sessions""" dispatch = MetaCatchingDict(StockPickler.dispatch.copy()) - _session = False from .settings import settings + # Flags set by dump_module() is dill.session: + _refimported = False + _refonfail = False + _session = False + _first_pass = False def __init__(self, file, *args, **kwds): settings = Pickler.settings @@ -338,9 +346,18 @@ def __init__(self, file, *args, **kwds): self._fmode = settings['fmode'] if _fmode is None else _fmode self._recurse = settings['recurse'] if _recurse is None else _recurse self._postproc = OrderedDict() - self._file = file + self._file_tell = getattr(file, 'tell', None) # for logger and refonfail def save(self, obj, save_persistent_id=True): + # This method overrides StockPickler.save() and is called for every + # object pickled. When 'refonfail' is True, it tries to save the object + # by reference if pickling it fails with a common pickling error, as + # defined by the constant UNPICKLEABLE_ERRORS. If that also fails, then + # the exception is raised and, if this method was called indirectly from + # another Pickler.save() call, the parent objects will try to be saved + # by reference recursively, until it succeeds or the exception + # propagates beyond the topmost save() call. 
+ # register if the object is a numpy ufunc # thanks to Paul Kienzle for pointing out ufuncs didn't pickle obj_type = type(obj) @@ -375,7 +392,7 @@ def save_numpy_dtype(pickler, obj): if NumpyArrayType and ndarraysubclassinstance(obj_type): @register(obj_type) def save_numpy_array(pickler, obj): - logger.trace(pickler, "Nu: (%s, %s)", obj.shape, obj.dtype) + logger.trace(pickler, "Nu: (%s, %s)", obj.shape, obj.dtype, obj=obj) npdict = getattr(obj, '__dict__', None) f, args, state = obj.__reduce__() pickler.save_reduce(_create_array, (f,args,state,npdict), obj=obj) @@ -385,8 +402,61 @@ def save_numpy_array(pickler, obj): if GENERATOR_FAIL and type(obj) == GeneratorType: msg = "Can't pickle %s: attribute lookup builtins.generator failed" % GeneratorType raise PicklingError(msg) - StockPickler.save(self, obj, save_persistent_id) + if not self._refonfail: + StockPickler.save(self, obj, save_persistent_id) + return + + ## Save with 'refonfail' ## + + # Disable framing. This must be set right after the + # framer.init_framing() call at StockPickler.dump()). + self.framer.current_frame = None + # Store initial state. + position = self._file_tell() + memo_size = len(self.memo) + try: + StockPickler.save(self, obj, save_persistent_id) + except UNPICKLEABLE_ERRORS as error_stack: + trace_message = ( + "# X: fallback to save as global: <%s object at %#012x>" + % (type(obj).__name__, id(obj)) + ) + # Roll back the stream. Note: truncate(position) doesn't always work. + self._file_seek(position) + self._file_truncate() + # Roll back memo. + for _ in range(len(self.memo) - memo_size): + self.memo.popitem() # LIFO order is guaranteed since 3.7 + # Handle session main. + if self._session and obj is self._main: + if self._main is _main_module or not _is_imported_module(self._main): + raise + # Save an empty dict as state to distinguish from modules saved with dump(). 
+ self.save_reduce(_import_module, (obj.__name__,), obj=obj, state={}) + logger.trace(self, trace_message, obj=obj) + warnings.warn( + "module %r saved by reference due to the unpickleable " + "variable %r. No changes to the module were saved." + % (self._main.__name__, error_stack.name), + PicklingWarning, + stacklevel=5, + ) + # Try to save object by reference. + elif hasattr(obj, '__name__') or hasattr(obj, '__qualname__'): + try: + self.save_global(obj) + logger.trace(self, trace_message, obj=obj) + return True # for _saved_byref, ignored otherwise + except PicklingError as error: + # Roll back trace state. + logger.roll_back(self, obj) + raise error from error_stack + else: + # Roll back trace state. + logger.roll_back(self, obj) + raise + return save.__doc__ = StockPickler.save.__doc__ def dump(self, obj): #NOTE: if settings change, need to update attributes @@ -1157,30 +1227,156 @@ def save_code(pickler, obj): logger.trace(pickler, "# Co") return +def _module_map(main_module): + """get map of imported modules""" + from collections import defaultdict + from types import SimpleNamespace + modmap = SimpleNamespace( + by_name = defaultdict(list), + by_id = defaultdict(list), + top_level = {}, # top-level modules + module = main_module.__name__, + package = _module_package(main_module), + ) + for modname, module in sys.modules.items(): + if (modname in ('__main__', '__mp_main__') or module is main_module + or not isinstance(module, ModuleType)): + continue + if '.' not in modname: + modmap.top_level[id(module)] = modname + for objname, modobj in module.__dict__.items(): + modmap.by_name[objname].append((modobj, modname)) + modmap.by_id[id(modobj)].append((objname, modname)) + return modmap + +def _lookup_module(modmap, name, obj, lookup_by_id=True) -> typing.Tuple[str, str, bool]: + """Lookup name or id of obj if module is imported. + + Lookup for objects identical to 'obj' at modules in 'modmap'.
If multiple + copies are found in different modules, return the one from the module with + higher probability of being available at unpickling time, according to the + hierarchy: + + 1. Standard Library modules + 2. modules of the same top-level package as the module being saved (if it's part of a package) + 3. installed modules in general + 4. non-installed modules + + Returns: + A 3-tuple containing the module's name, the object's name in the module, + and a boolean flag, which is `True` if the module falls under categories + (1) to (3) from the hierarchy, or `False` if it's in category (4). + """ + not_found = None, None, None + # Don't look for objects likely related to the module itself. + obj_module = getattr(obj, '__module__', type(obj).__module__) + if obj_module == modmap.module: + return not_found + obj_package = _module_package(_import_module(obj_module, safe=True)) + + for map, by_id in [(modmap.by_name, False), (modmap.by_id, True)]: + if by_id and not lookup_by_id: + break + _2nd_choice = _3rd_choice = _4th_choice = None + key = id(obj) if by_id else name + for other, modname in map[key]: + if by_id or other is obj: + other_name = other if by_id else name + other_module = sys.modules[modname] + other_package = _module_package(other_module) + # Don't return a reference to a module of another package + # if the object is likely from the same top-level package. + if (modmap.package and obj_package == modmap.package + and other_package != modmap.package): + continue + # Prefer modules imported earlier (the first found). + if _is_stdlib_module(other_module): + return modname, other_name, True + elif modmap.package and modmap.package == other_package: + if _2nd_choice: continue + _2nd_choice = modname, other_name, True + elif not _2nd_choice: + # Don't call _is_builtin_module() unnecessarily. 
+ if _is_builtin_module(other_module): + if _3rd_choice: continue + _3rd_choice = modname, other_name, True + else: + if _4th_choice: continue + _4th_choice = modname, other_name, False # unsafe + found = _2nd_choice or _3rd_choice or _4th_choice + if found: + return found + return not_found + +def _global_string(modname, name): + return GLOBAL + bytes('%s\n%s\n' % (modname, name), 'UTF-8') + +def _save_module_dict(pickler, main_dict): + """Save a module's dictionary, saving unpickleable variables by reference.""" + main = getattr(pickler, '_original_main', pickler._main) + modmap = getattr(pickler, '_modmap', None) # cached from _stash_modules() + is_builtin = _is_builtin_module(main) + pickler.write(MARK + DICT) # don't need to memoize + for name, value in main_dict.items(): + _logger.debug("Pickling %r (%s)", name, type(value).__name__) + pickler.save(name) + try: + if pickler.save(value): + global_name = getattr(value, '__qualname__', value.__name__) + pickler._saved_byref.append((name, value.__module__, global_name)) + except UNPICKLEABLE_ERRORS as error_stack: + if modmap is None: + modmap = _module_map(main) + modname, objname, installed = _lookup_module(modmap, name, value) + if modname and (installed or not is_builtin): + pickler.write(_global_string(modname, objname)) + pickler._saved_byref.append((name, modname, objname)) + elif is_builtin: + pickler.write(_global_string(main.__name__, name)) + pickler._saved_byref.append((name, main.__name__, name)) + else: + error = PicklingError("can't save variable %r as global" % name) + error.name = name + raise error from error_stack + pickler.memoize(value) + pickler.write(SETITEM) + def _repr_dict(obj): - """make a short string representation of a dictionary""" + """Make a short string representation of a dictionary.""" return "<%s object at %#012x>" % (type(obj).__name__, id(obj)) @register(dict) def save_module_dict(pickler, obj): - if is_dill(pickler, child=False) and obj == pickler._main.__dict__ and \ - not
(pickler._session and pickler._first_pass): - logger.trace(pickler, "D1: %s", _repr_dict(obj)) # obj - pickler.write(bytes('c__builtin__\n__main__\n', 'UTF-8')) + is_pickler_dill = is_dill(pickler, child=False) + if (is_pickler_dill + and obj is pickler._main.__dict__ + and not (pickler._session and pickler._first_pass)): + logger.trace(pickler, "D1: %s", _repr_dict(obj), obj=obj) + pickler.write(GLOBAL + b'__builtin__\n__main__\n') logger.trace(pickler, "# D1") - elif (not is_dill(pickler, child=False)) and (obj == _main_module.__dict__): - logger.trace(pickler, "D3: %s", _repr_dict(obj)) # obj - pickler.write(bytes('c__main__\n__dict__\n', 'UTF-8')) #XXX: works in general? + elif not is_pickler_dill and obj is _main_module.__dict__: #pragma: no cover + logger.trace(pickler, "D3: %s", _repr_dict(obj), obj=obj) + pickler.write(GLOBAL + b'__main__\n__dict__\n') #XXX: works in general? logger.trace(pickler, "# D3") - elif '__name__' in obj and obj != _main_module.__dict__ \ - and type(obj['__name__']) is str \ - and obj is getattr(_import_module(obj['__name__'],True), '__dict__', None): - logger.trace(pickler, "D4: %s", _repr_dict(obj)) # obj - pickler.write(bytes('c%s\n__dict__\n' % obj['__name__'], 'UTF-8')) + elif (is_pickler_dill + and pickler._session + and pickler._refonfail + and obj is pickler._main_dict_copy): + logger.trace(pickler, "D5: %s", _repr_dict(obj), obj=obj) + # we only care about session the first pass thru + pickler._first_pass = False + _save_module_dict(pickler, obj) + logger.trace(pickler, "# D5") + elif ('__name__' in obj + and obj is not _main_module.__dict__ + and type(obj['__name__']) is str + and obj is getattr(_import_module(obj['__name__'], safe=True), '__dict__', None)): + logger.trace(pickler, "D4: %s", _repr_dict(obj), obj=obj) + pickler.write(_global_string(obj['__name__'], '__dict__')) logger.trace(pickler, "# D4") else: - logger.trace(pickler, "D2: %s", _repr_dict(obj)) # obj - if is_dill(pickler, child=False) and pickler._session:
+ logger.trace(pickler, "D2: %s", _repr_dict(obj), obj=obj) + if is_pickler_dill: # we only care about session the first pass thru pickler._first_pass = False StockPickler.save_dict(pickler, obj) @@ -1498,7 +1694,7 @@ def save_cell(pickler, obj): if MAPPING_PROXY_TRICK: @register(DictProxyType) def save_dictproxy(pickler, obj): - logger.trace(pickler, "Mp: %s", _repr_dict(obj)) # obj + logger.trace(pickler, "Mp: %s", _repr_dict(obj), obj=obj) mapping = obj | _dictproxy_helper_instance pickler.save_reduce(DictProxyType, (mapping,), obj=obj) logger.trace(pickler, "# Mp") @@ -1506,7 +1702,7 @@ def save_dictproxy(pickler, obj): else: @register(DictProxyType) def save_dictproxy(pickler, obj): - logger.trace(pickler, "Mp: %s", _repr_dict(obj)) # obj + logger.trace(pickler, "Mp: %s", _repr_dict(obj), obj=obj) pickler.save_reduce(DictProxyType, (obj.copy(),), obj=obj) logger.trace(pickler, "# Mp") return @@ -1577,24 +1773,78 @@ def save_weakref(pickler, obj): @register(CallableProxyType) def save_weakproxy(pickler, obj): # Must do string substitution here and use %r to avoid ReferenceError. 
- logger.trace(pickler, "R2: %r" % obj) + logger.trace(pickler, "R2: %r" % obj, obj=obj) refobj = _locate_object(_proxy_helper(obj)) pickler.save_reduce(_create_weakproxy, (refobj, callable(obj)), obj=obj) logger.trace(pickler, "# R2") return +def _weak_cache(func=None, *, defaults=None): + if defaults is None: + defaults = {} + if func is None: + return partial(_weak_cache, defaults=defaults) + cache = weakref.WeakKeyDictionary() + @wraps(func) + def wrapper(referent): + try: + return defaults[referent] + except KeyError: + try: + return cache[referent] + except KeyError: + value = func(referent) + cache[referent] = value + return value + return wrapper + +@_weak_cache(defaults={None: False}) +def _is_imported_module(module): + return getattr(module, '__loader__', None) is not None or module in sys.modules.values() + +PYTHONPATH_PREFIXES = {getattr(sys, attr) for attr in ( + 'base_prefix', 'prefix', 'base_exec_prefix', 'exec_prefix', + 'real_prefix', # for old virtualenv versions + ) if hasattr(sys, attr)} +PYTHONPATH_PREFIXES = tuple(os.path.realpath(path) for path in PYTHONPATH_PREFIXES) +EXTENSION_SUFFIXES = tuple(importlib.machinery.EXTENSION_SUFFIXES) +if OLD310: + STDLIB_PREFIX = os.path.dirname(os.path.realpath(os.__file__)) + +@_weak_cache(defaults={None: True}) #XXX: shouldn't return False for None? def _is_builtin_module(module): - if not hasattr(module, "__file__"): return True + if module.__name__ in ('__main__', '__mp_main__'): + return False + mod_path = getattr(module, '__file__', None) + if not mod_path: + return _is_imported_module(module) # If a module file name starts with prefix, it should be a builtin # module, so should always be pickled as a reference. 
- names = ["base_prefix", "base_exec_prefix", "exec_prefix", "prefix", "real_prefix"] - return any(os.path.realpath(module.__file__).startswith(os.path.realpath(getattr(sys, name))) - for name in names if hasattr(sys, name)) or \ - module.__file__.endswith(EXTENSION_SUFFIXES) or \ - 'site-packages' in module.__file__ + mod_path = os.path.realpath(mod_path) + return ( + any(mod_path.startswith(prefix) for prefix in PYTHONPATH_PREFIXES) + or mod_path.endswith(EXTENSION_SUFFIXES) + or 'site-packages' in mod_path + ) -def _is_imported_module(module): - return getattr(module, '__loader__', None) is not None or module in sys.modules.values() +@_weak_cache(defaults={None: False}) +def _is_stdlib_module(module): + first_level = module.__name__.partition('.')[0] + if OLD310: + if first_level in sys.builtin_module_names: + return True + mod_path = getattr(module, '__file__', '') + if mod_path: + mod_path = os.path.realpath(mod_path) + return mod_path.startswith(STDLIB_PREFIX) + else: + return first_level in sys.stdlib_module_names + +@_weak_cache(defaults={None: None}) +def _module_package(module): + """get the top-level package of a module, if any""" + package = getattr(module, '__package__', None) + return package.partition('.')[0] if package else None @register(ModuleType) def save_module(pickler, obj): @@ -1617,13 +1867,16 @@ def save_module(pickler, obj): logger.trace(pickler, "# M1") else: builtin_mod = _is_builtin_module(obj) - if obj.__name__ not in ("builtins", "dill", "dill._dill") and not builtin_mod or \ - is_dill(pickler, child=True) and obj is pickler._main: + is_session_main = is_dill(pickler, child=True) and obj is pickler._main + if (obj.__name__ not in ("builtins", "dill", "dill._dill") and not builtin_mod + or is_session_main): logger.trace(pickler, "M1: %s", obj) _main_dict = obj.__dict__.copy() #XXX: better no copy? option to copy? 
[_main_dict.pop(item, None) for item in singletontypes + ["__builtins__", "__loader__"]] mod_name = obj.__name__ if _is_imported_module(obj) else '__runtime__.%s' % obj.__name__ + if is_session_main: + pickler._main_dict_copy = _main_dict pickler.save_reduce(_import_module, (mod_name,), obj=obj, state=_main_dict) logger.trace(pickler, "# M1") @@ -1661,7 +1914,7 @@ def save_type(pickler, obj, postproc_list=None): elif obj is type(None): logger.trace(pickler, "T7: %s", obj) #XXX: pickler.save_reduce(type, (None,), obj=obj) - pickler.write(bytes('c__builtin__\nNoneType\n', 'UTF-8')) + pickler.write(GLOBAL + b'__builtin__\nNoneType\n') logger.trace(pickler, "# T7") elif obj is NotImplementedType: logger.trace(pickler, "T7: %s", obj) @@ -1763,8 +2016,7 @@ def save_function(pickler, obj): logger.trace(pickler, "F1: %s", obj) _recurse = getattr(pickler, '_recurse', None) _postproc = getattr(pickler, '_postproc', None) - _main_modified = getattr(pickler, '_main_modified', None) - _original_main = getattr(pickler, '_original_main', __builtin__)#'None' + _original_main = getattr(pickler, '_original_main', None) postproc_list = [] if _recurse: # recurse to get all globals referred to by obj @@ -1781,8 +2033,8 @@ def save_function(pickler, obj): # If the globals is the __dict__ from the module being saved as a # session, substitute it by the dictionary being actually saved. - if _main_modified and globs_copy is _original_main.__dict__: - globs_copy = getattr(pickler, '_main', _original_main).__dict__ + if _original_main is not None and globs_copy is _original_main.__dict__: + globs_copy = pickler._main.__dict__ globs = globs_copy # If the globals is a module __dict__, do not save it in the pickle. 
elif globs_copy is not None and obj.__module__ is not None and \ @@ -1949,7 +2201,7 @@ def pickles(obj,exact=False,safe=False,**kwds): """ if safe: exceptions = (Exception,) # RuntimeError, ValueError else: - exceptions = (TypeError, AssertionError, NotImplementedError, PicklingError, UnpicklingError) + exceptions = UNPICKLEABLE_ERRORS + (AssertionError, UnpicklingError) try: pik = copy(obj, **kwds) #FIXME: should check types match first, then check content if "exact" diff --git a/dill/_utils.py b/dill/_utils.py new file mode 100644 index 00000000..912a2e8e --- /dev/null +++ b/dill/_utils.py @@ -0,0 +1,122 @@ +#!/usr/bin/env python +# +# Author: Leonardo Gama (@leogama) +# Copyright (c) 2022 The Uncertainty Quantification Foundation. +# License: 3-clause BSD. The full license text is available at: +# - https://github.com/uqfoundation/dill/blob/master/LICENSE +""" +Auxiliary classes and functions used in more than one module, defined here to +avoid circular import problems. +""" + +import contextlib +import io +import math +from contextlib import suppress + +#NOTE: dill._dill is not completely loaded at this point, can't import from it. +from dill import _dill + +# Type hints. 
+from typing import Tuple, Union + +def _format_bytes_size(size: Union[int, float]) -> Tuple[int, str]: + """Return bytes size text representation in human-redable form.""" + unit = "B" + power_of_2 = math.trunc(size).bit_length() - 1 + magnitude = min(power_of_2 - power_of_2 % 10, 80) # 2**80 == 1 YiB + if magnitude: + # Rounding trick: 1535 (1024 + 511) -> 1K; 1536 -> 2K + size = ((size >> magnitude-1) + 1) >> 1 + unit = "%siB" % "KMGTPEZY"[(magnitude // 10) - 1] + return size, unit + + +## File-related utilities ## + +class _PeekableReader(contextlib.AbstractContextManager): + """lightweight readable stream wrapper that implements peek()""" + def __init__(self, stream, closing=True): + self.stream = stream + self.closing = closing + def __exit__(self, *exc_info): + if self.closing: + self.stream.close() + def read(self, n): + return self.stream.read(n) + def readline(self): + return self.stream.readline() + def tell(self): + return self.stream.tell() + def close(self): + return self.stream.close() + def peek(self, n): + stream = self.stream + try: + if hasattr(stream, 'flush'): + stream.flush() + position = stream.tell() + stream.seek(position) # assert seek() works before reading + chunk = stream.read(n) + stream.seek(position) + return chunk + except (AttributeError, OSError): + raise NotImplementedError("stream is not peekable: %r", stream) from None + +class _SeekableWriter(io.BytesIO, contextlib.AbstractContextManager): + """works as an unlimited buffer, writes to file on close""" + def __init__(self, stream, closing=True, *args, **kwds): + super().__init__(*args, **kwds) + self.stream = stream + self.closing = closing + def __exit__(self, *exc_info): + self.close() + def close(self): + self.stream.write(self.getvalue()) + with suppress(AttributeError): + self.stream.flush() + super().close() + if self.closing: + self.stream.close() + +def _open(file, mode, *, peekable=False, seekable=False): + """return a context manager with an opened file-like object""" 
+ readonly = ('r' in mode and '+' not in mode) + if not readonly and peekable: + raise ValueError("the 'peekable' option is invalid for writable files") + if readonly and seekable: + raise ValueError("the 'seekable' option is invalid for read-only files") + should_close = not hasattr(file, 'read' if readonly else 'write') + if should_close: + file = open(file, mode) + # Wrap stream in a helper class if necessary. + if peekable and not hasattr(file, 'peek'): + # Try our best to return it as an object with a peek() method. + if hasattr(file, 'seekable'): + file_seekable = file.seekable() + elif hasattr(file, 'seek') and hasattr(file, 'tell'): + try: + file.seek(file.tell()) + file_seekable = True + except Exception: + file_seekable = False + else: + file_seekable = False + if file_seekable: + file = _PeekableReader(file, closing=should_close) + else: + try: + file = io.BufferedReader(file) + except Exception: + # It won't be peekable, but will fail gracefully in _identify_module(). + file = _PeekableReader(file, closing=should_close) + elif seekable and ( + not hasattr(file, 'seek') + or not hasattr(file, 'truncate') + or (hasattr(file, 'seekable') and not file.seekable()) + ): + file = _SeekableWriter(file, closing=should_close) + if should_close or isinstance(file, (_PeekableReader, _SeekableWriter)): + return file + else: + return contextlib.nullcontext(file) diff --git a/dill/detect.py b/dill/detect.py index b6a6cb76..e6149d15 100644 --- a/dill/detect.py +++ b/dill/detect.py @@ -13,7 +13,7 @@ from inspect import ismethod, isfunction, istraceback, isframe, iscode from .pointers import parent, reference, at, parents, children -from .logger import trace +from .logging import trace __all__ = ['baditems','badobjects','badtypes','code','errors','freevars', 'getmodule','globalvars','nestedcode','nestedglobals','outermost', diff --git a/dill/logger.py b/dill/logging.py similarity index 65% rename from dill/logger.py rename to dill/logging.py index be557a5e..92386e0c 
100644 --- a/dill/logger.py +++ b/dill/logging.py @@ -11,37 +11,45 @@ The 'logger' object is dill's top-level logger. The 'adapter' object wraps the logger and implements a 'trace()' method that -generates a detailed tree-style trace for the pickling call at log level INFO. +generates a detailed tree-style trace for the pickling call at log level +:const:`dill.logging.TRACE`, which has an intermediary value between +:const:`logging.INFO` and :const:`logging.DEBUG`. The 'trace()' function sets and resets dill's logger log level, enabling and disabling the pickling trace. The trace shows a tree structure depicting the depth of each object serialized *with dill save functions*, but not the ones that use save functions from -'pickle._Pickler.dispatch'. If the information is available, it also displays +``pickle._Pickler.dispatch``. If the information is available, it also displays the size in bytes that the object contributed to the pickle stream (including its child objects). Sample trace output: - >>> import dill, dill.tests - >>> dill.detect.trace(True) - >>> dill.dump_session(main=dill.tests) - ┬ M1: - ├┬ F2: + >>> import dill + >>> import keyword + >>> with dill.detect.trace(): + ...
dill.dump_module(module=keyword) + ┬ M1: + ├┬ F2: │└ # F2 [32 B] - ├┬ D2: + ├┬ D5: │├┬ T4: ││└ # T4 [35 B] - │├┬ D2: + │├┬ D2: ││├┬ T4: │││└ # T4 [50 B] - ││├┬ D2: - │││└ # D2 [84 B] - ││└ # D2 [413 B] - │└ # D2 [763 B] - └ # M1 [813 B] + ││├┬ D2: + │││└ # D2 [47 B] + ││└ # D2 [280 B] + │└ # D5 [1 KiB] + └ # M1 [1 KiB] """ -__all__ = ['adapter', 'logger', 'trace'] +from __future__ import annotations + +__all__ = [ + 'adapter', 'logger', 'trace', 'getLogger', + 'CRITICAL', 'ERROR', 'WARNING', 'INFO', 'TRACE', 'DEBUG', 'NOTSET', +] import codecs import contextlib @@ -49,10 +57,21 @@ import logging import math import os +from contextlib import suppress +from logging import getLogger, CRITICAL, ERROR, WARNING, INFO, DEBUG, NOTSET from functools import partial -from typing import TextIO, Union +from typing import Optional, TextIO, Union import dill +from ._utils import _format_bytes_size + +# Intermediary logging level for tracing. +TRACE = (INFO + DEBUG) // 2 + +_nameOrBoolToLevel = logging._nameToLevel.copy() +_nameOrBoolToLevel['TRACE'] = TRACE +_nameOrBoolToLevel[False] = WARNING +_nameOrBoolToLevel[True] = TRACE # Tree drawing characters: Unicode to ASCII map. ASCII_MAP = str.maketrans({"│": "|", "├": "|", "┬": "+", "└": "`"}) @@ -105,13 +124,24 @@ class TraceAdapter(logging.LoggerAdapter): creates extra values to be added in the LogRecord from it, then calls 'info()'. - Usage of logger with 'trace()' method: + Examples: - >>> from dill.logger import adapter as logger #NOTE: not dill.logger.logger - >>> ... - >>> def save_atype(pickler, obj): - >>> logger.trace(pickler, "Message with %s and %r etc. placeholders", 'text', obj) - >>> ... + In the first call to `trace()`, before pickling an object, it must be passed + to `trace()` as the last positional argument or as the keyword argument + `obj`. 
Note how, in the second example, the object is not passed as a + positional argument, and therefore won't be substituted in the message: + + >>> from dill.logger import adapter as logger #NOTE: not dill.logger.logger + >>> ... + >>> def save_atype(pickler, obj): + >>> logger.trace(pickler, "X: Message with %s and %r placeholders", 'text', obj) + >>> ... + >>> logger.trace(pickler, "# X") + >>> def save_weakproxy(pickler, obj) + >>> trace_message = "W: This works even with a broken weakproxy: %r" % obj + >>> logger.trace(pickler, trace_message, obj=obj) + >>> ... + >>> logger.trace(pickler, "# W") """ def __init__(self, logger): self.logger = logger @@ -128,44 +158,57 @@ def trace_setup(self, pickler): # Called by Pickler.dump(). if not dill._dill.is_dill(pickler, child=False): return - if self.isEnabledFor(logging.INFO): - pickler._trace_depth = 1 + elif self.isEnabledFor(TRACE): + pickler._trace_stack = [] pickler._size_stack = [] else: - pickler._trace_depth = None - def trace(self, pickler, msg, *args, **kwargs): - if not hasattr(pickler, '_trace_depth'): + pickler._trace_stack = None + def trace(self, pickler, msg, *args, obj=None, **kwargs): + if not hasattr(pickler, '_trace_stack'): logger.info(msg, *args, **kwargs) return - if pickler._trace_depth is None: + elif pickler._trace_stack is None: return extra = kwargs.get('extra', {}) pushed_obj = msg.startswith('#') + if not pushed_obj: + if obj is None and (not args or type(args[-1]) is str): + raise TypeError( + "the pickled object must be passed as the last positional " + "argument (being substituted in the message) or as the " + "'obj' keyword argument." + ) + if obj is None: + obj = args[-1] + pickler._trace_stack.append(id(obj)) size = None - try: + with suppress(AttributeError, TypeError): # Streams are not required to be tellable. 
- size = pickler._file.tell() + size = pickler._file_tell() frame = pickler.framer.current_frame try: size += frame.tell() except AttributeError: # PyPy may use a BytesBuilder as frame size += len(frame) - except (AttributeError, TypeError): - pass if size is not None: if not pushed_obj: pickler._size_stack.append(size) + if len(pickler._size_stack) == 3: # module > dict > variable + with suppress(AttributeError, KeyError): + extra['varname'] = pickler._id_to_name.pop(id(obj)) else: size -= pickler._size_stack.pop() extra['size'] = size - if pushed_obj: - pickler._trace_depth -= 1 - extra['depth'] = pickler._trace_depth + extra['depth'] = len(pickler._trace_stack) kwargs['extra'] = extra self.info(msg, *args, **kwargs) - if not pushed_obj: - pickler._trace_depth += 1 + if pushed_obj: + pickler._trace_stack.pop() + def roll_back(self, pickler, obj): + if pickler._trace_stack and id(obj) == pickler._trace_stack[-1]: + pickler._trace_stack.pop() + pickler._size_stack.pop() class TraceFormatter(logging.Formatter): """ @@ -200,24 +243,26 @@ def format(self, record): if not self.is_utf8: prefix = prefix.translate(ASCII_MAP) + "-" fields['prefix'] = prefix + " " - if hasattr(record, 'size'): - # Show object size in human-redable form. 
- power = int(math.log(record.size, 2)) // 10 - size = record.size >> power*10 - fields['suffix'] = " [%d %sB]" % (size, "KMGTP"[power] + "i" if power else "") + if hasattr(record, 'varname'): + fields['suffix'] = " as %r" % record.varname + elif hasattr(record, 'size'): + fields['suffix'] = " [%d %s]" % _format_bytes_size(record.size) vars(record).update(fields) return super().format(record) -logger = logging.getLogger('dill') +logger = getLogger('dill') logger.propagate = False adapter = TraceAdapter(logger) stderr_handler = logging._StderrHandler() adapter.addHandler(stderr_handler) -def trace(arg: Union[bool, TextIO, str, os.PathLike] = None, *, mode: str = 'a') -> None: +def trace( + arg: Union[bool, str, TextIO, os.PathLike] = None, *, mode: str = 'a' + ) -> Optional[TraceManager]: """print a trace through the stack when pickling; useful for debugging - With a single boolean argument, enable or disable the tracing. + With a single boolean argument, enable or disable the tracing. Or, with a + logging level name (not ``int``), set the logging level of the dill logger. Example usage: @@ -227,10 +272,10 @@ def trace(arg: Union[bool, TextIO, str, os.PathLike] = None, *, mode: str = 'a') Alternatively, ``trace()`` can be used as a context manager. With no arguments, it just takes care of restoring the tracing state on exit. - Either a file handle, or a file name and (optionally) a file mode may be - specitfied to redirect the tracing output in the ``with`` block context. A - log function is yielded by the manager so the user can write extra - information to the file. + Either a file handle, or a file name and a file mode (optional) may be + specified to redirect the tracing output in the ``with`` block. A ``log()`` + function is yielded by the manager so the user can write extra information + to the file. 
Example usage: @@ -249,13 +294,18 @@ def trace(arg: Union[bool, TextIO, str, os.PathLike] = None, *, mode: str = 'a') >>> log("> squared = %r", squared) >>> dumps(squared) - Arguments: - arg: a boolean value, or an optional file-like or path-like object for the context manager - mode: mode string for ``open()`` if a file name is passed as the first argument + Parameters: + arg: a boolean value, the name of a logging level (including "TRACE") + or an optional file-like or path-like object for the context manager + mode: mode string for ``open()`` if a file name is passed as the first + argument """ - if not isinstance(arg, bool): + level = _nameOrBoolToLevel.get(arg) if isinstance(arg, (bool, str)) else None + if level is not None: + logger.setLevel(level) + return + else: return TraceManager(file=arg, mode=mode) - logger.setLevel(logging.INFO if arg else logging.WARNING) class TraceManager(contextlib.AbstractContextManager): """context manager version of trace(); can redirect the trace to a file""" @@ -274,7 +324,7 @@ def __enter__(self): adapter.removeHandler(stderr_handler) adapter.addHandler(self.handler) self.old_level = adapter.getEffectiveLevel() - adapter.setLevel(logging.INFO) + adapter.setLevel(TRACE) return adapter.info def __exit__(self, *exc_info): adapter.setLevel(self.old_level) diff --git a/dill/session.py b/dill/session.py index 6acdd432..a993407a 100644 --- a/dill/session.py +++ b/dill/session.py @@ -7,7 +7,23 @@ # License: 3-clause BSD. The full license text is available at: # - https://github.com/uqfoundation/dill/blob/master/LICENSE """ -Pickle and restore the intepreter session. +Pickle and restore the intepreter session or a module's state. + +The functions :func:`dump_module`, :func:`load_module` and +:func:`load_module_asdict` are capable of saving and restoring, as long as +objects are pickleable, the complete state of a module. For imported modules +that are pickled, `dill` assumes that they are importable when unpickling. 
+ +Contrary of using :func:`dill.dump` and :func:`dill.load` to save and load a +module object, :func:`dill.dump_module` always tries to pickle the module by +value (including built-in modules). Also, options like +``dill.settings['byref']`` and ``dill.settings['recurse']`` don't affect its +behavior. + +However, if a module contains references to objects originating from other +modules, that would prevent it from pickling or drastically increase its disk +size, they can be saved by reference instead of by value, using the option +``refimported``. """ __all__ = [ @@ -19,79 +35,63 @@ import sys import warnings -from dill import _dill, Pickler, Unpickler +from dill import _dill, logging +from dill import Pickler, Unpickler, UnpicklingError from ._dill import ( BuiltinMethodType, FunctionType, MethodType, ModuleType, TypeType, - _import_module, _is_builtin_module, _is_imported_module, _main_module, - _reverse_typemap, __builtin__, + _import_module, _is_builtin_module, _is_imported_module, + _lookup_module, _main_module, _module_map, _reverse_typemap, __builtin__, ) +from ._utils import _open + +logger = logging.getLogger(__name__) # Type hints. -from typing import Optional, Union +from typing import Any, Dict, Optional, Union import pathlib import tempfile TEMPDIR = pathlib.PurePath(tempfile.gettempdir()) -def _module_map(): - """get map of imported modules""" - from collections import defaultdict - from types import SimpleNamespace - modmap = SimpleNamespace( - by_name=defaultdict(list), - by_id=defaultdict(list), - top_level={}, - ) - for modname, module in sys.modules.items(): - if modname in ('__main__', '__mp_main__') or not isinstance(module, ModuleType): - continue - if '.' 
not in modname: - modmap.top_level[id(module)] = modname - for objname, modobj in module.__dict__.items(): - modmap.by_name[objname].append((modobj, modname)) - modmap.by_id[id(modobj)].append((modobj, objname, modname)) - return modmap - +# Unique objects (with no duplicates) that may be imported with "import as". IMPORTED_AS_TYPES = (ModuleType, TypeType, FunctionType, MethodType, BuiltinMethodType) if 'PyCapsuleType' in _reverse_typemap: IMPORTED_AS_TYPES += (_reverse_typemap['PyCapsuleType'],) -IMPORTED_AS_MODULES = ('ctypes', 'typing', 'subprocess', 'threading', - r'concurrent\.futures(\.\w+)?', r'multiprocessing(\.\w+)?') -IMPORTED_AS_MODULES = tuple(re.compile(x) for x in IMPORTED_AS_MODULES) - -def _lookup_module(modmap, name, obj, main_module): - """lookup name or id of obj if module is imported""" - for modobj, modname in modmap.by_name[name]: - if modobj is obj and sys.modules[modname] is not main_module: - return modname, name - __module__ = getattr(obj, '__module__', None) - if isinstance(obj, IMPORTED_AS_TYPES) or (__module__ is not None - and any(regex.fullmatch(__module__) for regex in IMPORTED_AS_MODULES)): - for modobj, objname, modname in modmap.by_id[id(obj)]: - if sys.modules[modname] is not main_module: - return modname, objname - return None, None - -def _stash_modules(main_module): - modmap = _module_map() - newmod = ModuleType(main_module.__name__) +# For unique objects of various types that have a '__module__' attribute. +IMPORTED_AS_MODULES = [re.compile(x) for x in ( + 'ctypes', 'typing', 'subprocess', 'threading', + r'concurrent\.futures(\.\w+)?', r'multiprocessing(\.\w+)?' 
+)] + +BUILTIN_CONSTANTS = (None, False, True, NotImplemented) + +def _stash_modules(main_module, original_main): + """pop imported variables to be saved by reference in the __dill_imported* attributes""" + modmap = _module_map(original_main) + newmod = ModuleType(main_module.__name__) + original = {} imported = [] imported_as = [] imported_top_level = [] # keep separated for backward compatibility - original = {} + for name, obj in main_module.__dict__.items(): - if obj is main_module: - original[name] = newmod # self-reference - elif obj is main_module.__dict__: - original[name] = newmod.__dict__ - # Avoid incorrectly matching a singleton value in another package (ex.: __doc__). - elif any(obj is singleton for singleton in (None, False, True)) \ - or isinstance(obj, ModuleType) and _is_builtin_module(obj): # always saved by ref + # Avoid incorrectly matching a singleton value in another package (e.g. __doc__ == None). + if (any(obj is constant for constant in BUILTIN_CONSTANTS) # must compare by identity + or type(obj) is str and obj == '' # internalized, for cases like: __package__ == '' + or type(obj) is int and -128 <= obj <= 256 # possibly cached by compiler/interpreter + or isinstance(obj, ModuleType) and _is_builtin_module(obj) # always saved by ref + or obj is main_module or obj is main_module.__dict__): original[name] = obj else: - source_module, objname = _lookup_module(modmap, name, obj, main_module) + modname = getattr(obj, '__module__', None) + lookup_by_id = ( + isinstance(obj, IMPORTED_AS_TYPES) + or modname is not None + and any(regex.fullmatch(modname) for regex in IMPORTED_AS_MODULES) + ) + source_module, objname, _ = _lookup_module(modmap, name, obj, lookup_by_id) if source_module is not None: if objname == name: imported.append((source_module, name)) @@ -108,51 +108,91 @@ def _stash_modules(main_module): newmod.__dill_imported = imported newmod.__dill_imported_as = imported_as newmod.__dill_imported_top_level = imported_top_level - if 
getattr(newmod, '__loader__', None) is None and _is_imported_module(main_module): - # Trick _is_imported_module() to force saving as an imported module. - newmod.__loader__ = True # will be discarded by save_module() - return newmod + _discard_added_variables(newmod, main_module.__dict__) + + if logger.isEnabledFor(logging.INFO): + refimported = [(name, "%s.%s" % (mod, name)) for mod, name in imported] + refimported += [(name, "%s.%s" % (mod, objname)) for mod, objname, name in imported_as] + refimported += [(name, mod) for mod, name in imported_top_level] + message = "[dump_module] Variables saved by reference (refimported):\n" + logger.info(message + _format_log_dict(dict(refimported))) + logger.debug("main namespace after _stash_modules(): %s", dir(newmod)) + + return newmod, modmap else: - return main_module + return main_module, modmap def _restore_modules(unpickler, main_module): - try: - for modname, name in main_module.__dict__.pop('__dill_imported'): - main_module.__dict__[name] = unpickler.find_class(modname, name) - for modname, objname, name in main_module.__dict__.pop('__dill_imported_as'): - main_module.__dict__[name] = unpickler.find_class(modname, objname) - for modname, name in main_module.__dict__.pop('__dill_imported_top_level'): - main_module.__dict__[name] = __import__(modname) - except KeyError: - pass - -#NOTE: 06/03/15 renamed main_module to main + for modname, name in main_module.__dict__.pop('__dill_imported', ()): + main_module.__dict__[name] = unpickler.find_class(modname, name) + for modname, objname, name in main_module.__dict__.pop('__dill_imported_as', ()): + main_module.__dict__[name] = unpickler.find_class(modname, objname) + for modname, name in main_module.__dict__.pop('__dill_imported_top_level', ()): + main_module.__dict__[name] = _import_module(modname) + +def _format_log_dict(dict): + return pprint.pformat(dict, compact=True, sort_dicts=True).replace("'", "") + +def _discard_added_variables(main, original_namespace): + # Some 
empty attributes like __doc__ may have been added by ModuleType(). + added_names = set(main.__dict__) + added_names.discard('__name__') # required + added_names.difference_update(original_namespace) + added_names.difference_update('__dill_imported%s' % s for s in ('', '_as', '_top_level')) + for name in added_names: + delattr(main, name) + +def _fix_module_namespace(main, original_main): + # Self-references. + for name, obj in main.__dict__.items(): + if obj is original_main: + setattr(main, name, main) + elif obj is original_main.__dict__: + setattr(main, name, main.__dict__) + # Trick _is_imported_module(), forcing main to be saved as an imported module. + if getattr(main, '__loader__', None) is None and _is_imported_module(original_main): + main.__loader__ = True # will be discarded by _dill.save_module() + def dump_module( filename = str(TEMPDIR/'session.pkl'), module: Optional[Union[ModuleType, str]] = None, - refimported: bool = False, + *, + refimported: Optional[bool] = None, + refonfail: Optional[bool] = None, **kwds ) -> None: - """Pickle the current state of :py:mod:`__main__` or another module to a file. + """Pickle the current state of :mod:`__main__` or another module to a file. - Save the contents of :py:mod:`__main__` (e.g. from an interactive + Save the contents of :mod:`__main__` (e.g. from an interactive interpreter session), an imported module, or a module-type object (e.g. - built with :py:class:`~types.ModuleType`), to a file. The pickled - module can then be restored with the function :py:func:`load_module`. + built with :class:`~types.ModuleType`), to a file. The pickled + module can then be restored with the function :func:`load_module`. Parameters: filename: a path-like object or a writable stream. module: a module object or the name of an importable module. If `None` - (the default), :py:mod:`__main__` is saved. + (the default), :mod:`__main__` is saved. 
refimported: if `True`, all objects identified as having been imported into the module's namespace are saved by reference. *Note:* this is - similar but independent from ``dill.settings[`byref`]``, as + similar but independent from ``dill.settings['byref']``, as ``refimported`` refers to virtually all imported objects, while ``byref`` only affects select objects. - **kwds: extra keyword arguments passed to :py:class:`Pickler()`. + refonfail: if `True` (the default), objects that fail to pickle by value + will try to be saved by reference. If this also fails, saving their + parent objects by reference will be attempted recursively. In the + worst case scenario, the module itself may be saved by reference, + with a warning. *Note:* this has the side effect of disabling framing + for pickle protocol ≥ 4. Turning this option off may improve + unpickling speed, but may cause a module to fail pickling. + **kwds: extra keyword arguments passed to :class:`Pickler()`. Raises: - :py:exc:`PicklingError`: if pickling fails. + :exc:`PicklingError`: if pickling fails. + :exc:`PicklingWarning`: if the module itself ends being saved by + reference due to unpickleable objects in its namespace. + + Default values for keyword-only arguments can be set in + `dill.session.settings`. Examples: @@ -177,7 +217,16 @@ def dump_module( >>> foo.values = [1,2,3] >>> import math >>> foo.sin = math.sin - >>> dill.dump_module('foo_session.pkl', module=foo, refimported=True) + >>> dill.dump_module('foo_session.pkl', module=foo) + + - Save the state of a module with unpickleable objects: + + >>> import dill + >>> import os + >>> os.altsep = '\\' + >>> dill.dump_module('os_session.pkl', module=os, refonfail=False) + PicklingError: ... 
+ >>> dill.dump_module('os_session.pkl', module=os, refonfail=True) # the default - Restore the state of the saved modules: @@ -191,6 +240,25 @@ def dump_module( >>> foo = dill.load_module('foo_session.pkl') >>> [foo.sin(x) for x in foo.values] [0.8414709848078965, 0.9092974268256817, 0.1411200080598672] + >>> os = dill.load_module('os_session.pkl') + >>> print(os.altsep.join('path')) + p\\a\\t\\h + + - Use `refimported` to save imported objects by reference: + + >>> import dill + >>> from html.entities import html5 + >>> type(html5), len(html5) + (dict, 2231) + >>> import io + >>> buf = io.BytesIO() + >>> dill.dump_module(buf) # saves __main__, with html5 saved by value + >>> len(buf.getvalue()) # pickle size in bytes + 71665 + >>> buf = io.BytesIO() + >>> dill.dump_module(buf, refimported=True) # html5 saved by reference + >>> len(buf.getvalue()) + 438 *Changed in version 0.3.6:* Function ``dump_session()`` was renamed to ``dump_module()``. Parameters ``main`` and ``byref`` were renamed to @@ -211,8 +279,11 @@ def dump_module( refimported = kwds.pop('byref', refimported) module = kwds.pop('main', module) - from .settings import settings - protocol = settings['protocol'] + from .settings import settings as dill_settings + protocol = dill_settings['protocol'] + if refimported is None: refimported = settings['refimported'] + if refonfail is None: refonfail = settings['refonfail'] + main = module if main is None: main = _main_module @@ -220,25 +291,37 @@ def dump_module( main = _import_module(main) if not isinstance(main, ModuleType): raise TypeError("%r is not a module" % main) - if hasattr(filename, 'write'): - file = filename - else: - file = open(filename, 'wb') - try: + original_main = main + + logger.debug("original main namespace: %s", dir(main)) + if refimported: + main, modmap = _stash_modules(main, original_main) + + with _open(filename, 'wb', seekable=True) as file: pickler = Pickler(file, protocol, **kwds) - pickler._original_main = main - if refimported: 
- main = _stash_modules(main) pickler._main = main #FIXME: dill.settings are disabled pickler._byref = False # disable pickling by name reference pickler._recurse = False # disable pickling recursion for globals pickler._session = True # is best indicator of when pickling a session pickler._first_pass = True - pickler._main_modified = main is not pickler._original_main + if main is not original_main: + pickler._original_main = original_main + _fix_module_namespace(main, original_main) + if refonfail: + pickler._refonfail = True # False by default + pickler._file_seek = file.seek + pickler._file_truncate = file.truncate + pickler._saved_byref = [] + if refimported: + # Cache modmap for refonfail. + pickler._modmap = modmap + if logger.isEnabledFor(logging.TRACE): + pickler._id_to_name = {id(v): k for k, v in main.__dict__.items()} pickler.dump(main) - finally: - if file is not filename: # if newly opened file - file.close() + if refonfail and pickler._saved_byref and logger.isEnabledFor(logging.INFO): + saved_byref = {var: "%s.%s" % (mod, obj) for var, mod, obj in pickler._saved_byref} + message = "[dump_module] Variables saved by reference (refonfail):\n" + logger.info(message + _format_log_dict(saved_byref)) return # Backward compatibility. 
@@ -247,42 +330,6 @@ def dump_session(filename=str(TEMPDIR/'session.pkl'), main=None, byref=False, ** dump_module(filename, module=main, refimported=byref, **kwds) dump_session.__doc__ = dump_module.__doc__ -class _PeekableReader: - """lightweight stream wrapper that implements peek()""" - def __init__(self, stream): - self.stream = stream - def read(self, n): - return self.stream.read(n) - def readline(self): - return self.stream.readline() - def tell(self): - return self.stream.tell() - def close(self): - return self.stream.close() - def peek(self, n): - stream = self.stream - try: - if hasattr(stream, 'flush'): stream.flush() - position = stream.tell() - stream.seek(position) # assert seek() works before reading - chunk = stream.read(n) - stream.seek(position) - return chunk - except (AttributeError, OSError): - raise NotImplementedError("stream is not peekable: %r", stream) from None - -def _make_peekable(stream): - """return stream as an object with a peek() method""" - import io - if hasattr(stream, 'peek'): - return stream - if not (hasattr(stream, 'tell') and hasattr(stream, 'seek')): - try: - return io.BufferedReader(stream) - except Exception: - pass - return _PeekableReader(stream) - def _identify_module(file, main=None): """identify the name of the module stored in the given file-type object""" from pickletools import genops @@ -311,34 +358,87 @@ def load_module( module: Optional[Union[ModuleType, str]] = None, **kwds ) -> Optional[ModuleType]: - """Update the selected module (default is :py:mod:`__main__`) with - the state saved at ``filename``. + """Update the selected module with the state saved at ``filename``. - Restore a module to the state saved with :py:func:`dump_module`. The - saved module can be :py:mod:`__main__` (e.g. an interpreter session), + Restore a module to the state saved with :func:`dump_module`. The + saved module can be :mod:`__main__` (e.g. an interpreter session), an imported module, or a module-type object (e.g. 
created with - :py:class:`~types.ModuleType`). + :class:`~types.ModuleType`). - When restoring the state of a non-importable module-type object, the - current instance of this module may be passed as the argument ``main``. - Otherwise, a new instance is created with :py:class:`~types.ModuleType` + When restoring the state of a non-importable, module-type object, the + current instance of this module may be passed as the argument ``module``. + Otherwise, a new instance is created with :class:`~types.ModuleType` and returned. Parameters: filename: a path-like object or a readable stream. module: a module object or the name of an importable module; - the module name and kind (i.e. imported or non-imported) must + the module's name and kind (i.e. imported or non-imported) must match the name and kind of the module stored at ``filename``. - **kwds: extra keyword arguments passed to :py:class:`Unpickler()`. + **kwds: extra keyword arguments passed to :class:`Unpickler()`. Raises: - :py:exc:`UnpicklingError`: if unpickling fails. - :py:exc:`ValueError`: if the argument ``main`` and module saved - at ``filename`` are incompatible. + :exc:`UnpicklingError`: if unpickling fails. + :exc:`ValueError`: if the argument ``module`` and the module + saved at ``filename`` are incompatible. Returns: - A module object, if the saved module is not :py:mod:`__main__` or - a module instance wasn't provided with the argument ``main``. + A module object, if the saved module is not :mod:`__main__` and + a module instance wasn't provided with the argument ``module``. + + Passing an argument to ``module`` forces `dill` to verify that the module + being loaded is compatible with the argument value. Additionally, if the + argument is a module instance (instead of a module name), it supresses the + return value. Each case and behavior is exemplified below: + + 1. 
`module`: ``None`` --- This call loads a previously saved state of + the module ``math`` and returns it (the module object) at the end: + + >>> import dill + >>> # load module -> restore state -> return module + >>> dill.load_module('math_session.pkl') + + + 2. `module`: ``str`` --- Passing the module name does the same as above, + but also verifies that the module being loaded, restored and returned is + indeed ``math``: + + >>> import dill + >>> # load module -> check name/kind -> restore state -> return module + >>> dill.load_module('math_session.pkl', module='math') + + >>> dill.load_module('math_session.pkl', module='cmath') + ValueError: can't update module 'cmath' with the saved state of module 'math' + + 3. `module`: ``ModuleType`` --- Passing the module itself instead of its + name has the additional effect of suppressing the return value (and the + module is already loaded at this point): + + >>> import dill + >>> import math + >>> # check name/kind -> restore state -> return None + >>> dill.load_module('math_session.pkl', module=math) + + For imported modules, the return value is meant as a convenience, so that + the function call can substitute an ``import`` statement. Therefore these + statements: + + >>> import dill + >>> math2 = dill.load_module('math_session.pkl', module='math') + + are equivalent to these: + + >>> import dill + >>> import math as math2 + >>> dill.load_module('math_session.pkl', module=math2) + + Note that, in both cases, ``math2`` is just a reference to + ``sys.modules['math']``: + + >>> import math + >>> import sys + >>> math is math2 is sys.modules['math'] + True Examples: @@ -402,10 +502,6 @@ def load_module( *Changed in version 0.3.6:* Function ``load_session()`` was renamed to ``load_module()``. Parameter ``main`` was renamed to ``module``. - - See also: - :py:func:`load_module_asdict` to load the contents of module saved - with :py:func:`dump_module` into a dictionary. 
""" if 'main' in kwds: warnings.warn( @@ -415,20 +511,12 @@ def load_module( if module is not None: raise TypeError("both 'module' and 'main' arguments were used") module = kwds.pop('main') - main = module - if hasattr(filename, 'read'): - file = filename - else: - file = open(filename, 'rb') - try: - file = _make_peekable(file) - #FIXME: dill.settings are disabled - unpickler = Unpickler(file, **kwds) - unpickler._session = True - # Resolve unpickler._main + main = module + with _open(filename, 'rb', peekable=True) as file: + # Resolve main. pickle_main = _identify_module(file, main) - if main is None and pickle_main is not None: + if main is None: main = pickle_main if isinstance(main, str): if main.startswith('__runtime__.'): @@ -436,12 +524,8 @@ def load_module( main = ModuleType(main.partition('.')[-1]) else: main = _import_module(main) - if main is not None: - if not isinstance(main, ModuleType): - raise TypeError("%r is not a module" % main) - unpickler._main = main - else: - main = unpickler._main + if not isinstance(main, ModuleType): + raise TypeError("%r is not a module" % main) # Check against the pickle's main. 
is_main_imported = _is_imported_module(main) @@ -450,32 +534,33 @@ def load_module( if is_runtime_mod: pickle_main = pickle_main.partition('.')[-1] error_msg = "can't update{} module{} %r with the saved state of{} module{} %r" - if is_runtime_mod and is_main_imported: + if main.__name__ != pickle_main: + raise ValueError(error_msg.format("", "", "", "") % (main.__name__, pickle_main)) + elif is_runtime_mod and is_main_imported: raise ValueError( error_msg.format(" imported", "", "", "-type object") - % (main.__name__, pickle_main) + % (main.__name__, main.__name__) ) - if not is_runtime_mod and not is_main_imported: + elif not is_runtime_mod and not is_main_imported: raise ValueError( error_msg.format("", "-type object", " imported", "") - % (pickle_main, main.__name__) + % (main.__name__, main.__name__) ) - if main.__name__ != pickle_main: - raise ValueError(error_msg.format("", "", "", "") % (main.__name__, pickle_main)) - - # This is for find_class() to be able to locate it. - if not is_main_imported: - runtime_main = '__runtime__.%s' % main.__name__ - sys.modules[runtime_main] = main - loaded = unpickler.load() - finally: - if not hasattr(filename, 'read'): # if newly opened file - file.close() + # Load the module's state. + #FIXME: dill.settings are disabled + unpickler = Unpickler(file, **kwds) + unpickler._session = True try: - del sys.modules[runtime_main] - except (KeyError, NameError): - pass + if not is_main_imported: + # This is for find_class() to be able to locate it. 
+ runtime_main = '__runtime__.%s' % main.__name__ + sys.modules[runtime_main] = main + loaded = unpickler.load() + finally: + if not is_main_imported: + del sys.modules[runtime_main] + assert loaded is main _restore_modules(unpickler, main) if main is _main_module or main is module: @@ -491,9 +576,8 @@ def load_session(filename=str(TEMPDIR/'session.pkl'), main=None, **kwds): def load_module_asdict( filename = str(TEMPDIR/'session.pkl'), - update: bool = False, **kwds -) -> dict: +) -> Dict[str, Any]: """ Load the contents of a saved module into a dictionary. @@ -501,27 +585,22 @@ def load_module_asdict( lambda filename: vars(dill.load_module(filename)).copy() - however, does not alter the original module. Also, the path of - the loaded module is stored in the ``__session__`` attribute. + however, it does not alter the original module. Also, the path of + the loaded file is stored with the key ``'__session__'``. Parameters: filename: a path-like object or a readable stream - update: if `True`, initialize the dictionary with the current state - of the module prior to loading the state stored at filename. - **kwds: extra keyword arguments passed to :py:class:`Unpickler()` + **kwds: extra keyword arguments passed to :class:`Unpickler()` Raises: - :py:exc:`UnpicklingError`: if unpickling fails + :exc:`UnpicklingError`: if unpickling fails Returns: A copy of the restored module's dictionary. Note: - If ``update`` is True, the corresponding module may first be imported - into the current namespace before the saved state is loaded from - filename to the dictionary. Note that any module that is imported into - the current namespace as a side-effect of using ``update`` will not be - modified by loading the saved module in filename to a dictionary. + Even if not changed, the module refered in the file is always loaded + before its saved state is restored from `filename` to the dictionary. 
Example: >>> import dill @@ -541,46 +620,52 @@ def load_module_asdict( False >>> main['anum'] == anum # changed after the session was saved False - >>> new_var in main # would be True if the option 'update' was set - False + >>> new_var in main # it was initialized with the current state of __main__ + True """ if 'module' in kwds: raise TypeError("'module' is an invalid keyword argument for load_module_asdict()") - if hasattr(filename, 'read'): - file = filename - else: - file = open(filename, 'rb') - try: - file = _make_peekable(file) - main_name = _identify_module(file) - old_main = sys.modules.get(main_name) - main = ModuleType(main_name) - if update: - if old_main is None: - old_main = _import_module(main_name) - main.__dict__.update(old_main.__dict__) - else: - main.__builtins__ = __builtin__ - sys.modules[main_name] = main - load_module(file, **kwds) - finally: - if not hasattr(filename, 'read'): # if newly opened file - file.close() + + with _open(filename, 'rb', peekable=True) as file: + main_qualname = _identify_module(file) + main = _import_module(main_qualname) + main_copy = ModuleType(main_qualname) + main_copy.__dict__.clear() + main_copy.__dict__.update(main.__dict__) + + parent_name, _, main_name = main_qualname.rpartition('.') + if parent_name: + parent = sys.modules[parent_name] try: - if old_main is None: - del sys.modules[main_name] - else: - sys.modules[main_name] = old_main - except NameError: # failed before setting old_main - pass - main.__session__ = str(filename) - return main.__dict__ + sys.modules[main_qualname] = main_copy + if parent_name and getattr(parent, main_name, None) is main: + setattr(parent, main_name, main_copy) + load_module(file, **kwds) + finally: + sys.modules[main_qualname] = main + if parent_name and getattr(parent, main_name, None) is main_copy: + setattr(parent, main_name, main) + + if isinstance(getattr(filename, 'name', None), str): + main_copy.__session__ = filename.name + else: + main_copy.__session__ = 
str(filename) + return main_copy.__dict__ + + +## Session settings ## + +settings = { + 'refimported': False, + 'refonfail': True, +} + +## Variables set in this module to avoid circular import problems ## # Internal exports for backward compatibility with dill v0.3.5.1 -# Can't be placed in dill._dill because of circular import problems. for name in ( - '_lookup_module', '_module_map', '_restore_modules', '_stash_modules', + '_restore_modules', '_stash_modules', 'dump_session', 'load_session' # backward compatibility functions ): setattr(_dill, name, globals()[name]) diff --git a/dill/tests/test_logger.py b/dill/tests/test_logging.py similarity index 97% rename from dill/tests/test_logger.py rename to dill/tests/test_logging.py index b4e4881a..ed33e6c4 100644 --- a/dill/tests/test_logger.py +++ b/dill/tests/test_logging.py @@ -11,7 +11,7 @@ import dill from dill import detect -from dill.logger import stderr_handler, adapter as logger +from dill.logging import stderr_handler, adapter as logger try: from StringIO import StringIO diff --git a/dill/tests/test_session.py b/dill/tests/test_session.py index 51128916..52e9cdd0 100644 --- a/dill/tests/test_session.py +++ b/dill/tests/test_session.py @@ -11,8 +11,10 @@ import __main__ from contextlib import suppress from io import BytesIO +from types import ModuleType import dill +from dill import _dill session_file = os.path.join(os.path.dirname(__file__), 'session-refimported-%s.pkl') @@ -20,7 +22,7 @@ # Child process # ################### -def _error_line(error, obj, refimported): +def _error_line(obj, refimported): import traceback line = traceback.format_exc().splitlines()[-2].replace('[obj]', '['+repr(obj)+']') return "while testing (with refimported=%s): %s" % (refimported, line.lstrip()) @@ -52,7 +54,7 @@ def test_modules(refimported): assert __main__.complex_log is cmath.log except AssertionError as error: - error.args = (_error_line(error, obj, refimported),) + error.args = (_error_line(obj, refimported),) raise 
test_modules(refimported) @@ -91,6 +93,7 @@ def weekdays(self): return [day_name[i] for i in self.iterweekdays()] cal = CalendarSubclass() selfref = __main__ +self_dict = __main__.__dict__ # Setup global namespace for session saving tests. class TestNamespace: @@ -120,7 +123,7 @@ def _clean_up_cache(module): def _test_objects(main, globals_copy, refimported): try: main_dict = __main__.__dict__ - global Person, person, Calendar, CalendarSubclass, cal, selfref + global Person, person, Calendar, CalendarSubclass, cal, selfref, self_dict for obj in ('json', 'url', 'local_mod', 'sax', 'dom'): assert globals()[obj].__name__ == globals_copy[obj].__name__ @@ -141,9 +144,10 @@ def _test_objects(main, globals_copy, refimported): assert cal.weekdays() == globals_copy['cal'].weekdays() assert selfref is __main__ + assert self_dict is __main__.__dict__ except AssertionError as error: - error.args = (_error_line(error, obj, refimported),) + error.args = (_error_line(obj, refimported),) raise def test_session_main(refimported): @@ -192,13 +196,12 @@ def test_session_other(): assert module.selfref is module def test_runtime_module(): - from types import ModuleType - modname = '__runtime__' - runtime = ModuleType(modname) - runtime.x = 42 + modname = 'runtime' + runtime_mod = ModuleType(modname) + runtime_mod.x = 42 - mod = dill.session._stash_modules(runtime) - if mod is not runtime: + mod, _ = dill.session._stash_modules(runtime_mod, runtime_mod) + if mod is not runtime_mod: print("There are objects to save by referenece that shouldn't be:", mod.__dill_imported, mod.__dill_imported_as, mod.__dill_imported_top_level, file=sys.stderr) @@ -207,46 +210,23 @@ def test_runtime_module(): # without imported objects in the namespace. It's a contrived example because # even dill can't be in it. This should work after fixing #462. 
session_buffer = BytesIO() - dill.dump_module(session_buffer, module=runtime, refimported=True) + dill.dump_module(session_buffer, module=runtime_mod, refimported=True) session_dump = session_buffer.getvalue() # Pass a new runtime created module with the same name. - runtime = ModuleType(modname) # empty - return_val = dill.load_module(BytesIO(session_dump), module=runtime) + runtime_mod = ModuleType(modname) # empty + return_val = dill.load_module(BytesIO(session_dump), module=runtime_mod) assert return_val is None - assert runtime.__name__ == modname - assert runtime.x == 42 - assert runtime not in sys.modules.values() + assert runtime_mod.__name__ == modname + assert runtime_mod.x == 42 + assert runtime_mod not in sys.modules.values() # Pass nothing as main. load_module() must create it. session_buffer.seek(0) - runtime = dill.load_module(BytesIO(session_dump)) - assert runtime.__name__ == modname - assert runtime.x == 42 - assert runtime not in sys.modules.values() - -def test_refimported_imported_as(): - import collections - import concurrent.futures - import types - import typing - mod = sys.modules['__test__'] = types.ModuleType('__test__') - dill.executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) - mod.Dict = collections.UserDict # select by type - mod.AsyncCM = typing.AsyncContextManager # select by __module__ - mod.thread_exec = dill.executor # select by __module__ with regex - - session_buffer = BytesIO() - dill.dump_module(session_buffer, mod, refimported=True) - session_buffer.seek(0) - mod = dill.load(session_buffer) - del sys.modules['__test__'] - - assert set(mod.__dill_imported_as) == { - ('collections', 'UserDict', 'Dict'), - ('typing', 'AsyncContextManager', 'AsyncCM'), - ('dill', 'executor', 'thread_exec'), - } + runtime_mod = dill.load_module(BytesIO(session_dump)) + assert runtime_mod.__name__ == modname + assert runtime_mod.x == 42 + assert runtime_mod not in sys.modules.values() def test_load_module_asdict(): with 
TestNamespace(): @@ -268,13 +248,155 @@ def test_load_module_asdict(): assert main_vars['names'] == names assert main_vars['names'] is not names assert main_vars['x'] != x - assert 'y' not in main_vars + assert 'y' in main_vars assert 'empty' in main_vars + # Test a submodule. + import html + from html import entities + entitydefs = entities.entitydefs + + session_buffer = BytesIO() + dill.dump_module(session_buffer, entities) + session_buffer.seek(0) + entities_vars = dill.load_module_asdict(session_buffer) + + assert entities is html.entities # restored + assert entities is sys.modules['html.entities'] # restored + assert entitydefs is entities.entitydefs # unchanged + assert entitydefs is not entities_vars['entitydefs'] # saved by value + assert entitydefs == entities_vars['entitydefs'] + +def test_lookup_module(): + assert not _dill._is_builtin_module(local_mod) and local_mod.__package__ == '' + + def lookup(mod, name, obj, lookup_by_name=True): + from dill._dill import _lookup_module, _module_map + return _lookup_module(_module_map(mod), name, obj, lookup_by_name) + + name = '__unpickleable' + obj = object() + setattr(dill, name, obj) + assert lookup(dill, name, obj) == (None, None, None) + + # 4th level: non-installed module + setattr(local_mod, name, obj) + sys.modules[local_mod.__name__] = sys.modules.pop(local_mod.__name__) # put at the end + assert lookup(dill, name, obj) == (local_mod.__name__, name, False) # not installed + try: + import pox + # 3rd level: installed third-party module + setattr(pox, name, obj) + sys.modules['pox'] = sys.modules.pop('pox') + assert lookup(dill, name, obj) == ('pox', name, True) + except ModuleNotFoundError: + pass + # 2nd level: module of same package + setattr(dill.session, name, obj) + sys.modules['dill.session'] = sys.modules.pop('dill.session') + assert lookup(dill, name, obj) == ('dill.session', name, True) + # 1st level: stdlib module + setattr(os, name, obj) + sys.modules['os'] = sys.modules.pop('os') + assert 
lookup(dill, name, obj) == ('os', name, True) + + # Lookup by id. + name2 = name + '2' + setattr(dill, name2, obj) + assert lookup(dill, name2, obj) == ('os', name, True) + assert lookup(dill, name2, obj, lookup_by_name=False) == (None, None, None) + setattr(local_mod, name2, obj) + assert lookup(dill, name2, obj) == (local_mod.__name__, name2, False) + +def test_refimported(): + import collections + import concurrent.futures + import types + import typing + + mod = sys.modules['__test__'] = ModuleType('__test__') + mod.builtin_module_names = sys.builtin_module_names + dill.executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + mod.Dict = collections.UserDict # select by type + mod.AsyncCM = typing.AsyncContextManager # select by __module__ + mod.thread_exec = dill.executor # select by __module__ with regex + mod.local_mod = local_mod + + session_buffer = BytesIO() + dill.dump_module(session_buffer, mod, refimported=True) + session_buffer.seek(0) + mod = dill.load(session_buffer) + + assert mod.__dill_imported == [('sys', 'builtin_module_names')] + assert set(mod.__dill_imported_as) == { + ('collections', 'UserDict', 'Dict'), + ('typing', 'AsyncContextManager', 'AsyncCM'), + ('dill', 'executor', 'thread_exec'), + } + assert mod.__dill_imported_top_level == [(local_mod.__name__, 'local_mod')] + + session_buffer.seek(0) + dill.load_module(session_buffer, mod) + del sys.modules['__test__'] + assert mod.builtin_module_names is sys.builtin_module_names + assert mod.Dict is collections.UserDict + assert mod.AsyncCM is typing.AsyncContextManager + assert mod.thread_exec is dill.executor + assert mod.local_mod is local_mod + +def test_unpickleable_var(): + global local_mod + import keyword as builtin_mod + from dill._dill import _global_string + refonfail_default = dill.session.settings['refonfail'] + dill.session.settings['refonfail'] = True + name = '__unpickleable' + obj = memoryview(b'') + assert _dill._is_builtin_module(builtin_mod) + assert not 
_dill._is_builtin_module(local_mod) + # assert not dill.pickles(obj) + try: + dill.dumps(obj) + except _dill.UNPICKLEABLE_ERRORS: + pass + else: + raise Exception("test object should be unpickleable") + + def dump_with_ref(mod, other_mod): + setattr(other_mod, name, obj) + buf = BytesIO() + dill.dump_module(buf, mod) + return buf.getvalue() + + # "user" modules + _local_mod = local_mod + del local_mod # remove from __main__'s namespace + try: + dump_with_ref(__main__, __main__) + except dill.PicklingError: + pass # success + else: + raise Exception("saving with a reference to the module itself should fail for '__main__'") + assert _global_string(_local_mod.__name__, name) in dump_with_ref(__main__, _local_mod) + assert _global_string('os', name) in dump_with_ref(__main__, os) + local_mod = _local_mod + del _local_mod, __main__.__unpickleable, local_mod.__unpickleable, os.__unpickleable + + # "builtin" or "installed" modules + assert _global_string(builtin_mod.__name__, name) in dump_with_ref(builtin_mod, builtin_mod) + assert _global_string(builtin_mod.__name__, name) in dump_with_ref(builtin_mod, local_mod) + assert _global_string('os', name) in dump_with_ref(builtin_mod, os) + del builtin_mod.__unpickleable, local_mod.__unpickleable, os.__unpickleable + + dill.session.settings['refonfail'] = refonfail_default + if __name__ == '__main__': - test_session_main(refimported=False) - test_session_main(refimported=True) + if os.getenv('COVERAGE') != 'true': + test_session_main(refimported=False) + test_session_main(refimported=True) test_session_other() test_runtime_module() - test_refimported_imported_as() test_load_module_asdict() + test_lookup_module() + test_refimported() + test_unpickleable_var() diff --git a/dill/tests/test_stdlib_modules.py b/dill/tests/test_stdlib_modules.py new file mode 100644 index 00000000..15cb0767 --- /dev/null +++ b/dill/tests/test_stdlib_modules.py @@ -0,0 +1,136 @@ +#!/usr/bin/env python + +# Author: Leonardo Gama (@leogama) +# 
Copyright (c) 2022 The Uncertainty Quantification Foundation. +# License: 3-clause BSD. The full license text is available at: +# - https://github.com/uqfoundation/dill/blob/master/LICENSE + +import io +import itertools +import logging +import multiprocessing +import os +import sys +import warnings + +import dill + +if not dill._dill.OLD310: + STDLIB_MODULES = list(sys.stdlib_module_names) + STDLIB_MODULES += [ + # From https://docs.python.org/3.11/library/ + 'collections.abc', 'concurrent.futures', 'curses.ascii', 'curses.panel', 'curses.textpad', + 'html.entities', 'html.parser', 'http.client', 'http.cookiejar', 'http.cookies', 'http.server', + 'importlib.metadata', 'importlib.resources', 'importlib.resources.abc', 'logging.config', + 'logging.handlers', 'multiprocessing.shared_memory', 'os.path', 'test.support', + 'test.support.bytecode_helper', 'test.support.import_helper', 'test.support.os_helper', + 'test.support.script_helper', 'test.support.socket_helper', 'test.support.threading_helper', + 'test.support.warnings_helper', 'tkinter.colorchooser', 'tkinter.dnd', 'tkinter.font', + 'tkinter.messagebox', 'tkinter.scrolledtext', 'tkinter.tix', 'tkinter.ttk', 'unittest.mock', + 'urllib.error', 'urllib.parse', 'urllib.request', 'urllib.response', 'urllib.robotparser', + 'xml.dom', 'xml.dom.minidom', 'xml.dom.pulldom', 'xml.etree.ElementTree', 'xml.parsers.expat', + 'xml.sax', 'xml.sax.handler', 'xml.sax.saxutils', 'xml.sax.xmlreader', 'xmlrpc.client', + 'xmlrpc.server', + ] + STDLIB_MODULES.sort() +else: + STDLIB_MODULES = [ + # From https://docs.python.org/3.9/library/ + '__future__', '_thread', 'abc', 'aifc', 'argparse', 'array', 'ast', 'asynchat', 'asyncio', + 'asyncore', 'atexit', 'audioop', 'base64', 'bdb', 'binascii', 'binhex', 'bisect', 'builtins', + 'bz2', 'calendar', 'cgi', 'cgitb', 'chunk', 'cmath', 'cmd', 'code', 'codecs', 'codeop', + 'collections', 'collections.abc', 'colorsys', 'compileall', 'concurrent', 'concurrent.futures', + 'configparser', 
'contextlib', 'contextvars', 'copy', 'copyreg', 'crypt', 'csv', 'ctypes', + 'curses', 'curses.ascii', 'curses.panel', 'curses.textpad', 'dataclasses', 'datetime', 'dbm', + 'decimal', 'difflib', 'dis', 'distutils', 'doctest', 'email', 'ensurepip', 'enum', 'errno', + 'faulthandler', 'fcntl', 'filecmp', 'fileinput', 'fnmatch', 'formatter', 'fractions', 'ftplib', + 'functools', 'gc', 'getopt', 'getpass', 'gettext', 'glob', 'graphlib', 'grp', 'gzip', 'hashlib', + 'heapq', 'hmac', 'html', 'html.entities', 'html.parser', 'http', 'http.client', + 'http.cookiejar', 'http.cookies', 'http.server', 'imaplib', 'imghdr', 'imp', 'importlib', + 'importlib.metadata', 'inspect', 'io', 'ipaddress', 'itertools', 'json', 'keyword', 'linecache', + 'locale', 'logging', 'logging.config', 'logging.handlers', 'lzma', 'mailbox', 'mailcap', + 'marshal', 'math', 'mimetypes', 'mmap', 'modulefinder', 'msilib', 'msvcrt', 'multiprocessing', + 'multiprocessing.shared_memory', 'netrc', 'nis', 'nntplib', 'numbers', 'operator', 'optparse', + 'os', 'os.path', 'ossaudiodev', 'parser', 'pathlib', 'pdb', 'pickle', 'pickletools', 'pipes', + 'pkgutil', 'platform', 'plistlib', 'poplib', 'posix', 'pprint', 'pty', 'pwd', 'py_compile', + 'pyclbr', 'pydoc', 'queue', 'quopri', 'random', 're', 'readline', 'reprlib', 'resource', + 'rlcompleter', 'runpy', 'sched', 'secrets', 'select', 'selectors', 'shelve', 'shlex', 'shutil', + 'signal', 'site', 'site', 'smtpd', 'smtplib', 'sndhdr', 'socket', 'socketserver', 'spwd', + 'sqlite3', 'ssl', 'stat', 'statistics', 'string', 'stringprep', 'struct', 'subprocess', 'sunau', + 'symbol', 'symtable', 'sys', 'sysconfig', 'syslog', 'tabnanny', 'tarfile', 'telnetlib', + 'tempfile', 'termios', 'test', 'test.support', 'test.support.bytecode_helper', + 'test.support.script_helper', 'test.support.socket_helper', 'textwrap', 'threading', 'time', + 'timeit', 'tkinter', 'tkinter.colorchooser', 'tkinter.dnd', 'tkinter.font', + 'tkinter.messagebox', 'tkinter.scrolledtext', 'tkinter.tix', 
'tkinter.ttk', 'token', 'tokenize', + 'trace', 'traceback', 'tracemalloc', 'tty', 'turtle', 'types', 'typing', 'unicodedata', + 'unittest', 'unittest.mock', 'urllib', 'urllib.error', 'urllib.parse', 'urllib.request', + 'urllib.response', 'urllib.robotparser', 'uu', 'uuid', 'venv', 'warnings', 'wave', 'weakref', + 'webbrowser', 'winreg', 'winsound', 'wsgiref', 'xdrlib', 'xml.dom', 'xml.dom.minidom', + 'xml.dom.pulldom', 'xml.etree.ElementTree', 'xml.parsers.expat', 'xml.sax', 'xml.sax.handler', + 'xml.sax.saxutils', 'xml.sax.xmlreader', 'xmlrpc', 'xmlrpc.client', 'xmlrpc.server', 'zipapp', + 'zipfile', 'zipimport', 'zlib', 'zoneinfo', +] + +def _dump_load_module(module_name, refonfail): + try: + __import__(module_name) + except ImportError: + return None, None + success_load = None + buf = io.BytesIO() + try: + dill.dump_module(buf, module_name, refonfail=refonfail) + except Exception: + print("F", end="") + success_dump = False + return success_dump, success_load + print(":", end="") + success_dump = True + buf.seek(0) + try: + module = dill.load_module(buf) + except Exception: + success_load = False + return success_dump, success_load + success_load = True + return success_dump, success_load + +def test_stdlib_modules(): + modules = [x for x in STDLIB_MODULES if + not x.startswith('_') + and not x.startswith('test') + and x not in ('antigravity', 'this')] + + + print("\nTesting pickling and unpickling of Standard Library modules...") + message = "Success rate (%s_module, refonfail=%s): %.1f%% [%d/%d]" + with multiprocessing.Pool(maxtasksperchild=1) as pool: + for refonfail in (False, True): + args = zip(modules, itertools.repeat(refonfail)) + result = pool.starmap(_dump_load_module, args, chunksize=1) + dump_successes = sum(dumped for dumped, loaded in result if dumped is not None) + load_successes = sum(loaded for dumped, loaded in result if loaded is not None) + dump_failures = sum(not dumped for dumped, loaded in result if dumped is not None) + load_failures = 
sum(not loaded for dumped, loaded in result if loaded is not None) + dump_total = dump_successes + dump_failures + load_total = load_successes + load_failures + dump_percent = 100 * dump_successes / dump_total + load_percent = 100 * load_successes / load_total + if logging.getLogger().isEnabledFor(logging.INFO): print() + logging.info(message, "dump", refonfail, dump_percent, dump_successes, dump_total) + logging.info(message, "load", refonfail, load_percent, load_successes, load_total) + if refonfail: + failed_dump = [mod for mod, (dumped, _) in zip(modules, result) if dumped is False] + failed_load = [mod for mod, (_, loaded) in zip(modules, result) if loaded is False] + if failed_dump: + logging.info("dump_module() FAILURES: %s", str(failed_dump).replace("'", "")[1:-1]) + if failed_load: + logging.info("load_module() FAILURES: %s", str(failed_load).replace("'", "")[1:-1]) + assert dump_percent > 99 + assert load_percent > 85 #FIXME: many important modules fail to unpickle + print() + +if __name__ == '__main__': + logging.basicConfig(level=os.environ.get('PYTHONLOGLEVEL', 'WARNING')) + warnings.simplefilter('ignore') + test_stdlib_modules() diff --git a/dill/tests/test_utils.py b/dill/tests/test_utils.py new file mode 100644 index 00000000..8da0ac99 --- /dev/null +++ b/dill/tests/test_utils.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python + +# Author: Leonardo Gama (@leogama) +# Copyright (c) 2022 The Uncertainty Quantification Foundation. +# License: 3-clause BSD. 
The full license text is available at: +# - https://github.com/uqfoundation/dill/blob/master/LICENSE + +"""test general utilities in _utils.py""" + +import io +import os +import sys + +from dill import _utils + +def test_format_bytes(): + formatb = _utils._format_bytes_size + assert formatb(1000) == (1000, 'B') + assert formatb(1024) == (1, 'KiB') + assert formatb(1024 + 511) == (1, 'KiB') + assert formatb(1024 + 512) == (2, 'KiB') + assert formatb(10**9) == (954, 'MiB') + +def test_open(): + file_unpeekable = open(__file__, 'rb', buffering=0) + assert not hasattr(file_unpeekable, 'peek') + + content = file_unpeekable.read() + peeked_chars = content[:10] + first_line = content[:100].partition(b'\n')[0] + b'\n' + file_unpeekable.seek(0) + + # Test _PeekableReader for seekable stream + with _utils._open(file_unpeekable, 'r', peekable=True) as file: + assert isinstance(file, _utils._PeekableReader) + assert file.peek(10)[:10] == peeked_chars + assert file.readline() == first_line + assert not file_unpeekable.closed + file_unpeekable.close() + + _pipe_r, _pipe_w = os.pipe() + pipe_r = io.FileIO(_pipe_r, closefd=False) + pipe_w = io.FileIO(_pipe_w, mode='w') + assert not hasattr(pipe_r, 'peek') + assert not pipe_r.seekable() + assert not pipe_w.seekable() + + # Test io.BufferedReader for unseekable stream + with _utils._open(pipe_r, 'r', peekable=True) as file: + assert isinstance(file, io.BufferedReader) + pipe_w.write(content[:100]) + assert file.peek(10)[:10] == peeked_chars + assert file.readline() == first_line + assert not pipe_r.closed + + # Test _SeekableWriter for unseekable stream + with _utils._open(pipe_w, 'w', seekable=True) as file: + # pipe_r is closed here for some reason... 
+ assert isinstance(file, _utils._SeekableWriter) + file.write(content) + file.flush() + file.seek(0) + file.truncate() + file.write(b'a line of text\n') + assert not pipe_w.closed + pipe_r = io.FileIO(_pipe_r) + assert pipe_r.readline() == b'a line of text\n' + pipe_r.close() + pipe_w.close() + +if __name__ == '__main__': + test_format_bytes() + test_open() diff --git a/docs/source/dill.rst b/docs/source/dill.rst index 2770af2a..e1ca2344 100644 --- a/docs/source/dill.rst +++ b/docs/source/dill.rst @@ -25,17 +25,17 @@ detect module :imported-members: .. :exclude-members: ismethod, isfunction, istraceback, isframe, iscode, parent, reference, at, parents, children -logger module -------------- +logging module +-------------- -.. automodule:: dill.logger +.. automodule:: dill.logging :members: :undoc-members: :private-members: :special-members: :show-inheritance: :imported-members: -.. :exclude-members: + :exclude-members: +trace objtypes module --------------- @@ -62,7 +62,7 @@ pointers module .. :exclude-members: session module ---------------- +-------------- .. automodule:: dill.session :members: