From 3919f2f90605c968f0089e393d8a6e55b90f5fcf Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Fri, 22 Nov 2024 10:10:58 -0800 Subject: [PATCH] feat(ux): include basename of path in generated table names in read_*() --- ibis/backends/clickhouse/__init__.py | 4 ++-- ibis/backends/datafusion/__init__.py | 30 +++++++++++++-------------- ibis/backends/duckdb/__init__.py | 22 ++++++++------------ ibis/backends/flink/__init__.py | 3 +-- ibis/backends/polars/__init__.py | 31 +++++++++++++--------------- ibis/backends/pyspark/__init__.py | 14 ++++++------- ibis/backends/snowflake/__init__.py | 6 +++--- ibis/util.py | 15 ++++++++++++++ 8 files changed, 66 insertions(+), 59 deletions(-) diff --git a/ibis/backends/clickhouse/__init__.py b/ibis/backends/clickhouse/__init__.py index 8707ec00cf06..c42096e4160a 100644 --- a/ibis/backends/clickhouse/__init__.py +++ b/ibis/backends/clickhouse/__init__.py @@ -581,7 +581,7 @@ def read_parquet( paths = list(glob.glob(str(path))) schema = PyArrowSchema.to_ibis(ds.dataset(paths, format="parquet").schema) - name = table_name or util.gen_name("read_parquet") + name = table_name or util.gen_name_from_path(paths[0], "parquet") table = self.create_table(name, engine=engine, schema=schema, temp=True) for file_path in paths: @@ -609,7 +609,7 @@ def read_csv( paths = list(glob.glob(str(path))) schema = PyArrowSchema.to_ibis(ds.dataset(paths, format="csv").schema) - name = table_name or util.gen_name("read_csv") + name = table_name or util.gen_name_from_path(paths[0], "csv") table = self.create_table(name, engine=engine, schema=schema, temp=True) for file_path in paths: diff --git a/ibis/backends/datafusion/__init__.py b/ibis/backends/datafusion/__init__.py index 733d6a772b48..c2350dd128ff 100644 --- a/ibis/backends/datafusion/__init__.py +++ b/ibis/backends/datafusion/__init__.py @@ -28,7 +28,7 @@ from ibis.common.dispatch import lazy_singledispatch from ibis.expr.operations.udf import InputType from ibis.formats.pyarrow import PyArrowSchema, PyArrowType -from ibis.util import deprecated, gen_name, normalize_filename, normalize_filenames +from ibis.util import deprecated, normalize_filename, normalize_filenames try: from datafusion import ExecutionContext as SessionContext @@ -160,7 +160,7 @@ def _safe_raw_sql(self, sql: sge.Statement) -> Any: yield self.raw_sql(sql).collect() def _get_schema_using_query(self, query: str) -> sch.Schema: - name = gen_name("datafusion_metadata_view") + name = util.gen_name("datafusion_metadata_view") table = sg.table(name, quoted=self.compiler.quoted) src = sge.Create( this=table, @@ -437,11 +437,11 @@ def read_csv( The just-registered table """ - path = normalize_filenames(source_list) - table_name = table_name or gen_name("read_csv") + paths = normalize_filenames(source_list) + table_name = table_name or util.gen_name_from_path(paths[0], "csv") # Our other backends support overwriting views / tables when re-registering self.con.deregister_table(table_name) - self.con.register_csv(table_name, path, **kwargs) + self.con.register_csv(table_name, paths, **kwargs) return self.table(table_name) def read_parquet( @@ -466,7 +466,7 @@ def read_parquet( """ path = normalize_filename(path) - table_name = table_name or gen_name("read_parquet") + table_name = table_name or util.gen_name_from_path(path, "parquet") # Our other backends support overwriting views / tables when reregistering self.con.deregister_table(table_name) self.con.register_parquet(table_name, path, **kwargs) @@ -496,7 +496,7 @@ def read_delta( """ source_table = normalize_filename(source_table) - table_name = table_name or gen_name("read_delta") + table_name = table_name or util.gen_name_from_path(source_table, "delta") # Our other backends support overwriting views / tables when reregistering self.con.deregister_table(table_name) @@ -730,55 +730,55 @@ def _read_in_memory( @_read_in_memory.register(dict) def _pydict(source, table_name, _conn, overwrite: bool = False): - tmp_name = gen_name("pydict") + tmp_name = util.gen_name("pydict") with _create_and_drop_memtable(_conn, table_name, tmp_name, overwrite): _conn.con.from_pydict(source, name=tmp_name) @_read_in_memory.register("polars.DataFrame") def _polars(source, table_name, _conn, overwrite: bool = False): - tmp_name = gen_name("polars") + tmp_name = util.gen_name("polars") with _create_and_drop_memtable(_conn, table_name, tmp_name, overwrite): _conn.con.from_polars(source, name=tmp_name) @_read_in_memory.register("polars.LazyFrame") def _polars(source, table_name, _conn, overwrite: bool = False): - tmp_name = gen_name("polars") + tmp_name = util.gen_name("polars") with _create_and_drop_memtable(_conn, table_name, tmp_name, overwrite): _conn.con.from_polars(source.collect(), name=tmp_name) @_read_in_memory.register("pyarrow.Table") def _pyarrow_table(source, table_name, _conn, overwrite: bool = False): - tmp_name = gen_name("pyarrow") + tmp_name = util.gen_name("pyarrow") with _create_and_drop_memtable(_conn, table_name, tmp_name, overwrite): _conn.con.from_arrow(source, name=tmp_name) @_read_in_memory.register("pyarrow.RecordBatchReader") def _pyarrow_rbr(source, table_name, _conn, overwrite: bool = False): - tmp_name = gen_name("pyarrow") + tmp_name = util.gen_name("pyarrow") with _create_and_drop_memtable(_conn, table_name, tmp_name, overwrite): _conn.con.from_arrow(source.read_all(), name=tmp_name) @_read_in_memory.register("pyarrow.RecordBatch") def _pyarrow_rb(source, table_name, _conn, overwrite: bool = False): - tmp_name = gen_name("pyarrow") + tmp_name = util.gen_name("pyarrow") with _create_and_drop_memtable(_conn, table_name, tmp_name, overwrite): _conn.con.register_record_batches(tmp_name, [[source]]) @_read_in_memory.register("pyarrow.dataset.Dataset") def _pyarrow_rb(source, table_name, _conn, overwrite: bool = False): - tmp_name = gen_name("pyarrow") + tmp_name = util.gen_name("pyarrow") with _create_and_drop_memtable(_conn, table_name, tmp_name, overwrite): _conn.con.register_dataset(tmp_name, source) @_read_in_memory.register("pandas.DataFrame") def _pandas(source: pd.DataFrame, table_name, _conn, overwrite: bool = False): - tmp_name = gen_name("pandas") + tmp_name = util.gen_name("pandas") with _create_and_drop_memtable(_conn, table_name, tmp_name, overwrite): _conn.con.from_pandas(source, name=tmp_name) diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index 03c463907443..6637def85509 100644 --- a/ibis/backends/duckdb/__init__.py +++ b/ibis/backends/duckdb/__init__.py @@ -588,8 +588,9 @@ def read_json( An ibis table expression """ + filenames = util.normalize_filenames(source_list) if not table_name: - table_name = util.gen_name("read_json") + table_name = util.gen_name_from_path(filenames[0], "json") options = [ sg.to_identifier(key).eq(sge.convert(val)) for key, val in kwargs.items() @@ -612,11 +613,7 @@ def read_json( self._create_temp_view( table_name, - sg.select(STAR).from_( - self.compiler.f.read_json_auto( - util.normalize_filenames(source_list), *options - ) - ), + sg.select(STAR).from_(self.compiler.f.read_json_auto(filenames, *options)), ) return self.table(table_name) @@ -703,7 +700,7 @@ def read_csv( source_list = util.normalize_filenames(source_list) if not table_name: - table_name = util.gen_name("read_csv") + table_name = util.gen_name_from_path(source_list[0], "csv") # auto_detect and columns collide, so we set auto_detect=True # unless COLUMNS has been specified @@ -779,10 +776,6 @@ def read_geo( The just-registered table """ - - if not table_name: - table_name = util.gen_name("read_geo") - # load geospatial extension self.load_extension("spatial") @@ -790,6 +783,9 @@ def read_geo( if source.startswith(("http://", "https://", "s3://")): self._load_extensions(["httpfs"]) + if not table_name: + table_name = util.gen_name_from_path(source, "geo") + source_expr = sg.select(STAR).from_( self.compiler.f.st_read( source, @@ -835,7 +831,7 @@ def read_parquet( """ source_list = util.normalize_filenames(source_list) - table_name = table_name or util.gen_name("read_parquet") + table_name = table_name or util.gen_name_from_path(source_list[0], "parquet") # Default to using the native duckdb parquet reader # If that fails because of auth issues, fall back to ingesting via @@ -944,7 +940,7 @@ def read_delta( """ source_table = util.normalize_filenames(source_table)[0] - table_name = table_name or util.gen_name("read_delta") + table_name = table_name or util.gen_name_from_path(source_table, "delta") try: from deltalake import DeltaTable diff --git a/ibis/backends/flink/__init__.py b/ibis/backends/flink/__init__.py index b411118de25a..a798bc693196 100644 --- a/ibis/backends/flink/__init__.py +++ b/ibis/backends/flink/__init__.py @@ -27,7 +27,6 @@ from ibis.backends.sql import SQLBackend from ibis.backends.tests.errors import Py4JJavaError from ibis.expr.operations.udf import InputType -from ibis.util import gen_name if TYPE_CHECKING: from collections.abc import Mapping @@ -767,7 +766,7 @@ def _read_file( f"`schema` must be explicitly provided when calling `read_{file_type}`" ) - table_name = table_name or gen_name(f"read_{file_type}") + table_name = table_name or util.gen_name_from_path(path, file_type) tbl_properties = { "connector": "filesystem", "path": path, diff --git a/ibis/backends/polars/__init__.py b/ibis/backends/polars/__init__.py index ff65f5fb3876..0f4cc79ed91c 100644 --- a/ibis/backends/polars/__init__.py +++ b/ibis/backends/polars/__init__.py @@ -12,6 +12,7 @@ import ibis.expr.operations as ops import ibis.expr.schema as sch import ibis.expr.types as ir +from ibis import util from ibis.backends import BaseBackend, NoUrl from ibis.backends.polars.compiler import translate from ibis.backends.polars.rewrites import bind_unbound_table, rewrite_join @@ -19,7 +20,6 @@ from ibis.common.dispatch import lazy_singledispatch from ibis.expr.rewrites import lower_stringslice, replace_parameter from ibis.formats.polars import PolarsSchema -from ibis.util import deprecated, gen_name, normalize_filename, normalize_filenames if TYPE_CHECKING: from collections.abc import Iterable @@ -100,7 +100,7 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None: def _finalize_memtable(self, name: str) -> None: self.drop_table(name, force=True) - @deprecated( + @util.deprecated( as_of="9.1", instead="use the explicit `read_*` method for the filetype you are trying to read, e.g., read_parquet, read_csv, etc.", ) @@ -209,12 +209,12 @@ def read_csv( The just-registered table """ - source_list = normalize_filenames(path) + source_list = util.normalize_filenames(path) + table_name = table_name or util.gen_name_from_path(source_list[0], "csv") # Flatten the list if there's only one element because Polars # can't handle glob strings, or compressed CSVs in a single-element list if len(source_list) == 1: source_list = source_list[0] - table_name = table_name or gen_name("read_csv") try: table = pl.scan_csv(source_list, **kwargs) # triggers a schema computation to handle compressed csv inference @@ -250,8 +250,8 @@ def read_json( The just-registered table """ - path = normalize_filename(path) - table_name = table_name or gen_name("read_json") + path = util.normalize_filename(path) + table_name = table_name or util.gen_name_from_path(path, "json") try: self._add_table(table_name, pl.scan_ndjson(path, **kwargs)) except pl.exceptions.ComputeError: @@ -290,8 +290,8 @@ def read_delta( "read_delta method. You can install it using pip:\n\n" "pip install 'ibis-framework[polars,deltalake]'\n" ) - path = normalize_filename(path) - table_name = table_name or gen_name("read_delta") + path = util.normalize_filename(path) + table_name = table_name or util.gen_name_from_path(path, "delta") self._add_table(table_name, pl.scan_delta(path, **kwargs)) return self.table(table_name) @@ -318,7 +318,7 @@ def read_pandas( The just-registered table """ - table_name = table_name or gen_name("read_in_memory") + table_name = table_name or util.gen_name("read_in_memory") self._add_table(table_name, pl.from_pandas(source, **kwargs).lazy()) return self.table(table_name) @@ -351,24 +351,21 @@ def read_parquet( The just-registered table """ - table_name = table_name or gen_name("read_parquet") - if not isinstance(path, (str, Path)) and len(path) == 1: - path = path[0] + paths = util.normalize_filenames(path) + table_name = table_name or util.gen_name_from_path(paths[0], "parquet") - if not isinstance(path, (str, Path)) and len(path) > 1: + if len(paths) > 1: self._import_pyarrow() import pyarrow.dataset as ds - paths = [normalize_filename(p) for p in path] obj = pl.scan_pyarrow_dataset( source=ds.dataset(paths, format="parquet"), **kwargs, ) - self._add_table(table_name, obj) else: - path = normalize_filename(path) - self._add_table(table_name, pl.scan_parquet(path, **kwargs)) + obj = pl.scan_parquet(paths[0], **kwargs) + self._add_table(table_name, obj) return self.table(table_name) def create_table( diff --git a/ibis/backends/pyspark/__init__.py b/ibis/backends/pyspark/__init__.py index f8d9cfffb3c6..0612fa70f325 100644 --- a/ibis/backends/pyspark/__init__.py +++ b/ibis/backends/pyspark/__init__.py @@ -790,7 +790,7 @@ def read_delta( ) path = util.normalize_filename(path) spark_df = self._session.read.format("delta").load(path, **kwargs) - table_name = table_name or util.gen_name("read_delta") + table_name = table_name or util.gen_name_from_path(path, "delta") spark_df.createOrReplaceTempView(table_name) return self.table(table_name) @@ -827,7 +827,7 @@ def read_parquet( ) path = util.normalize_filename(path) spark_df = self._session.read.parquet(path, **kwargs) - table_name = table_name or util.gen_name("read_parquet") + table_name = table_name or util.gen_name_from_path(path, "parquet") spark_df.createOrReplaceTempView(table_name) return self.table(table_name) @@ -869,7 +869,7 @@ def read_csv( spark_df = self._session.read.csv( source_list, inferSchema=inferSchema, header=header, **kwargs ) - table_name = table_name or util.gen_name("read_csv") + table_name = table_name or util.gen_name_from_path(source_list[0], "csv") spark_df.createOrReplaceTempView(table_name) return self.table(table_name) @@ -907,7 +907,7 @@ def read_json( ) source_list = util.normalize_filenames(source_list) spark_df = self._session.read.json(source_list, **kwargs) - table_name = table_name or util.gen_name("read_json") + table_name = table_name or util.gen_name_from_path(source_list[0], "json") spark_df.createOrReplaceTempView(table_name) return self.table(table_name) @@ -1217,7 +1217,7 @@ def read_csv_dir( watermark.time_col, _interval_to_string(watermark.allowed_delay), ) - table_name = table_name or util.gen_name("read_csv_dir") + table_name = table_name or util.gen_name_from_path(path, "csv_dir") spark_df.createOrReplaceTempView(table_name) return self.table(table_name) @@ -1272,7 +1272,7 @@ def read_parquet_dir( watermark.time_col, _interval_to_string(watermark.allowed_delay), ) - table_name = table_name or util.gen_name("read_parquet_dir") + table_name = table_name or util.gen_name_from_path(path, "parquet_dir") spark_df.createOrReplaceTempView(table_name) return self.table(table_name) @@ -1318,7 +1318,7 @@ def read_json_dir( watermark.time_col, _interval_to_string(watermark.allowed_delay), ) - table_name = table_name or util.gen_name("read_json_dir") + table_name = table_name or util.gen_name_from_path(path, "json_dir") spark_df.createOrReplaceTempView(table_name) return self.table(table_name) diff --git a/ibis/backends/snowflake/__init__.py b/ibis/backends/snowflake/__init__.py index f60e06fdc162..0c6fe79a3bc0 100644 --- a/ibis/backends/snowflake/__init__.py +++ b/ibis/backends/snowflake/__init__.py @@ -883,7 +883,7 @@ def read_csv( # 99 is the maximum allowed number of threads by Snowflake: # https://docs.snowflake.com/en/sql-reference/sql/put#optional-parameters threads = min((os.cpu_count() or 2) // 2, 99) - table = table_name or ibis.util.gen_name("read_csv_snowflake") + table = table_name or ibis.util.gen_name_from_path(path, "csv_snowflake") compiler = self.compiler quoted = compiler.quoted qtable = sg.to_identifier(table, quoted=quoted) @@ -1010,7 +1010,7 @@ def read_json( """ stage = util.gen_name("read_json_stage") file_format = util.gen_name("read_json_format") - table = table_name or util.gen_name("read_json_snowflake") + table = table_name or util.gen_name_from_path(path, "json_snowflake") quoted = self.compiler.quoted qtable = sg.table(table, quoted=quoted) threads = min((os.cpu_count() or 2) // 2, 99) @@ -1107,7 +1107,7 @@ def read_parquet( ) stage = util.gen_name("read_parquet_stage") - table = table_name or util.gen_name("read_parquet_snowflake") + table = table_name or util.gen_name_from_path(abspath, "parquet_snowflake") quoted = self.compiler.quoted qtable = sg.table(table, quoted=quoted) threads = min((os.cpu_count() or 2) // 2, 99) diff --git a/ibis/util.py b/ibis/util.py index 993e539dfa88..84b420e084fa 100644 --- a/ibis/util.py +++ b/ibis/util.py @@ -522,6 +522,21 @@ def gen_name(namespace: str) -> str: return f"ibis_{namespace}_{uid}" +def gen_name_from_path(path: str | Path, file_type: str) -> str: + """Create a user-friendly unique identifier from a file path. + + Examples + -------- + >>> gen_name_from_file("path/to/myfile.csv", "csv") + 'read_csv_myfile_1a2b3c4d' + """ + root, _ext = os.path.splitext(path) + basename = os.path.basename(root) + basename = re.sub(r"[^a-zA-Z0-9_]", "_", basename).strip("_") + prefix = f"read_{file_type}_{basename}" + return gen_name(prefix) + + def slice_to_limit_offset( what: slice, count: ir.IntegerScalar ) -> tuple[int | ir.IntegerScalar, int | ir.IntegerScalar]: