From 392aaff01e94d638b56097d3399dcd9ad97a8d16 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Saugat=20Pachhai=20=28=E0=A4=B8=E0=A5=8C=E0=A4=97=E0=A4=BE?= =?UTF-8?q?=E0=A4=A4=29?= Date: Tue, 31 Dec 2024 22:26:39 +0545 Subject: [PATCH] fix type hints --- src/datachain/cache.py | 6 +++--- src/datachain/lib/file.py | 6 +++++- src/datachain/lib/prefetcher.py | 8 ++++---- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/datachain/cache.py b/src/datachain/cache.py index bbd1eccf9..a71a21f56 100644 --- a/src/datachain/cache.py +++ b/src/datachain/cache.py @@ -49,7 +49,7 @@ def __init__(self, cache_dir: str, tmp_dir: str): tmp_dir=tmp_dir, ) - def __eq__(self, other): + def __eq__(self, other) -> bool: return self.odb == other.odb @property @@ -108,13 +108,13 @@ async def download( def store_data(self, file: "File", contents: bytes) -> None: self.odb.add_bytes(file.get_hash(), contents) - def clear(self): + def clear(self) -> None: """ Completely clear the cache. """ self.odb.clear() - def destroy(self): + def destroy(self) -> None: # `clear` leaves the prefix directory structure intact. remove(self.cache_dir) diff --git a/src/datachain/lib/file.py b/src/datachain/lib/file.py index 5e3cd987e..d9ee418c0 100644 --- a/src/datachain/lib/file.py +++ b/src/datachain/lib/file.py @@ -269,7 +269,11 @@ def ensure_cached(self) -> None: client = self._catalog.get_client(self.source) client.download(self, callback=self._download_cb) - async def _prefetch(self, catalog=None, download_cb=None) -> bool: + async def _prefetch( + self, + catalog: Optional["Catalog"] = None, + download_cb: Optional["Callback"] = None, + ) -> bool: from datachain.client.hf import HfClient catalog = catalog or self._catalog diff --git a/src/datachain/lib/prefetcher.py b/src/datachain/lib/prefetcher.py index f2c60fc86..4bc0fbda0 100644 --- a/src/datachain/lib/prefetcher.py +++ b/src/datachain/lib/prefetcher.py @@ -7,13 +7,13 @@ from datachain.asyn import AsyncMapper from datachain.cache import temporary_cache -from datachain.catalog.catalog import Catalog from datachain.lib.file import File if TYPE_CHECKING: from contextlib import AbstractContextManager from datachain.cache import DataChainCache as Cache + from datachain.catalog.catalog import Catalog T = TypeVar("T", bound=Sequence[Any]) @@ -23,7 +23,7 @@ def noop(*args, **kwargs): pass -async def _prefetch_input(row: T, catalog: Catalog, download_cb: Callback) -> T: +async def _prefetch_input(row: T, catalog: "Catalog", download_cb: Callback) -> T: try: callback = download_cb.increment_file_count except AttributeError: @@ -37,14 +37,14 @@ async def _prefetch_input(row: T, catalog: Catalog, download_cb: Callback) -> T: return row -def clone_catalog_with_cache(catalog: Catalog, cache: "Cache"): +def clone_catalog_with_cache(catalog: "Catalog", cache: "Cache") -> "Catalog": clone = catalog.copy() clone.cache = cache return clone def rows_prefetcher( - catalog, + catalog: "Catalog", rows: Iterable[T], prefetch: int, cache: Optional["Cache"] = None,