From 218fd74b9df46690f124c5a128d5b5dd49bace66 Mon Sep 17 00:00:00 2001 From: Spencer McIntyre Date: Sun, 9 Jun 2024 10:43:33 -0400 Subject: [PATCH] Add BYTES support to some expressions --- lib/rule_engine/ast.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/lib/rule_engine/ast.py b/lib/rule_engine/ast.py index 75e33ef..56e1c1b 100644 --- a/lib/rule_engine/ast.py +++ b/lib/rule_engine/ast.py @@ -42,6 +42,10 @@ from .suggestions import suggest_symbol from .types import * +def _assert_is_bytes(*values): + if not all(map(isinstance, values, [bytes])): + raise errors.EvaluationError('data type mismatch (not a bytes value)') + def _assert_is_integer_number(*values): if not all(map(is_integer_number, values)): raise errors.EvaluationError('data type mismatch (not an integer number)') @@ -401,7 +405,7 @@ def to_graphviz(self, digraph, *args, **kwargs): class AddExpression(LeftOperatorRightExpressionBase): """A class for representing addition expressions from the grammar text.""" - compatible_types = (DataType.FLOAT, DataType.STRING, DataType.DATETIME, DataType.TIMEDELTA) + compatible_types = (DataType.BYTES, DataType.FLOAT, DataType.STRING, DataType.DATETIME, DataType.TIMEDELTA) result_type = DataType.UNDEFINED def __init__(self, *args, **kwargs): @@ -429,6 +433,8 @@ def _op_add(self, thing): elif isinstance(left_value, datetime.timedelta): if not isinstance(right_value, (datetime.timedelta, datetime.datetime)): raise errors.EvaluationError('data type mismatch (not a datetime or timedelta value)') + elif isinstance(left_value, bytes) or isinstance(right_value, bytes): + _assert_is_bytes(left_value, right_value) elif isinstance(left_value, str) or isinstance(right_value, str): _assert_is_string(left_value, right_value) else: @@ -711,8 +717,8 @@ class ContainsExpression(ExpressionBase): __slots__ = ('container', 'member') result_type = DataType.BOOLEAN def __init__(self, context, container, member): - if container.result_type == DataType.STRING: - if member.result_type != DataType.UNDEFINED and member.result_type != DataType.STRING: + if container.result_type == DataType.BYTES or container.result_type == DataType.STRING: + if member.result_type != DataType.UNDEFINED and member.result_type != container.result_type: raise errors.EvaluationError('data type mismatch') elif container.result_type != DataType.UNDEFINED and container.result_type.is_scalar: raise errors.EvaluationError('data type mismatch') @@ -835,7 +841,11 @@ def __init__(self, context, container, item, safe=False): """ self.context = context self.container = container - if container.result_type == DataType.STRING: + if container.result_type == DataType.BYTES: + if not DataType.is_compatible(item.result_type, DataType.FLOAT): + raise errors.EvaluationError('data type mismatch (not an integer number)') + self.result_type = DataType.FLOAT + elif container.result_type == DataType.STRING: if not DataType.is_compatible(item.result_type, DataType.FLOAT): raise errors.EvaluationError('data type mismatch (not an integer number)') self.result_type = DataType.STRING @@ -872,7 +882,7 @@ def evaluate(self, thing): raise errors.EvaluationError('data type mismatch (container is null)') resolved_item = self.item.evaluate(thing) - if isinstance(resolved_obj, (str, tuple)): + if isinstance(resolved_obj, (bytes, str, tuple)): _assert_is_integer_number(resolved_item) resolved_item = int(resolved_item) try: @@ -915,7 +925,9 @@ def __init__(self, context, container, start=None, stop=None, safe=False): """ self.context = context self.container = container - if container.result_type == DataType.STRING: + if container.result_type == DataType.BYTES: + self.result_type = DataType.BYTES + elif container.result_type == DataType.STRING: self.result_type = DataType.STRING # check against __class__ so the parent class is dynamic in case it changes in the future, what we're doing here # is explicitly checking if result_type is an array with out checking the value_type