Skip to content

Commit

Permalink
Add more tests for coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
zeroSteiner committed Jun 16, 2024
1 parent 40bbe61 commit 21bc606
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 3 deletions.
5 changes: 3 additions & 2 deletions lib/rule_engine/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,9 +738,10 @@ def __repr__(self):

def evaluate(self, thing):
container_value = self.container.evaluate(thing)
container_value_type = DataType.from_value(container_value)
member_value = self.member.evaluate(thing)
if DataType.from_value(container_value) == DataType.STRING:
if DataType.from_value(member_value) != DataType.STRING:
if container_value_type == DataType.BYTES or container_value_type == DataType.STRING:
if DataType.from_value(member_value) != container_value_type:
raise errors.EvaluationError('data type mismatch')
return bool(member_value in container_value)

Expand Down
10 changes: 10 additions & 0 deletions tests/ast/expression/function_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,13 @@ def _function():
raise SomeException()
with self.assertRaises(errors.EvaluationError):
function_call.evaluate({'function': _function})

def test_ast_expression_function_call_error_on_incompatible_return_type(self):
symbol = ast.SymbolExpression(context, 'function')
function_call = ast.FunctionCallExpression(context, symbol, [])
function_call.result_type = ast.DataType.FUNCTION('function', return_type=ast.DataType.FLOAT)

def _function():
return ''
with self.assertRaises(errors.FunctionCallError):
function_call.evaluate({'function': _function})
4 changes: 3 additions & 1 deletion tests/ast/expression/miscellaneous.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def test_ast_expression_contains_error(self):
class GetItemExpressionTests(unittest.TestCase):
containers = {
types.DataType.ARRAY: ast.LiteralExpressionBase.from_value(context, ['one', 'two']), # ARRAY
types.DataType.BYTES: ast.LiteralExpressionBase.from_value(context, b'Rule Engine!'), # BYTES
types.DataType.MAPPING: ast.LiteralExpressionBase.from_value(context, {'foo': 'bar'}), # MAPPING
types.DataType.STRING: ast.LiteralExpressionBase.from_value(context, 'Rule Engine!') # STRING
}
Expand Down Expand Up @@ -206,8 +207,9 @@ class GetSliceExpressionTests(unittest.TestCase):
def test_ast_expression_getslice(self):
ary_value = tuple(random.choice(string.ascii_letters) for _ in range(12))
str_value = ''.join(ary_value)
byt_value = str_value.encode()
cases = (
(ary_value, str_value),
(ary_value, byt_value, str_value),
(None, 0, 2),
(None, -1, -3),
)
Expand Down

0 comments on commit 21bc606

Please sign in to comment.