From 5cf20d35bd5775d782b2a7cce3ebe0f83dfdcf96 Mon Sep 17 00:00:00 2001 From: skshetry <18718008+skshetry@users.noreply.github.com> Date: Wed, 17 Jul 2024 10:23:07 +0545 Subject: [PATCH] Implement sys feature, and rename id/random columns (#28) --- examples/pose_detection.py | 3 +- src/datachain/__init__.py | 3 +- src/datachain/catalog/catalog.py | 10 ++-- src/datachain/cli.py | 2 - src/datachain/data_storage/schema.py | 10 ++-- src/datachain/data_storage/sqlite.py | 2 +- src/datachain/data_storage/warehouse.py | 14 ++--- src/datachain/lib/dc.py | 61 ++++++++++++++----- src/datachain/lib/image_transform.py | 1 - src/datachain/lib/signal_schema.py | 9 +++ src/datachain/node.py | 8 +-- src/datachain/query/batch.py | 2 +- src/datachain/query/dataset.py | 67 +++++++++++---------- src/datachain/query/udf.py | 4 +- src/datachain/remote/studio.py | 10 +--- tests/conftest.py | 2 +- tests/func/test_dataset_query.py | 80 ++++++++++++------------- tests/func/test_datasets.py | 4 +- tests/func/test_pull.py | 6 +- tests/scripts/feature_class_parallel.py | 2 +- tests/unit/lib/test_datachain.py | 33 +++++++++- tests/unit/test_dataset.py | 6 +- tests/unit/test_listing.py | 4 +- tests/unit/test_udf.py | 66 ++++---------------- 24 files changed, 213 insertions(+), 196 deletions(-) diff --git a/examples/pose_detection.py b/examples/pose_detection.py index d26099e0c..cd3897d63 100644 --- a/examples/pose_detection.py +++ b/examples/pose_detection.py @@ -130,7 +130,6 @@ def __call__( return # CLeanup to records - del record["random"] # random will be populated automatically record["is_latest"] = record["is_latest"] > 0 # needs to be a bool row = DatasetRow.create(**record) @@ -211,7 +210,7 @@ def __call__( data = ( DatasetQuery(os.path.join(cloud_prefix, bucket)) .filter(C.name.glob(file_type)) - .filter(C.random % filter_mod == chunk_num) + .filter(C.sys__rand % filter_mod == chunk_num) .generate(pose_udf) .results() ) diff --git a/src/datachain/__init__.py b/src/datachain/__init__.py index 712e1300e..95abb4d9f 100644 --- a/src/datachain/__init__.py +++ b/src/datachain/__init__.py @@ -1,5 +1,5 @@ from datachain.lib.data_model import DataModel, DataType, FileBasic, is_chain_type -from datachain.lib.dc import C, Column, DataChain +from datachain.lib.dc import C, Column, DataChain, Sys from datachain.lib.file import ( File, FileError, @@ -31,6 +31,7 @@ "IndexedFile", "Mapper", "Session", + "Sys", "TarVFile", "TextFile", "is_chain_type", diff --git a/src/datachain/catalog/catalog.py b/src/datachain/catalog/catalog.py index c5be5ed14..b42e9f65c 100644 --- a/src/datachain/catalog/catalog.py +++ b/src/datachain/catalog/catalog.py @@ -256,7 +256,7 @@ def do_task(self, urls): self.fix_columns(df) # id will be autogenerated in DB - df = df.drop("id", axis=1) + df = df.drop("sys__id", axis=1) inserted = warehouse.insert_dataset_rows( df, dataset, self.dataset_version @@ -1041,7 +1041,7 @@ def create_dataset( If version is None, then next unused version is created. If version is given, then it must be an unused version number. """ - assert [c.name for c in columns if c.name != "id"], f"got {columns=}" + assert [c.name for c in columns if c.name != "sys__id"], f"got {columns=}" if not listing and Client.is_data_source_uri(name): raise RuntimeError( "Cannot create dataset that starts with source prefix, e.g s3://" @@ -1103,7 +1103,7 @@ def create_new_dataset_version( Creates dataset version if it doesn't exist. 
If create_rows is False, dataset rows table will not be created """ - assert [c.name for c in columns if c.name != "id"], f"got {columns=}" + assert [c.name for c in columns if c.name != "sys__id"], f"got {columns=}" schema = { c.name: c.type.to_dict() for c in columns if isinstance(c.type, SQLType) } @@ -1433,7 +1433,7 @@ def ls_dataset_rows( if offset: q = q.offset(offset) - q = q.order_by("id") + q = q.order_by("sys__id") return q.to_records() @@ -1786,7 +1786,7 @@ def _instantiate_dataset(): schema = DatasetRecord.parse_schema(remote_dataset_version.schema) columns = tuple( - sa.Column(name, typ) for name, typ in schema.items() if name != "id" + sa.Column(name, typ) for name, typ in schema.items() if name != "sys__id" ) # creating new dataset (version) locally dataset = self.create_dataset( diff --git a/src/datachain/cli.py b/src/datachain/cli.py index e783da6c2..87aa9a64c 100644 --- a/src/datachain/cli.py +++ b/src/datachain/cli.py @@ -811,8 +811,6 @@ def show( from datachain.query import DatasetQuery from datachain.utils import show_records - if columns: - columns = ("id", *columns) query = ( DatasetQuery(name=name, version=version, catalog=catalog) .select(*columns) diff --git a/src/datachain/data_storage/schema.py b/src/datachain/data_storage/schema.py index d51c575e0..c2e3f7e25 100644 --- a/src/datachain/data_storage/schema.py +++ b/src/datachain/data_storage/schema.py @@ -72,7 +72,7 @@ class DirExpansion: @staticmethod def base_select(q): return sa.select( - q.c.id, + q.c.sys__id, q.c.vtype, (q.c.dir_type == DirType.DIR).label("is_dir"), q.c.source, @@ -86,7 +86,7 @@ def base_select(q): def apply_group_by(q): return ( sa.select( - f.min(q.c.id).label("id"), + f.min(q.c.sys__id).label("sys__id"), q.c.vtype, q.c.is_dir, q.c.source, @@ -111,7 +111,7 @@ def query(cls, q): parent_name = path.name(q.c.parent) q = q.union_all( sa.select( - sa.literal(-1).label("id"), + sa.literal(-1).label("sys__id"), sa.literal("").label("vtype"), true().label("is_dir"), q.c.source, @@ -233,9 +233,9 @@ def delete(self): @staticmethod def sys_columns(): return [ - sa.Column("id", Int, primary_key=True), + sa.Column("sys__id", Int, primary_key=True), sa.Column( - "random", UInt64, nullable=False, server_default=f.abs(f.random()) + "sys__rand", UInt64, nullable=False, server_default=f.abs(f.random()) ), ] diff --git a/src/datachain/data_storage/sqlite.py b/src/datachain/data_storage/sqlite.py index 60f5c67a8..0c6f493cc 100644 --- a/src/datachain/data_storage/sqlite.py +++ b/src/datachain/data_storage/sqlite.py @@ -631,7 +631,7 @@ def merge_dataset_rows( dst_empty = True dst_dr = self.dataset_rows(dst, dst_version).table - merge_fields = [c.name for c in src_dr.c if c.name != "id"] + merge_fields = [c.name for c in src_dr.c if c.name != "sys__id"] select_src = select(*(getattr(src_dr.c, f) for f in merge_fields)) if dst_empty: diff --git a/src/datachain/data_storage/warehouse.py b/src/datachain/data_storage/warehouse.py index 27588e67b..7c396a981 100644 --- a/src/datachain/data_storage/warehouse.py +++ b/src/datachain/data_storage/warehouse.py @@ -195,7 +195,7 @@ def dataset_select_paginated( cols_names = [c.name for c in cols] if not order_by: - ordering = [cols.id] + ordering = [cols.sys__id] else: ordering = order_by # type: ignore[assignment] @@ -372,7 +372,7 @@ def dataset_rows_count(self, dataset: DatasetRecord, version=None) -> int: """Returns total number of rows in a dataset""" dr = self.dataset_rows(dataset, version) table = dr.get_table() - query = select(sa.func.count(table.c.id)) + query = 
select(sa.func.count(table.c.sys__id)) (res,) = self.db.execute(query) return res[0] @@ -388,7 +388,7 @@ def dataset_stats( dr = self.dataset_rows(dataset, version) table = dr.get_table() expressions: tuple[_ColumnsClauseArgument[Any], ...] = ( - sa.func.count(table.c.id), + sa.func.count(table.c.sys__id), ) if "size" in table.columns: expressions = (*expressions, sa.func.sum(table.c.size)) @@ -607,7 +607,7 @@ def with_default(column): return func.coalesce(column, default).label(column.name) return sa.select( - de.c.id, + de.c.sys__id, with_default(dr.c.vtype), case((de.c.is_dir == true(), DirType.DIR), else_=dr.c.dir_type).label( "dir_type" @@ -621,10 +621,10 @@ def with_default(column): with_default(dr.c.size), with_default(dr.c.owner_name), with_default(dr.c.owner_id), - with_default(dr.c.random), + with_default(dr.c.sys__rand), dr.c.location, de.c.source, - ).select_from(de.outerjoin(dr.table, de.c.id == dr.c.id)) + ).select_from(de.outerjoin(dr.table, de.c.sys__id == dr.c.sys__id)) def get_node_by_path(self, dataset_rows: "DataTable", path: str) -> Node: """Gets node that corresponds to some path""" @@ -878,7 +878,7 @@ def create_udf_table( tbl = sa.Table( name, sa.MetaData(), - sa.Column("id", Int, primary_key=True), + sa.Column("sys__id", Int, primary_key=True), *columns, ) self.db.create_table(tbl, if_not_exists=True) diff --git a/src/datachain/lib/dc.py b/src/datachain/lib/dc.py index e3dd869a5..47ece6154 100644 --- a/src/datachain/lib/dc.py +++ b/src/datachain/lib/dc.py @@ -1,3 +1,4 @@ +import copy import re from collections.abc import Iterator, Sequence from typing import ( @@ -72,6 +73,11 @@ def __init__(self, on: Sequence[str], right_on: Optional[Sequence[str]], msg: st OutputType = Union[None, DataType, Sequence[str], dict[str, DataType]] +class Sys(DataModel): + id: int + rand: int + + class DataChain(DatasetQuery): """AI 🔗 DataChain - a data structure for batch data processing and evaluation. @@ -124,12 +130,10 @@ class Rating(Feature): """ DEFAULT_FILE_RECORD: ClassVar[dict] = { - "id": 0, "source": "", "name": "", "vtype": "", "size": 0, - "random": 0, } def __init__(self, *args, **kwargs): @@ -155,8 +159,19 @@ def schema(self): def print_schema(self): self.signals_schema.print_tree() + def clone(self, new_table: bool = True) -> "Self": + obj = super().clone(new_table=new_table) + obj.signals_schema = copy.deepcopy(self.signals_schema) + return obj + def settings( - self, cache=None, batch=None, parallel=None, workers=None, min_task_size=None + self, + cache=None, + batch=None, + parallel=None, + workers=None, + min_task_size=None, + include_sys: Optional[bool] = None, ) -> "Self": """Change settings for chain. 
@@ -180,8 +195,13 @@ def settings( ) ``` """ - self._settings.add(Settings(cache, batch, parallel, workers, min_task_size)) - return self + chain = self.clone() + if include_sys is True: + chain.signals_schema = SignalSchema({"sys": Sys}) | chain.signals_schema + elif include_sys is False and "sys" in chain.signals_schema: + chain.signals_schema.remove("sys") + chain._settings.add(Settings(cache, batch, parallel, workers, min_task_size)) + return chain def reset_settings(self, settings: Optional[Settings] = None) -> "Self": """Reset all settings to default values.""" @@ -193,12 +213,11 @@ def reset_schema(self, signals_schema: SignalSchema) -> "Self": return self def add_schema(self, signals_schema: SignalSchema) -> "Self": - union = self.signals_schema.values | signals_schema.values - self.signals_schema = SignalSchema(union) + self.signals_schema |= signals_schema return self def get_file_signals(self) -> list[str]: - return self.signals_schema.get_file_signals() + return list(self.signals_schema.get_file_signals()) @classmethod def from_storage( @@ -352,6 +371,7 @@ def save( # type: ignore[override] version : version of a dataset. Default - the last version that exist. """ schema = self.signals_schema.serialize() + schema.pop("sys", None) return super().save(name=name, version=version, feature_schema=schema) def apply(self, func, *args, **kwargs): @@ -528,6 +548,20 @@ def select_except(self, *args: str) -> "Self": chain.signals_schema = new_schema return chain + def iterate_flatten(self) -> Iterator[tuple[Any]]: + db_signals = self.signals_schema.db_signals() + with super().select(*db_signals).as_iterable() as rows: + yield from rows + + def results( + self, row_factory: Optional[Callable] = None, **kwargs + ) -> list[tuple[Any, ...]]: + rows = self.iterate_flatten() + if row_factory: + db_signals = self.signals_schema.db_signals() + rows = (row_factory(db_signals, r) for r in rows) + return list(rows) + def iterate(self, *cols: str) -> Iterator[list[DataType]]: """Iterate over rows. @@ -535,13 +569,10 @@ def iterate(self, *cols: str) -> Iterator[list[DataType]]: columns. 
""" chain = self.select(*cols) if cols else self - - db_signals = chain.signals_schema.db_signals() - with super().select(*db_signals).as_iterable() as rows_iter: - for row in rows_iter: - yield chain.signals_schema.row_to_features( - row, catalog=chain.session.catalog, cache=chain._settings.cache - ) + for row in chain.iterate_flatten(): + yield chain.signals_schema.row_to_features( + row, catalog=chain.session.catalog, cache=chain._settings.cache + ) def iterate_one(self, col: str) -> Iterator[DataType]: for item in self.iterate(col): diff --git a/src/datachain/lib/image_transform.py b/src/datachain/lib/image_transform.py index 958311ed0..0c33c9c43 100644 --- a/src/datachain/lib/image_transform.py +++ b/src/datachain/lib/image_transform.py @@ -66,7 +66,6 @@ def __call__( ): # Build a dict from row contents record = dict(zip(DatasetRow.schema.keys(), args)) - del record["random"] # random will be populated automatically record["is_latest"] = record["is_latest"] > 0 # needs to be a bool # yield same row back diff --git a/src/datachain/lib/signal_schema.py b/src/datachain/lib/signal_schema.py index e58dd1729..1e7f233d8 100644 --- a/src/datachain/lib/signal_schema.py +++ b/src/datachain/lib/signal_schema.py @@ -331,6 +331,15 @@ def print_tree(self, indent: int = 4, start_at: int = 0): sub_schema = SignalSchema({"* list of": args[0]}) sub_schema.print_tree(indent=indent, start_at=total_indent + indent) + def __or__(self, other): + return self.__class__(self.values | other.values) + + def __contains__(self, name: str): + return name in self.values + + def remove(self, name: str): + return self.values.pop(name) + @staticmethod def _type_to_str(type_): # noqa: PLR0911 origin = get_origin(type_) diff --git a/src/datachain/node.py b/src/datachain/node.py index 7d5b7ad06..7e82ba1c8 100644 --- a/src/datachain/node.py +++ b/src/datachain/node.py @@ -46,8 +46,8 @@ class DirTypeGroup: @attrs.define class Node: - id: int = 0 - random: int = -1 + sys__id: int = 0 + sys__rand: int = -1 vtype: str = "" dir_type: Optional[int] = None parent: str = "" @@ -127,11 +127,11 @@ def from_dict(cls, d: dict[str, Any]) -> "Self": @classmethod def from_dir(cls, parent, name, **kwargs) -> "Node": - return cls(id=-1, dir_type=DirType.DIR, parent=parent, name=name, **kwargs) + return cls(sys__id=-1, dir_type=DirType.DIR, parent=parent, name=name, **kwargs) @classmethod def root(cls) -> "Node": - return cls(-1, dir_type=DirType.DIR) + return cls(sys__id=-1, dir_type=DirType.DIR) @attrs.define diff --git a/src/datachain/query/batch.py b/src/datachain/query/batch.py index a11d07151..8e6e15abb 100644 --- a/src/datachain/query/batch.py +++ b/src/datachain/query/batch.py @@ -104,7 +104,7 @@ def __call__( with contextlib.closing( execute( query, - order_by=(PARTITION_COLUMN_ID, "id", *query._order_by_clauses), + order_by=(PARTITION_COLUMN_ID, "sys__id", *query._order_by_clauses), limit=query._limit, ) ) as rows: diff --git a/src/datachain/query/dataset.py b/src/datachain/query/dataset.py index 12eb2da88..e88c71c12 100644 --- a/src/datachain/query/dataset.py +++ b/src/datachain/query/dataset.py @@ -257,7 +257,7 @@ def query( """ def apply(self, query_generator, temp_tables: list[str]): - source_query = query_generator.exclude(("id",)) + source_query = query_generator.exclude(("sys__id",)) target_query = self.dq.apply_steps().select() temp_tables.extend(self.dq.temp_table_names) @@ -640,7 +640,7 @@ def create_partitions_table(self, query: Select) -> "Table": # fill table with partitions cols = [ - query.selected_columns.id, + 
query.selected_columns.sys__id, f.dense_rank().over(order_by=list_partition_by).label(PARTITION_COLUMN_ID), ] self.catalog.warehouse.db.execute( @@ -674,7 +674,7 @@ def apply( subq = query.subquery() query = ( sqlalchemy.select(*subq.c) - .outerjoin(partition_tbl, partition_tbl.c.id == subq.c.id) + .outerjoin(partition_tbl, partition_tbl.c.sys__id == subq.c.sys__id) .add_columns(*partition_columns()) ) @@ -706,18 +706,18 @@ def create_pre_udf_table(self, query: Select) -> "Table": columns = [ sqlalchemy.Column(c.name, c.type) for c in query.selected_columns - if c.name != "id" + if c.name != "sys__id" ] table = self.catalog.warehouse.create_udf_table(self.udf_table_name(), columns) select_q = query.with_only_columns( - *[c for c in query.selected_columns if c.name != "id"] + *[c for c in query.selected_columns if c.name != "sys__id"] ) # if there is order by clause we need row_number to preserve order # if there is no order by clause we still need row_number to generate # unique ids as uniqueness is important for this table select_q = select_q.add_columns( - f.row_number().over(order_by=select_q._order_by_clauses).label("id") + f.row_number().over(order_by=select_q._order_by_clauses).label("sys__id") ) self.catalog.warehouse.db.execute( @@ -733,7 +733,7 @@ def process_input_query(self, query: Select) -> tuple[Select, list["Table"]]: if query._order_by_clauses: # we are adding ordering only if it's explicitly added by user in # query part before adding signals - q = q.order_by(table.c.id) + q = q.order_by(table.c.sys__id) return q, [table] def create_result_query( @@ -743,7 +743,7 @@ def create_result_query( original_cols = [c for c in subq.c if c.name not in partition_col_names] # new signal columns that are added to udf_table - signal_cols = [c for c in udf_table.c if c.name != "id"] + signal_cols = [c for c in udf_table.c if c.name != "sys__id"] signal_name_cols = {c.name: c for c in signal_cols} cols = signal_cols @@ -763,7 +763,7 @@ def q(*columns): res = ( sqlalchemy.select(*cols1) .select_from(subq) - .outerjoin(udf_table, udf_table.c.id == subq.c.id) + .outerjoin(udf_table, udf_table.c.sys__id == subq.c.sys__id) .add_columns(*cols2) ) else: @@ -772,7 +772,7 @@ def q(*columns): if query._order_by_clauses: # if ordering is used in query part before adding signals, we # will have it as order by id from select from pre-created udf table - res = res.order_by(subq.c.id) + res = res.order_by(subq.c.sys__id) if self.partition_by is not None: subquery = res.subquery() @@ -810,7 +810,7 @@ def create_result_query( # we get the same rows as we got as inputs of UDF since selecting # without ordering can be non deterministic in some databases c = query.selected_columns - query = query.order_by(c.id) + query = query.order_by(c.sys__id) udf_table_query = udf_table.select().subquery() udf_table_cols: list[sqlalchemy.Label[Any]] = [ @@ -1002,7 +1002,7 @@ def apply( q1_column_names = {c.name for c in q1_columns} q2_columns = [ c - if c.name not in q1_column_names and c.name != "id" + if c.name not in q1_column_names and c.name != "sys__id" else c.label(self.rname.format(name=c.name)) for c in q2.c ] @@ -1142,8 +1142,8 @@ def __init__( self.version = version or ds.latest_version self.feature_schema = ds.get_version(self.version).feature_schema self.column_types = copy(ds.schema) - if "id" in self.column_types: - self.column_types.pop("id") + if "sys__id" in self.column_types: + self.column_types.pop("sys__id") self.starting_step = QueryStep(self.catalog, name, self.version) # attaching to specific 
dataset self.name = name @@ -1216,7 +1216,7 @@ def apply_steps(self) -> QueryGenerator: query.steps = self._chunk_limit(query.steps, index, total) # Prepend the chunk filter to the step chain. - query = query.filter(C.random % total == index) + query = query.filter(C.sys__rand % total == index) query.steps = query.steps[-1:] + query.steps[:-1] result = query.starting_step.apply() @@ -1343,10 +1343,8 @@ async def get_params(row: RowDict) -> tuple: finally: self.cleanup() - def to_records(self) -> list[dict]: - with self.as_iterable() as result: - cols = result.columns - return [dict(zip(cols, row)) for row in result] + def to_records(self) -> list[dict[str, Any]]: + return self.results(lambda cols, row: dict(zip(cols, row))) def to_pandas(self) -> "pd.DataFrame": records = self.to_records() @@ -1356,7 +1354,7 @@ def to_pandas(self) -> "pd.DataFrame": def shuffle(self) -> "Self": # ToDo: implement shaffle based on seed and/or generating random column - return self.order_by(C.random) + return self.order_by(C.sys__rand) def sample(self, n) -> "Self": """ @@ -1485,30 +1483,35 @@ def offset(self, offset: int) -> "Self": query.steps.append(SQLOffset(offset)) return query + def as_scalar(self) -> Any: + with self.as_iterable() as rows: + row = next(iter(rows)) + return row[0] + def count(self) -> int: query = self.clone() query.steps.append(SQLCount()) - return query.results()[0][0] + return query.as_scalar() - def sum(self, col: ColumnElement): + def sum(self, col: ColumnElement) -> int: query = self.clone() query.steps.append(SQLSelect((f.sum(col),))) - return query.results()[0][0] + return query.as_scalar() - def avg(self, col: ColumnElement): + def avg(self, col: ColumnElement) -> int: query = self.clone() query.steps.append(SQLSelect((f.avg(col),))) - return query.results()[0][0] + return query.as_scalar() - def min(self, col: ColumnElement): + def min(self, col: ColumnElement) -> int: query = self.clone() query.steps.append(SQLSelect((f.min(col),))) - return query.results()[0][0] + return query.as_scalar() - def max(self, col: ColumnElement): + def max(self, col: ColumnElement) -> int: query = self.clone() query.steps.append(SQLSelect((f.max(col),))) - return query.results()[0][0] + return query.as_scalar() @detach def group_by(self, *cols: ColumnElement) -> "Self": @@ -1700,7 +1703,7 @@ def save( c if isinstance(c, Column) else Column(c.name, c.type) for c in query.columns ] - if not [c for c in columns if c.name != "id"]: + if not [c for c in columns if c.name != "sys__id"]: raise RuntimeError( "No columns to save in the query. " "Ensure at least one column (other than 'id') is selected." @@ -1719,11 +1722,11 @@ def save( # Exclude the id column and let the db create it to avoid unique # constraint violations. 
- q = query.exclude(("id",)) + q = query.exclude(("sys__id",)) if q._order_by_clauses: # ensuring we have id sorted by order by clause if it exists in a query q = q.add_columns( - f.row_number().over(order_by=q._order_by_clauses).label("id") + f.row_number().over(order_by=q._order_by_clauses).label("sys__id") ) cols = tuple(c.name for c in q.columns) diff --git a/src/datachain/query/udf.py b/src/datachain/query/udf.py index 21d7cb330..e0f74a52f 100644 --- a/src/datachain/query/udf.py +++ b/src/datachain/query/udf.py @@ -147,9 +147,9 @@ def _process_results( return (dict(zip(self.signal_names, row)) for row in results) # outputting signals - row_ids = [row["id"] for row in rows] + row_ids = [row["sys__id"] for row in rows] return [ - dict(id=row_id, **dict(zip(self.signal_names, signals))) + {"sys__id": row_id} | dict(zip(self.signal_names, signals)) for row_id, signals in zip(row_ids, results) if signals is not None # skip rows with no output ] diff --git a/src/datachain/remote/studio.py b/src/datachain/remote/studio.py index f2b85c55e..5934020e4 100644 --- a/src/datachain/remote/studio.py +++ b/src/datachain/remote/studio.py @@ -190,19 +190,11 @@ def _parse_dataset_info(dataset_info): def dataset_rows_chunk( self, name: str, version: int, offset: int ) -> Response[DatasetRowsData]: - def _parse_row(row): - row["id"] = int(row["id"]) - return row - req_data = {"dataset_name": name, "dataset_version": version} - response = self._send_request_msgpack( + return self._send_request_msgpack( "dataset-rows", {**req_data, "offset": offset, "limit": DATASET_ROWS_CHUNK_SIZE}, ) - if response.ok: - response.data = [_parse_row(r) for r in response.data] - - return response def dataset_stats(self, name: str, version: int) -> Response[DatasetStatsData]: response = self._send_request( diff --git a/tests/conftest.py b/tests/conftest.py index c79000ddd..bfc330458 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -539,7 +539,7 @@ def dataset_rows(): "owner_name": "aws-iterative-sandbox", "last_modified": "2024-02-23T10:42:31.842944+00:00", "size": 49807360, - "random": 12123123123, + "sys__rand": 12123123123, "int_col": 5, "int_col_32": 5, "int_col_64": 5, diff --git a/tests/func/test_dataset_query.py b/tests/func/test_dataset_query.py index 8111e1737..df04b4187 100644 --- a/tests/func/test_dataset_query.py +++ b/tests/func/test_dataset_query.py @@ -537,16 +537,16 @@ def test_row_number_with_order_by_name_descending(cloud_test_catalog): results = DatasetQuery(name=ds_name, catalog=catalog).to_records() results_name_id = [ - {k: v for k, v in r.items() if k in ["id", "name"]} for r in results + {k: v for k, v in r.items() if k in ["sys__id", "name"]} for r in results ] - assert sorted(results_name_id, key=lambda k: k["id"]) == [ - {"id": 1, "name": "dog4"}, - {"id": 2, "name": "dog3"}, - {"id": 3, "name": "dog2"}, - {"id": 4, "name": "dog1"}, - {"id": 5, "name": "description"}, - {"id": 6, "name": "cat2"}, - {"id": 7, "name": "cat1"}, + assert sorted(results_name_id, key=lambda k: k["sys__id"]) == [ + {"sys__id": 1, "name": "dog4"}, + {"sys__id": 2, "name": "dog3"}, + {"sys__id": 3, "name": "dog2"}, + {"sys__id": 4, "name": "dog1"}, + {"sys__id": 5, "name": "description"}, + {"sys__id": 6, "name": "cat2"}, + {"sys__id": 7, "name": "cat1"}, ] @@ -567,16 +567,16 @@ def test_row_number_with_order_by_name_ascending(cloud_test_catalog): results = DatasetQuery(name=ds_name, catalog=catalog).to_records() results_name_id = [ - {k: v for k, v in r.items() if k in ["id", "name"]} for r in results + {k: v for 
k, v in r.items() if k in ["sys__id", "name"]} for r in results ] - assert sorted(results_name_id, key=lambda k: k["id"]) == [ - {"id": 1, "name": "cat1"}, - {"id": 2, "name": "cat2"}, - {"id": 3, "name": "description"}, - {"id": 4, "name": "dog1"}, - {"id": 5, "name": "dog2"}, - {"id": 6, "name": "dog3"}, - {"id": 7, "name": "dog4"}, + assert sorted(results_name_id, key=lambda k: k["sys__id"]) == [ + {"sys__id": 1, "name": "cat1"}, + {"sys__id": 2, "name": "cat2"}, + {"sys__id": 3, "name": "description"}, + {"sys__id": 4, "name": "dog1"}, + {"sys__id": 5, "name": "dog2"}, + {"sys__id": 6, "name": "dog3"}, + {"sys__id": 7, "name": "dog4"}, ] @@ -601,16 +601,16 @@ def name_len(name): results = DatasetQuery(name=ds_name, catalog=catalog).to_records() results_name_id = [ - {k: v for k, v in r.items() if k in ["id", "name"]} for r in results + {k: v for k, v in r.items() if k in ["sys__id", "name"]} for r in results ] - assert sorted(results_name_id, key=lambda k: k["id"]) == [ - {"id": 1, "name": "description"}, - {"id": 2, "name": "cat1"}, - {"id": 3, "name": "cat2"}, - {"id": 4, "name": "dog1"}, - {"id": 5, "name": "dog2"}, - {"id": 6, "name": "dog3"}, - {"id": 7, "name": "dog4"}, + assert sorted(results_name_id, key=lambda k: k["sys__id"]) == [ + {"sys__id": 1, "name": "description"}, + {"sys__id": 2, "name": "cat1"}, + {"sys__id": 3, "name": "cat2"}, + {"sys__id": 4, "name": "dog1"}, + {"sys__id": 5, "name": "dog2"}, + {"sys__id": 6, "name": "dog3"}, + {"sys__id": 7, "name": "dog4"}, ] @@ -635,18 +635,18 @@ def name_len(name): results = DatasetQuery(name=ds_name, catalog=catalog).to_records() results_name_id = [ - {k: v for k, v in r.items() if k in ["id", "name"]} for r in results + {k: v for k, v in r.items() if k in ["sys__id", "name"]} for r in results ] # we should preserve order in final result based on order by which was added # before add_signals - assert sorted(results_name_id, key=lambda k: k["id"]) == [ - {"id": 1, "name": "cat1"}, - {"id": 2, "name": "cat2"}, - {"id": 3, "name": "description"}, - {"id": 4, "name": "dog1"}, - {"id": 5, "name": "dog2"}, - {"id": 6, "name": "dog3"}, - {"id": 7, "name": "dog4"}, + assert sorted(results_name_id, key=lambda k: k["sys__id"]) == [ + {"sys__id": 1, "name": "cat1"}, + {"sys__id": 2, "name": "cat2"}, + {"sys__id": 3, "name": "description"}, + {"sys__id": 4, "name": "dog1"}, + {"sys__id": 5, "name": "dog2"}, + {"sys__id": 6, "name": "dog3"}, + {"sys__id": 7, "name": "dog4"}, ] @@ -1366,8 +1366,8 @@ def test_extract_limit(cloud_test_catalog, dogs_dataset): def test_extract_order_by(cloud_test_catalog, dogs_dataset): catalog = cloud_test_catalog.catalog q = DatasetQuery(name=dogs_dataset.name, version=1, catalog=catalog) - results = list(q.order_by("random").extract("name")) - pairs = list(q.extract("random", "name")) + results = list(q.order_by("sys__rand").extract("name")) + pairs = list(q.extract("sys__rand", "name")) assert results == [(p[1],) for p in sorted(pairs)] @@ -2770,7 +2770,7 @@ def test_simple_dataset_query(cloud_test_catalog): ds1, ds2 = ( [ - {k.name: v for k, v in zip(q.selected_columns, r) if k.name != "id"} + {k.name: v for k, v in zip(q.selected_columns, r) if k.name != "sys__id"} for r in warehouse.db.execute(q) ] for q in ds_queries @@ -3413,7 +3413,7 @@ def get_result(query): # incorrect results on clickhouse cloud. 
# See https://github.com/iterative/dvcx/issues/940 assert get_result(ds.order_by("name")) == expected - assert len(get_result(ds.order_by("random"))) == 100 + assert len(get_result(ds.order_by("sys__rand"))) == 100 assert len(get_result(ds)) == 100 diff --git a/tests/func/test_datasets.py b/tests/func/test_datasets.py index 11e78af80..29c898d89 100644 --- a/tests/func/test_datasets.py +++ b/tests/func/test_datasets.py @@ -892,7 +892,7 @@ def test_row_random(cloud_test_catalog): catalog = ctc.catalog catalog.index([ctc.src_uri]) catalog.create_dataset_from_sources("test", [ctc.src_uri]) - random_values = [row["random"] for row in catalog.ls_dataset_rows("test", 1)] + random_values = [row["sys__rand"] for row in catalog.ls_dataset_rows("test", 1)] # Random values are unique assert len(set(random_values)) == len(random_values) @@ -908,7 +908,7 @@ def test_row_random(cloud_test_catalog): # Creating a new dataset preserves random values catalog.create_dataset_from_sources("test2", [ctc.src_uri]) - random_values2 = {row["random"] for row in catalog.ls_dataset_rows("test2", 1)} + random_values2 = {row["sys__rand"] for row in catalog.ls_dataset_rows("test2", 1)} assert random_values2 == set(random_values) diff --git a/tests/func/test_pull.py b/tests/func/test_pull.py index 6e5391f5b..bce057e30 100644 --- a/tests/func/test_pull.py +++ b/tests/func/test_pull.py @@ -43,12 +43,12 @@ def _adapt_row(row): else: adapted[k] = v - adapted["id"] = 1 + adapted["sys__id"] = 1 + adapted["sys__rand"] = 1 adapted["vtype"] = b"" adapted["location"] = b"" adapted["source"] = b"s3://dogs" adapted["dir_type"] = DirType.FILE - adapted["random"] = 1 return adapted dog_entries = [_adapt_row(e) for e in dog_entries] @@ -74,7 +74,7 @@ def schema(): "size": {"type": "Int64"}, "owner_name": {"type": "String"}, "owner_id": {"type": "String"}, - "random": {"type": "Int64"}, + "sys__rand": {"type": "Int64"}, "location": {"type": "String"}, "source": {"type": "String"}, } diff --git a/tests/scripts/feature_class_parallel.py b/tests/scripts/feature_class_parallel.py index 36d82a9ca..d38639525 100644 --- a/tests/scripts/feature_class_parallel.py +++ b/tests/scripts/feature_class_parallel.py @@ -27,4 +27,4 @@ class Embedding(BaseModel): ) for row in ds.results(): - print(row[5]) + print(row[2]) diff --git a/tests/unit/lib/test_datachain.py b/tests/unit/lib/test_datachain.py index 4bbd44f24..b18d03bde 100644 --- a/tests/unit/lib/test_datachain.py +++ b/tests/unit/lib/test_datachain.py @@ -1,13 +1,14 @@ import datetime import math from collections.abc import Generator, Iterator +from unittest.mock import ANY import numpy as np import pandas as pd import pytest from pydantic import BaseModel -from datachain.lib.dc import C, DataChain +from datachain.lib.dc import C, DataChain, Sys from datachain.lib.file import File from datachain.lib.signal_schema import ( SignalResolvingError, @@ -837,3 +838,33 @@ def test_parse_tabular_object_name(tmp_dir, catalog): df.to_parquet(path) dc = DataChain.from_storage(path.as_uri()).parse_tabular(object_name="name") assert "name.first_name" in dc.to_pandas().columns + + +def test_sys_feature(tmp_dir, catalog): + ds = DataChain.from_values(t1=features) + ds_sys = ds.settings(include_sys=True) + assert ds.signals_schema.values == {"t1": MyFr} + assert ds_sys.signals_schema.values == {"t1": MyFr, "sys": Sys} + + args = [] + ds_sys.map(res=lambda sys, t1: args.append((sys, t1))).save("ds_sys") + + sys_cls = Sys.model_construct + assert args == [ + (sys_cls(id=1, rand=ANY), MyFr(nnn="n1", count=3)), + 
(sys_cls(id=2, rand=ANY), MyFr(nnn="n2", count=5)), + (sys_cls(id=3, rand=ANY), MyFr(nnn="n1", count=1)), + ] + assert "sys" not in ds_sys.catalog.get_dataset("ds_sys").feature_schema + + ds_no_sys = ds_sys.settings(include_sys=False) + assert ds_no_sys.signals_schema.values == {"t1": MyFr} + + args = [] + ds_no_sys.map(res=lambda t1: args.append(t1)).save("ds_no_sys") + assert args == [ + MyFr(nnn="n1", count=3), + MyFr(nnn="n2", count=5), + MyFr(nnn="n1", count=1), + ] + assert "sys" not in ds_no_sys.catalog.get_dataset("ds_no_sys").feature_schema diff --git a/tests/unit/test_dataset.py b/tests/unit/test_dataset.py index fc0dd2be6..eed7d4b0f 100644 --- a/tests/unit/test_dataset.py +++ b/tests/unit/test_dataset.py @@ -43,8 +43,8 @@ def test_dataset_table_compilation(): assert result.string == ( "\n" 'CREATE TABLE IF NOT EXISTS "ds-1" (\n' - "\tid INTEGER NOT NULL, \n" - "\trandom INTEGER DEFAULT (abs(random())) NOT NULL, \n" + "\tsys__id INTEGER NOT NULL, \n" + "\tsys__rand INTEGER DEFAULT (abs(random())) NOT NULL, \n" "\tvtype VARCHAR NOT NULL, \n" "\tdir_type INTEGER, \n" "\tparent VARCHAR, \n" @@ -60,7 +60,7 @@ def test_dataset_table_compilation(): "\tsource VARCHAR NOT NULL, \n" "\tscore FLOAT NOT NULL, \n" "\tmeta_info JSON, \n" - "\tPRIMARY KEY (id)\n" + "\tPRIMARY KEY (sys__id)\n" ")\n" "\n" ) diff --git a/tests/unit/test_listing.py b/tests/unit/test_listing.py index 241b849d8..db0c75f2c 100644 --- a/tests/unit/test_listing.py +++ b/tests/unit/test_listing.py @@ -121,9 +121,9 @@ def test_list_dir(listing): def test_list_file(listing): file = listing.resolve_path("dir1/dataset.csv") src = DataSource(listing, file) - results = list(src.ls(["id", "name", "dir_type"])) + results = list(src.ls(["sys__id", "name", "dir_type"])) assert {r[1] for r in results} == {"dataset.csv"} - assert results[0][0] == file.id + assert results[0][0] == file.sys__id assert results[0][1] == file.name assert results[0][2] == DirType.FILE diff --git a/tests/unit/test_udf.py b/tests/unit/test_udf.py index fb1f3a820..00acf3975 100644 --- a/tests/unit/test_udf.py +++ b/tests/unit/test_udf.py @@ -12,23 +12,7 @@ def test_udf_single_signal(): def t(a, b): return (a * b,) - row = RowDict( - id=6, - vtype="", - dir_type=1, - parent="", - name="obj", - last_modified=None, - etag="", - version="", - is_latest=True, - size=7, - owner_name="", - owner_id="", - source="", - random=1234, - location=None, - ) + row = RowDict(sys__id=1, sys__rand=1234, id=6, size=7) result = t.run_once(None, row) assert result[0]["mul"] == (42) @@ -38,25 +22,9 @@ def test_udf_multiple_signals(): def t(a, b): return (a * b, a + b) - row = RowDict( - id=6, - vtype="", - dir_type=1, - parent="", - name="obj", - last_modified=None, - etag="", - version="", - is_latest=True, - size=7, - owner_name="", - owner_id="", - source="", - random=1234, - location=None, - ) + row = RowDict(sys__id=1, sys__rand=1234, id=6, size=7) result = t.run_once(None, row) - assert result[0] == {"id": 6, "mul": 42, "sum": 13} + assert result[0] == {"sys__id": 1, "mul": 42, "sum": 13} def test_udf_batching(): @@ -66,24 +34,8 @@ def t(vals): inputs = list(zip(range(1, 11), range(21, 31))) results = [] - for size, row_id in inputs: - row = RowDict( - id=row_id, - vtype="", - dir_type=1, - parent="", - name="obj", - last_modified=None, - etag="", - version="", - is_latest=True, - size=size, - owner_name="", - owner_id="", - source="", - random=1234, - location=None, - ) + for row_id, (size, id) in enumerate(inputs): + row = RowDict(sys__id=row_id, sys__rand=1234 + row_id, 
id=id, size=size) batch = RowBatch([row]) result = t.run_once(None, batch) if result: @@ -91,7 +43,9 @@ def t(vals): results.extend(result) assert len(results) == len(inputs) - assert results == [{"id": b, "mul": a * b} for (a, b) in inputs] + assert results == [ + {"sys__id": id, "mul": a * b} for id, (a, b) in enumerate(inputs) + ] def test_stateful_udf(): @@ -108,7 +62,7 @@ def sum(self, size): results = [] for size in inputs: row = RowDict( - id=5, + sys__id=5, vtype="", dir_type=1, parent="", @@ -127,7 +81,7 @@ def sum(self, size): results.extend(udf_inst.run_once(None, row)) assert len(results) == len(inputs) - assert results == [{"id": 5, "sum": 5 + size} for size in inputs] + assert results == [{"sys__id": 5, "sum": 5 + size} for size in inputs] @pytest.mark.parametrize("param", ["foo", ("foo",)])
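
Usage note (not part of the patch): a minimal sketch of the new Sys feature introduced by this change, modeled on test_sys_feature above. The `num` values, the `tag` mapper, and the dataset name "nums_tagged" are illustrative assumptions, not names from this change.

from datachain import DataChain, Sys

chain = DataChain.from_values(num=[3, 5, 1])

# settings(include_sys=True) prepends {"sys": Sys} to the signal schema,
# so a UDF can accept a `sys` parameter and read the renamed system
# columns as sys.id / sys.rand (stored as sys__id / sys__rand in the
# warehouse tables).
with_sys = chain.settings(include_sys=True)
tagged = with_sys.map(tag=lambda sys, num: f"{sys.id}:{num}", output=str)

# save() pops "sys" from the serialized schema, so the stored dataset's
# feature_schema stays free of system signals.
tagged.save("nums_tagged")

# settings(include_sys=False) removes the signal again (SignalSchema
# gains __contains__ and remove() in this patch).
assert "sys" not in with_sys.settings(include_sys=False).signals_schema

The sys__ prefix keeps the system columns (sys__id, sys__rand) out of the user-facing signal namespace and out of the saved feature schema, while raw queries can still address them directly, e.g. C.sys__rand in filters or order_by("sys__rand") as shown in the diff above.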