diff --git a/src/datachain/lib/file.py b/src/datachain/lib/file.py index 64147469e..0dc2268c3 100644 --- a/src/datachain/lib/file.py +++ b/src/datachain/lib/file.py @@ -269,10 +269,22 @@ def ensure_cached(self) -> None: client = self._catalog.get_client(self.source) client.download(self, callback=self._download_cb) - async def _prefetch(self) -> None: - if self._caching_enabled: - client = self._catalog.get_client(self.source) - await client._download(self, callback=self._download_cb) + async def _prefetch(self, catalog=None, download_cb=None) -> bool: + from datachain.client.hf import HfClient + + catalog = catalog or self._catalog + download_cb = download_cb or self._download_cb + if catalog is None: + raise RuntimeError("cannot prefetch file because catalog is not setup") + + client = catalog.get_client(self.source) + if client.protocol != HfClient.protocol: + self._set_stream(catalog, self._caching_enabled, download_cb=download_cb) + return False + + await client._download(self, callback=download_cb) + self._set_stream(catalog, caching_enabled=True) # reset download callback + return True def get_local_path(self) -> Optional[str]: """Return path to a file in a local cache. diff --git a/src/datachain/lib/prefetcher.py b/src/datachain/lib/prefetcher.py index b4315e7bb..fa6431ce4 100644 --- a/src/datachain/lib/prefetcher.py +++ b/src/datachain/lib/prefetcher.py @@ -27,9 +27,9 @@ async def _prefetch_input(row, catalog, download_cb): for obj in row: if isinstance(obj, File): - obj._set_stream(catalog, True, download_cb) - await obj._prefetch() - callback() + prefetched = await obj._prefetch(catalog, download_cb) + if prefetched: + callback() return row