Skip to content

Commit

Permalink
fix type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
skshetry committed Dec 31, 2024
1 parent ce4ab40 commit 392aaff
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 8 deletions.
6 changes: 3 additions & 3 deletions src/datachain/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(self, cache_dir: str, tmp_dir: str):
tmp_dir=tmp_dir,
)

def __eq__(self, other):
def __eq__(self, other) -> bool:
return self.odb == other.odb

@property
Expand Down Expand Up @@ -108,13 +108,13 @@ async def download(
def store_data(self, file: "File", contents: bytes) -> None:
self.odb.add_bytes(file.get_hash(), contents)

def clear(self):
def clear(self) -> None:
"""
Completely clear the cache.
"""
self.odb.clear()

def destroy(self):
def destroy(self) -> None:
# `clear` leaves the prefix directory structure intact.
remove(self.cache_dir)

Expand Down
6 changes: 5 additions & 1 deletion src/datachain/lib/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,11 @@ def ensure_cached(self) -> None:
client = self._catalog.get_client(self.source)
client.download(self, callback=self._download_cb)

async def _prefetch(self, catalog=None, download_cb=None) -> bool:
async def _prefetch(
self,
catalog: Optional["Catalog"] = None,
download_cb: Optional["Callback"] = None,
) -> bool:
from datachain.client.hf import HfClient

catalog = catalog or self._catalog
Expand Down
8 changes: 4 additions & 4 deletions src/datachain/lib/prefetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@

from datachain.asyn import AsyncMapper
from datachain.cache import temporary_cache
from datachain.catalog.catalog import Catalog
from datachain.lib.file import File

if TYPE_CHECKING:
from contextlib import AbstractContextManager

from datachain.cache import DataChainCache as Cache
from datachain.catalog.catalog import Catalog


T = TypeVar("T", bound=Sequence[Any])
Expand All @@ -23,7 +23,7 @@ def noop(*args, **kwargs):
pass

Check warning on line 23 in src/datachain/lib/prefetcher.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/lib/prefetcher.py#L23

Added line #L23 was not covered by tests


async def _prefetch_input(row: T, catalog: Catalog, download_cb: Callback) -> T:
async def _prefetch_input(row: T, catalog: "Catalog", download_cb: Callback) -> T:
try:
callback = download_cb.increment_file_count
except AttributeError:
Expand All @@ -37,14 +37,14 @@ async def _prefetch_input(row: T, catalog: Catalog, download_cb: Callback) -> T:
return row


def clone_catalog_with_cache(catalog: Catalog, cache: "Cache"):
def clone_catalog_with_cache(catalog: "Catalog", cache: "Cache") -> "Catalog":
clone = catalog.copy()
clone.cache = cache
return clone


def rows_prefetcher(
catalog,
catalog: "Catalog",
rows: Iterable[T],
prefetch: int,
cache: Optional["Cache"] = None,
Expand Down

0 comments on commit 392aaff

Please sign in to comment.