Skip to content

Commit

Permalink
v0.1.6
Browse files Browse the repository at this point in the history
  • Loading branch information
yymao committed Dec 29, 2020
1 parent 727c0b4 commit 1d2176a
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 37 deletions.
85 changes: 54 additions & 31 deletions easyquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
NumPy structured arrays, astropy Table, and Pandas DataFrame.
Project website: https://github.com/yymao/easyquery
The MIT License (MIT)
Copyright (c) 2017-2019 Yao-Yuan Mao (yymao)
Copyright (c) 2017-2020 Yao-Yuan Mao (yymao)
http://opensource.org/licenses/MIT
"""

import warnings
import functools
import numpy as np
import numexpr as ne

Expand All @@ -19,7 +20,7 @@


__all__ = ['Query', 'QueryMaker']
__version__ = '0.1.5'
__version__ = '0.1.6'


def _is_string_like(obj):
Expand Down Expand Up @@ -96,27 +97,22 @@ def __init__(self, *queries):
self._operator = 'AND'
self._operands = [self._query_class(query) for query in queries]


@staticmethod
def _get_table_dict(table):
    """
    Return the mapping handed to numexpr as `local_dict` when a
    string query is evaluated. This implementation returns `table` itself.
    """
    return table


@staticmethod
def _get_table_len(table):
    """
    Return the length of `table`, as reported by the built-in `len`.
    """
    return len(table)


@staticmethod
def _get_table_column(table, column):
    """
    Return the entry of `table` indexed by `column` (``table[column]``).
    Used to fetch the columns passed to the callable of a tuple-style query.
    """
    return table[column]


@staticmethod
def _mask_table(table, mask_):
    """
    Return `table` indexed by `mask_` (``table[mask_]``).
    """
    return table[mask_]


def _combine_queries(self, other, operator, out=None):
if operator not in {'AND', 'OR', 'XOR'}:
raise ValueError('`operator` must be "AND" or "OR" or "XOR"')
Expand All @@ -143,7 +139,6 @@ def _combine_queries(self, other, operator, out=None):

return out


def __and__(self, other):
    """
    Implement ``self & other``: combine this query with `other`
    using the 'AND' operator and return the combined query.
    """
    return self._combine_queries(other, operator='AND')

Expand Down Expand Up @@ -180,28 +175,35 @@ def __invert__(self):

@staticmethod
def _check_basic_query(basic_query):
return basic_query is None or _is_string_like(basic_query) or \
callable(basic_query) or (isinstance(basic_query, tuple) and \
len(basic_query) > 1 and callable(basic_query[0]))

return (
basic_query is None or
_is_string_like(basic_query) or
callable(basic_query) or
(
isinstance(basic_query, tuple) and
len(basic_query) > 1 and
callable(basic_query[0])
)
)

def _create_mask(self, table, basic_query):
if _is_string_like(basic_query):
return ne.evaluate(basic_query,
local_dict=self._get_table_dict(table),
global_dict={})
return ne.evaluate(
basic_query,
local_dict=self._get_table_dict(table),
global_dict={}
)

elif callable(basic_query):
return basic_query(table)

elif isinstance(basic_query, tuple) and len(basic_query) > 1 and callable(basic_query[0]):
return basic_query[0](*(self._get_table_column(table, c) for c in basic_query[1:]))


def mask(self, table):
"""
Use the current Query object to create a mask (a boolean array)
for `table`. Values in the returned mask are determined based on
for `table`. Values in the returned mask are determined based on
whether the corresponding rows satisfy input queries.
Parameters
Expand Down Expand Up @@ -234,15 +236,14 @@ def mask(self, table):

return mask_this


def filter(self, table, column_slice=None):
"""
Use the current Query object to select the rows in `table`
that satisfy input queries.
If `column_slice` is provided, also select on columns.
Equivalent to table[Query(...).mask(table)][column_slice]
but with a more efficient implementation.
but with a more efficient implementation.
Parameters
----------
Expand Down Expand Up @@ -288,7 +289,6 @@ def count(self, table):

return np.count_nonzero(self.mask(table))


def copy(self):
"""
Create a copy of the current Query object.
Expand All @@ -302,7 +302,6 @@ def copy(self):
out._operands = self._operands if self._operator is None else self._operands.copy()
return out


@staticmethod
def _get_variable_names(basic_query):
if _is_string_like(basic_query):
Expand All @@ -315,7 +314,6 @@ def _get_variable_names(basic_query):
elif isinstance(basic_query, tuple) and len(basic_query) > 1 and callable(basic_query[0]):
return tuple(set(basic_query[1:]))


@property
def variable_names(self):
"""
Expand Down Expand Up @@ -350,10 +348,11 @@ def set_query_class(query_class=Query):
"""
if not issubclass(query_class, Query):
raise ValueError('`query_class` must be a subclass of `Query`')
global _query_class
_query_class = query_class


def filter(table, *queries): # pylint: disable=redefined-builtin
def filter(table, *queries): # pylint: disable=redefined-builtin
"""
A convenient function to filter `table` with `queries`.
Equivalent to Query(*queries).filter(table)
Expand Down Expand Up @@ -411,8 +410,12 @@ class QueryMaker():
provides convenience functions to generate query objects
"""
@staticmethod
def in1d(col_name, arr, assume_unique=False, invert=False):
return _query_class((lambda x: np.in1d(x, arr, assume_unique, invert), col_name))
def in1d(col_name, test_elements, assume_unique=False, invert=False):
return _query_class((functools.partial(np.in1d, ar2=test_elements, assume_unique=assume_unique, invert=invert), col_name))

@staticmethod
def isin(col_name, test_elements, assume_unique=False, invert=False):
    """
    Build a query that applies `numpy.isin` to column `col_name`.

    `test_elements`, `assume_unique` and `invert` are forwarded to
    `numpy.isin` as keyword arguments; the column is passed as its
    first positional argument.
    """
    membership_test = functools.partial(
        np.isin,
        test_elements=test_elements,
        assume_unique=assume_unique,
        invert=invert,
    )
    return _query_class((membership_test, col_name))

@staticmethod
def vectorize(row_function, *col_names):
Expand All @@ -423,13 +426,33 @@ def contains(col_name, test_value):
return QueryMaker.vectorize((lambda x: test_value in x), col_name)

@staticmethod
def equals(col_name, test_value):
return QueryMaker.vectorize((lambda x: x == test_value), col_name)
def find(col_name, test_value, start=0, end=None):
return _query_class((lambda x: np.char.find(x, test_value, start=start, end=end) > -1, col_name))

contains_str = find

@staticmethod
def equal(col_name, test_value):
    """
    Build a query that compares column `col_name` to `test_value`
    with ``==``.
    """
    def _is_equal(column):
        return column == test_value

    return _query_class((_is_equal, col_name))

equals = equal

@staticmethod
def not_equal(col_name, test_value):
    """
    Build a query that compares column `col_name` to `test_value`
    with ``!=``.
    """
    def _is_not_equal(column):
        return column != test_value

    return _query_class((_is_not_equal, col_name))

@staticmethod
def equal_columns(col1_name, col2_name):
    """
    Build a query that compares column `col1_name` to column
    `col2_name` with ``==``.
    """
    def _columns_equal(first, second):
        return first == second

    return _query_class((_columns_equal, col1_name, col2_name))

@staticmethod
def not_equal_columns(col1_name, col2_name):
    """
    Build a query that compares column `col1_name` to column
    `col2_name` with ``!=``.
    """
    def _columns_differ(first, second):
        return first != second

    return _query_class((_columns_differ, col1_name, col2_name))

@staticmethod
def startswith(col_name, test_value):
return QueryMaker.vectorize((lambda x: x.startswith(test_value)), col_name)
def startswith(col_name, prefix, start=0, end=None):
return _query_class((functools.partial(np.char.startswith, prefix=prefix, start=start, end=end), col_name))

@staticmethod
def endswith(col_name, test_value):
return QueryMaker.vectorize((lambda x: x.endswith(test_value)), col_name)
def endswith(col_name, suffix, start=0, end=None):
return _query_class((functools.partial(np.char.endswith, suffix=suffix, start=start, end=end), col_name))
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
NumPy structured arrays, astropy Table, and Pandas DataFrame.
Project website: https://github.com/yymao/easyquery
The MIT License (MIT)
Copyright (c) 2017-2019 Yao-Yuan Mao (yymao)
Copyright (c) 2017-2020 Yao-Yuan Mao (yymao)
http://opensource.org/licenses/MIT
"""

Expand Down Expand Up @@ -38,6 +38,8 @@
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 2.7',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
],
keywords='easyquery query numpy',
py_modules=[_name],
Expand Down
29 changes: 24 additions & 5 deletions test_main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
from easyquery import Query
from easyquery import Query, QueryMaker


def test_valid_init():
"""
Expand Down Expand Up @@ -28,13 +29,13 @@ def test_invalid_init():
"""
test invalid Query object creation
"""
for q in (1, [lambda x: x>1, 'a'], (lambda x: x>1,), ('a', lambda x: x>1)):
for q in (1, [lambda x: x > 1, 'a'], (lambda x: x > 1,), ('a', lambda x: x > 1)):
check_invalid_init(q)


def gen_test_table():
return np.array([(1, 5, 4.5), (1, 1, 6.2), (3, 2, 0.5), (5, 5, -3.5)],
dtype=np.dtype([('a', '<i8'), ('b', '<i8'), ('c', '<f8')]))
return np.array([(1, 5, 4.5, "abcd"), (1, 1, 6.2, "pqrs"), (3, 2, 0.5, "asdf"), (5, 5, -3.5, "wxyz")],
dtype=np.dtype([('a', '<i8'), ('b', '<i8'), ('c', '<f8'), ('s', '<U4')]))


def check_query_on_table(table, query_object, true_mask=None):
Expand All @@ -53,7 +54,7 @@ def check_query_on_dict_table(table, query_object, true_mask=None):
ftable = query_object.filter(table)
ftable_true = {k: table[k][true_mask] for k in table}
assert set(ftable) == set(ftable_true), 'filter not correct'
assert all((ftable[k]==ftable_true[k]).all() for k in ftable), 'filter not correct'
assert all((ftable[k] == ftable_true[k]).all() for k in ftable), 'filter not correct'
assert query_object.count(table) == np.count_nonzero(true_mask), 'count not correct'
assert (query_object.mask(table) == true_mask).all(), 'mask not correct'

Expand Down Expand Up @@ -154,6 +155,23 @@ def test_filter_column_slice():
assert (q.filter(t, 'a') == t['a']).all()


def test_query_maker():
    """
    test query objects generated by QueryMaker
    """
    t = gen_test_table()
    test_elements = [1, 2, 3]

    # each entry pairs a QueryMaker-built query with the mask it should produce
    cases = [
        # membership queries
        (QueryMaker.in1d("a", test_elements), np.in1d(t["a"], test_elements)),
        (QueryMaker.isin("a", test_elements), np.isin(t["a"], test_elements)),
        # comparison and prefix/suffix queries on the string column
        (QueryMaker.equal("s", "abcd"), t["s"] == "abcd"),
        (QueryMaker.equals("s", "abcd"), t["s"] == "abcd"),
        (QueryMaker.not_equal("s", "abcd"), t["s"] != "abcd"),
        (QueryMaker.startswith("s", "a"), np.char.startswith(t["s"], "a")),
        (QueryMaker.endswith("s", "s"), np.char.endswith(t["s"], "s")),
        # substring queries
        (QueryMaker.contains("s", "a"), np.char.find(t["s"], "a") > -1),
        (QueryMaker.find("s", "a"), np.char.find(t["s"], "a") > -1),
    ]
    for query_object, true_mask in cases:
        check_query_on_table(t, query_object, true_mask)

    # a column compared with itself must match on every row
    assert QueryMaker.equal_columns("s", "s").mask(t).all()

if __name__ == '__main__':
test_valid_init()
test_invalid_init()
Expand All @@ -162,3 +180,4 @@ def test_filter_column_slice():
test_derive_class()
test_variable_names()
test_filter_column_slice()
test_query_maker()

0 comments on commit 1d2176a

Please sign in to comment.