From edee1ea4af53bddf02dc7adc217e6c26c4a1450d Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Thu, 5 Dec 2024 09:09:00 +0900 Subject: [PATCH 1/5] feat(ux): include basename of path in generated table names in read_*() --- ibis/backends/clickhouse/__init__.py | 4 ++-- ibis/backends/datafusion/__init__.py | 35 ++++++++++++++-------------- ibis/backends/duckdb/__init__.py | 22 +++++++---------- ibis/backends/flink/__init__.py | 3 +-- ibis/backends/polars/__init__.py | 31 +++++++++++------------- ibis/backends/pyspark/__init__.py | 14 +++++------ ibis/backends/snowflake/__init__.py | 6 ++--- ibis/tests/test_util.py | 34 ++++++++++++++++++++++++++- ibis/util.py | 24 ++++++++++++++++++- 9 files changed, 109 insertions(+), 64 deletions(-) diff --git a/ibis/backends/clickhouse/__init__.py b/ibis/backends/clickhouse/__init__.py index 8707ec00cf06..aa3079058989 100644 --- a/ibis/backends/clickhouse/__init__.py +++ b/ibis/backends/clickhouse/__init__.py @@ -581,7 +581,7 @@ def read_parquet( paths = list(glob.glob(str(path))) schema = PyArrowSchema.to_ibis(ds.dataset(paths, format="parquet").schema) - name = table_name or util.gen_name("read_parquet") + name = table_name or util.gen_name_from_path(paths[0]) table = self.create_table(name, engine=engine, schema=schema, temp=True) for file_path in paths: @@ -609,7 +609,7 @@ def read_csv( paths = list(glob.glob(str(path))) schema = PyArrowSchema.to_ibis(ds.dataset(paths, format="csv").schema) - name = table_name or util.gen_name("read_csv") + name = table_name or util.gen_name_from_path(paths[0]) table = self.create_table(name, engine=engine, schema=schema, temp=True) for file_path in paths: diff --git a/ibis/backends/datafusion/__init__.py b/ibis/backends/datafusion/__init__.py index 733d6a772b48..4d3f9586dbc6 100644 --- a/ibis/backends/datafusion/__init__.py +++ b/ibis/backends/datafusion/__init__.py @@ -28,7 +28,6 @@ from ibis.common.dispatch import lazy_singledispatch from ibis.expr.operations.udf import InputType from ibis.formats.pyarrow import PyArrowSchema, PyArrowType -from ibis.util import deprecated, gen_name, normalize_filename, normalize_filenames try: from datafusion import ExecutionContext as SessionContext @@ -160,7 +159,7 @@ def _safe_raw_sql(self, sql: sge.Statement) -> Any: yield self.raw_sql(sql).collect() def _get_schema_using_query(self, query: str) -> sch.Schema: - name = gen_name("datafusion_metadata_view") + name = util.gen_name("datafusion_metadata_view") table = sg.table(name, quoted=self.compiler.quoted) src = sge.Create( this=table, @@ -345,7 +344,7 @@ def get_schema( table = database.table(table_name) return sch.schema(table.schema) - @deprecated( + @util.deprecated( as_of="9.1", instead="use the explicit `read_*` method for the filetype you are trying to read, e.g., read_parquet, read_csv, etc.", ) @@ -437,11 +436,11 @@ def read_csv( The just-registered table """ - path = normalize_filenames(source_list) - table_name = table_name or gen_name("read_csv") + paths = util.normalize_filenames(source_list) + table_name = table_name or util.gen_name_from_path(paths[0]) # Our other backends support overwriting views / tables when re-registering self.con.deregister_table(table_name) - self.con.register_csv(table_name, path, **kwargs) + self.con.register_csv(table_name, paths, **kwargs) return self.table(table_name) def read_parquet( @@ -465,8 +464,8 @@ def read_parquet( The just-registered table """ - path = normalize_filename(path) - table_name = table_name or gen_name("read_parquet") + path = util.normalize_filename(path) + 
table_name = table_name or util.gen_name_from_path(path) # Our other backends support overwriting views / tables when reregistering self.con.deregister_table(table_name) self.con.register_parquet(table_name, path, **kwargs) @@ -494,9 +493,9 @@ def read_delta( The just-registered table """ - source_table = normalize_filename(source_table) + source_table = util.normalize_filename(source_table) - table_name = table_name or gen_name("read_delta") + table_name = table_name or util.gen_name_from_path(source_table) # Our other backends support overwriting views / tables when reregistering self.con.deregister_table(table_name) @@ -730,55 +729,55 @@ def _read_in_memory( @_read_in_memory.register(dict) def _pydict(source, table_name, _conn, overwrite: bool = False): - tmp_name = gen_name("pydict") + tmp_name = util.gen_name("pydict") with _create_and_drop_memtable(_conn, table_name, tmp_name, overwrite): _conn.con.from_pydict(source, name=tmp_name) @_read_in_memory.register("polars.DataFrame") def _polars(source, table_name, _conn, overwrite: bool = False): - tmp_name = gen_name("polars") + tmp_name = util.gen_name("polars") with _create_and_drop_memtable(_conn, table_name, tmp_name, overwrite): _conn.con.from_polars(source, name=tmp_name) @_read_in_memory.register("polars.LazyFrame") def _polars(source, table_name, _conn, overwrite: bool = False): - tmp_name = gen_name("polars") + tmp_name = util.gen_name("polars") with _create_and_drop_memtable(_conn, table_name, tmp_name, overwrite): _conn.con.from_polars(source.collect(), name=tmp_name) @_read_in_memory.register("pyarrow.Table") def _pyarrow_table(source, table_name, _conn, overwrite: bool = False): - tmp_name = gen_name("pyarrow") + tmp_name = util.gen_name("pyarrow") with _create_and_drop_memtable(_conn, table_name, tmp_name, overwrite): _conn.con.from_arrow(source, name=tmp_name) @_read_in_memory.register("pyarrow.RecordBatchReader") def _pyarrow_rbr(source, table_name, _conn, overwrite: bool = False): - tmp_name = gen_name("pyarrow") + tmp_name = util.gen_name("pyarrow") with _create_and_drop_memtable(_conn, table_name, tmp_name, overwrite): _conn.con.from_arrow(source.read_all(), name=tmp_name) @_read_in_memory.register("pyarrow.RecordBatch") def _pyarrow_rb(source, table_name, _conn, overwrite: bool = False): - tmp_name = gen_name("pyarrow") + tmp_name = util.gen_name("pyarrow") with _create_and_drop_memtable(_conn, table_name, tmp_name, overwrite): _conn.con.register_record_batches(tmp_name, [[source]]) @_read_in_memory.register("pyarrow.dataset.Dataset") def _pyarrow_rb(source, table_name, _conn, overwrite: bool = False): - tmp_name = gen_name("pyarrow") + tmp_name = util.gen_name("pyarrow") with _create_and_drop_memtable(_conn, table_name, tmp_name, overwrite): _conn.con.register_dataset(tmp_name, source) @_read_in_memory.register("pandas.DataFrame") def _pandas(source: pd.DataFrame, table_name, _conn, overwrite: bool = False): - tmp_name = gen_name("pandas") + tmp_name = util.gen_name("pandas") with _create_and_drop_memtable(_conn, table_name, tmp_name, overwrite): _conn.con.from_pandas(source, name=tmp_name) diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index 03c463907443..6bb66f57114c 100644 --- a/ibis/backends/duckdb/__init__.py +++ b/ibis/backends/duckdb/__init__.py @@ -588,8 +588,9 @@ def read_json( An ibis table expression """ + filenames = util.normalize_filenames(source_list) if not table_name: - table_name = util.gen_name("read_json") + table_name = util.gen_name_from_path(filenames[0]) 
options = [ sg.to_identifier(key).eq(sge.convert(val)) for key, val in kwargs.items() @@ -612,11 +613,7 @@ def read_json( self._create_temp_view( table_name, - sg.select(STAR).from_( - self.compiler.f.read_json_auto( - util.normalize_filenames(source_list), *options - ) - ), + sg.select(STAR).from_(self.compiler.f.read_json_auto(filenames, *options)), ) return self.table(table_name) @@ -703,7 +700,7 @@ def read_csv( source_list = util.normalize_filenames(source_list) if not table_name: - table_name = util.gen_name("read_csv") + table_name = util.gen_name_from_path(source_list[0]) # auto_detect and columns collide, so we set auto_detect=True # unless COLUMNS has been specified @@ -779,10 +776,6 @@ def read_geo( The just-registered table """ - - if not table_name: - table_name = util.gen_name("read_geo") - # load geospatial extension self.load_extension("spatial") @@ -790,6 +783,9 @@ def read_geo( if source.startswith(("http://", "https://", "s3://")): self._load_extensions(["httpfs"]) + if not table_name: + table_name = util.gen_name_from_path(source) + source_expr = sg.select(STAR).from_( self.compiler.f.st_read( source, @@ -835,7 +831,7 @@ def read_parquet( """ source_list = util.normalize_filenames(source_list) - table_name = table_name or util.gen_name("read_parquet") + table_name = table_name or util.gen_name_from_path(source_list[0]) # Default to using the native duckdb parquet reader # If that fails because of auth issues, fall back to ingesting via @@ -944,7 +940,7 @@ def read_delta( """ source_table = util.normalize_filenames(source_table)[0] - table_name = table_name or util.gen_name("read_delta") + table_name = table_name or util.gen_name_from_path(source_table) try: from deltalake import DeltaTable diff --git a/ibis/backends/flink/__init__.py b/ibis/backends/flink/__init__.py index b411118de25a..cf85324fd04c 100644 --- a/ibis/backends/flink/__init__.py +++ b/ibis/backends/flink/__init__.py @@ -27,7 +27,6 @@ from ibis.backends.sql import SQLBackend from ibis.backends.tests.errors import Py4JJavaError from ibis.expr.operations.udf import InputType -from ibis.util import gen_name if TYPE_CHECKING: from collections.abc import Mapping @@ -767,7 +766,7 @@ def _read_file( f"`schema` must be explicitly provided when calling `read_{file_type}`" ) - table_name = table_name or gen_name(f"read_{file_type}") + table_name = table_name or util.gen_name_from_path(path) tbl_properties = { "connector": "filesystem", "path": path, diff --git a/ibis/backends/polars/__init__.py b/ibis/backends/polars/__init__.py index ff65f5fb3876..1b5d286a58f2 100644 --- a/ibis/backends/polars/__init__.py +++ b/ibis/backends/polars/__init__.py @@ -12,6 +12,7 @@ import ibis.expr.operations as ops import ibis.expr.schema as sch import ibis.expr.types as ir +from ibis import util from ibis.backends import BaseBackend, NoUrl from ibis.backends.polars.compiler import translate from ibis.backends.polars.rewrites import bind_unbound_table, rewrite_join @@ -19,7 +20,6 @@ from ibis.common.dispatch import lazy_singledispatch from ibis.expr.rewrites import lower_stringslice, replace_parameter from ibis.formats.polars import PolarsSchema -from ibis.util import deprecated, gen_name, normalize_filename, normalize_filenames if TYPE_CHECKING: from collections.abc import Iterable @@ -100,7 +100,7 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None: def _finalize_memtable(self, name: str) -> None: self.drop_table(name, force=True) - @deprecated( + @util.deprecated( as_of="9.1", instead="use the explicit `read_*` 
method for the filetype you are trying to read, e.g., read_parquet, read_csv, etc.", ) @@ -209,12 +209,12 @@ def read_csv( The just-registered table """ - source_list = normalize_filenames(path) + source_list = util.normalize_filenames(path) + table_name = table_name or util.gen_name_from_path(source_list[0]) # Flatten the list if there's only one element because Polars # can't handle glob strings, or compressed CSVs in a single-element list if len(source_list) == 1: source_list = source_list[0] - table_name = table_name or gen_name("read_csv") try: table = pl.scan_csv(source_list, **kwargs) # triggers a schema computation to handle compressed csv inference @@ -250,8 +250,8 @@ def read_json( The just-registered table """ - path = normalize_filename(path) - table_name = table_name or gen_name("read_json") + path = util.normalize_filename(path) + table_name = table_name or util.gen_name_from_path(path) try: self._add_table(table_name, pl.scan_ndjson(path, **kwargs)) except pl.exceptions.ComputeError: @@ -290,8 +290,8 @@ def read_delta( "read_delta method. You can install it using pip:\n\n" "pip install 'ibis-framework[polars,deltalake]'\n" ) - path = normalize_filename(path) - table_name = table_name or gen_name("read_delta") + path = util.normalize_filename(path) + table_name = table_name or util.gen_name_from_path(path) self._add_table(table_name, pl.scan_delta(path, **kwargs)) return self.table(table_name) @@ -318,7 +318,7 @@ def read_pandas( The just-registered table """ - table_name = table_name or gen_name("read_in_memory") + table_name = table_name or util.gen_name("read_in_memory") self._add_table(table_name, pl.from_pandas(source, **kwargs).lazy()) return self.table(table_name) @@ -351,24 +351,21 @@ def read_parquet( The just-registered table """ - table_name = table_name or gen_name("read_parquet") - if not isinstance(path, (str, Path)) and len(path) == 1: - path = path[0] + paths = util.normalize_filenames(path) + table_name = table_name or util.gen_name_from_path(paths[0]) - if not isinstance(path, (str, Path)) and len(path) > 1: + if len(paths) > 1: self._import_pyarrow() import pyarrow.dataset as ds - paths = [normalize_filename(p) for p in path] obj = pl.scan_pyarrow_dataset( source=ds.dataset(paths, format="parquet"), **kwargs, ) - self._add_table(table_name, obj) else: - path = normalize_filename(path) - self._add_table(table_name, pl.scan_parquet(path, **kwargs)) + obj = pl.scan_parquet(paths[0], **kwargs) + self._add_table(table_name, obj) return self.table(table_name) def create_table( diff --git a/ibis/backends/pyspark/__init__.py b/ibis/backends/pyspark/__init__.py index f8d9cfffb3c6..de9f2067a6ea 100644 --- a/ibis/backends/pyspark/__init__.py +++ b/ibis/backends/pyspark/__init__.py @@ -790,7 +790,7 @@ def read_delta( ) path = util.normalize_filename(path) spark_df = self._session.read.format("delta").load(path, **kwargs) - table_name = table_name or util.gen_name("read_delta") + table_name = table_name or util.gen_name_from_path(path) spark_df.createOrReplaceTempView(table_name) return self.table(table_name) @@ -827,7 +827,7 @@ def read_parquet( ) path = util.normalize_filename(path) spark_df = self._session.read.parquet(path, **kwargs) - table_name = table_name or util.gen_name("read_parquet") + table_name = table_name or util.gen_name_from_path(path, "parquet") spark_df.createOrReplaceTempView(table_name) return self.table(table_name) @@ -869,7 +869,7 @@ def read_csv( spark_df = self._session.read.csv( source_list, inferSchema=inferSchema, header=header, **kwargs ) - 
table_name = table_name or util.gen_name("read_csv") + table_name = table_name or util.gen_name_from_path(source_list[0], "csv") spark_df.createOrReplaceTempView(table_name) return self.table(table_name) @@ -907,7 +907,7 @@ def read_json( ) source_list = util.normalize_filenames(source_list) spark_df = self._session.read.json(source_list, **kwargs) - table_name = table_name or util.gen_name("read_json") + table_name = table_name or util.gen_name_from_path(source_list[0], "json") spark_df.createOrReplaceTempView(table_name) return self.table(table_name) @@ -1217,7 +1217,7 @@ def read_csv_dir( watermark.time_col, _interval_to_string(watermark.allowed_delay), ) - table_name = table_name or util.gen_name("read_csv_dir") + table_name = table_name or util.gen_name_from_path(path, "csv_dir") spark_df.createOrReplaceTempView(table_name) return self.table(table_name) @@ -1272,7 +1272,7 @@ def read_parquet_dir( watermark.time_col, _interval_to_string(watermark.allowed_delay), ) - table_name = table_name or util.gen_name("read_parquet_dir") + table_name = table_name or util.gen_name_from_path(path, "parquet_dir") spark_df.createOrReplaceTempView(table_name) return self.table(table_name) @@ -1318,7 +1318,7 @@ def read_json_dir( watermark.time_col, _interval_to_string(watermark.allowed_delay), ) - table_name = table_name or util.gen_name("read_json_dir") + table_name = table_name or util.gen_name_from_path(path, "json_dir") spark_df.createOrReplaceTempView(table_name) return self.table(table_name) diff --git a/ibis/backends/snowflake/__init__.py b/ibis/backends/snowflake/__init__.py index f60e06fdc162..c3ec56fe26cf 100644 --- a/ibis/backends/snowflake/__init__.py +++ b/ibis/backends/snowflake/__init__.py @@ -883,7 +883,7 @@ def read_csv( # 99 is the maximum allowed number of threads by Snowflake: # https://docs.snowflake.com/en/sql-reference/sql/put#optional-parameters threads = min((os.cpu_count() or 2) // 2, 99) - table = table_name or ibis.util.gen_name("read_csv_snowflake") + table = table_name or ibis.util.gen_name_from_path(path) compiler = self.compiler quoted = compiler.quoted qtable = sg.to_identifier(table, quoted=quoted) @@ -1010,7 +1010,7 @@ def read_json( """ stage = util.gen_name("read_json_stage") file_format = util.gen_name("read_json_format") - table = table_name or util.gen_name("read_json_snowflake") + table = table_name or util.gen_name_from_path(path, "json_snowflake") quoted = self.compiler.quoted qtable = sg.table(table, quoted=quoted) threads = min((os.cpu_count() or 2) // 2, 99) @@ -1107,7 +1107,7 @@ def read_parquet( ) stage = util.gen_name("read_parquet_stage") - table = table_name or util.gen_name("read_parquet_snowflake") + table = table_name or util.gen_name_from_path(abspath, "parquet_snowflake") quoted = self.compiler.quoted qtable = sg.table(table, quoted=quoted) threads = min((os.cpu_count() or 2) // 2, 99) diff --git a/ibis/tests/test_util.py b/ibis/tests/test_util.py index 9b0872f77407..1dbf7d2433c2 100644 --- a/ibis/tests/test_util.py +++ b/ibis/tests/test_util.py @@ -2,9 +2,16 @@ from __future__ import annotations +from pathlib import Path + import pytest -from ibis.util import PseudoHashable, flatten_iterable, import_object +from ibis.util import ( + PseudoHashable, + flatten_iterable, + gen_name_from_path, + import_object, +) @pytest.mark.parametrize( @@ -138,3 +145,28 @@ class MyMap(dict): assert ph2 != ph3 assert ph3 == ph3 assert ph1 == ph4 + + +@pytest.mark.parametrize( + ("path", "expected"), + [ + ("my file.csv", "ibis_read_my_file_csv"), + ("/my file.csv", 
"ibis_read_my_file_csv"), + ( + "/really extra super long file name.csv", + "ibis_read_super_long_file_name_csv", + ), + ("s3://my file.csv", "ibis_read_my_file_csv"), + ("PATH-TO/my-file.csv", "ibis_read_my_file_csv"), + ("/PATH-TO/my-file.csv", "ibis_read_my_file_csv"), + ("s3://PATH-TO/my-file.csv", "ibis_read_my_file_csv"), + ], +) +@pytest.mark.parametrize("use_path", [True, False]) +def test_gen_name_from_path(path, expected, use_path): + if use_path: + path = Path(path) + result = gen_name_from_path(path) + # MySQL has a 64 character limit for table names + assert len(result) <= 64 + assert result.startswith(expected) diff --git a/ibis/util.py b/ibis/util.py index 993e539dfa88..fe1fb22866b1 100644 --- a/ibis/util.py +++ b/ibis/util.py @@ -16,6 +16,7 @@ import types import uuid import warnings +from pathlib import Path from types import ModuleType from typing import TYPE_CHECKING, Any, Generic, TypeVar from uuid import uuid4 @@ -27,7 +28,6 @@ if TYPE_CHECKING: from collections.abc import Callable, Iterator, Sequence from numbers import Real - from pathlib import Path import ibis.expr.types as ir @@ -522,6 +522,28 @@ def gen_name(namespace: str) -> str: return f"ibis_{namespace}_{uid}" +def gen_name_from_path(path: str | Path) -> str: + """Create a user-friendly unique identifier from a file path. + + This is NOT a stable API. We may change the implementation at any time. + + Examples + -------- + >>> gen_name_from_path("s3://path/to/myfile.csv") # doctest: +ELLIPSIS + 'ibis_read_s3__path__to__myfile__csv...' + >>> gen_name_from_path("s3://long_long_long_path/to/myfile.csv") # doctest: +ELLIPSIS + 'ibis_read_s3__myfile__csv...' + """ + basename = os.path.basename(path) + basename = re.sub(r"[^a-zA-Z0-9_]", "_", basename) + # MySQL has a limit of 64 characters for table names. + # Let's not give users runtime errors because of this. + basename = basename[-25:] + basename = basename.strip("_") + prefix = f"read_{basename}" + return gen_name(prefix) + + def slice_to_limit_offset( what: slice, count: ir.IntegerScalar ) -> tuple[int | ir.IntegerScalar, int | ir.IntegerScalar]: From 99e6fa56d6bb058d28d92ed0cfb51e254e6c4a19 Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Thu, 5 Dec 2024 10:16:20 +0900 Subject: [PATCH 2/5] chore: update register tests to not rely on generated table names This is an unstable API and shouldn't be relied on. 
--- ibis/backends/tests/test_register.py | 53 +++++++++++++--------------- 1 file changed, 24 insertions(+), 29 deletions(-) diff --git a/ibis/backends/tests/test_register.py b/ibis/backends/tests/test_register.py index 1ca96eb42221..85b4e7336a72 100644 --- a/ibis/backends/tests/test_register.py +++ b/ibis/backends/tests/test_register.py @@ -50,13 +50,12 @@ def gzip_csv(data_dir, tmp_path): # TODO: rewrite or delete test when register api is removed @pytest.mark.parametrize( - ("fname", "in_table_name", "out_table_name"), + ("fname", "table_name"), [ - param("diamonds.csv", None, "ibis_read_csv_", id="default"), + param("diamonds.csv", None, id="default"), param( "csv://diamonds.csv", "Diamonds2", - "Diamonds2", id="csv_name", marks=pytest.mark.notyet( ["pyspark"], reason="pyspark lowercases view names" @@ -65,13 +64,11 @@ def gzip_csv(data_dir, tmp_path): param( "file://diamonds.csv", "fancy_stones", - "fancy_stones", id="file_name", ), param( "file://diamonds.csv", "fancy stones", - "fancy stones", id="file_atypical_name", marks=pytest.mark.notyet( ["pyspark"], reason="no spaces allowed in view names" @@ -80,7 +77,6 @@ def gzip_csv(data_dir, tmp_path): param( ["file://diamonds.csv", "diamonds.csv"], "fancy_stones2", - "fancy_stones2", id="multi_csv", marks=pytest.mark.notyet( ["datafusion"], @@ -105,12 +101,16 @@ def gzip_csv(data_dir, tmp_path): "databricks", ] ) -def test_register_csv(con, data_dir, fname, in_table_name, out_table_name): +def test_register_csv(con, data_dir, fname, table_name): + tables_before = set(con.list_tables()) with pushd(data_dir / "csv"): with pytest.warns(FutureWarning, match="v9.1"): - table = con.register(fname, table_name=in_table_name) + table = con.register(fname, table_name=table_name) + new_tables = set(con.list_tables()) - tables_before + assert len(new_tables) == 1 + if table_name is not None: + assert new_tables.pop() == table_name - assert any(out_table_name in t for t in con.list_tables()) if con.name != "datafusion": table.count().execute() @@ -185,18 +185,12 @@ def read_table(path: Path) -> Iterator[tuple[str, pa.Table]]: # TODO: rewrite or delete test when register api is removed @pytest.mark.parametrize( - ("fname", "in_table_name", "out_table_name"), + ("fname", "table_name"), [ - param( - "parquet://functional_alltypes.parquet", None, "ibis_read_parquet", id="url" - ), - param("functional_alltypes.parquet", "funk_all", "funk_all", id="basename"), - param( - "parquet://functional_alltypes.parq", "funk_all", "funk_all", id="url_parq" - ), - param( - "parquet://functional_alltypes", None, "ibis_read_parquet", id="url_no_ext" - ), + param("parquet://functional_alltypes.parquet", None, id="url"), + param("functional_alltypes.parquet", "my_table1", id="basename"), + param("parquet://functional_alltypes.parq", "my_table2", id="url_parq"), + param("parquet://functional_alltypes", None, id="url_no_ext"), ], ) @pytest.mark.notyet( @@ -214,9 +208,7 @@ def read_table(path: Path) -> Iterator[tuple[str, pa.Table]]: "trino", ] ) -def test_register_parquet( - con, tmp_path, data_dir, fname, in_table_name, out_table_name -): +def test_register_parquet(con, tmp_path, data_dir, fname, table_name): pq = pytest.importorskip("pyarrow.parquet") fname = Path(fname) @@ -224,12 +216,14 @@ def test_register_parquet( pq.write_table(table, tmp_path / fname.name) + tables_before = set(con.list_tables()) with pushd(tmp_path): with pytest.warns(FutureWarning, match="v9.1"): - table = con.register(f"parquet://{fname.name}", table_name=in_table_name) - - assert 
any(out_table_name in t for t in con.list_tables()) - + table = con.register(f"parquet://{fname.name}", table_name=table_name) + new_tables = set(con.list_tables()) - tables_before + assert len(new_tables) == 1 + if table_name is not None: + assert new_tables.pop() == table_name if con.name != "datafusion": table.count().execute() @@ -263,6 +257,7 @@ def test_register_iterator_parquet( pq.write_table(table, tmp_path / "functional_alltypes.parquet") + tables_before = set(con.list_tables()) with pushd(tmp_path): with pytest.warns(FutureWarning, match="v9.1"): table = con.register( @@ -272,8 +267,8 @@ def test_register_iterator_parquet( ], table_name=None, ) - - assert any("ibis_read_parquet" in t for t in con.list_tables()) + new_tables = set(con.list_tables()) - tables_before + assert len(new_tables) == 1 assert table.count().execute() From f7cb256ffdcd6df34fd3a0553801db8ab1b5e8d2 Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Fri, 6 Dec 2024 22:45:50 +0900 Subject: [PATCH 3/5] chore: clean up registered tables after every test --- ibis/backends/tests/test_register.py | 27 +++++++++++++++++++++++++++ ibis/util.py | 4 +--- 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/ibis/backends/tests/test_register.py b/ibis/backends/tests/test_register.py index 85b4e7336a72..055e1aaede6f 100644 --- a/ibis/backends/tests/test_register.py +++ b/ibis/backends/tests/test_register.py @@ -39,6 +39,19 @@ def pushd(new_dir): os.chdir(previous_dir) +def drop(table: ibis.Table): + backend = table._find_backend() + name = table.get_name() + try: + backend.drop_table(name) + return + except Exception as e: + # This is a lazy way to check if the error is due to the table being a view + if "view" not in str(e).lower(): + raise + backend.drop_view(name) + + @pytest.fixture def gzip_csv(data_dir, tmp_path): basename = "diamonds.csv" @@ -113,6 +126,7 @@ def test_register_csv(con, data_dir, fname, table_name): if con.name != "datafusion": table.count().execute() + drop(table) # TODO: rewrite or delete test when register api is removed @@ -139,6 +153,7 @@ def test_register_csv_gz(con, data_dir, gzip_csv): table = con.register(gzip_csv) assert table.count().execute() + drop(table) # TODO: rewrite or delete test when register api is removed @@ -168,6 +183,7 @@ def test_register_with_dotted_name(con, data_dir, tmp_path): if con.name != "datafusion": table.count().execute() + drop(table) def read_table(path: Path) -> Iterator[tuple[str, pa.Table]]: @@ -226,6 +242,7 @@ def test_register_parquet(con, tmp_path, data_dir, fname, table_name): assert new_tables.pop() == table_name if con.name != "datafusion": table.count().execute() + drop(table) # TODO: rewrite or delete test when register api is removed @@ -270,6 +287,7 @@ def test_register_iterator_parquet( new_tables = set(con.list_tables()) - tables_before assert len(new_tables) == 1 assert table.count().execute() + drop(table) # TODO: remove entirely when `register` is removed @@ -299,11 +317,13 @@ def test_register_pandas(con): with pytest.warns(FutureWarning, match="v9.1"): t = con.register(df) assert t.x.sum().execute() == 6 + drop(t) with pytest.warns(FutureWarning, match="v9.1"): t = con.register(df, "my_table") assert t.op().name == "my_table" assert t.x.sum().execute() == 6 + drop(t) # TODO: remove entirely when `register` is removed @@ -333,6 +353,7 @@ def test_register_pyarrow_tables(con): with pytest.warns(FutureWarning, match="v9.1"): t = con.register(pa_t) assert t.x.sum().execute() == 6 + drop(t) @pytest.mark.notyet( @@ -371,6 +392,7 @@ def 
test_csv_reregister_schema(con, tmp_path): assert result_schema["cola"].is_integer() assert result_schema["colb"].is_float64() assert result_schema["colc"].is_string() + drop(foo_table) @pytest.mark.notimpl( @@ -433,6 +455,7 @@ def test_read_parquet(con, tmp_path, data_dir, fname, in_table_name): if in_table_name is not None: assert table.op().name == in_table_name assert table.count().execute() + drop(table) @pytest.fixture(scope="module") @@ -469,6 +492,7 @@ def test_read_parquet_glob(con, tmp_path, ft_data): table = con.read_parquet(tmp_path / f"*.{ext}") assert table.count().execute() == nrows * ntables + drop(table) @pytest.mark.notyet( @@ -497,6 +521,7 @@ def test_read_csv_glob(con, tmp_path, ft_data): table = con.read_csv(tmp_path / f"*.{ext}") assert table.count().execute() == nrows * ntables + drop(table) @pytest.mark.notyet( @@ -532,6 +557,7 @@ def test_read_json_glob(con, tmp_path, ft_data): table = con.read_json(tmp_path / f"*.{ext}") assert table.count().execute() == nrows * ntables + drop(table) @pytest.fixture(scope="module") @@ -592,3 +618,4 @@ def test_read_csv(con, data_dir, in_table_name, num_diamonds): } ) assert table.count().execute() == num_diamonds + drop(table) diff --git a/ibis/util.py b/ibis/util.py index fe1fb22866b1..6ae821cc2d8a 100644 --- a/ibis/util.py +++ b/ibis/util.py @@ -530,9 +530,7 @@ def gen_name_from_path(path: str | Path) -> str: Examples -------- >>> gen_name_from_path("s3://path/to/myfile.csv") # doctest: +ELLIPSIS - 'ibis_read_s3__path__to__myfile__csv...' - >>> gen_name_from_path("s3://long_long_long_path/to/myfile.csv") # doctest: +ELLIPSIS - 'ibis_read_s3__myfile__csv...' + 'ibis_read_myfile_csv...' """ basename = os.path.basename(path) basename = re.sub(r"[^a-zA-Z0-9_]", "_", basename) From dcd19dfd35bb77397f7e1a77554bd39188bbd1e9 Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Fri, 6 Dec 2024 22:48:33 +0900 Subject: [PATCH 4/5] chore: move import into TYPE_CHECKING block --- ibis/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ibis/util.py b/ibis/util.py index 6ae821cc2d8a..b804674f13fe 100644 --- a/ibis/util.py +++ b/ibis/util.py @@ -16,7 +16,6 @@ import types import uuid import warnings -from pathlib import Path from types import ModuleType from typing import TYPE_CHECKING, Any, Generic, TypeVar from uuid import uuid4 @@ -28,6 +27,7 @@ if TYPE_CHECKING: from collections.abc import Callable, Iterator, Sequence from numbers import Real + from pathlib import Path import ibis.expr.types as ir From e79e26df21fe78573082101036c612a51ee96a71 Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Sat, 7 Dec 2024 07:48:18 +0900 Subject: [PATCH 5/5] chore: fixup usage in snowflake and pyspark --- ibis/backends/pyspark/__init__.py | 12 ++++++------ ibis/backends/snowflake/__init__.py | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/ibis/backends/pyspark/__init__.py b/ibis/backends/pyspark/__init__.py index de9f2067a6ea..8a7968edd110 100644 --- a/ibis/backends/pyspark/__init__.py +++ b/ibis/backends/pyspark/__init__.py @@ -827,7 +827,7 @@ def read_parquet( ) path = util.normalize_filename(path) spark_df = self._session.read.parquet(path, **kwargs) - table_name = table_name or util.gen_name_from_path(path, "parquet") + table_name = table_name or util.gen_name_from_path(path) spark_df.createOrReplaceTempView(table_name) return self.table(table_name) @@ -869,7 +869,7 @@ def read_csv( spark_df = self._session.read.csv( source_list, inferSchema=inferSchema, header=header, **kwargs ) - table_name = table_name or 
util.gen_name_from_path(source_list[0], "csv") + table_name = table_name or util.gen_name_from_path(source_list[0]) spark_df.createOrReplaceTempView(table_name) return self.table(table_name) @@ -907,7 +907,7 @@ def read_json( ) source_list = util.normalize_filenames(source_list) spark_df = self._session.read.json(source_list, **kwargs) - table_name = table_name or util.gen_name_from_path(source_list[0], "json") + table_name = table_name or util.gen_name_from_path(source_list[0]) spark_df.createOrReplaceTempView(table_name) return self.table(table_name) @@ -1217,7 +1217,7 @@ def read_csv_dir( watermark.time_col, _interval_to_string(watermark.allowed_delay), ) - table_name = table_name or util.gen_name_from_path(path, "csv_dir") + table_name = table_name or util.gen_name_from_path(path) spark_df.createOrReplaceTempView(table_name) return self.table(table_name) @@ -1272,7 +1272,7 @@ def read_parquet_dir( watermark.time_col, _interval_to_string(watermark.allowed_delay), ) - table_name = table_name or util.gen_name_from_path(path, "parquet_dir") + table_name = table_name or util.gen_name_from_path(path) spark_df.createOrReplaceTempView(table_name) return self.table(table_name) @@ -1318,7 +1318,7 @@ def read_json_dir( watermark.time_col, _interval_to_string(watermark.allowed_delay), ) - table_name = table_name or util.gen_name_from_path(path, "json_dir") + table_name = table_name or util.gen_name_from_path(path) spark_df.createOrReplaceTempView(table_name) return self.table(table_name) diff --git a/ibis/backends/snowflake/__init__.py b/ibis/backends/snowflake/__init__.py index c3ec56fe26cf..fd34a2b4755e 100644 --- a/ibis/backends/snowflake/__init__.py +++ b/ibis/backends/snowflake/__init__.py @@ -1010,7 +1010,7 @@ def read_json( """ stage = util.gen_name("read_json_stage") file_format = util.gen_name("read_json_format") - table = table_name or util.gen_name_from_path(path, "json_snowflake") + table = table_name or util.gen_name_from_path(path) quoted = self.compiler.quoted qtable = sg.table(table, quoted=quoted) threads = min((os.cpu_count() or 2) // 2, 99) @@ -1107,7 +1107,7 @@ def read_parquet( ) stage = util.gen_name("read_parquet_stage") - table = table_name or util.gen_name_from_path(abspath, "parquet_snowflake") + table = table_name or util.gen_name_from_path(abspath) quoted = self.compiler.quoted qtable = sg.table(table, quoted=quoted) threads = min((os.cpu_count() or 2) // 2, 99)
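
A minimal usage sketch of the net effect of this series (the DuckDB backend and
the file paths are placeholders, and the random suffix plus the 25-character
basename truncation are implementation details of gen_name_from_path, not a
stable API):

    import ibis
    from ibis.util import gen_name_from_path

    # Only the basename survives: sanitized to [a-zA-Z0-9_], truncated, and
    # given a unique suffix via gen_name().
    print(gen_name_from_path("s3://bucket/path/to/my-file.csv"))
    # e.g. "ibis_read_my_file_csv_<random-suffix>"

    con = ibis.duckdb.connect()

    # With no explicit table_name, the generated name now hints at the source
    # file instead of the generic "ibis_read_csv_..." used before this series.
    t = con.read_csv("data/diamonds.csv")
    print(t.get_name())  # e.g. "ibis_read_diamonds_csv_<random-suffix>"

    # An explicit table_name still takes precedence, unchanged.
    t2 = con.read_csv("data/diamonds.csv", table_name="fancy_stones")
    assert t2.get_name() == "fancy_stones"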