diff --git a/dill/_utils.py b/dill/_utils.py index 0d9d86e2..0aaf65a5 100644 --- a/dill/_utils.py +++ b/dill/_utils.py @@ -166,11 +166,12 @@ def __repr__(self): Filter = Union[str, Pattern[str], int, type, FilterFunction] Rule = Tuple[RuleType, Union[Filter, Iterable[Filter]]] -def _iter(filters): - if isinstance(filters, str): +def _iter(obj): + """return iterator of object if it's not a string""" + if isinstance(obj, (str, bytes)): return None try: - return iter(filters) + return iter(obj) except TypeError: return None @@ -199,7 +200,7 @@ def _match_type(self, filter: Filter) -> Tuple[filter, str]: else: filter = re.compile(filter) field = 'regexes' - elif filter_type is re.Pattern: + elif filter_type is re.Pattern and type(filter.pattern) is str: field = 'regexes' elif filter_type is int: field = 'ids' diff --git a/dill/tests/test_filtering.py b/dill/tests/test_filtering.py index 259b5fd6..3bcc0c9c 100644 --- a/dill/tests/test_filtering.py +++ b/dill/tests/test_filtering.py @@ -11,9 +11,64 @@ import dill from dill import _dill from dill.session import ( - EXCLUDE, INCLUDE, FilterRules, RuleType, ipython_filter, size_filter, settings + EXCLUDE, INCLUDE, FilterRules, FilterSet, RuleType, ipython_filter, size_filter, settings ) +def test_filterset(): + import re + + name = 'test' + regex1 = re.compile(r'\w+\d+') + regex2 = r'_\w+' + id_ = id(FilterSet) + type1 = FilterSet + type2 = 'type:List' + func = lambda obj: obj.name == 'Arthur' + + empty_filters = FilterSet() + assert bool(empty_filters) is False + assert len(empty_filters) == 0 + assert len([*empty_filters]) == 0 + + # also tests add() and __ior__() for non-FilterSet other + filters = FilterSet._from_iterable([name, regex1, regex2, id_, type1, type2, func]) + assert filters.names == {name} + assert filters.regexes == {regex1, re.compile(regex2)} + assert filters.ids == {id_} + assert filters.types == {type1, list} + assert filters.funcs == {func} + + assert bool(filters) is True + assert len(filters) == 7 + assert all(x in filters for x in [name, regex1, id_, type1, func]) + + try: + filters.add(re.compile(b'an 8-bit string regex')) + except ValueError: + pass + else: + raise AssertionError("adding invalid filter should raise error") + + filters_copy = filters.copy() + for field in FilterSet._fields: + original, copy = getattr(filters, field), getattr(filters_copy, field) + assert copy is not original + assert copy == original + + filters.remove(re.compile(regex2)) + assert filters.regexes == {regex1} + filters.discard(list) + filters.discard(list) # should not raise error + assert filters.types == {type1} + assert [*filters] == [name, regex1, id_, type1, func] + + # also tests __ior__() for FilterSet other + filters.update(filters_copy) + assert filters.types == {type1, list} + + filters.clear() + assert len(filters) == 0 + NS = { 'a': 1, 'aa': 2, @@ -184,6 +239,7 @@ def test_size_filter(): assert did_exclude(NS_copy, filter_size, excluded_subset={'large'}) if __name__ == '__main__': + test_filterset() test_basic_filtering() test_exclude_include() test_add_type()