From 26c7b72438853d13b063ded5b9db63f6c3546660 Mon Sep 17 00:00:00 2001 From: Leonardo Gama Date: Thu, 12 May 2022 16:51:15 -0300 Subject: [PATCH] Session: check id against module being saved --- dill/session.py | 14 +++++++------- dill/utils.py | 4 ++-- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/dill/session.py b/dill/session.py index b6ebd4fc..9adf605c 100644 --- a/dill/session.py +++ b/dill/session.py @@ -17,7 +17,7 @@ import dill from dill import Pickler, Unpickler -from ._dill import ModuleType, _import_module, _is_builtin_module +from ._dill import ModuleType, _import_module, _is_builtin_module, _main_module from .utils import AttrDict, CheckerSet, TransSet from .settings import settings @@ -117,7 +117,7 @@ def _exclude_objs(main, exclude_extra, filters_extra, settings): for item in exclude_extra: for category, klass in categories.items(): if isinstance(item, klass): - exclude[category].add(item) + exclude[category].add(item, main=main) break else: raise ValueError("bad value type for 'exclude' parameter: %r" % item) @@ -220,21 +220,21 @@ def load_session(filename: Union[os.PathLike, io.BytesIO] = '/tmp/session.pkl', # Settings # ############## -def _as_id(item): +def _as_id(item, *, main=_main_module): if isinstance(item, int): - import warnings, __main__ - if not any(id(obj) == item for obj in __main__.__dict__.values()): + import warnings + if not any(id(obj) == item for obj in main.__dict__.values()): warnings.warn("%d isn't the id of any object in __main__ namespace. " "Did you mean 'id(%d)?'" % (item, item)) return item return id(item) -def _as_regex(item): +def _as_regex(item, **kwargs): if isinstance(item, re.Pattern): return item return re.compile(item) -def _as_type(item): +def _as_type(item, **kwargs): if isinstance(item, str): import types if hasattr(types, item + 'Type'): diff --git a/dill/utils.py b/dill/utils.py index 87f63043..6d1bcd34 100644 --- a/dill/utils.py +++ b/dill/utils.py @@ -47,8 +47,8 @@ class TransSet(set): def __init__(self, func: Callable, *args): self.constructor = func super().__init__(*args) - def add(self, item): - super().add(self.constructor(item)) + def add(self, item, *, **kwags): + super().add(self.constructor(item, **kwargs)) def discard(self, item): super().discard(self.constructor(item)) def remove(self, item):