Dataset pull fixes (#560)
* fixing dataset pull

* fixing pull test

* fixing random default

* testing uint32

* removing uint32 support from this PR
ilongin authored Nov 6, 2024
1 parent 69f45eb commit 3dc4f3e
Showing 5 changed files with 14 additions and 19 deletions.
8 changes: 2 additions & 6 deletions src/datachain/catalog/catalog.py
@@ -58,7 +58,7 @@
 from datachain.node import DirType, Node, NodeWithPath
 from datachain.nodes_thread_pool import NodesThreadPool
 from datachain.remote.studio import StudioClient
-from datachain.sql.types import DateTime, SQLType, String
+from datachain.sql.types import DateTime, SQLType
 from datachain.utils import (
     DataChainDir,
     batched,
@@ -196,11 +196,6 @@ def fix_columns(self, df) -> None:
         for c in [c for c, t in self.schema.items() if t == DateTime]:
             df[c] = pd.to_datetime(df[c], unit="s")

-        # strings are represented as binaries in parquet export so need to
-        # decode it back to strings
-        for c in [c for c, t in self.schema.items() if t == String]:
-            df[c] = df[c].str.decode("utf-8")
-
     def do_task(self, urls):
         import lz4.frame
         import pandas as pd
@@ -1403,6 +1398,7 @@ def _instantiate_dataset():
                 query_script=remote_dataset_version.query_script,
                 create_rows=True,
                 columns=columns,
+                feature_schema=remote_dataset_version.feature_schema,
                 validate_version=False,
             )

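With the String branch removed, fix_columns only converts epoch-seconds columns back to datetimes; pulled string columns are expected to arrive as plain Python strings. The last hunk additionally forwards the remote version's feature_schema when the local dataset version is instantiated. A minimal illustration of the remaining datetime fix-up, using a made-up column name and pandas directly rather than datachain code:

    import pandas as pd

    df = pd.DataFrame({"file__last_modified": [1730851200]})  # epoch seconds
    df["file__last_modified"] = pd.to_datetime(df["file__last_modified"], unit="s")
    print(df["file__last_modified"][0])  # 2024-11-06 00:00:00
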
6 changes: 4 additions & 2 deletions src/datachain/data_storage/schema.py
@@ -145,6 +145,8 @@ def query(self, q):


 class DataTable:
+    MAX_RANDOM = 2**63 - 1
+
     def __init__(
         self,
         name: str,
@@ -269,8 +271,8 @@ def update(self):
     def delete(self):
         return self.apply_conditions(self.table.delete())

-    @staticmethod
-    def sys_columns():
+    @classmethod
+    def sys_columns(cls):
         return [
             sa.Column("sys__id", Int, primary_key=True),
             sa.Column(
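DataTable.MAX_RANDOM gives callers a single, backend-independent upper bound for sys__rand values, and sys_columns becomes a classmethod so it can reference class attributes like this one. A hypothetical illustration of drawing a value within the bound (not necessarily how datachain seeds sys__rand):

    import random

    from datachain.data_storage.schema import DataTable

    # random.randint is inclusive on both ends, so the result fits a signed 64-bit int.
    value = random.randint(0, DataTable.MAX_RANDOM)
    assert 0 <= value <= 2**63 - 1
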
2 changes: 2 additions & 0 deletions src/datachain/sql/types.py
@@ -440,6 +440,8 @@ def array(self, value, item_type, dialect):

     def json(self, value):
         if isinstance(value, str):
+            if value == "":
+                return {}
             return orjson.loads(value)
         return value

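The new guard makes an empty string deserialize to an empty dict instead of raising from orjson.loads. A standalone sketch mirroring the patched converter (json_value is a stand-in name, not the datachain API):

    import orjson

    def json_value(value):
        # Empty strings become empty dicts, JSON strings are parsed,
        # anything already deserialized passes through unchanged.
        if isinstance(value, str):
            if value == "":
                return {}
            return orjson.loads(value)
        return value

    assert json_value("") == {}
    assert json_value('{"a": 1}') == {"a": 1}
    assert json_value({"a": 1}) == {"a": 1}
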
7 changes: 2 additions & 5 deletions tests/func/test_datasets.py
@@ -5,7 +5,7 @@
 import pytest
 import sqlalchemy as sa

-from datachain.data_storage.sqlite import SQLiteWarehouse
+from datachain.data_storage.schema import DataTable
 from datachain.dataset import DatasetDependencyType, DatasetStatus
 from datachain.error import (
     DatasetInvalidVersionError,
@@ -827,10 +827,7 @@ def test_row_random(cloud_test_catalog):
     # Random values are unique
     assert len(set(random_values)) == len(random_values)

-    if isinstance(catalog.warehouse, SQLiteWarehouse):
-        RAND_MAX = 2**63  # noqa: N806
-    else:
-        RAND_MAX = 2**64  # noqa: N806
+    RAND_MAX = DataTable.MAX_RANDOM  # noqa: N806

     # Values are drawn uniformly from range(2**63)
     assert 0 <= min(random_values) < 0.4 * RAND_MAX
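For reference, the shared bound is the largest signed 64-bit integer, so it is valid for the SQLite warehouse as well as the backends that previously used 2**64 (a quick sanity check, not part of the patch):

    import struct

    MAX_RANDOM = 2**63 - 1
    struct.pack(">q", MAX_RANDOM)  # fits a signed 64-bit integer, which SQLite's INTEGER type requires
    # struct.pack(">q", 2**64 - 1) would raise struct.error
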
10 changes: 4 additions & 6 deletions tests/func/test_pull.py
@@ -49,19 +49,17 @@ def _adapt_row(row):
         """
         adapted = {}
         for k, v in row.items():
-            if isinstance(v, str):
-                adapted[k] = v.encode("utf-8")
-            elif isinstance(v, datetime):
+            if isinstance(v, datetime):
                 adapted[k] = v.timestamp()
             elif v is None:
-                adapted[k] = b""
+                adapted[k] = ""
             else:
                 adapted[k] = v

         adapted["sys__id"] = 1
         adapted["sys__rand"] = 1
-        adapted["file__location"] = b""
-        adapted["file__source"] = b"s3://dogs"
+        adapted["file__location"] = ""
+        adapted["file__source"] = "s3://dogs"
         return adapted

     dog_entries = [_adapt_row(e) for e in dog_entries]
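Since strings are no longer byte-encoded and None now maps to an empty string, an adapted row keeps plain str values throughout. A hypothetical input/output pair for _adapt_row (field names chosen for illustration):

    from datetime import datetime, timezone

    row = {
        "file__path": "dogs/dog1",
        "file__last_modified": datetime(2024, 11, 6, tzinfo=timezone.utc),
        "file__version": None,
    }
    _adapt_row(row)
    # {
    #     "file__path": "dogs/dog1",            # stays str, no .encode("utf-8")
    #     "file__last_modified": 1730851200.0,  # datetime -> POSIX timestamp
    #     "file__version": "",                  # None -> "" instead of b""
    #     "sys__id": 1,
    #     "sys__rand": 1,
    #     "file__location": "",
    #     "file__source": "s3://dogs",
    # }
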
