Skip to content

Commit

Permalink
bug fix for subclass
Browse files Browse the repository at this point in the history
  • Loading branch information
yymao committed Nov 12, 2017
1 parent 5b9906f commit 74b5d60
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 44 deletions.
42 changes: 27 additions & 15 deletions easyquery/__init__.py → easyquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,20 @@
NumPy structured arrays, astropy Table, and Pandas DataFrame.
Project website: https://github.com/yymao/easyquery
The MIT License (MIT)
Copyright (c) 2015-2017 Yao-Yuan Mao (yymao)
Copyright (c) 2017 Yao-Yuan Mao (yymao)
http://opensource.org/licenses/MIT
"""

if not hasattr(list, 'copy'):
from builtins import list
try:
from builtins import list
except ImportError:
raise ImportError('Please install python package "future"')
import numpy as np
import numexpr as ne

__all__ = ['Query', 'filter', 'count', 'mask']
__version__ = '0.1.0'
__all__ = ['Query']
__version__ = '0.1.2'

def _is_string_like(obj):
"""
Expand Down Expand Up @@ -69,10 +73,11 @@ class Query(object):
def __init__(self, *queries):
self._operator = None
self._operands = None
self._query_class = type(self)

if len(queries) == 1:
query = queries[0]
if isinstance(query, Query):
if isinstance(query, self._query_class):
self._operator = query._operator
self._operands = query._operands if query._operator is None else query._operands.copy()
else:
Expand All @@ -82,7 +87,7 @@ def __init__(self, *queries):

elif len(queries) > 1:
self._operator = 'AND'
self._operands = [Query(query) for query in queries]
self._operands = [self._query_class(query) for query in queries]


@staticmethod
Expand All @@ -109,11 +114,11 @@ def _combine_queries(self, other, operator, out=None):
if operator not in {'AND', 'OR', 'XOR'}:
raise ValueError('`operator` must be "AND" or "OR" or "XOR"')

if not isinstance(other, Query):
other = Query(other)
if not isinstance(other, self._query_class):
other = self._query_class(other)

if out is None:
out = Query()
out = self._query_class()

out._operator = operator

Expand Down Expand Up @@ -157,7 +162,7 @@ def __invert__(self):
if self._operator == 'NOT':
return self._operands.copy()
else:
out = Query()
out = self._query_class()
out._operator = 'NOT'
out._operands = self
return out
Expand Down Expand Up @@ -273,13 +278,20 @@ def copy(self):
-------
out : Query object
"""
out = Query()
out = self._query_class()
out._operator = self._operator
out._operands = self._operands if self._operator is None else self._operands.copy()
return out


_Query_Class = Query

_query_class = Query


def set_query_class(query_class):
assert issubclass(query_class, Query)
_query_class = query_class


def filter(table, *queries):
"""
Expand All @@ -296,7 +308,7 @@ def filter(table, *queries):
-------
table : filtered table
"""
return _Query_Class(*queries).filter(table)
return _query_class(*queries).filter(table)


def count(table, *queries):
Expand All @@ -315,7 +327,7 @@ def count(table, *queries):
-------
count : int
"""
return _Query_Class(*queries).count(table)
return _query_class(*queries).count(table)


def mask(table, *queries):
Expand All @@ -334,4 +346,4 @@ def mask(table, *queries):
-------
mask : numpy bool array
"""
return _Query_Class(*queries).mask(table)
return _query_class(*queries).mask(table)
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@

setup(
name='easyquery',
version='0.1.1',
version='0.1.2',
description='Create easy-to-use Query objects that can apply on NumPy structured arrays, astropy Table, and Pandas DataFrame.',
url='https://github.com/yymao/easyquery',
download_url = 'https://github.com/yymao/easyquery/archive/v0.1.2.zip',
author='Yao-Yuan Mao',
author_email='[email protected]',
maintainer='Yao-Yuan Mao',
Expand All @@ -28,6 +29,6 @@
'Programming Language :: Python :: 3.6',
],
keywords='easyquery query numpy',
packages=['easyquery'],
py_modules=['easyquery'],
install_requires=['numpy', 'numexpr'],
)
91 changes: 64 additions & 27 deletions tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def test_valid_init():
q8 = Query(q3, 'x > 2')


def _check_invalid_init(*queries):
def check_invalid_init(*queries):
try:
q = Query(*queries)
except ValueError:
Expand All @@ -29,15 +29,15 @@ def test_invalid_init():
test invalid Query object creation
"""
for q in (1, [lambda x: x>1, 'a'], (lambda x: x>1,), ('a', lambda x: x>1)):
_check_invalid_init(q)
check_invalid_init(q)


def _gen_test_table():
def gen_test_table():
return np.array([(1, 5, 4.5), (1, 1, 6.2), (3, 2, 0.5), (5, 5, -3.5)],
dtype=np.dtype([('a', '<i8'), ('b', '<i8'), ('c', '<f8')]))


def _check_query_on_table(table, query_object, true_mask=None):
def check_query_on_table(table, query_object, true_mask=None):
if true_mask is None:
true_mask = np.ones(len(table), np.bool)

Expand All @@ -46,32 +46,40 @@ def _check_query_on_table(table, query_object, true_mask=None):
assert (query_object.mask(table) == true_mask).all(), 'mask not correct'


def check_query_on_dict_table(table, query_object, true_mask=None):
if true_mask is None:
true_mask = np.ones(len(next(table.values())), np.bool)

ftable = query_object.filter(table)
ftable_true = {k: table[k][true_mask] for k in table}
assert set(ftable) == set(ftable_true), 'filter not correct'
assert all((ftable[k]==ftable_true[k]).all() for k in ftable), 'filter not correct'
assert query_object.count(table) == np.count_nonzero(true_mask), 'count not correct'
assert (query_object.mask(table) == true_mask).all(), 'mask not correct'


def test_simple_query():
"""
test simple queries
"""
t = _gen_test_table()
_check_query_on_table(t, Query(), None)
_check_query_on_table(t, Query('a > 3'), t['a'] > 3)
_check_query_on_table(t, Query('a == 100'), t['a'] == 100)
_check_query_on_table(t, Query('b > c'), t['b'] > t['c'])
_check_query_on_table(t, Query('a < 3', 'b > c'), (t['a'] < 3) & (t['b'] > t['c']))
t = gen_test_table()
check_query_on_table(t, Query(), None)
check_query_on_table(t, Query('a > 3'), t['a'] > 3)
check_query_on_table(t, Query('a == 100'), t['a'] == 100)
check_query_on_table(t, Query('b > c'), t['b'] > t['c'])
check_query_on_table(t, Query('a < 3', 'b > c'), (t['a'] < 3) & (t['b'] > t['c']))


def test_compound_query():
"""
test compound queries
"""
t = _gen_test_table()
q1 = Query('a == 1')
def do_compound_query(t, query_class, check_query):
q1 = query_class('a == 1')
m1 = t['a'] == 1
q2 = Query('a == b')
q2 = query_class('a == b')
m2 = t['a'] == t['b']
q3 = Query('b > c')
q3 = 'b > c'
m3 = t['b'] > t['c']

q4 = ~~q3
m4 = ~~m3
q4 = ~~q2
m4 = ~~m2
q5 = q1 & q2 | q3
m5 = m1 & m2 | m3
q6 = ~q1 | q2 ^ q3
Expand All @@ -80,18 +88,47 @@ def test_compound_query():
m7 = m5 ^ m6
q7 |= q2
m7 |= m2
q8 = q3 | q4
m8 = m3 | m4
q9 = q5.copy()
m9 = m5

check_query(t, q1, m1)
check_query(t, q2, m2)
check_query(t, q4, m4)
check_query(t, q5, m5)
check_query(t, q6, m6)
check_query(t, q7, m7)
check_query(t, q8, m8)
check_query(t, q9, m9)


def test_compound_query():
"""
test compound queries
"""
do_compound_query(gen_test_table(), Query, check_query_on_table)


class DictQuery(Query):
@staticmethod
def _get_table_len(table):
return len(next(table.values()))

@staticmethod
def _mask_table(table, mask):
return {k: v[mask] for k, v in table.items()}


_check_query_on_table(t, q1, m1)
_check_query_on_table(t, q2, m2)
_check_query_on_table(t, q3, m3)
_check_query_on_table(t, q4, m4)
_check_query_on_table(t, q5, m5)
_check_query_on_table(t, q6, m6)
_check_query_on_table(t, q7, m7)
def test_derive_class():
t = gen_test_table()
t = {k: t[k] for k in t.dtype.names}
do_compound_query(t, DictQuery, check_query_on_dict_table)


if __name__ == '__main__':
test_valid_init()
test_invalid_init()
test_simple_query()
test_compound_query()
test_derive_class()

0 comments on commit 74b5d60

Please sign in to comment.