Skip to content

Commit

Permalink
tests: FilterSet methods; added extra filter type check
Browse files Browse the repository at this point in the history
  • Loading branch information
leogama committed Sep 28, 2022
1 parent 4244d3b commit 78f5e2d
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 5 deletions.
9 changes: 5 additions & 4 deletions dill/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,11 +166,12 @@ def __repr__(self):
Filter = Union[str, Pattern[str], int, type, FilterFunction]
Rule = Tuple[RuleType, Union[Filter, Iterable[Filter]]]

def _iter(filters):
if isinstance(filters, str):
def _iter(obj):
"""return iterator of object if it's not a string"""
if isinstance(obj, (str, bytes)):
return None
try:
return iter(filters)
return iter(obj)
except TypeError:
return None

Expand Down Expand Up @@ -199,7 +200,7 @@ def _match_type(self, filter: Filter) -> Tuple[filter, str]:
else:
filter = re.compile(filter)
field = 'regexes'
elif filter_type is re.Pattern:
elif filter_type is re.Pattern and type(filter.pattern) is str:
field = 'regexes'
elif filter_type is int:
field = 'ids'
Expand Down
58 changes: 57 additions & 1 deletion dill/tests/test_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,64 @@
import dill
from dill import _dill
from dill.session import (
EXCLUDE, INCLUDE, FilterRules, RuleType, ipython_filter, size_filter, settings
EXCLUDE, INCLUDE, FilterRules, FilterSet, RuleType, ipython_filter, size_filter, settings
)

def test_filterset():
import re

name = 'test'
regex1 = re.compile(r'\w+\d+')
regex2 = r'_\w+'
id_ = id(FilterSet)
type1 = FilterSet
type2 = 'type:List'
func = lambda obj: obj.name == 'Arthur'

empty_filters = FilterSet()
assert bool(empty_filters) is False
assert len(empty_filters) == 0
assert len([*empty_filters]) == 0

# also tests add() and __ior__() for non-FilterSet other
filters = FilterSet._from_iterable([name, regex1, regex2, id_, type1, type2, func])
assert filters.names == {name}
assert filters.regexes == {regex1, re.compile(regex2)}
assert filters.ids == {id_}
assert filters.types == {type1, list}
assert filters.funcs == {func}

assert bool(filters) is True
assert len(filters) == 7
assert all(x in filters for x in [name, regex1, id_, type1, func])

try:
filters.add(re.compile(b'an 8-bit string regex'))
except ValueError:
pass
else:
raise AssertionError("adding invalid filter should raise error")

filters_copy = filters.copy()
for field in FilterSet._fields:
original, copy = getattr(filters, field), getattr(filters_copy, field)
assert copy is not original
assert copy == original

filters.remove(re.compile(regex2))
assert filters.regexes == {regex1}
filters.discard(list)
filters.discard(list) # should not raise error
assert filters.types == {type1}
assert [*filters] == [name, regex1, id_, type1, func]

# also tests __ior__() for FilterSet other
filters.update(filters_copy)
assert filters.types == {type1, list}

filters.clear()
assert len(filters) == 0

NS = {
'a': 1,
'aa': 2,
Expand Down Expand Up @@ -184,6 +239,7 @@ def test_size_filter():
assert did_exclude(NS_copy, filter_size, excluded_subset={'large'})

if __name__ == '__main__':
test_filterset()
test_basic_filtering()
test_exclude_include()
test_add_type()
Expand Down

0 comments on commit 78f5e2d

Please sign in to comment.