From 298a09972cfccb4d73abf37b145d20635a6e30e1 Mon Sep 17 00:00:00 2001 From: Leonardo Gama Date: Thu, 12 May 2022 16:51:15 -0300 Subject: [PATCH] Session: check id against module being saved --- dill/session.py | 16 ++++++++++------ dill/utils.py | 4 ++-- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/dill/session.py b/dill/session.py index b6ebd4fc..d1f08b8a 100644 --- a/dill/session.py +++ b/dill/session.py @@ -17,7 +17,7 @@ import dill from dill import Pickler, Unpickler -from ._dill import ModuleType, _import_module, _is_builtin_module +from ._dill import ModuleType, _import_module, _is_builtin_module, _main_module from .utils import AttrDict, CheckerSet, TransSet from .settings import settings @@ -111,10 +111,14 @@ def _exclude_objs(main, exclude_extra, filters_extra, settings): categories = {'ids': int, 'names': str, 'regex': re.Pattern, 'types': type} exclude = AttrDict({cat: copy(settings.session_exclude[cat]) for cat in categories}) filters = copy(settings.session_filters) + del categories['ids'] # special case if exclude_extra is not None: if isinstance(exclude_extra, str): raise ValueError("'exclude' can be of type Iterable[str], but not str") for item in exclude_extra: + if isinstance(item, int): + exclude.ids.add(item, main=main) + continue for category, klass in categories.items(): if isinstance(item, klass): exclude[category].add(item) @@ -220,12 +224,12 @@ def load_session(filename: Union[os.PathLike, io.BytesIO] = '/tmp/session.pkl', # Settings # ############## -def _as_id(item): +def _as_id(item, *, main=_main_module): if isinstance(item, int): - import warnings, __main__ - if not any(id(obj) == item for obj in __main__.__dict__.values()): - warnings.warn("%d isn't the id of any object in __main__ namespace. " - "Did you mean 'id(%d)?'" % (item, item)) + import warnings + if not any(id(obj) == item for obj in main.__dict__.values()): + warnings.warn("%d isn't the id of any object in the '%s' namespace. " + "Did you mean 'id(%d)'?" % (item, main.__name__, item)) return item return id(item) diff --git a/dill/utils.py b/dill/utils.py index 87f63043..33ce4e0f 100644 --- a/dill/utils.py +++ b/dill/utils.py @@ -47,8 +47,8 @@ class TransSet(set): def __init__(self, func: Callable, *args): self.constructor = func super().__init__(*args) - def add(self, item): - super().add(self.constructor(item)) + def add(self, item, **kwargs): + super().add(self.constructor(item, **kwargs)) def discard(self, item): super().discard(self.constructor(item)) def remove(self, item):