diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 4e6a2fe6..4b279830 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -81,3 +81,6 @@ jobs:
       - name: Run Ibis expression tests
         if: ${{ matrix.producer == 'ibis' }}
         run: pytest -v substrait_consumer/ibis_expressions/
+
+      - name: Run adhoc tests
+        run: pytest -v substrait_consumer/adhoc --adhoc_producer=${{ matrix.producer }} --consumer=duckdb
diff --git a/substrait_consumer/tests/__init__.py b/substrait_consumer/adhoc/__init__.py
similarity index 100%
rename from substrait_consumer/tests/__init__.py
rename to substrait_consumer/adhoc/__init__.py
diff --git a/substrait_consumer/tests/adhoc/ibis_expr.py b/substrait_consumer/adhoc/ibis_expr.py
similarity index 100%
rename from substrait_consumer/tests/adhoc/ibis_expr.py
rename to substrait_consumer/adhoc/ibis_expr.py
diff --git a/substrait_consumer/tests/adhoc/query.sql b/substrait_consumer/adhoc/query.sql
similarity index 100%
rename from substrait_consumer/tests/adhoc/query.sql
rename to substrait_consumer/adhoc/query.sql
diff --git a/substrait_consumer/adhoc/test_adhoc_producer.py b/substrait_consumer/adhoc/test_adhoc_producer.py
new file mode 100644
index 00000000..48382276
--- /dev/null
+++ b/substrait_consumer/adhoc/test_adhoc_producer.py
@@ -0,0 +1,132 @@
+import json
+from pathlib import Path
+from typing import Any
+
+import duckdb
+import pytest
+from ibis_substrait.tests.compiler.conftest import *
+
+from .ibis_expr import ibis_expr
+from substrait_consumer.producers.duckdb_producer import DuckDBProducer
+from substrait_consumer.producers.ibis_producer import IbisProducer
+
+SQL_FILE_PATH = Path(__file__).parent / "query.sql"
+
+FILE_NAMES = {
+    "customer": "customer.parquet",
+    "lineitem": "lineitem.parquet",
+    "nation": "nation.parquet",
+    "orders": "orders.parquet",
+    "part": "part.parquet",
+    "partsupp": "partsupp.parquet",
+    "region": "region.parquet",
+    "supplier": "supplier.parquet",
+}
+
+
+def verify_equals(actual: Any, expected: Any, message: str = "") -> None:
+    """
+    Verify that 2 objects are equal. First check to see that object
+    types are the same. If they differ, log the objects types and raise
+    an error.
+    If object types are the same but values are not equal, an error is
+    raised and the message is shown.
+
+    Parameters:
+        actual:
+            Object to evaluate against the expected object.
+        expected:
+            Object to be evaluated against.
+        message:
+            Message to be displayed if objects are not equal.
+    """
+    # Keep the message a plain string so a failure is not shown as a list repr.
+    msg = message or f"TEST FAILURE: Verifying equals: {actual} == {expected}."
+
+    assert isinstance(actual, type(expected)), (
+        f"TEST FAILURE: Object types are not the same. \nActual "
+        f"type: {type(actual)}\nExpected type: {type(expected)}"
+    )
+    assert actual == expected, msg
+
+
+@pytest.mark.usefixtures("prepare_tpch_parquet_data")
+class TestAdhocExpression:
+    """
+    Test CLI for generating substrait plans from adhoc SQL queries or ibis expressions
+    and testing them against different consumers.
+    """
+
+    @staticmethod
+    @pytest.fixture(autouse=True)
+    def setup_teardown_function(request):
+        """Attach an empty ``produced_plans`` set to the requesting test class."""
+        # NOTE(review): no scope is given, so this runs before every test and the
+        # set is always empty when the test body starts; the "already produced"
+        # skip branch in the test looks unreachable -- confirm the intended scope.
+        cls = request.cls
+        cls.produced_plans = set()
+
+    def test_adhoc_sql_query(
+        self,
+        adhoc_producer,
+        consumer,
+        saveplan,
+        db_con,
+        part,
+        supplier,
+        partsupp,
+        customer,
+        orders,
+        lineitem,
+        nation,
+        region,
+    ) -> None:
+        """
+        Produce a substrait plan from query.sql (or from ibis_expr for the ibis
+        producer), run it on the consumer and check that the ``columns`` of the
+        result equal those of DuckDB running the same SQL directly.
+        """
+        local_files = dict()
+        named_tables = FILE_NAMES
+        adhoc_producer.setup(db_con, local_files, named_tables)
+        consumer.setup(db_con, local_files, named_tables)
+
+        sql_query = SQL_FILE_PATH.read_text(encoding="utf-8")
+
+        # A file holding only whitespace is as unusable as an empty one.
+        if not sql_query.strip():
+            raise ValueError("No SQL query. Please write SQL into query.sql")
+
+        if isinstance(adhoc_producer, IbisProducer):
+            expr = ibis_expr(
+                part, supplier, partsupp, customer, orders, lineitem, nation, region
+            )
+            substrait_plan = adhoc_producer._produce_substrait(expr)
+        else:
+            substrait_plan = adhoc_producer.produce_substrait(sql_query)
+
+        producer_name = adhoc_producer.name()
+        if isinstance(substrait_plan, str) and saveplan:
+            if producer_name not in self.produced_plans:
+                self.produced_plans.add(producer_name)
+                python_json = json.loads(substrait_plan)
+                with open(f"{producer_name}_substrait.json", "w") as outfile:
+                    outfile.write(json.dumps(python_json, indent=4))
+            else:
+                pytest.skip(
+                    f"Plan already produced using the producer: {producer_name}"
+                )
+
+        actual_result = consumer.run_substrait_query(substrait_plan)
+        duckdb_producer = DuckDBProducer()
+        duckdb_producer.setup(db_con, local_files, named_tables)
+        expected_result = duckdb_producer.run_sql_query(sql_query)
+
+        verify_equals(
+            actual_result.columns,
+            expected_result.columns,
+            message=f"Result: {actual_result.columns} "
+            f"is not equal to the expected: "
+            f"{expected_result.columns}",
+        )
diff --git a/substrait_consumer/tests/adhoc/__init__.py b/substrait_consumer/tests/adhoc/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/substrait_consumer/tests/adhoc/test_adhoc_expression.py b/substrait_consumer/tests/adhoc/test_adhoc_expression.py
deleted file mode 100644
index 4fd7d206..00000000
--- a/substrait_consumer/tests/adhoc/test_adhoc_expression.py
+++ /dev/null
@@ -1,106 +0,0 @@
-import json
-from pathlib import Path
-
-import duckdb
-from ibis_substrait.tests.compiler.conftest import *
-
-from substrait_consumer.tests.adhoc.ibis_expr import ibis_expr
-from substrait_consumer.verification import verify_equals
-
-CUR_DIR = Path(__file__).parent
-SQL_FILE_PATH = CUR_DIR / "query.sql"
-
-
-FILE_NAMES = [
-    "customer.parquet",
-    "lineitem.parquet",
-    "nation.parquet",
-    "orders.parquet",
-    "part.parquet",
-    "partsupp.parquet",
-    "region.parquet",
-    "supplier.parquet",
-]
-
-
-@pytest.mark.usefixtures("prepare_tpch_parquet_data")
-class TestAdhocExpression:
-    """
-    Test CLI for generating substrait plans from adhoc SQL queries or ibis expressions
-    and testing them against different consumers.
-    """
-
-    @staticmethod
-    @pytest.fixture(autouse=True)
-    def setup_teardown_function(request):
-        cls = request.cls
-        cls.produced_plans = set()
-
-    @staticmethod
-    @pytest.fixture(autouse=True)
-    def setup_teardown_function(request):
-        cls = request.cls
-
-        cls.db_connection = duckdb.connect()
-        cls.db_connection.execute("install substrait")
-        cls.db_connection.execute("load substrait")
-
-        yield
-
-        cls.db_connection.close()
-
-    def test_adhoc_expression(
-        self,
-        adhoc_producer,
-        consumer,
-        saveplan,
-        part,
-        supplier,
-        partsupp,
-        customer,
-        orders,
-        lineitem,
-        nation,
-        region,
-    ) -> None:
-        local_files = FILE_NAMES
-        named_tables = dict()
-        producer.setup(self.db_connection, local_files, named_tables)
-        consumer.setup(self.db_connection, local_files, named_tables)
-
-        with open(SQL_FILE_PATH, "r") as f:
-            sql_query = f.read()
-
-        if not sql_query:
-            raise ValueError("No SQL query. Please write SQL into query.sql")
-        substrait_plan = adhoc_producer.produce_substrait(
-            sql_query,
-            consumer,
-            ibis_expr(
-                part, supplier, partsupp, customer, orders, lineitem, nation, region
-            ),
-        )
-        producer_name = type(adhoc_producer).__name__
-        if isinstance(substrait_plan, str) and saveplan:
-            if producer_name not in self.produced_plans:
-                self.produced_plans.add(producer_name)
-                python_json = json.loads(substrait_plan)
-                with open(f"{producer_name}_substrait.json", "w") as outfile:
-                    outfile.write(json.dumps(python_json, indent=4))
-            else:
-                pytest.skip(
-                    f"Plan already produced using the producer: {adhoc_producer.name()}"
-                )
-
-        actual_result = consumer.run_substrait_query(substrait_plan)
-        duckdb_producer = DuckDBProducer()
-        duckdb_producer.setup(self.db_connection, local_files, named_tables)
-        expected_result = duckdb_producer.run_substrait_query(sql_query)
-
-        verify_equals(
-            actual_result.columns,
-            expected_result.columns,
-            message=f"Result: {actual_result.columns} "
-            f"is not equal to the expected: "
-            f"{expected_result.columns}",
-        )
diff --git a/substrait_consumer/verification.py b/substrait_consumer/verification.py
deleted file mode 100644
index b35bc6c7..00000000
--- a/substrait_consumer/verification.py
+++ /dev/null
@@ -1,27 +0,0 @@
-from typing import Any
-
-
-def verify_equals(actual: Any, expected: Any, message: str = "") -> None:
-    """
-    Verify that 2 objects are equal. First check to see that object
-    types are the same. If they differ, log the objects types and raise
-    an error.
-    If object types are the same but values are not equal, an error is
-    raised and the message is shown.
-
-    Parameters:
-        actual:
-            Object to evaluate against the expected object.
-        expected:
-            Object to be evaluated against.
-        message:
-            Message to be displayed if objects are not equal.
-    """
-    msg = [f"TEST FAILURE: Verifying equals: {actual} == {expected}."]
-    msg = [message] if message else msg
-
-    assert isinstance(actual, type(expected)), (
-        f"TEST FAILURE: Object types are not the same. \nActual "
-        f"type: {type(actual)}\nExpected type: {type(expected)}"
-    )
-    assert actual == expected, msg