feat: overhaul SQL string reformatting and Producer/Consumer interfaces (#128)

This PR makes two somewhat independent changes that make writing test
cases simpler and more robust: an overhaul of the `Producer` and
`Consumer` interfaces and a change in how the format arguments of named
tables and local files in SQL strings are specified. The two changes
are tied together into one PR because changing the arguments alone
proved difficult due to the previous brittleness of the two
interfaces.

The change related to the format arguments consists of the following:
previously, each test case specified a list of `local_files`, each of
which would either be loaded into a table whose name was derived from
the corresponding file name *or* be processed as is if the SQL string
contained the magic word `read_parquet`. Now, local files and named
tables are specified independently of each other, and both are
specified as a dict: the key of each entry corresponds to the
placeholder used in the format string (such as `'{customer}'`) and the
value is the local file path (such as `customer_small.parquet`). For
named tables, the idea is that the corresponding system loads the local
file into a table with the given name; local files are processed
directly. Since the definition of test cases is used to create
parametrized test fixtures, this change involves all test functions
using these parametrized fixtures. As another consequence of this
change, some plan snapshots change: some table names no longer have the
`_small` suffix because the table name is now specified explicitly
rather than derived from the file name, and in one case the order of
the input tables in the `FROM` clause has changed (the new order
corresponds to the one in the official TPC-H query whereas the previous
order didn't).
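The new scheme boils down to Python's `str.format` with named placeholders. A minimal sketch (the `format_sql` helper and the test-case literal below are illustrative, not the actual code from the PR; the file name `customer_small.parquet` is taken from the description above):

```python
# Hypothetical sketch of the new test-case layout and how a SQL string
# with named placeholders gets formatted.
test_case = {
    "local_files": {},  # placeholders mapped to files processed directly
    "named_tables": {"customer": "customer_small.parquet"},
    "sql_query": "SELECT COUNT(*) FROM '{customer}';",
}

def format_sql(sql_query: str, local_files: dict, named_tables: dict) -> str:
    # Named placeholders make the mapping explicit, so the order of
    # entries no longer matters (unlike the old positional "file_names").
    return sql_query.format(**local_files, **named_tables)

print(format_sql(test_case["sql_query"], test_case["local_files"],
                 test_case["named_tables"]))
# SELECT COUNT(*) FROM 'customer_small.parquet';
```

With positional `'{}'` placeholders, reordering the file list silently changed which table landed where; with named placeholders, a missing or misspelled key fails loudly with a `KeyError`.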

The change related to the `Producer` and `Consumer` interfaces
simplifies how consumers are created and used. First, both interfaces
now have a `setup` method implemented by the interface which takes care
of expanding the relative file paths into absolute ones. This removes
the need to do that expansion in various other places. Similarly,
`Producer.format_sql` takes care of replacing format arguments such that
derived classes don't have to. Again in the same spirit,
`Producer.produce_substrait` also takes care of formatting the SQL query
such that call sites can directly call that function instead of having
to remember to reformat the SQL string beforehand. The PR also replaces
some direct usages of the DuckDB connection with more high-level usages
of `DuckDBConsumer`, such that the encapsulated functionality described
above can be used. To that end, that class also gets a new method
`run_sql_query`.
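The shared `setup` method described above is an instance of the template-method pattern: the base class does the path expansion once, then delegates to a subclass hook. A simplified sketch (the `compute_full_paths` helper mirrors the one shown in the diff below; `RecordingConsumer` and the `DATA_DIR` value are illustrative assumptions):

```python
from abc import ABC, abstractmethod

DATA_DIR = "/data/tpch_parquet"  # assumed data location for illustration

def compute_full_paths(files: dict[str, str]) -> dict[str, str]:
    # Expand relative file names to absolute paths in one place,
    # so no derived class has to do it itself.
    return {name: f"{DATA_DIR}/{path}" for name, path in files.items()}

class Consumer(ABC):
    def setup(self, db_connection, local_files, named_tables):
        # Template method: shared preprocessing, then the subclass hook.
        local_files = compute_full_paths(local_files)
        named_tables = compute_full_paths(named_tables)
        self._setup(db_connection, local_files, named_tables)

    @abstractmethod
    def _setup(self, db_connection, local_files, named_tables):
        ...

class RecordingConsumer(Consumer):
    # Minimal concrete consumer that just records what it was given.
    def _setup(self, db_connection, local_files, named_tables):
        self.named_tables = named_tables

c = RecordingConsumer()
c.setup(None, {}, {"lineitem": "lineitem.parquet"})
print(c.named_tables)  # {'lineitem': '/data/tpch_parquet/lineitem.parquet'}
```

Every concrete consumer now receives absolute paths in `_setup` without repeating the expansion logic, which is exactly the brittleness the old per-consumer `setup` implementations suffered from.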

Finally, the PR also removes some duplicate or unused code related to
loading local files and formatting queries.

I have checked manually that no tests change their fail/pass status
compared to the current `main`.

Signed-off-by: Ingo Müller <[email protected]>
ingomueller-net authored Nov 12, 2024
1 parent fda68c9 commit d4f2aa7
Showing 186 changed files with 1,622 additions and 1,176 deletions.
59 changes: 20 additions & 39 deletions README.md
@@ -98,58 +98,38 @@ query_1.py
TPCH_QUERY_TESTS = (
{
"test_name": "test_tpch_sql_1",
"file_names": ["lineitem.parquet"],
"local_files": {},
"named_tables": {"lineitem": "lineitem.parquet"},
"sql_query": get_sql("q1.sql"),
"substrait_query": get_substrait_plan("query_01_plan.json"),
},
{
"test_name": "test_tpch_sql_2",
"file_names": [
"part.parquet",
"supplier.parquet",
"partsupp.parquet",
"nation.parquet",
"region.parquet",
"partsupp.parquet",
"supplier.parquet",
"nation.parquet",
"region.parquet",
],
"local_files": {},
"named_tables": {
"part": "part.parquet",
"supplier": "supplier.parquet",
"partsupp": "partsupp.parquet",
"nation": "nation.parquet",
"region": "region.parquet",
"partsupp": "partsupp.parquet",
"supplier": "supplier.parquet",
"nation": "nation.parquet",
"region": "region.parquet",
},
"sql_query": get_sql("q2.sql"),
"substrait_query": get_substrait_plan("query_02_plan.json"),
},
]
)
```
## Substrait Plans
Substrait query plans are located in `substrait_consumer/tests/integration/queries/tpch_substrait_plans`.
The substrait query plans have placeholder strings in the `local_files` objects in the json
structure.
```json
"local_files": {
"items": [
{
"uri_file": "file://FILENAME_PLACEHOLDER_0",
"parquet": {}
}
]
}
```


When the tests are run, these placeholders are replaced by the parquet data listed
in `"file_names"` in the test case args file. The order of parquet file appearance in the
`"file_names"` list should be consistent with the ordering for the table names in the substrait
query plan.

## SQL Queries
SQL queries are located in `substrait_consumer/tests/integration/queries/tpch_sql`.

The SQL queries have empty bracket placeholders (`'{}'`) where the table names will be inserted.
Table names are determined based on the `"file_names"` in the test case args file. The order of
parquet file appearance in the `"file_names"` list should be consistent with the ordering for the
table names in the SQL query. The actual format after replacement will depend on the consumer being
used.

The SQL queries have named placeholders (`'{customer}'`) where the table names or file paths will be inserted.
Table names are determined based on the `"named_tables"` and `"local_files"` in the test case args file.

# Function Tests
The substrait function tests aim to test the functions available in Substrait. This is done
@@ -182,7 +162,8 @@ arithmetic_tests.py
SCALAR_FUNCTIONS = (
{
"test_name": "add",
"file_names": ["partsupp.parquet"],
"local_files": {},
"named_tables": {"partsupp": "partsupp.parquet"},
"sql_query": SQL_SCALAR["add"],
"ibis_expr": IBIS_SCALAR["add"],
},
@@ -196,7 +177,7 @@ SQL_SCALAR = {
"add":
"""
SELECT PS_PARTKEY, PS_SUPPKEY, add(PS_PARTKEY, PS_SUPPKEY) AS ADD_KEY
FROM '{}';
FROM '{partsupp}';
""",
```

63 changes: 6 additions & 57 deletions substrait_consumer/common.py
@@ -44,67 +44,16 @@ class SubstraitUtils:
"""

@staticmethod
def get_full_path(file_names: Iterable[str]) -> list[str]:
def compute_full_paths(local_files: dict[str, str]) -> dict[str, str]:
"""
Get full paths for the TPCH parquet data.
Get the full paths for the given local files.
Parameters:
file_names:
List of TPCH parquet data file names provided by the test case.
local_files:
A `dict` mapping format argument names to local files paths.
Returns:
List of full paths.
A `dict` where the paths are expanded to absolute paths.
"""
data_dir = CUR_DIR / "data" / "tpch_parquet"
full_paths_list = [f"{data_dir}/{dataset}" for dataset in file_names]

return full_paths_list

def format_sql_query(self, sql_query: str, file_names: list[str]) -> str:
"""
Replace the 'Table' Parameters from the SQL query with the relative
file paths of the parquet data.
Parameters:
sql_query:
SQL query.
file_names:
List of file names.
Returns:
SQL Query with file paths.
"""
sql_commands_list = [line.strip() for line in sql_query.strip().split("\n")]
sql_query = " ".join(sql_commands_list)
# Get full path for all datasets used in the query
parquet_file_paths = self.get_full_path(file_names)

return sql_query.format(*parquet_file_paths)

def format_substrait_query(
self, substrait_query: str, file_names: list[str]
) -> str:
"""
Replace the 'local_files' path in the substrait query plan with
the full path of the parquet data.
Parameters:
substrait_query:
Substrait query.
file_names:
List of file names.
Returns:
Substrait query plan in byte format.
"""
# Get full path for all datasets used in the query
parquet_file_paths = self.get_full_path(file_names)

# Replace the filename placeholder in the substrait query plan with
# the proper parquet data file paths.
for count, file_path in enumerate(parquet_file_paths):
substrait_query = substrait_query.replace(
f"FILENAME_PLACEHOLDER_{count}", file_path
)

return substrait_query
return {k: f"{data_dir}/{v}" for k, v in local_files.items()}
25 changes: 9 additions & 16 deletions substrait_consumer/consumers/acero_consumer.py
@@ -1,15 +1,9 @@
from __future__ import annotations

import string
from pathlib import Path
from typing import Iterable

import pyarrow as pa
import pyarrow.parquet as pq
import pyarrow.substrait as substrait

from substrait_consumer.common import SubstraitUtils

from .consumer import COLUMN_A, COLUMN_B, COLUMN_C, COLUMN_D, Consumer


@@ -19,15 +13,14 @@ class AceroConsumer(Consumer):
"""

def __init__(self):
self.tables = {}
self.table_provider = lambda names, schema: self.tables[names[0].lower()]

def setup(self, db_connection, file_names: Iterable[str]):
if len(file_names) > 0:
parquet_file_paths = SubstraitUtils.get_full_path(file_names)
for file_name, file_path in zip(file_names, parquet_file_paths):
table_name = Path(file_name).stem
self.tables[table_name] = pq.read_table(file_path)
self.named_tables = {}
self.table_provider = lambda names, schema: self.named_tables[names[0].lower()]

def _setup(
self, db_connection, local_files: dict[str, str], named_tables: dict[str, str]
):
for table_name, file_path in named_tables.items():
self.named_tables[table_name] = pq.read_table(file_path)
else:
table = pa.table(
{
@@ -37,7 +30,7 @@ def setup(self, db_connection, file_names: Iterable[str]):
"d": COLUMN_D,
}
)
self.tables["t"] = table
self.named_tables["t"] = table

def run_substrait_query(self, substrait_query: str) -> pa.Table:
"""
31 changes: 29 additions & 2 deletions substrait_consumer/consumers/consumer.py
@@ -1,10 +1,12 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Iterable

import pyarrow as pa

from substrait_consumer.common import SubstraitUtils


COLUMN_A = [1, 2, 3, -4, 5, -6, 7, 8, 9, None]
COLUMN_B = [1, 1, 1, 1, 1, 2, 2, 2, 2, 2]
COLUMN_C = [
@@ -34,8 +36,33 @@


class Consumer(ABC):

def setup(
self, db_connection, local_files: dict[str, str], named_tables: dict[str, str]
):
"""
Initializes this `Consumer` instance.
In particular, expands the paths in `local_files` and `named_tables` to
absolute paths and forwards the arguments to `self._setup` implemented
by classes inheriting from `Consumer`.
Parameters:
db_connection:
DuckDB connection for this `Consumer`.
local_files:
A `dict` mapping format argument names to local files paths.
named_tables:
A `dict` mapping table names to local file paths.
"""
local_files = SubstraitUtils.compute_full_paths(local_files)
named_tables = SubstraitUtils.compute_full_paths(named_tables)
self._setup(db_connection, local_files, named_tables)

@abstractmethod
def setup(self, db_connection, file_names: Iterable[str]):
def _setup(
self, db_connection, local_files: dict[str, str], named_tables: dict[str, str]
):
pass

@abstractmethod
24 changes: 9 additions & 15 deletions substrait_consumer/consumers/datafusion_consumer.py
@@ -1,18 +1,13 @@
from __future__ import annotations

import json
import string
from pathlib import Path
from typing import Iterable

import pyarrow as pa
from datafusion import SessionContext
from datafusion import substrait as ds
from google.protobuf.json_format import Parse
from substrait.gen.proto.plan_pb2 import Plan

from substrait_consumer.common import SubstraitUtils

from .consumer import COLUMN_A, COLUMN_B, COLUMN_C, COLUMN_D, Consumer


@@ -24,17 +19,16 @@ class DataFusionConsumer(Consumer):
def __init__(self):
self._ctx = SessionContext()

def setup(self, db_connection, file_names: Iterable[str]):
if len(file_names) > 0:
parquet_file_paths = SubstraitUtils.get_full_path(file_names)
for file_name, file_path in zip(file_names, parquet_file_paths):
table_name = Path(file_name).stem
if self._ctx.table_exist(table_name):
self._ctx.deregister_table(table_name)
self._ctx.register_parquet(table_name, file_path)
def _setup(
self, db_connection, local_files: dict[str, str], named_tables: dict[str, str]
):
for table_name, file_path in named_tables.items():
if self._ctx.table_exist(table_name):
self._ctx.deregister_table(table_name)
self._ctx.register_parquet(table_name, file_path)
else:
if not self._ctx.table_exist("t"):
tables = pa.RecordBatch.from_arrays(
named_tables = pa.RecordBatch.from_arrays(
[
pa.array(COLUMN_A),
pa.array(COLUMN_B),
@@ -44,7 +38,7 @@ def setup(self, db_connection, file_names: Iterable[str]):
names=["a", "b", "c", "d"],
)

self._ctx.register_record_batches("t", [[tables]])
self._ctx.register_record_batches("t", [[named_tables]])

def run_substrait_query(self, substrait_query: str) -> pa.Table:
"""
41 changes: 5 additions & 36 deletions substrait_consumer/consumers/duckdb_consumer.py
@@ -1,15 +1,10 @@
from __future__ import annotations

import string
from pathlib import Path
from typing import Iterable

import duckdb
import pyarrow as pa

from substrait_consumer.common import SubstraitUtils

from .consumer import Consumer
from substrait_consumer.producers.producer import load_named_tables


class DuckDBConsumer(Consumer):
@@ -26,9 +21,11 @@ def __init__(self, db_connection=None):
self.db_connection.execute("INSTALL substrait")
self.db_connection.execute("LOAD substrait")

def setup(self, db_connection, file_names: Iterable[str]):
def _setup(
self, db_connection, local_files: dict[str, str], named_tables: dict[str, str]
):
self.db_connection = db_connection
self.load_tables_from_parquet(file_names)
load_named_tables(db_connection, named_tables)

def run_substrait_query(self, substrait_query: str) -> pa.Table:
"""
@@ -42,31 +39,3 @@ def run_substrait_query(self, substrait_query: str) -> pa.Table:
A pyarrow table resulting from running the substrait query plan.
"""
return self.db_connection.from_substrait_json(substrait_query).arrow()

def load_tables_from_parquet(
self,
file_names: Iterable[str],
) -> list:
"""
Load all the parquet files into separate tables in DuckDB.
Parameters:
file_names:
Name of parquet files.
Returns:
A list of the table names.
"""
parquet_file_paths = SubstraitUtils.get_full_path(file_names)
table_names = []
for file_name, file_path in zip(file_names, parquet_file_paths):
table_name = Path(file_name).stem
try:
self.db_connection.execute(f"DROP TABLE {table_name}")
except:
pass
create_table_sql = f"CREATE TABLE {table_name} AS SELECT * FROM read_parquet('{file_path}');"
self.db_connection.execute(create_table_sql)
table_names.append(table_name)

return table_names