Implement sys feature, and rename id/random columns (#28)
skshetry authored Jul 17, 2024
1 parent d39d9af commit 5cf20d3
Showing 24 changed files with 213 additions and 196 deletions.
3 changes: 1 addition & 2 deletions examples/pose_detection.py
@@ -130,7 +130,6 @@ def __call__(
return

# Clean up the record
del record["random"] # random will be populated automatically
record["is_latest"] = record["is_latest"] > 0 # needs to be a bool
row = DatasetRow.create(**record)

@@ -211,7 +210,7 @@ def __call__(
data = (
DatasetQuery(os.path.join(cloud_prefix, bucket))
.filter(C.name.glob(file_type))
.filter(C.random % filter_mod == chunk_num)
.filter(C.sys__rand % filter_mod == chunk_num)
.generate(pose_udf)
.results()
)
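For context on the filter change above: sys__rand holds a stable random integer per row, so a modulo filter splits a dataset into disjoint, roughly equal chunks. A minimal sketch of the pattern, assuming hypothetical bucket and glob values and that C and DatasetQuery come from datachain.query (the real script takes these from CLI arguments):

    import os

    from datachain.query import C, DatasetQuery

    cloud_prefix, bucket = "s3://", "my-bucket"  # hypothetical values
    filter_mod = 4  # number of chunks to split the dataset into

    for chunk_num in range(filter_mod):
        results = (
            DatasetQuery(os.path.join(cloud_prefix, bucket))
            .filter(C.name.glob("*.jpg"))
            # sys__rand is fixed per row, so every row lands in exactly one chunk
            .filter(C.sys__rand % filter_mod == chunk_num)
            .results()
        )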
3 changes: 2 additions & 1 deletion src/datachain/__init__.py
@@ -1,5 +1,5 @@
from datachain.lib.data_model import DataModel, DataType, FileBasic, is_chain_type
from datachain.lib.dc import C, Column, DataChain
from datachain.lib.dc import C, Column, DataChain, Sys
from datachain.lib.file import (
File,
FileError,
@@ -31,6 +31,7 @@
"IndexedFile",
"Mapper",
"Session",
"Sys",
"TarVFile",
"TextFile",
"is_chain_type",
10 changes: 5 additions & 5 deletions src/datachain/catalog/catalog.py
@@ -256,7 +256,7 @@ def do_task(self, urls):
self.fix_columns(df)

# id will be autogenerated in DB
df = df.drop("id", axis=1)
df = df.drop("sys__id", axis=1)

inserted = warehouse.insert_dataset_rows(
df, dataset, self.dataset_version
@@ -1041,7 +1041,7 @@ def create_dataset(
If version is None, then next unused version is created.
If version is given, then it must be an unused version number.
"""
assert [c.name for c in columns if c.name != "id"], f"got {columns=}"
assert [c.name for c in columns if c.name != "sys__id"], f"got {columns=}"
if not listing and Client.is_data_source_uri(name):
raise RuntimeError(
"Cannot create dataset that starts with source prefix, e.g s3://"
@@ -1103,7 +1103,7 @@ def create_new_dataset_version(
Creates dataset version if it doesn't exist.
If create_rows is False, dataset rows table will not be created
"""
assert [c.name for c in columns if c.name != "id"], f"got {columns=}"
assert [c.name for c in columns if c.name != "sys__id"], f"got {columns=}"
schema = {
c.name: c.type.to_dict() for c in columns if isinstance(c.type, SQLType)
}
@@ -1433,7 +1433,7 @@ def ls_dataset_rows(
if offset:
q = q.offset(offset)

q = q.order_by("id")
q = q.order_by("sys__id")

return q.to_records()

@@ -1786,7 +1786,7 @@ def _instantiate_dataset():
schema = DatasetRecord.parse_schema(remote_dataset_version.schema)

columns = tuple(
sa.Column(name, typ) for name, typ in schema.items() if name != "id"
sa.Column(name, typ) for name, typ in schema.items() if name != "sys__id"
)
# creating new dataset (version) locally
dataset = self.create_dataset(
2 changes: 0 additions & 2 deletions src/datachain/cli.py
@@ -811,8 +811,6 @@ def show(
from datachain.query import DatasetQuery
from datachain.utils import show_records

if columns:
columns = ("id", *columns)
query = (
DatasetQuery(name=name, version=version, catalog=catalog)
.select(*columns)
10 changes: 5 additions & 5 deletions src/datachain/data_storage/schema.py
@@ -72,7 +72,7 @@ class DirExpansion:
@staticmethod
def base_select(q):
return sa.select(
q.c.id,
q.c.sys__id,
q.c.vtype,
(q.c.dir_type == DirType.DIR).label("is_dir"),
q.c.source,
@@ -86,7 +86,7 @@ def base_select(q):
def apply_group_by(q):
return (
sa.select(
f.min(q.c.id).label("id"),
f.min(q.c.sys__id).label("sys__id"),
q.c.vtype,
q.c.is_dir,
q.c.source,
@@ -111,7 +111,7 @@ def query(cls, q):
parent_name = path.name(q.c.parent)
q = q.union_all(
sa.select(
sa.literal(-1).label("id"),
sa.literal(-1).label("sys__id"),
sa.literal("").label("vtype"),
true().label("is_dir"),
q.c.source,
@@ -233,9 +233,9 @@ def delete(self):
@staticmethod
def sys_columns():
return [
sa.Column("id", Int, primary_key=True),
sa.Column("sys__id", Int, primary_key=True),
sa.Column(
"random", UInt64, nullable=False, server_default=f.abs(f.random())
"sys__rand", UInt64, nullable=False, server_default=f.abs(f.random())
),
]

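The system columns above are the schema-level counterpart of the rename: sys__id is the autoincrementing primary key, and sys__rand is filled by the database through a server default, so inserts never supply either value. A standalone sketch of the same pattern in plain SQLAlchemy (standard integer types stand in for datachain's Int/UInt64, and the table is illustrative):

    import sqlalchemy as sa
    from sqlalchemy import func as f

    metadata = sa.MetaData()
    rows = sa.Table(
        "example_rows",
        metadata,
        sa.Column("sys__id", sa.Integer, primary_key=True),
        # The DB populates sys__rand on insert; callers never pass it.
        sa.Column(
            "sys__rand",
            sa.BigInteger,
            nullable=False,
            server_default=f.abs(f.random()),
        ),
        sa.Column("name", sa.Text),
    )

    engine = sa.create_engine("sqlite://")
    metadata.create_all(engine)
    with engine.begin() as conn:
        # Only user columns are supplied; both sys columns are generated.
        conn.execute(sa.insert(rows), [{"name": "a"}, {"name": "b"}])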
2 changes: 1 addition & 1 deletion src/datachain/data_storage/sqlite.py
@@ -631,7 +631,7 @@ def merge_dataset_rows(
dst_empty = True

dst_dr = self.dataset_rows(dst, dst_version).table
merge_fields = [c.name for c in src_dr.c if c.name != "id"]
merge_fields = [c.name for c in src_dr.c if c.name != "sys__id"]
select_src = select(*(getattr(src_dr.c, f) for f in merge_fields))

if dst_empty:
14 changes: 7 additions & 7 deletions src/datachain/data_storage/warehouse.py
@@ -195,7 +195,7 @@ def dataset_select_paginated(
cols_names = [c.name for c in cols]

if not order_by:
ordering = [cols.id]
ordering = [cols.sys__id]
else:
ordering = order_by # type: ignore[assignment]

Expand Down Expand Up @@ -372,7 +372,7 @@ def dataset_rows_count(self, dataset: DatasetRecord, version=None) -> int:
"""Returns total number of rows in a dataset"""
dr = self.dataset_rows(dataset, version)
table = dr.get_table()
query = select(sa.func.count(table.c.id))
query = select(sa.func.count(table.c.sys__id))
(res,) = self.db.execute(query)
return res[0]

@@ -388,7 +388,7 @@ def dataset_stats(
dr = self.dataset_rows(dataset, version)
table = dr.get_table()
expressions: tuple[_ColumnsClauseArgument[Any], ...] = (
sa.func.count(table.c.id),
sa.func.count(table.c.sys__id),
)
if "size" in table.columns:
expressions = (*expressions, sa.func.sum(table.c.size))
@@ -607,7 +607,7 @@ def with_default(column):
return func.coalesce(column, default).label(column.name)

return sa.select(
de.c.id,
de.c.sys__id,
with_default(dr.c.vtype),
case((de.c.is_dir == true(), DirType.DIR), else_=dr.c.dir_type).label(
"dir_type"
@@ -621,10 +621,10 @@ def with_default(column):
with_default(dr.c.size),
with_default(dr.c.owner_name),
with_default(dr.c.owner_id),
with_default(dr.c.random),
with_default(dr.c.sys__rand),
dr.c.location,
de.c.source,
).select_from(de.outerjoin(dr.table, de.c.id == dr.c.id))
).select_from(de.outerjoin(dr.table, de.c.sys__id == dr.c.sys__id))

def get_node_by_path(self, dataset_rows: "DataTable", path: str) -> Node:
"""Gets node that corresponds to some path"""
@@ -878,7 +878,7 @@ def create_udf_table(
tbl = sa.Table(
name,
sa.MetaData(),
sa.Column("id", Int, primary_key=True),
sa.Column("sys__id", Int, primary_key=True),
*columns,
)
self.db.create_table(tbl, if_not_exists=True)
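One detail worth noting in dataset_select_paginated above: when the caller gives no explicit ordering, the reader falls back to the sys__id primary key so successive pages are stable and non-overlapping. A hedged sketch of that offset-pagination pattern in generic SQLAlchemy (the helper and its names are illustrative, not the warehouse's actual internals):

    import sqlalchemy as sa

    def select_paginated(conn, table, order_by=None, page_size=1000):
        # Fall back to the sys__id primary key for a stable, repeatable order.
        ordering = order_by or [table.c.sys__id]
        offset = 0
        while True:
            q = (
                sa.select(table)
                .order_by(*ordering)
                .limit(page_size)
                .offset(offset)
            )
            page = conn.execute(q).fetchall()
            if not page:
                break
            yield from page
            offset += page_size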
61 changes: 46 additions & 15 deletions src/datachain/lib/dc.py
@@ -1,3 +1,4 @@
import copy
import re
from collections.abc import Iterator, Sequence
from typing import (
@@ -72,6 +73,11 @@ def __init__(self, on: Sequence[str], right_on: Optional[Sequence[str]], msg: str):
OutputType = Union[None, DataType, Sequence[str], dict[str, DataType]]


class Sys(DataModel):
id: int
rand: int


class DataChain(DatasetQuery):
"""AI 🔗 DataChain - a data structure for batch data processing and evaluation.
@@ -124,12 +130,10 @@ class Rating(Feature):
"""

DEFAULT_FILE_RECORD: ClassVar[dict] = {
"id": 0,
"source": "",
"name": "",
"vtype": "",
"size": 0,
"random": 0,
}

def __init__(self, *args, **kwargs):
@@ -155,8 +159,19 @@ def schema(self):
def print_schema(self):
self.signals_schema.print_tree()

def clone(self, new_table: bool = True) -> "Self":
obj = super().clone(new_table=new_table)
obj.signals_schema = copy.deepcopy(self.signals_schema)
return obj

def settings(
self, cache=None, batch=None, parallel=None, workers=None, min_task_size=None
self,
cache=None,
batch=None,
parallel=None,
workers=None,
min_task_size=None,
include_sys: Optional[bool] = None,
) -> "Self":
"""Change settings for chain.
@@ -180,8 +195,13 @@ def settings(
)
```
"""
self._settings.add(Settings(cache, batch, parallel, workers, min_task_size))
return self
chain = self.clone()
if include_sys is True:
chain.signals_schema = SignalSchema({"sys": Sys}) | chain.signals_schema
elif include_sys is False and "sys" in chain.signals_schema:
chain.signals_schema.remove("sys")
chain._settings.add(Settings(cache, batch, parallel, workers, min_task_size))
return chain

def reset_settings(self, settings: Optional[Settings] = None) -> "Self":
"""Reset all settings to default values."""
@@ -193,12 +213,11 @@ def reset_schema(self, signals_schema: SignalSchema) -> "Self":
return self

def add_schema(self, signals_schema: SignalSchema) -> "Self":
union = self.signals_schema.values | signals_schema.values
self.signals_schema = SignalSchema(union)
self.signals_schema |= signals_schema
return self

def get_file_signals(self) -> list[str]:
return self.signals_schema.get_file_signals()
return list(self.signals_schema.get_file_signals())

@classmethod
def from_storage(
@@ -352,6 +371,7 @@ def save( # type: ignore[override]
version : version of a dataset. Default - the last version that exist.
"""
schema = self.signals_schema.serialize()
schema.pop("sys", None)
return super().save(name=name, version=version, feature_schema=schema)

def apply(self, func, *args, **kwargs):
@@ -528,20 +548,31 @@ def select_except(self, *args: str) -> "Self":
chain.signals_schema = new_schema
return chain

def iterate_flatten(self) -> Iterator[tuple[Any]]:
db_signals = self.signals_schema.db_signals()
with super().select(*db_signals).as_iterable() as rows:
yield from rows

def results(
self, row_factory: Optional[Callable] = None, **kwargs
) -> list[tuple[Any, ...]]:
rows = self.iterate_flatten()
if row_factory:
db_signals = self.signals_schema.db_signals()
rows = (row_factory(db_signals, r) for r in rows)
return list(rows)

def iterate(self, *cols: str) -> Iterator[list[DataType]]:
"""Iterate over rows.
If columns are specified - limit them to specified
columns.
"""
chain = self.select(*cols) if cols else self

db_signals = chain.signals_schema.db_signals()
with super().select(*db_signals).as_iterable() as rows_iter:
for row in rows_iter:
yield chain.signals_schema.row_to_features(
row, catalog=chain.session.catalog, cache=chain._settings.cache
)
for row in chain.iterate_flatten():
yield chain.signals_schema.row_to_features(
row, catalog=chain.session.catalog, cache=chain._settings.cache
)

def iterate_one(self, col: str) -> Iterator[DataType]:
for item in self.iterate(col):
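Taken together, the dc.py changes wire the new Sys model into the chain: settings(include_sys=True) prepends it to the signal schema, clone() deep-copies the schema so the setting does not leak between chains, and save() pops "sys" so system signals are never persisted as user features. A hypothetical usage sketch (the bucket URI and the exact signal layout are illustrative):

    from datachain import C, DataChain, Sys

    chain = (
        DataChain.from_storage("s3://my-bucket")   # assumed URI
        .filter(C.name.glob("*.json"))
        .settings(include_sys=True)                # schema becomes {sys, file}
    )

    # Each row now starts with a Sys instance carrying sys__id / sys__rand.
    for sys_info, file in chain.iterate():
        print(sys_info.id, sys_info.rand, file.name)

    # results() yields flat tuples; a row_factory can reshape them.
    records = chain.results(row_factory=lambda cols, row: dict(zip(cols, row)))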
1 change: 0 additions & 1 deletion src/datachain/lib/image_transform.py
@@ -66,7 +66,6 @@ def __call__(
):
# Build a dict from row contents
record = dict(zip(DatasetRow.schema.keys(), args))
del record["random"] # random will be populated automatically
record["is_latest"] = record["is_latest"] > 0 # needs to be a bool

# yield same row back
9 changes: 9 additions & 0 deletions src/datachain/lib/signal_schema.py
@@ -331,6 +331,15 @@ def print_tree(self, indent: int = 4, start_at: int = 0):
sub_schema = SignalSchema({"* list of": args[0]})
sub_schema.print_tree(indent=indent, start_at=total_indent + indent)

def __or__(self, other):
return self.__class__(self.values | other.values)

def __contains__(self, name: str):
return name in self.values

def remove(self, name: str):
return self.values.pop(name)

@staticmethod
def _type_to_str(type_): # noqa: PLR0911
origin = get_origin(type_)
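These three dunder additions give SignalSchema the small set-like surface that settings(include_sys=...) relies on above: | merges two schemas into a new one, in tests for a top-level signal name, and remove() pops an entry. A brief sketch (the extra "name" signal is illustrative):

    from datachain.lib.dc import Sys
    from datachain.lib.signal_schema import SignalSchema

    schema = SignalSchema({"name": str})
    schema = SignalSchema({"sys": Sys}) | schema  # __or__ returns a new SignalSchema

    assert "sys" in schema   # __contains__ checks top-level signal names
    schema.remove("sys")     # pops and returns the removed entry
    assert "sys" not in schema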
8 changes: 4 additions & 4 deletions src/datachain/node.py
@@ -46,8 +46,8 @@ class DirTypeGroup:

@attrs.define
class Node:
id: int = 0
random: int = -1
sys__id: int = 0
sys__rand: int = -1
vtype: str = ""
dir_type: Optional[int] = None
parent: str = ""
@@ -127,11 +127,11 @@ def from_dict(cls, d: dict[str, Any]) -> "Self":

@classmethod
def from_dir(cls, parent, name, **kwargs) -> "Node":
return cls(id=-1, dir_type=DirType.DIR, parent=parent, name=name, **kwargs)
return cls(sys__id=-1, dir_type=DirType.DIR, parent=parent, name=name, **kwargs)

@classmethod
def root(cls) -> "Node":
return cls(-1, dir_type=DirType.DIR)
return cls(sys__id=-1, dir_type=DirType.DIR)


@attrs.define
2 changes: 1 addition & 1 deletion src/datachain/query/batch.py
@@ -104,7 +104,7 @@ def __call__(
with contextlib.closing(
execute(
query,
order_by=(PARTITION_COLUMN_ID, "id", *query._order_by_clauses),
order_by=(PARTITION_COLUMN_ID, "sys__id", *query._order_by_clauses),
limit=query._limit,
)
) as rows:
(diff truncated: the remaining 12 of the 24 changed files are not shown)