From 78ee1ba74194fa00113258a8953f57682240911c Mon Sep 17 00:00:00 2001 From: Ivan Longin Date: Mon, 16 Sep 2024 14:27:17 +0200 Subject: [PATCH] Fix storage dependencies (#421) * fix storage dependencies * fix test * fix dataset parsing --- src/datachain/dataset.py | 2 +- tests/func/test_datachain.py | 18 +++++++++++++++++- tests/func/test_datasets.py | 7 ++++--- 3 files changed, 22 insertions(+), 5 deletions(-) diff --git a/src/datachain/dataset.py b/src/datachain/dataset.py index 772d0745d..b26d76014 100644 --- a/src/datachain/dataset.py +++ b/src/datachain/dataset.py @@ -112,7 +112,7 @@ def parse( if is_listing_dataset(dataset_name): dependency_type = DatasetDependencyType.STORAGE # type: ignore[arg-type] - dependency_name = listing_uri_from_name(dataset_name) + dependency_name, _ = Client.parse_url(listing_uri_from_name(dataset_name)) return cls( id, diff --git a/tests/func/test_datachain.py b/tests/func/test_datachain.py index 92db470a6..0591e498b 100644 --- a/tests/func/test_datachain.py +++ b/tests/func/test_datachain.py @@ -10,8 +10,9 @@ from PIL import Image from sqlalchemy import Column +from datachain.client.local import FileClient from datachain.data_storage.sqlite import SQLiteWarehouse -from datachain.dataset import DatasetStats +from datachain.dataset import DatasetDependencyType, DatasetStats from datachain.lib.dc import C, DataChain, DataChainColumnError from datachain.lib.file import File, ImageFile from datachain.lib.listing import ( @@ -178,6 +179,21 @@ def _list_dataset_name(uri: str) -> str: ) +def test_from_storage_dependencies(cloud_test_catalog, cloud_type): + ctc = cloud_test_catalog + src_uri = ctc.src_uri + uri = f"{src_uri}/cats" + ds_name = "dep" + DataChain.from_storage(uri, session=ctc.session).save(ds_name) + dependencies = ctc.session.catalog.get_dataset_dependencies(ds_name, 1) + assert len(dependencies) == 1 + assert dependencies[0].type == DatasetDependencyType.STORAGE + if cloud_type == "file": + assert dependencies[0].name == FileClient.root_path().as_uri() + else: + assert dependencies[0].name == src_uri + + @pytest.mark.parametrize("use_cache", [True, False]) def test_map_file(cloud_test_catalog, use_cache): ctc = cloud_test_catalog diff --git a/tests/func/test_datasets.py b/tests/func/test_datasets.py index cf0304eae..090088b65 100644 --- a/tests/func/test_datasets.py +++ b/tests/func/test_datasets.py @@ -7,8 +7,9 @@ import sqlalchemy as sa from datachain.catalog.catalog import DATASET_INTERNAL_ERROR_MESSAGE +from datachain.client.local import FileClient from datachain.data_storage.sqlite import SQLiteWarehouse -from datachain.dataset import LISTING_PREFIX, DatasetDependencyType, DatasetStatus +from datachain.dataset import DatasetDependencyType, DatasetStatus from datachain.error import DatasetInvalidVersionError, DatasetNotFoundError from datachain.lib.dc import DataChain from datachain.lib.listing import parse_listing_uri @@ -805,7 +806,7 @@ def test_dataset_stats_registered_ds(cloud_test_catalog, dogs_dataset): @pytest.mark.parametrize("indirect", [True, False]) -def test_dataset_storage_dependencies(cloud_test_catalog, indirect): +def test_dataset_storage_dependencies(cloud_test_catalog, cloud_type, indirect): ctc = cloud_test_catalog session = ctc.session catalog = session.catalog @@ -824,7 +825,7 @@ def test_dataset_storage_dependencies(cloud_test_catalog, indirect): { "id": ANY, "type": DatasetDependencyType.STORAGE, - "name": lst_dataset.name.removeprefix(LISTING_PREFIX), + "name": uri if cloud_type != "file" else FileClient.root_path().as_uri(), "version": "1", "created_at": lst_dataset.get_version(1).created_at, "dependencies": [],