diff --git a/dill/__init__.py b/dill/__init__.py index b28a1ac8..12ad3cec 100644 --- a/dill/__init__.py +++ b/dill/__init__.py @@ -283,16 +283,17 @@ """ -from ._dill import dump, dumps, load, loads, \ - Pickler, Unpickler, register, copy, pickle, pickles, check, \ - HIGHEST_PROTOCOL, DEFAULT_PROTOCOL, PicklingError, UnpicklingError, \ - HANDLE_FMODE, CONTENTS_FMODE, FILE_FMODE, PickleError, PickleWarning, \ - PicklingWarning, UnpicklingWarning +from ._dill import ( + Pickler, Unpickler, + dump, dumps, load, loads, copy, check, pickle, pickles, register, + DEFAULT_PROTOCOL, HIGHEST_PROTOCOL, HANDLE_FMODE, CONTENTS_FMODE, FILE_FMODE, + PicklingError, UnpicklingError, PickleError, PicklingWarning, UnpicklingWarning, PickleWarning, + ) from .session import dump_session, load_session from . import detect, session, source, temp # get global settings -from .settings import settings +from .settings import Settings, settings # make sure "trace" is turned off detect.trace(False) diff --git a/dill/_dill.py b/dill/_dill.py index f0d55115..9b062a1d 100644 --- a/dill/_dill.py +++ b/dill/_dill.py @@ -1917,8 +1917,7 @@ def save_function(pickler, obj): _recurse = getattr(pickler, '_recurse', None) _byref = getattr(pickler, '_byref', None) _postproc = getattr(pickler, '_postproc', None) - _main_modified = getattr(pickler, '_main_modified', None) - _original_main = getattr(pickler, '_original_main', __builtin__)#'None' + _original_main = getattr(pickler, '_original_main', None) postproc_list = [] if _recurse: # recurse to get all globals referred to by obj @@ -1935,7 +1934,7 @@ def save_function(pickler, obj): # If the globals is the __dict__ from the module being saved as a # session, substitute it by the dictionary being actually saved. - if _main_modified and globs_copy is _original_main.__dict__: + if _original_main and globs_copy is _original_main.__dict__: globs_copy = getattr(pickler, '_main', _original_main).__dict__ globs = globs_copy # If the globals is a module __dict__, do not save it in the pickle. diff --git a/dill/_utils.py b/dill/_utils.py new file mode 100644 index 00000000..1ce67acd --- /dev/null +++ b/dill/_utils.py @@ -0,0 +1,234 @@ +#!/usr/bin/env python +# +# Author: Leonardo Gama (@leogama) +# Copyright (c) 2022 The Uncertainty Quantification Foundation. +# License: 3-clause BSD. The full license text is available at: +# - https://github.com/uqfoundation/dill/blob/master/LICENSE +"""auxiliary internal classes used in multiple submodules, set here to avoid import recursion""" + +__all__ = ['AttrDict', 'ExcludeRules', 'Filter', 'RuleType'] + +import logging +logger = logging.getLogger('dill._utils') + +import inspect +from functools import partialmethod + +class AttrDict(dict): + """syntactic sugar for accessing dictionary items""" + _CAST = object() # singleton + def __init__(self, *args, **kwargs): + data = args[0] if len(args) == 2 and args[1] is self._CAST else dict(*args, **kwargs) + for key, val in tuple(data.items()): + if isinstance(val, dict) and not isinstance(val, AttrDict): + data[key] = AttrDict(val, self._CAST) + super().__setattr__('_data', data) + def _check_attr(self, name): + try: + super().__getattribute__(name) + except AttributeError: + pass + else: + raise AttributeError("'AttrDict' object attribute %r is read-only" % name) + def __getattr__(self, key): + # This is called only if dict.__getattribute__(key) fails. + try: + return self._data[key] + except KeyError: + raise AttributeError("'AttrDict' object has no attribute %r" % key) + def __setattr__(self, key, value): + self._check_attr(key) + if isinstance(value, dict): + self._data[key] = AttrDict(value, self._CAST) + else: + self._data[key] = value + def __delattr__(self, key): + self._check_attr(key) + del self._data[key] + def __proxy__(self, method, *args, **kwargs): + return getattr(self._data, method)(*args, **kwargs) + def __reduce__(self): + return AttrDict, (self._data,) + def copy(self): + # Deep copy. + copy = AttrDict(self._data) + for key, val in tuple(copy.items()): + if isinstance(val, AttrDict): + copy[key] = val.copy() + return copy + +for method, _ in inspect.getmembers(dict, inspect.ismethoddescriptor): + if method not in vars(AttrDict) and method not in {'__getattribute__', '__reduce_ex__'}: + setattr(AttrDict, method, partialmethod(AttrDict.__proxy__, method)) + + +### Namespace filtering +import re +from dataclasses import InitVar, dataclass, field, fields +from collections import abc, namedtuple +from enum import Enum +from functools import partialmethod +from itertools import filterfalse +from re import Pattern +from typing import Callable, Iterable, Set, Tuple, Union + +RuleType = Enum('RuleType', 'EXCLUDE INCLUDE', module=__name__) +NamedObj = namedtuple('NamedObj', 'name value', module=__name__) + +Filter = Union[str, Pattern, int, type, Callable] +Rule = Tuple[RuleType, Union[Filter, Iterable[Filter]]] + +def isiterable(arg): + return isinstance(arg, abc.Iterable) and not isinstance(arg, (str, bytes)) + +@dataclass +class ExcludeFilters: + ids: Set[int] = field(default_factory=set) + names: Set[str] = field(default_factory=set) + regex: Set[Pattern] = field(default_factory=set) + types: Set[type] = field(default_factory=set) + funcs: Set[Callable] = field(default_factory=set) + + @property + def filter_sets(self): + return tuple(field.name for field in fields(self)) + def __bool__(self): + return any(getattr(self, filter_set) for filter_set in self.filter_sets) + def _check(self, filter): + if isinstance(filter, str): + if filter.isidentifier(): + field = 'names' + else: + filter, field = re.compile(filter), 'regex' + elif isinstance(filter, Pattern): + field = 'regex' + elif isinstance(filter, int): + field = 'ids' + elif isinstance(filter, type): + field = 'types' + elif callable(filter): + field = 'funcs' + else: + raise ValueError("invalid filter: %r" % filter) + return filter, getattr(self, field) + def add(self, filter): + filter, filter_set = self._check(filter) + filter_set.add(filter) + def discard(self, filter): + filter, filter_set = self._check(filter) + filter_set.discard(filter) + def remove(self, filter): + filter, filter_set = self._check(filter) + filter_set.remove(filter) + def update(self, filters): + for filter in filters: + self.add(filter) + def clear(self): + for filter_set in self.filter_sets: + getattr(self, filter_set).clear() + def add_type(self, type_name): + import types + name_suffix = type_name + 'Type' if not type_name.endswith('Type') else type_name + if hasattr(types, name_suffix): + type_name = name_suffix + type_obj = getattr(types, type_name, None) + if not isinstance(type_obj, type): + named = type_name if type_name == name_suffix else "%r or %r" % (type_name, name_suffix) + raise NameError("could not find a type named %s in module 'types'" % named) + self.types.add(type_obj) + +@dataclass +class ExcludeRules: + exclude: ExcludeFilters = field(init=False, default_factory=ExcludeFilters) + include: ExcludeFilters = field(init=False, default_factory=ExcludeFilters) + rules: InitVar[Iterable[Rule]] = None + + def __post_init__(self, rules): + if rules is not None: + self.update(rules) + + def __proxy__(self, method, filter, *, rule_type=RuleType.EXCLUDE): + if rule_type is RuleType.EXCLUDE: + getattr(self.exclude, method)(filter) + elif rule_type is RuleType.INCLUDE: + getattr(self.include, method)(filter) + else: + raise ValueError("invalid rule type: %r (must be one of %r)" % (rule_type, list(RuleType))) + + add = partialmethod(__proxy__, 'add') + discard = partialmethod(__proxy__, 'discard') + remove = partialmethod(__proxy__, 'remove') + + def update(self, rules): + if isinstance(rules, ExcludeRules): + for filter_set in self.exclude.filter_sets: + getattr(self.exclude, filter_set).update(getattr(rules.exclude, filter_set)) + getattr(self.include, filter_set).update(getattr(rules.include, filter_set)) + else: + # Validate rules. + for rule in rules: + if not isinstance(rule, tuple) or len(rule) != 2: + raise ValueError("invalid rule format: %r" % rule) + for rule_type, filter in rules: + if isiterable(filter): + for f in filter: + self.add(f, rule_type=rule_type) + else: + self.add(filter, rule_type=rule_type) + + def clear(self): + self.exclude.clear() + self.include.clear() + + def filter_namespace(self, namespace, obj=None): + if not self.exclude and not self.include: + return namespace + + # Protect agains dict changes during the call. + namespace_copy = namespace.copy() if obj is None or namespace is vars(obj) else namespace + objects = all_objects = [NamedObj._make(item) for item in namespace_copy.items()] + + for filters in (self.exclude, self.include): + if filters is self.exclude and not filters: + # Treat the rule set as an allowlist. + exclude_objs = objects + continue + elif filters is self.include: + if not filters or not exclude_objs: + break + objects = exclude_objs + + flist = [] + types_list = tuple(filters.types) + # Apply cheaper/broader filters first. + if types_list: + flist.append(lambda obj: isinstance(obj.value, types_list)) + if filters.ids: + flist.append(lambda obj: id(obj.value) in filters.ids) + if filters.names: + flist.append(lambda obj: obj.name in filters.names) + if filters.regex: + flist.append(lambda obj: any(regex.fullmatch(obj.name) for regex in filters.regex)) + flist.extend(filters.funcs) + for f in flist: + objects = filterfalse(f, objects) + + if filters is self.exclude: + include_names = {obj.name for obj in objects} + exclude_objs = [obj for obj in all_objects if obj.name not in include_names] + else: + exclude_objs = list(objects) + + if not exclude_objs: + return namespace + if len(exclude_objs) == len(namespace): + warnings.warn("filtering operation left the namespace empty!", PicklingWarning) + return {} + if logger.isEnabledFor(logging.INFO): + exclude_listing = {obj.name: type(obj.value).__name__ for obj in sorted(exclude_objs)} + exclude_listing = str(exclude_listing).translate({ord(","): "\n", ord("'"): None}) + logger.info("Objects excluded from dump_session():\n%s\n", exclude_listing) + + for obj in exclude_objs: + del namespace_copy[obj.name] + return namespace_copy diff --git a/dill/session.py b/dill/session.py index 35485009..6e561e9e 100644 --- a/dill/session.py +++ b/dill/session.py @@ -10,12 +10,22 @@ Pickle and restore the intepreter session. """ -__all__ = ['dump_session', 'load_session'] +__all__ = ['dump_session', 'load_session', 'ipython_filter', 'ExcludeRules', 'EXCLUDE', 'INCLUDE'] -import logging, sys +import logging, re, sys +from copy import copy from dill import _dill, Pickler, Unpickler -from ._dill import ModuleType, _import_module, _is_builtin_module, _main_module, PY3 +from ._dill import ModuleType, _import_module, _is_builtin_module, _main_module +from ._utils import AttrDict, ExcludeRules, Filter, RuleType +from .settings import settings + +# Classes and abstract classes for type hints. +from io import BytesIO +from os import PathLike +from typing import Iterable, NoReturn, Union + +EXCLUDE, INCLUDE = RuleType.EXCLUDE, RuleType.INCLUDE SESSION_IMPORTED_AS_TYPES = tuple([Exception] + [getattr(_dill, name) for name in ('ModuleType', 'TypeType', 'FunctionType', 'MethodType', 'BuiltinMethodType')]) @@ -24,11 +34,9 @@ def _module_map(): """get map of imported modules""" - from collections import defaultdict, namedtuple - modmap = namedtuple('Modmap', ['by_name', 'by_id', 'top_level']) - modmap = modmap(defaultdict(list), defaultdict(list), {}) - items = 'items' if PY3 else 'iteritems' - for modname, module in getattr(sys.modules, items)(): + from collections import defaultdict + modmap = AttrDict(by_name=defaultdict(list), by_id=defaultdict(list), top_level={}) + for modname, module in sys.modules.items(): if not isinstance(module, ModuleType): continue if '.' not in modname: @@ -57,8 +65,7 @@ def _stash_modules(main_module): imported_as = [] imported_top_level = [] # keep separeted for backwards compatibility original = {} - items = 'items' if PY3 else 'iteritems' - for name, obj in getattr(main_module.__dict__, items)(): + for name, obj in vars(main_module).items(): if obj is main_module: original[name] = newmod # self-reference continue @@ -101,36 +108,64 @@ def _restore_modules(unpickler, main_module): except KeyError: pass -#NOTE: 06/03/15 renamed main_module to main -def dump_session(filename='/tmp/session.pkl', main=None, byref=False, **kwds): +def _filter_objects(main, exclude_extra, include_extra, obj=None): + filters = ExcludeRules(getattr(settings, 'session_exclude', None)) + if exclude_extra is not None: + filters.update([(EXCLUDE, exclude_extra)]) + if include_extra is not None: + filters.update([(INCLUDE, include_extra)]) + + namespace = filters.filter_namespace(vars(main), obj=obj) + if namespace is vars(main): + return main + + main = ModuleType(main.__name__) + vars(main).update(namespace) + return main + +def dump_session(filename: Union[PathLike, BytesIO] = '/tmp/session.pkl', + main: Union[str, ModuleType] = '__main__', + byref: bool = False, + exclude: Union[Filter, Iterable[Filter]] = None, + include: Union[Filter, Iterable[Filter]] = None, + **kwds) -> NoReturn: """pickle the current state of __main__ to a file""" - from .settings import settings - protocol = settings['protocol'] - if main is None: main = _main_module + protocol = settings.protocol + if isinstance(main, str): + main = _import_module(main) + original_main = main + if byref: + #NOTE: *must* run before _filter_objects() + main = _stash_modules(main) + main = _filter_objects(main, exclude, include, obj=original_main) + + print(list(vars(main))) + if hasattr(filename, 'write'): f = filename else: f = open(filename, 'wb') try: pickler = Pickler(f, protocol, **kwds) - pickler._original_main = main - if byref: - main = _stash_modules(main) pickler._main = main #FIXME: dill.settings are disabled pickler._byref = False # disable pickling by name reference pickler._recurse = False # disable pickling recursion for globals pickler._session = True # is best indicator of when pickling a session pickler._first_pass = True - pickler._main_modified = main is not pickler._original_main + if main is not original_main: + pickler._original_main = original_main pickler.dump(main) finally: if f is not filename: # If newly opened file f.close() return -def load_session(filename='/tmp/session.pkl', main=None, **kwds): +def load_session(filename: Union[PathLike, BytesIO] = '/tmp/session.pkl', + main: ModuleType = None, + **kwds) -> NoReturn: """update the __main__ module with the state from the session file""" - if main is None: main = _main_module + if main is None: + main = _main_module if hasattr(filename, 'read'): f = filename else: @@ -147,3 +182,43 @@ def load_session(filename='/tmp/session.pkl', main=None, **kwds): if f is not filename: # If newly opened file f.close() return + +############# +# IPython # +############# + +def ipython_filter(*, keep_input=True, keep_output=False): + """filter factory for IPython sessions (can't be added to settings currently) + + Usage: + >>> from dill.session import * + >>> dump_session(exclude=[ipython_filter()]) + """ + if not __builtins__.get('__IPYTHON__'): + # Return no-op filter if not in IPython. + return (lambda x: False) + + from IPython import get_ipython + ipython_shell = get_ipython() + + # Code snippet adapted from IPython.core.magics.namespace.who_ls() + user_ns = ipython_shell.user_ns + user_ns_hidden = ipython_shell.user_ns_hidden + nonmatching = object() # This can never be in user_ns + interactive_vars = {x for x in user_ns if user_ns[x] is not user_ns_hidden.get(x, nonmatching)} + + # Input and output history. + history_regex = [] + if keep_input: + interactive_vars |= {'_ih', 'In', '_i', '_ii', '_iii'} + history_regex.append(re.compile(r'_i\d+')) + if keep_output: + interactive_vars |= {'_oh', 'Out', '_', '__', '___'} + history_regex.append(re.compile(r'_\d+')) + + def not_interactive_var(obj): + if any(regex.fullmatch(obj.name) for regex in history_regex): + return False + return obj.name not in interactive_vars + + return not_interactive_var diff --git a/dill/settings.py b/dill/settings.py index 4d0226b0..9e3c06c9 100644 --- a/dill/settings.py +++ b/dill/settings.py @@ -9,12 +9,15 @@ global settings for Pickler """ +__all__ = ['settings', 'Settings'] + try: from pickle import DEFAULT_PROTOCOL except ImportError: from pickle import HIGHEST_PROTOCOL as DEFAULT_PROTOCOL +from ._utils import AttrDict as Settings, ExcludeRules -settings = { +settings = Settings({ #'main' : None, 'protocol' : DEFAULT_PROTOCOL, 'byref' : False, @@ -22,7 +25,8 @@ 'fmode' : 0, #HANDLE_FMODE 'recurse' : False, 'ignore' : False, -} + 'session_exclude': ExcludeRules(), +}) del DEFAULT_PROTOCOL