From dac6f3d6673bc189759c350b200f414206887a6b Mon Sep 17 00:00:00 2001 From: Jonathan Daniel <36337649+jond01@users.noreply.github.com> Date: Tue, 17 Sep 2024 12:45:11 +0300 Subject: [PATCH] Use parameter binding in TDEngine target insertions (#536) * Small changes * Set pytest marker at the module level https://docs.pytest.org/en/stable/example/markers.html#marking-whole-classes-or-modules * Add type hints * Upgrade pytest * Use parameter binding in TDEngine INSERT * Upgrade pytest-benchmark * Fix `pytest.skip` usage https://docs.pytest.org/en/stable/how-to/skipping.html * Revert "Fix `pytest.skip` usage" This reverts commit b6cbb804579f23791b6d9a55a3ede80403d7c447. * Revert "Upgrade pytest-benchmark" This reverts commit c7894bb7e7c0fe81011144014659c1f797dccb1d. * Revert "Upgrade pytest" This reverts commit 5576ab8087425f14698cbe47d502c2174673ce88. * Suppress type hint for old pytest * Remove a redundant parameter * Move error class to dtypes * Rename fun -> func * val_names -> regular_column_names * Add `_TDEngineField` named tuple * Add `_to_tag` and `_to_column` mappings * Rename `_TDEngineField` -> `_TDEngineFieldData` * Get TDEngine schema from table/super-table * format * Improve type hints * empty * Call `DESCRIBE` once in the `_init` method * Rename "field data" to "field" * Check first instead of try-except * Validate DB and table names * Improve ms comment * Rewrite the test into `test_get_table_schema` --- integration/test_tdengine.py | 80 +++++++++++---- storey/dtypes.py | 23 ++++- storey/targets.py | 182 +++++++++++++++++++++++++++-------- tests/test_targets.py | 106 ++++++++++++++++++++ tests/test_types.py | 2 +- 5 files changed, 333 insertions(+), 60 deletions(-) create mode 100644 tests/test_targets.py diff --git a/integration/test_tdengine.py b/integration/test_tdengine.py index e03f49e6..10f5e9c3 100644 --- a/integration/test_tdengine.py +++ b/integration/test_tdengine.py @@ -1,33 +1,33 @@ import os -from datetime import datetime +from collections.abc import Iterator +from datetime import datetime, timezone +from typing import Optional import pytest -import pytz import taosws from storey import SyncEmitSource, build_flow from storey.targets import TDEngineTarget -url = os.getenv("TDENGINE_URL") +url = os.getenv("TDENGINE_URL") # e.g.: taosws://root:taosdata@localhost:6041 user = os.getenv("TDENGINE_USER") password = os.getenv("TDENGINE_PASSWORD") -has_tdengine_credentials = all([url, user, password]) or (url and url.startswith("taosws")) +has_tdengine_credentials = all([url, user, password]) or (url and url.startswith("taosws://")) +pytestmark = pytest.mark.skipif(not has_tdengine_credentials, reason="Missing TDEngine URL, user, and/or password") -@pytest.fixture() -def tdengine(): +TDEngineData = tuple[taosws.Connection, str, Optional[str], Optional[str], str, str] + + +@pytest.fixture(params=[10]) +def tdengine(request: "pytest.FixtureRequest") -> Iterator[TDEngineData]: db_name = "storey" supertable_name = "test_supertable" - if url.startswith("taosws"): + if url.startswith("taosws://"): connection = taosws.connect(url) else: - - connection = taosws.connect( - url=url, - user=user, - password=password, - ) + connection = taosws.connect(url=url, user=user, password=password) try: connection.execute(f"DROP DATABASE {db_name};") @@ -44,7 +44,9 @@ def tdengine(): if "STable not exist" not in str(err): raise err - connection.execute(f"CREATE STABLE {supertable_name} (time TIMESTAMP, my_string NCHAR(10)) TAGS (my_int INT);") + connection.execute( + f"CREATE STABLE {supertable_name} (time TIMESTAMP, my_string NCHAR({request.param})) TAGS (my_int INT);" + ) # Test runs yield connection, url, user, password, db_name, supertable_name @@ -55,8 +57,7 @@ def tdengine(): @pytest.mark.parametrize("table_col", [None, "$key", "table"]) -@pytest.mark.skipif(not has_tdengine_credentials, reason="Missing TDEngine URL, user, and/or password") -def test_tdengine_target(tdengine, table_col): +def test_tdengine_target(tdengine: TDEngineData, table_col: Optional[str]) -> None: connection, url, user, password, db_name, supertable_name = tdengine time_format = "%d/%m/%y %H:%M:%S UTC%z" @@ -116,7 +117,7 @@ def test_tdengine_target(tdengine, table_col): if typ == "TIMESTAMP": t = datetime.fromisoformat(row[field_index]) # websocket returns a timestamp with the local time zone - t = t.astimezone(pytz.UTC).replace(tzinfo=None) + t = t.astimezone(timezone.utc).replace(tzinfo=None) row[field_index] = t result_list.append(row) if table_col: @@ -133,3 +134,48 @@ def test_tdengine_target(tdengine, table_col): [datetime(2019, 9, 18, 1, 55, 14), "hello4", 4], ] assert result_list == expected_result + + +@pytest.mark.parametrize("tdengine", [100], indirect=["tdengine"]) +def test_sql_injection(tdengine: TDEngineData) -> None: + connection, url, user, password, db_name, supertable_name = tdengine + # Create another table to be dropped via SQL injection + tb_name = "dont_drop_me" + connection.execute(f"CREATE TABLE IF NOT EXISTS {tb_name} USING {supertable_name} TAGS (101);") + extra_table_query = f"SHOW TABLES LIKE '{tb_name}';" + assert list(connection.query(extra_table_query)), "The extra table was not created" + + # Try dropping the table + table_name = "test_table" + table_col = "table" + controller = build_flow( + [ + SyncEmitSource(), + TDEngineTarget( + url=url, + time_col="time", + columns=["my_string"], + user=user, + password=password, + database=db_name, + table_col=table_col, + supertable=supertable_name, + tag_cols=["my_int"], + time_format="%d/%m/%y %H:%M:%S UTC%z", + max_events=10, + ), + ] + ).run() + + date_time_str = "18/09/19 01:55:1" + for i in range(5): + timestamp = f"{date_time_str}{i} UTC-0000" + subtable_name = f"{table_name}{i}" + event_body = {"time": timestamp, "my_int": i, "my_string": f"s); DROP TABLE {tb_name};"} + event_body[table_col] = subtable_name + controller.emit(event_body) + + controller.terminate() + controller.await_termination() + + assert list(connection.query(extra_table_query)), "The extra table was dropped" diff --git a/storey/dtypes.py b/storey/dtypes.py index e485caff..d9a98d49 100644 --- a/storey/dtypes.py +++ b/storey/dtypes.py @@ -11,10 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# + from datetime import datetime, timezone from enum import Enum -from typing import Callable, List, Optional, Union +from typing import Callable, List, Literal, NamedTuple, Optional, Union import numpy @@ -103,6 +103,14 @@ class FlowError(Exception): pass +class TDEngineTypeError(TypeError): + pass + + +class TDEngineValueError(ValueError): + pass + + class WindowBase: def __init__(self, window, period, window_str): self.window_millis = window @@ -446,3 +454,14 @@ def should_aggregate(self, element): class FixedWindowType(Enum): CurrentOpenWindow = 1 LastClosedWindow = 2 + + +class _TDEngineField(NamedTuple): + field: str + # https://docs.tdengine.com/reference/taos-sql/data-type/ + type: Literal["TIMESTAMP", "INT", "FLOAT", "DOUBLE", "BINARY", "BOOL", "NCHAR", "JSON", "VARCHAR"] + length: int + note: Literal["", "TAG"] + encode: str + compress: str + level: str diff --git a/storey/targets.py b/storey/targets.py index 925d28c9..d5f5377b 100644 --- a/storey/targets.py +++ b/storey/targets.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# + import asyncio import copy import csv @@ -21,10 +21,11 @@ import os import queue import random +import re import traceback import uuid from io import StringIO -from typing import Any, Callable, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union from urllib.parse import urlparse import pandas as pd @@ -33,11 +34,20 @@ import xxhash from . import Driver -from .dtypes import Event, V3ioError +from .dtypes import ( + Event, + TDEngineTypeError, + TDEngineValueError, + V3ioError, + _TDEngineField, +) from .flow import Flow, _Batching, _split_path, _termination_obj from .table import Table, _PersistJob from .utils import stringify_key, url_to_file_system, wrap_event_for_serialization +if TYPE_CHECKING: + import taosws + class _Writer: def __init__( @@ -805,6 +815,10 @@ class TDEngineTarget(_Batching, _Writer): :type flush_after_seconds: int """ + # https://docs.tdengine.com/reference/taos-sql/limit/ + _DB_NAME_PATTERN = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]{0,63}$") + _TABLE_NAME_PATTERN = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]{0,191}$") + def __init__( self, url: str, @@ -819,8 +833,7 @@ def __init__( tag_cols: Union[str, List[str], None] = None, time_format: Optional[str] = None, **kwargs, - ): - + ) -> None: if table and table_col: raise ValueError("Cannot set both table and table_col") @@ -865,7 +878,6 @@ def __init__( _Batching.__init__(self, **kwargs) self._time_col = time_col tag_cols = tag_cols or [] - self._number_of_tags = len(tag_cols) _Writer.__init__( self, tag_cols + [time_col] + columns, @@ -878,8 +890,80 @@ def __init__( self._user = user self._password = password self._database = database + self._validate_db_and_table_names() + self._tdengine_type_to_column_func = self._get_tdengine_type_to_column_func() + self._tdengine_type_to_tag_func = self._get_tdengine_type_to_tag_func() + + def _validate_db_and_table_names(self) -> None: + """Check the names match their pattern""" + if not self._database: + raise TDEngineValueError("TDEngine database must be set") + if not self._DB_NAME_PATTERN.fullmatch(self._database): + raise TDEngineValueError(f"TDEngine database '{self._database}' does not comply with the naming convention") + + for table_name in (self._table, self._supertable): + if table_name: + if not self._TABLE_NAME_PATTERN.fullmatch(table_name): + raise TDEngineValueError( + f"TDEngine table name '{table_name}' does not comply with the naming convention" + ) - def _init(self): + @staticmethod + def _get_tdengine_type_to_column_func() -> dict[str, Callable[[list], "taosws.PyColumnView"]]: + import taosws + + return { + "BINARY": taosws.binary_to_column, + "BOOL": taosws.bools_to_column, + "DOUBLE": taosws.doubles_to_column, + "FLOAT": taosws.floats_to_column, + "INT": taosws.ints_to_column, + "TIMESTAMP": taosws.millis_timestamps_to_column, + "NCHAR": taosws.nchar_to_column, + "VARCHAR": taosws.varchar_to_column, + } + + @staticmethod + def _get_tdengine_type_to_tag_func() -> dict[str, Callable[[Any], "taosws.PyTagView"]]: + import taosws + + return { + "BOOL": taosws.bool_to_tag, + "DOUBLE": taosws.double_to_tag, + "FLOAT": taosws.float_to_tag, + "INT": taosws.int_to_tag, + "JSON": taosws.json_to_tag, + "NCHAR": taosws.nchar_to_tag, + "TIMESTAMP": taosws.timestamp_to_tag, + "VARCHAR": taosws.varchar_to_tag, + } + + def _get_table_schema( + self, table_name: str + ) -> tuple[ + list[tuple[str, Callable[[Any], "taosws.PyTagView"]]], list[tuple[str, Callable[[list], "taosws.PyColumnView"]]] + ]: + fields = [_TDEngineField(*raw) for raw in self._connection.query(f"DESCRIBE {table_name};")] + tags_schema = [] + reg_cols_schema = [] + for field in fields: + field_name = field.field + field_type = field.type + + if field.note == "TAG": + if field_type in self._tdengine_type_to_tag_func: + tags_schema.append((field_name, self._tdengine_type_to_tag_func[field_type])) + else: + raise TDEngineTypeError(f"Unsupported tag type '{field_type}' of field '{field_name}'") + else: + if field_type in self._tdengine_type_to_column_func: + reg_cols_schema.append((field_name, self._tdengine_type_to_column_func[field_type])) + else: + raise TDEngineTypeError(f"Unsupported column type '{field_type}' of field '{field_name}'") + + return tags_schema, reg_cols_schema + + def _init(self) -> None: import taosws _Batching._init(self) @@ -888,48 +972,66 @@ def _init(self): self._connection = taosws.connect(self._url) else: self._connection = taosws.connect(url=self._url, user=self._user, password=self._password) + self._closeables.append(self._connection) self._connection.execute(f"USE {self._database}") + self._tags_schema, self._reg_cols_schema = self._get_table_schema(self._table or self._supertable) + self._number_of_tags = len(self._tags_schema) + self._number_of_reg_cols = len(self._reg_cols_schema) + self._sql_template = self._get_sql_template() + def _event_to_batch_entry(self, event): return self._event_to_writer_entry(event) @staticmethod - def _sanitize_value(value): + def _get_params_template(num_param: int) -> str: + return f"({','.join(num_param * ['?'])})" + + def _get_sql_template(self) -> str: + with StringIO() as sql: + sql.write("INSERT INTO ?") + if self._supertable: + sql.write(f" USING {self._supertable} TAGS {self._get_params_template(self._number_of_tags)}") + sql.write(f" VALUES {self._get_params_template(self._number_of_reg_cols)};") + return sql.getvalue() + + @staticmethod + def _get_tags_from_event( + tags_schema: list[tuple[str, Callable[[Any], "taosws.PyTagView"]]], event: dict + ) -> list["taosws.PyTagView"]: + return [tag_func(event.get(tag_name)) for tag_name, tag_func in tags_schema] + + @staticmethod + def _raw_value_to_value(value): if isinstance(value, datetime.datetime): - value = round(value.timestamp() * 1000) - elif isinstance(value, str): - value = f"'{value}'" - return str(value) + # We currently support only the default millisecond precision + return int(value.timestamp() * 1000) + return value + + @classmethod + def _get_batch_values( + cls, reg_cols_schema: list[tuple[str, Callable[[list], "taosws.PyColumnView"]]], batch: list[dict] + ) -> list["taosws.PyColumnView"]: + return [ + col_func([cls._raw_value_to_value(event.get(col_name)) for event in batch]) + for col_name, col_func in reg_cols_schema + ] + + async def _emit(self, batch: list[dict], batch_key: str, batch_time, batch_events, last_event_time=None): + stmt = self._connection.statement() + stmt.prepare(self._sql_template) + try: + stmt.set_tbname(self._table or batch_key) - async def _emit(self, batch, batch_key, batch_time, batch_events, last_event_time=None): - with StringIO() as b: - b.write("INSERT INTO ") - if self._table: - b.write(self._table) - else: # table is dynamic - b.write(batch_key) if self._supertable: - b.write(" USING ") - b.write(self._supertable) - b.write(" TAGS (") - for column_index in range(self._number_of_tags): - value = batch[0].get(self._columns[column_index], "NULL") - b.write(self._sanitize_value(value)) - if column_index < self._number_of_tags - 1: - b.write(",") - b.write(")") - b.write(" VALUES ") - for record in batch: - b.write("(") - for column_index in range(self._number_of_tags, len(self._columns)): - value = record.get(self._columns[column_index], "NULL") - b.write(self._sanitize_value(value)) - if column_index < len(self._columns) - 1: - b.write(",") - b.write(") ") - b.write(";") - insert_statement = b.getvalue() - self._connection.execute(insert_statement) + # take the tags from the first event in the batch + stmt.set_tags(self._get_tags_from_event(self._tags_schema, batch[0])) + + stmt.bind_param(self._get_batch_values(self._reg_cols_schema, batch)) + stmt.add_batch() + stmt.execute() + finally: + stmt.close() class StreamTarget(Flow, _Writer): diff --git a/tests/test_targets.py b/tests/test_targets.py new file mode 100644 index 00000000..5412b91c --- /dev/null +++ b/tests/test_targets.py @@ -0,0 +1,106 @@ +# Copyright 2024 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional +from unittest.mock import Mock + +import pytest +import taosws + +from storey.dtypes import TDEngineValueError +from storey.targets import TDEngineTarget + + +class TestTDEngineTarget: + @staticmethod + def test_tags_mapping_consistency() -> None: + for type_, func in TDEngineTarget._get_tdengine_type_to_tag_func().items(): + assert func.__name__ == f"{type_.lower()}_to_tag" + + @staticmethod + def test_columns_mapping_consistency() -> None: + for type_, func in TDEngineTarget._get_tdengine_type_to_column_func().items(): + if type_ == "TIMESTAMP": + assert func.__name__.startswith("millis_timestamp") + else: + assert func.__name__.startswith(type_.lower()) + assert func.__name__.endswith("_to_column") + + @staticmethod + @pytest.mark.parametrize( + ("database", "table", "supertable", "table_col", "tag_cols"), + [ + (None, None, "my_super_tb", "pass_this_check", ["also_this_one"]), + ("mydb", None, "my super tb", "pass_this_check", ["also_this_one"]), + ("_db", "9table", None, None, None), + ("_db", " cars", None, None, None), + ], + ) + def test_invalid_names( + database: Optional[str], + table: Optional[str], + supertable: Optional[str], + table_col: Optional[str], + tag_cols: Optional[list[str]], + ) -> None: + with pytest.raises(TDEngineValueError): + TDEngineTarget( + url="taosws://root:taosdata@localhost:6041", + time_col="ts", + columns=["value"], + table_col=table_col, + tag_cols=tag_cols, + database=database, + table=table, + supertable=supertable, + ) + + @staticmethod + @pytest.fixture + def tdengine_target() -> TDEngineTarget: + target = TDEngineTarget( + url="taosws://root:taosdata@localhost:6041", + time_col="ts", + columns=["value"], + database="test", + table="d6241", + ) + + target._connection = Mock() + # The following test schema is obtained from the `taosBenchmark` data: + # https://docs.tdengine.com/get-started/docker/#test-data-insert-performance + # list(conn.query("describe test.d6241;")) + target._connection.query = Mock( + return_value=[ + ("ts", "TIMESTAMP", 8, "", "delta-i", "lz4", "medium"), + ("current", "FLOAT", 4, "", "delta-d", "lz4", "medium"), + ("voltage", "INT", 4, "", "simple8b", "lz4", "medium"), + ("phase", "FLOAT", 4, "", "delta-d", "lz4", "medium"), + ("groupid", "INT", 4, "TAG", "disabled", "disabled", "disabled"), + ("location", "VARCHAR", 24, "TAG", "disabled", "disabled", "disabled"), + ], + ) + return target + + @staticmethod + def test_get_table_schema(tdengine_target: TDEngineTarget) -> None: + """Test that the parsing works""" + tags_schema, reg_cols_schema = tdengine_target._get_table_schema("d6241") + assert tags_schema == [("groupid", taosws.int_to_tag), ("location", taosws.varchar_to_tag)] + assert reg_cols_schema == [ + ("ts", taosws.millis_timestamps_to_column), + ("current", taosws.floats_to_column), + ("voltage", taosws.ints_to_column), + ("phase", taosws.floats_to_column), + ] diff --git a/tests/test_types.py b/tests/test_types.py index 0cc477c6..9248edba 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# + import pytest from storey.dtypes import (