Skip to content

Commit

Permalink
prefetch: disable for huggingface
Browse files Browse the repository at this point in the history
  • Loading branch information
skshetry committed Dec 26, 2024
1 parent 5443324 commit 278af30
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 7 deletions.
20 changes: 16 additions & 4 deletions src/datachain/lib/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,10 +269,22 @@ def ensure_cached(self) -> None:
client = self._catalog.get_client(self.source)
client.download(self, callback=self._download_cb)

async def _prefetch(self) -> None:
if self._caching_enabled:
client = self._catalog.get_client(self.source)
await client._download(self, callback=self._download_cb)
async def _prefetch(self, catalog=None, download_cb=None) -> bool:
from datachain.client.hf import HfClient

catalog = catalog or self._catalog
download_cb = download_cb or self._download_cb
if catalog is None:
raise RuntimeError("cannot prefetch file because catalog is not setup")

Check warning on line 278 in src/datachain/lib/file.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/lib/file.py#L278

Added line #L278 was not covered by tests

client = catalog.get_client(self.source)
if client.protocol == HfClient.protocol:
self._set_stream(catalog, self._caching_enabled, download_cb=download_cb)
return False

Check warning on line 283 in src/datachain/lib/file.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/lib/file.py#L282-L283

Added lines #L282 - L283 were not covered by tests

await client._download(self, callback=download_cb)
self._set_stream(catalog, caching_enabled=True) # reset download callback
return True

def get_local_path(self) -> Optional[str]:
"""Return path to a file in a local cache.
Expand Down
6 changes: 3 additions & 3 deletions src/datachain/lib/prefetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ async def _prefetch_input(row, catalog, download_cb):

for obj in row:
if isinstance(obj, File):
obj._set_stream(catalog, True, download_cb)
await obj._prefetch()
callback()
prefetched = await obj._prefetch(catalog, download_cb)
if prefetched:
callback()
return row


Expand Down

0 comments on commit 278af30

Please sign in to comment.