Skip to content

Commit

Permalink
Add a bunch of BYTES unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
zeroSteiner committed Jun 15, 2024
1 parent 73674f8 commit 40bbe61
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 14 deletions.
21 changes: 15 additions & 6 deletions lib/rule_engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,10 @@ def _bytes_decode(self, value, encoding):
return binascii.b2a_hex(value).decode()
elif encoding == 'base64':
return binascii.b2a_base64(value).decode().strip()
return value.decode(encoding)
try:
return value.decode(encoding)
except LookupError as error:
raise errors.FunctionCallError("invalid encoding name {}".format(encoding), error=error, function_name='decode')

@attribute('to_epoch', ast.DataType.DATETIME, result_type=ast.DataType.FLOAT)
def datetime_to_epoch(self, value):
Expand Down Expand Up @@ -302,11 +305,17 @@ def string_encode(self, value):
@classmethod
def _string_encode(self, value, encoding):
encoding = encoding.lower()
if encoding == 'base16' or encoding == 'hex':
return binascii.a2b_hex(value.encode())
elif encoding == 'base64':
return binascii.a2b_base64(value.encode())
return value.encode(encoding)
try:
if encoding == 'base16' or encoding == 'hex':
return binascii.a2b_hex(value.encode())
elif encoding == 'base64':
return binascii.a2b_base64(value.encode())
except binascii.Error as error:
raise errors.FunctionCallError("error converting to {}".format(encoding), error=error, function_name='encode')
try:
return value.encode(encoding)
except LookupError as error:
raise errors.FunctionCallError("invalid encoding name {}".format(encoding), error=error, function_name='encode')

@attribute('to_ary', ast.DataType.STRING, result_type=ast.DataType.ARRAY(ast.DataType.STRING))
def string_to_ary(self, value):
Expand Down
50 changes: 49 additions & 1 deletion tests/ast/expression/attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,35 @@ def test_ast_expression_array_attributes(self):
expression = ast.GetAttributeExpression(typed_context, typed_symbol, 'to_ary')
self.assertEqual(expression.result_type, typed_context.resolve_type(symbol.name))

def test_ast_expression_bytes_attributes(self):
value = b'Rule Engine'
symbol = ast.BytesExpression(context, value)

attributes = {
'to_set': set(value),
'length': len(value),
'is_empty': False
}
for attribute_name, value in attributes.items():
expression = ast.GetAttributeExpression(context, symbol, attribute_name)
self.assertEqual(expression.evaluate(None), value, "attribute {} failed".format(attribute_name))

def test_ast_expression_bytes_method_decode(self):
combos = [
('utf-8', 'Rule Engine'),
('hex', '52756c6520456e67696e65'),
('base16', '52756c6520456e67696e65'),
('base64', 'UnVsZSBFbmdpbmU=')
]
for encoding, string in combos:
bytes_expression = ast.BytesExpression(context, b'Rule Engine')
expression = ast.GetAttributeExpression(context, bytes_expression, 'decode')
method = expression.evaluate(None)
self.assertTrue(callable(method), "attribute decode failed (method not callable)")
self.assertEqual(method(encoding), string)
with self.assertRaises(errors.FunctionCallError):
method('invalid-encoding')

def test_ast_expression_datetime_attributes(self):
timestamp = datetime.datetime(2019, 9, 11, 20, 46, 57, 506406, tzinfo=dateutil.tz.UTC)
symbol = ast.DatetimeExpression(context, timestamp)
Expand Down Expand Up @@ -216,12 +245,31 @@ def test_ast_expression_string_attributes(self):
'to_ary': tuple(string),
'to_set': set(string),
'to_str': string,
'length': len(string)
'length': len(string),
'is_empty': False
}
for attribute_name, value in attributes.items():
expression = ast.GetAttributeExpression(context, symbol, attribute_name)
self.assertEqual(expression.evaluate(None), value, "attribute {} failed".format(attribute_name))

def test_ast_expression_string_method_encode(self):
combos = [
('utf-8', 'Rule Engine'),
('hex', '52756c6520456e67696e65'),
('base16', '52756c6520456e67696e65'),
('base64', 'UnVsZSBFbmdpbmU=')
]
for encoding, string in combos:
string_expression = ast.StringExpression(context, string)
expression = ast.GetAttributeExpression(context, string_expression, 'encode')
method = expression.evaluate(None)
self.assertTrue(callable(method), "attribute encode failed (method not callable)")
self.assertEqual(method(encoding), b'Rule Engine')
with self.assertRaises(errors.FunctionCallError):
method('invalid-encoding')
with self.assertRaises(errors.FunctionCallError):
method('base16') # last one is base64 so this should fail

def test_ast_expression_string_attributes_flt(self):
combos = (
('3.14159', decimal.Decimal('3.14159')),
Expand Down
17 changes: 10 additions & 7 deletions tests/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,15 @@ def test_data_type_equality_set(self):
def test_data_type_from_name(self):
self.assertIs(DataType.from_name('ARRAY'), DataType.ARRAY)
self.assertIs(DataType.from_name('BOOLEAN'), DataType.BOOLEAN)
self.assertIs(DataType.from_name('BYTES'), DataType.BYTES)
self.assertIs(DataType.from_name('DATETIME'), DataType.DATETIME)
self.assertIs(DataType.from_name('TIMEDELTA'), DataType.TIMEDELTA)
self.assertIs(DataType.from_name('FLOAT'), DataType.FLOAT)
self.assertIs(DataType.from_name('FUNCTION'), DataType.FUNCTION)
self.assertIs(DataType.from_name('MAPPING'), DataType.MAPPING)
self.assertIs(DataType.from_name('NULL'), DataType.NULL)
self.assertIs(DataType.from_name('SET'), DataType.SET)
self.assertIs(DataType.from_name('STRING'), DataType.STRING)
self.assertIs(DataType.from_name('FUNCTION'), DataType.FUNCTION)
self.assertIs(DataType.from_name('TIMEDELTA'), DataType.TIMEDELTA)

def test_data_type_from_name_error(self):
with self.assertRaises(TypeError):
Expand All @@ -103,16 +104,17 @@ def test_data_type_from_type(self):
self.assertIs(DataType.from_type(list), DataType.ARRAY)
self.assertIs(DataType.from_type(tuple), DataType.ARRAY)
self.assertIs(DataType.from_type(bool), DataType.BOOLEAN)
self.assertIs(DataType.from_type(bytes), DataType.BYTES)
self.assertIs(DataType.from_type(datetime.date), DataType.DATETIME)
self.assertIs(DataType.from_type(datetime.datetime), DataType.DATETIME)
self.assertIs(DataType.from_type(datetime.timedelta), DataType.TIMEDELTA)
self.assertIs(DataType.from_type(float), DataType.FLOAT)
self.assertIs(DataType.from_type(int), DataType.FLOAT)
self.assertIs(DataType.from_type(type(lambda: None)), DataType.FUNCTION)
self.assertIs(DataType.from_type(dict), DataType.MAPPING)
self.assertIs(DataType.from_type(type(None)), DataType.NULL)
self.assertIs(DataType.from_type(set), DataType.SET)
self.assertIs(DataType.from_type(str), DataType.STRING)
self.assertIs(DataType.from_type(type(lambda: None)), DataType.FUNCTION)
self.assertIs(DataType.from_type(datetime.timedelta), DataType.TIMEDELTA)

def test_data_type_from_type_hint(self):
# simple compound tests
Expand Down Expand Up @@ -183,15 +185,16 @@ def test_data_type_from_value_compound_set(self):

def test_data_type_from_value_scalar(self):
self.assertIs(DataType.from_value(False), DataType.BOOLEAN)
self.assertIs(DataType.from_value(b''), DataType.BYTES)
self.assertIs(DataType.from_value(datetime.date.today()), DataType.DATETIME)
self.assertIs(DataType.from_value(datetime.datetime.now()), DataType.DATETIME)
self.assertIs(DataType.from_value(datetime.timedelta()), DataType.TIMEDELTA)
self.assertIs(DataType.from_value(0), DataType.FLOAT)
self.assertIs(DataType.from_value(0.0), DataType.FLOAT)
self.assertIs(DataType.from_value(None), DataType.NULL)
self.assertIs(DataType.from_value(''), DataType.STRING)
self.assertIs(DataType.from_value(lambda: None), DataType.FUNCTION)
self.assertIs(DataType.from_value(print), DataType.FUNCTION)
self.assertIs(DataType.from_value(None), DataType.NULL)
self.assertIs(DataType.from_value(''), DataType.STRING)
self.assertIs(DataType.from_value(datetime.timedelta()), DataType.TIMEDELTA)

def test_data_type_from_value_error(self):
with self.assertRaisesRegex(TypeError, r'^can not map python type \'_UnsupportedType\' to a compatible data type$'):
Expand Down

0 comments on commit 40bbe61

Please sign in to comment.