From 3dc4f3ea9d44b49a3a30d71f05879a33dd418e02 Mon Sep 17 00:00:00 2001 From: Ivan Longin Date: Wed, 6 Nov 2024 14:17:54 +0100 Subject: [PATCH] Dataset pull fixes (#560) * fixing dataset pull * fixing pull test * fixing random default * testing uint32 * removing uint32 support from this PR --- src/datachain/catalog/catalog.py | 8 ++------ src/datachain/data_storage/schema.py | 6 ++++-- src/datachain/sql/types.py | 2 ++ tests/func/test_datasets.py | 7 ++----- tests/func/test_pull.py | 10 ++++------ 5 files changed, 14 insertions(+), 19 deletions(-) diff --git a/src/datachain/catalog/catalog.py b/src/datachain/catalog/catalog.py index 7b19e8538..f48e0da1e 100644 --- a/src/datachain/catalog/catalog.py +++ b/src/datachain/catalog/catalog.py @@ -58,7 +58,7 @@ from datachain.node import DirType, Node, NodeWithPath from datachain.nodes_thread_pool import NodesThreadPool from datachain.remote.studio import StudioClient -from datachain.sql.types import DateTime, SQLType, String +from datachain.sql.types import DateTime, SQLType from datachain.utils import ( DataChainDir, batched, @@ -196,11 +196,6 @@ def fix_columns(self, df) -> None: for c in [c for c, t in self.schema.items() if t == DateTime]: df[c] = pd.to_datetime(df[c], unit="s") - # strings are represented as binaries in parquet export so need to - # decode it back to strings - for c in [c for c, t in self.schema.items() if t == String]: - df[c] = df[c].str.decode("utf-8") - def do_task(self, urls): import lz4.frame import pandas as pd @@ -1403,6 +1398,7 @@ def _instantiate_dataset(): query_script=remote_dataset_version.query_script, create_rows=True, columns=columns, + feature_schema=remote_dataset_version.feature_schema, validate_version=False, ) diff --git a/src/datachain/data_storage/schema.py b/src/datachain/data_storage/schema.py index f34bdeeca..0c002e5dd 100644 --- a/src/datachain/data_storage/schema.py +++ b/src/datachain/data_storage/schema.py @@ -145,6 +145,8 @@ def query(self, q): class DataTable: + MAX_RANDOM = 2**63 - 1 + def __init__( self, name: str, @@ -269,8 +271,8 @@ def update(self): def delete(self): return self.apply_conditions(self.table.delete()) - @staticmethod - def sys_columns(): + @classmethod + def sys_columns(cls): return [ sa.Column("sys__id", Int, primary_key=True), sa.Column( diff --git a/src/datachain/sql/types.py b/src/datachain/sql/types.py index 14fec9502..8b90efaa2 100644 --- a/src/datachain/sql/types.py +++ b/src/datachain/sql/types.py @@ -440,6 +440,8 @@ def array(self, value, item_type, dialect): def json(self, value): if isinstance(value, str): + if value == "": + return {} return orjson.loads(value) return value diff --git a/tests/func/test_datasets.py b/tests/func/test_datasets.py index d1fe0ac8d..d637fea58 100644 --- a/tests/func/test_datasets.py +++ b/tests/func/test_datasets.py @@ -5,7 +5,7 @@ import pytest import sqlalchemy as sa -from datachain.data_storage.sqlite import SQLiteWarehouse +from datachain.data_storage.schema import DataTable from datachain.dataset import DatasetDependencyType, DatasetStatus from datachain.error import ( DatasetInvalidVersionError, @@ -827,10 +827,7 @@ def test_row_random(cloud_test_catalog): # Random values are unique assert len(set(random_values)) == len(random_values) - if isinstance(catalog.warehouse, SQLiteWarehouse): - RAND_MAX = 2**63 # noqa: N806 - else: - RAND_MAX = 2**64 # noqa: N806 + RAND_MAX = DataTable.MAX_RANDOM # noqa: N806 # Values are drawn uniformly from range(2**63) assert 0 <= min(random_values) < 0.4 * RAND_MAX diff --git a/tests/func/test_pull.py b/tests/func/test_pull.py index 53e818963..bb720fbf5 100644 --- a/tests/func/test_pull.py +++ b/tests/func/test_pull.py @@ -49,19 +49,17 @@ def _adapt_row(row): """ adapted = {} for k, v in row.items(): - if isinstance(v, str): - adapted[k] = v.encode("utf-8") - elif isinstance(v, datetime): + if isinstance(v, datetime): adapted[k] = v.timestamp() elif v is None: - adapted[k] = b"" + adapted[k] = "" else: adapted[k] = v adapted["sys__id"] = 1 adapted["sys__rand"] = 1 - adapted["file__location"] = b"" - adapted["file__source"] = b"s3://dogs" + adapted["file__location"] = "" + adapted["file__source"] = "s3://dogs" return adapted dog_entries = [_adapt_row(e) for e in dog_entries]