Add DataChain.listings() method and use it in getting storages (#331)
* first version of from_storage without deprecated listing

* first version of from_storage without deprecated listing

* fixing tests and removing prints, refactoring

* refactoring listing static methods

* fixing non recursive queries

* using ctc in test session

* fixing json

* added DataChain.listings classmethod that returns a list of ListingInfo objects for each cached listing

* another test for listings

* removed unneeded filters

* refactoring test

* removed unneeded catalog storage methods and related code

* fixing windows tests

* returning to all tests

* removed unlist_source method and related code

* fixing dataset dependencies

* added session on cloud test catalog and refactoring tests

* using new listings method in from_storage

* fixing test

* fixing test

* added dataset name dependencies test and fixes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* small refactoring

* refactor comments

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
ilongin and pre-commit-ci[bot] authored Sep 5, 2024
1 parent 944029b commit 576b69a
Showing 18 changed files with 282 additions and 180 deletions.
23 changes: 0 additions & 23 deletions src/datachain/catalog/catalog.py
@@ -1018,20 +1018,6 @@ def _row_to_node(d: dict[str, Any]) -> Node:

return node_groups

def unlist_source(self, uri: StorageURI) -> None:
self.metastore.clone(uri=uri).mark_storage_not_indexed(uri)

def storage_stats(self, uri: StorageURI) -> Optional[DatasetStats]:
"""
Returns tuple with storage stats: total number of rows and total dataset size.
"""
partial_path = self.metastore.get_last_partial_path(uri)
if partial_path is None:
return None
dataset = self.get_dataset(Storage.dataset_name(uri, partial_path))

return self.dataset_stats(dataset.name, dataset.latest_version)

def create_dataset(
self,
name: str,
@@ -1618,15 +1604,6 @@ def ls(
for source in data_sources: # type: ignore [union-attr]
yield source, source.ls(fields)

def ls_storage_uris(self) -> Iterator[str]:
yield from self.metastore.get_all_storage_uris()

def get_storage(self, uri: StorageURI) -> Storage:
return self.metastore.get_storage(uri)

def ls_storages(self) -> list[Storage]:
return self.metastore.list_storages()

def pull_dataset(
self,
dataset_uri: str,
18 changes: 4 additions & 14 deletions src/datachain/cli.py
@@ -14,6 +14,7 @@

from datachain import utils
from datachain.cli_utils import BooleanOptionalAction, CommaSeparatedArgs, KeyValueArgs
from datachain.lib.dc import DataChain
from datachain.utils import DataChainDir

if TYPE_CHECKING:
@@ -615,18 +616,6 @@ def _ls_urls_flat(
raise FileNotFoundError(f"No such file or directory: {source}")


def ls_indexed_storages(catalog: "Catalog", long: bool = False) -> Iterator[str]:
from datachain.node import long_line_str

storage_uris = catalog.ls_storage_uris()
if long:
for uri in storage_uris:
# TODO: add Storage.created so it can be used here
yield long_line_str(uri, None, "")
else:
yield from storage_uris


def ls_local(
sources,
long: bool = False,
@@ -657,8 +646,9 @@ def ls_local(
for entry in entries:
print(format_ls_entry(entry))
else:
for entry in ls_indexed_storages(catalog, long=long):
print(format_ls_entry(entry))
chain = DataChain.listings()
for ls in chain.collect("listing"):
print(format_ls_entry(f"{ls.uri}@v{ls.version}")) # type: ignore[union-attr]


def format_ls_entry(entry: str) -> str:
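For reference, a minimal sketch of what the new `ls` branch above does for indexed storages; it assumes the default "listing" object name and the `ListingInfo` fields added later in this commit.

```py
from datachain.lib.dc import DataChain

# Collect the cached listings and print them the way ls_local() now does:
# "<storage uri>@v<listing version>".
for info in DataChain.listings().collect("listing"):
    print(f"{info.uri}@v{info.version}")
```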
29 changes: 1 addition & 28 deletions src/datachain/data_storage/metastore.py
@@ -167,32 +167,17 @@ def mark_storage_indexed(
This method should be called when index operation is finished.
"""

@abstractmethod
def mark_storage_not_indexed(self, uri: StorageURI) -> None:
"""
Mark storage as not indexed.
This method should be called when storage index is deleted.
"""

@abstractmethod
def update_last_inserted_at(self, uri: Optional[StorageURI] = None) -> None:
"""Updates last inserted datetime in bucket with current time."""

@abstractmethod
def get_all_storage_uris(self) -> Iterator[StorageURI]:
"""Returns all storage uris."""

@abstractmethod
def get_storage(self, uri: StorageURI) -> Storage:
"""
Gets storage representation from database.
E.g. if s3 is used as storage this would be s3 bucket data.
"""

@abstractmethod
def list_storages(self) -> list[Storage]:
"""Returns all storages."""

@abstractmethod
def mark_storage_pending(self, storage: Storage) -> Storage:
"""Marks storage as pending."""
@@ -324,7 +309,7 @@ def add_dependency(
self.add_dataset_dependency(
source_dataset_name,
source_dataset_version,
dependency.name,
dependency.dataset_name,
int(dependency.version),
)
else:
@@ -906,11 +891,6 @@ def update_last_inserted_at(self, uri: Optional[StorageURI] = None) -> None:
self._storages_update().where(s.c.uri == uri).values(**updates) # type: ignore [attr-defined]
)

def get_all_storage_uris(self) -> Iterator[StorageURI]:
"""Returns all storage uris."""
s = self._storages
yield from (r[0] for r in self.db.execute(self._storages_select(s.c.uri)))

def get_storage(self, uri: StorageURI, conn=None) -> Storage:
"""
Gets storage representation from database.
@@ -926,13 +906,6 @@ def get_storage(self, uri: StorageURI, conn=None) -> Storage:

return self.storage_class._make(result)

def list_storages(self) -> list[Storage]:
result = self.db.execute(self._storages_select())
if not result:
return []

return [self.storage_class._make(r) for r in result]

def mark_storage_pending(self, storage: Storage, conn=None) -> Storage:
# Update status to pending and dates
updates = {
11 changes: 0 additions & 11 deletions src/datachain/data_storage/sqlite.py
@@ -517,17 +517,6 @@ def _datasets_versions_insert(self) -> "Insert":
def _datasets_dependencies_insert(self) -> "Insert":
return sqlite.insert(self._datasets_dependencies)

#
# Storages
#

def mark_storage_not_indexed(self, uri: StorageURI) -> None:
"""
Mark storage as not indexed.
This method should be called when storage index is deleted.
"""
self.db.execute(self._storages_delete().where(self._storages.c.uri == uri))

#
# Dataset dependencies
#
69 changes: 38 additions & 31 deletions src/datachain/dataset.py
@@ -11,8 +11,6 @@
)
from urllib.parse import urlparse

from dateutil.parser import isoparse

from datachain.client import Client
from datachain.sql.types import NAME_TYPES_MAPPING, SQLType

@@ -73,11 +71,22 @@ class DatasetDependencyType:
class DatasetDependency:
id: int
type: str
name: str # when the type is STORAGE, this is actually StorageURI
version: str # string until we'll have proper bucket listing versions
name: str
version: str # TODO change to int
created_at: datetime
dependencies: list[Optional["DatasetDependency"]]

@property
def dataset_name(self) -> str:
"""Returns clean dependency dataset name"""
from datachain.lib.listing import parse_listing_uri

if self.type == DatasetDependencyType.DATASET:
return self.name

list_dataset_name, _, _ = parse_listing_uri(self.name.strip("/"), None, {})
return list_dataset_name

@classmethod
def parse(
cls: builtins.type[DD],
@@ -92,33 +101,31 @@ def parse(
dataset_version_created_at: Optional[datetime],
bucket_uri: Optional["StorageURI"],
) -> Optional["DatasetDependency"]:
if dataset_id:
assert dataset_name is not None
return cls(
id,
DatasetDependencyType.DATASET,
dataset_name,
(
str(dataset_version) # type: ignore[arg-type]
if dataset_version
else None
),
dataset_version_created_at or dataset_created_at, # type: ignore[arg-type]
[],
)
if bucket_uri:
return cls(
id,
DatasetDependencyType.STORAGE,
bucket_uri,
bucket_version, # type: ignore[arg-type]
isoparse(bucket_version), # type: ignore[arg-type]
[],
)
# dependency has been removed
# TODO we should introduce flags for removed datasets, instead of
# removing them from tables so that we can still have references
return None
from datachain.lib.listing import is_listing_dataset, listing_uri_from_name

if not dataset_id:
return None

assert dataset_name is not None
dependency_type = DatasetDependencyType.DATASET
dependency_name = dataset_name

if is_listing_dataset(dataset_name):
dependency_type = DatasetDependencyType.STORAGE # type: ignore[arg-type]
dependency_name = listing_uri_from_name(dataset_name)

return cls(
id,
dependency_type,
dependency_name,
(
str(dataset_version) # type: ignore[arg-type]
if dataset_version
else None
),
dataset_version_created_at or dataset_created_at, # type: ignore[arg-type]
[],
)

@property
def is_dataset(self) -> bool:
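A small illustrative sketch of the classification rule that `parse` now applies: listing datasets are reported as STORAGE dependencies named by their clean storage URI, while everything else stays a DATASET dependency. The standalone helper below is not part of the library; it only mirrors the logic above using the helpers added in `datachain/lib/listing.py`.

```py
from datachain.dataset import DatasetDependencyType
from datachain.lib.listing import is_listing_dataset, listing_uri_from_name


def classify_dependency(dataset_name: str) -> tuple[str, str]:
    """Mirror how DatasetDependency.parse now derives (type, name)."""
    if is_listing_dataset(dataset_name):
        # A listing dataset becomes a STORAGE dependency on e.g. "s3://bucket/..."
        return DatasetDependencyType.STORAGE, listing_uri_from_name(dataset_name)
    return DatasetDependencyType.DATASET, dataset_name
```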
4 changes: 4 additions & 0 deletions src/datachain/lib/dataset_info.py
@@ -23,6 +23,8 @@ class DatasetInfo(DataModel):
size: Optional[int] = Field(default=None)
params: dict[str, str] = Field(default=dict)
metrics: dict[str, Any] = Field(default=dict)
error_message: str = Field(default="")
error_stack: str = Field(default="")

@staticmethod
def _validate_dict(
@@ -67,4 +69,6 @@ def from_models(
size=version.size,
params=job.params if job else {},
metrics=job.metrics if job else {},
error_message=version.error_message,
error_stack=version.error_stack,
)
47 changes: 39 additions & 8 deletions src/datachain/lib/dc.py
@@ -36,6 +36,7 @@
ls,
parse_listing_uri,
)
from datachain.lib.listing_info import ListingInfo
from datachain.lib.meta_formats import read_meta, read_schema
from datachain.lib.model_store import ModelStore
from datachain.lib.settings import Settings
@@ -349,10 +350,7 @@ def from_storage(
"""
file_type = get_file_type(type)

if anon:
client_config = {"anon": True}
else:
client_config = None
client_config = {"anon": True} if anon else None

session = Session.get(session, client_config=client_config, in_memory=in_memory)

@@ -361,12 +359,9 @@
)
need_listing = True

for ds in cls.datasets(
session=session, in_memory=in_memory, include_listing=True
).collect("dataset"):
for ds in cls.listings(session=session, in_memory=in_memory).collect("listing"):
if (
not is_listing_expired(ds.created_at) # type: ignore[union-attr]
and is_listing_dataset(ds.name) # type: ignore[union-attr]
and is_listing_subset(ds.name, list_dataset_name) # type: ignore[union-attr]
and not update
):
@@ -577,6 +572,42 @@ def datasets(
**{object_name: datasets}, # type: ignore[arg-type]
)

@classmethod
def listings(
cls,
session: Optional[Session] = None,
in_memory: bool = False,
object_name: str = "listing",
**kwargs,
) -> "DataChain":
"""Generate chain with list of cached listings.
Listing is a special kind of dataset which has directory listing data of
some underlying storage (e.g. an S3 bucket).
Example:
```py
from datachain import DataChain
DataChain.listings().show()
```
"""
session = Session.get(session, in_memory=in_memory)
catalog = kwargs.get("catalog") or session.catalog

listings = [
ListingInfo.from_models(d, v, j)
for d, v, j in catalog.list_datasets_versions(
include_listing=True, **kwargs
)
if is_listing_dataset(d.name)
]

return cls.from_values(
session=session,
in_memory=in_memory,
output={object_name: ListingInfo},
**{object_name: listings}, # type: ignore[arg-type]
)

def print_json_schema( # type: ignore[override]
self, jmespath: Optional[str] = None, model_name: Optional[str] = None
) -> "Self":
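A short usage sketch tying the two `DataChain` changes together; the bucket URI is hypothetical, and the reuse-unless-expired behavior is the one implemented by the loop in `from_storage` above.

```py
from datachain import DataChain

# First call lists the bucket and caches the listing as a dataset.
chain = DataChain.from_storage("s3://my-bucket/images/", anon=True)

# Later calls reuse the cached listing while it is still fresh;
# update=True forces a re-listing.
chain = DataChain.from_storage("s3://my-bucket/images/", anon=True, update=True)

# Cached listings are now inspectable through the new classmethod.
DataChain.listings().show()
```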
7 changes: 7 additions & 0 deletions src/datachain/lib/listing.py
@@ -97,6 +97,13 @@ def is_listing_dataset(name: str) -> bool:
return name.startswith(LISTING_PREFIX)


def listing_uri_from_name(dataset_name: str) -> str:
"""Returns clean storage URI from listing dataset name"""
if not is_listing_dataset(dataset_name):
raise ValueError(f"Dataset {dataset_name} is not a listing")
return dataset_name.removeprefix(LISTING_PREFIX)


def is_listing_expired(created_at: datetime) -> bool:
"""Checks if listing has expired based on it's creation date"""
return datetime.now(timezone.utc) > created_at + timedelta(seconds=LISTING_TTL)
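A minimal sketch of the new `listing_uri_from_name` helper next to the existing `is_listing_dataset` check; the example name is built from `LISTING_PREFIX`, so no particular prefix value is assumed.

```py
from datachain.lib.listing import (
    LISTING_PREFIX,
    is_listing_dataset,
    listing_uri_from_name,
)

name = f"{LISTING_PREFIX}s3://my-bucket/images/"
assert is_listing_dataset(name)
print(listing_uri_from_name(name))  # -> "s3://my-bucket/images/"

# Non-listing dataset names are rejected rather than silently returned.
try:
    listing_uri_from_name("my_regular_dataset")
except ValueError as exc:
    print(exc)
```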
32 changes: 32 additions & 0 deletions src/datachain/lib/listing_info.py
@@ -0,0 +1,32 @@
from datetime import datetime, timedelta, timezone
from typing import Optional

from datachain.client import Client
from datachain.lib.dataset_info import DatasetInfo
from datachain.lib.listing import LISTING_PREFIX, LISTING_TTL


class ListingInfo(DatasetInfo):
@property
def uri(self) -> str:
return self.name.removeprefix(LISTING_PREFIX)

@property
def storage_uri(self) -> str:
client, _ = Client.parse_url(self.uri, None) # type: ignore[arg-type]
return client.uri

@property
def expires(self) -> Optional[datetime]:
if not self.finished_at:
return None
return self.finished_at + timedelta(seconds=LISTING_TTL)

@property
def is_expired(self) -> bool:
return datetime.now(timezone.utc) > self.expires if self.expires else False

@property
def last_inserted_at(self):
# TODO we need to add updated_at to dataset version or explicit last_inserted_at
raise NotImplementedError
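A brief sketch of the `ListingInfo` convenience properties, assuming at least one listing has already been cached in the session.

```py
from datachain import DataChain

(info,) = DataChain.listings().limit(1).collect("listing")
print(info.uri)          # listing dataset name with the prefix stripped
print(info.storage_uri)  # bucket-level URI as resolved by Client.parse_url
print(info.expires)      # finished_at + LISTING_TTL, or None if unfinished
print(info.is_expired)
```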