Skip to content

Commit

Permalink
add tests for prefetch
Browse files Browse the repository at this point in the history
  • Loading branch information
skshetry committed Jan 1, 2025
1 parent cd59801 commit 92d0cc5
Show file tree
Hide file tree
Showing 6 changed files with 153 additions and 14 deletions.
8 changes: 1 addition & 7 deletions src/datachain/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,13 +106,7 @@ async def download(
os.unlink(tmp_info)

def store_data(self, file: "File", contents: bytes) -> None:
checksum = file.get_hash()
dst = self.path_from_checksum(checksum)
if not os.path.exists(dst):
# Create the file only if it's not already in cache
os.makedirs(os.path.dirname(dst), exist_ok=True)
with open(dst, mode="wb") as f:
f.write(contents)
self.odb.add_bytes(file.get_hash(), contents)

def clear(self) -> None:
"""
Expand Down
15 changes: 9 additions & 6 deletions src/datachain/lib/pytorch.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
import os
from collections.abc import Generator, Iterator
from collections.abc import Generator, Iterable, Iterator
from contextlib import closing
from typing import TYPE_CHECKING, Any, Callable, Optional

Expand Down Expand Up @@ -80,7 +80,7 @@ def __init__(
if (prefetch := dc_settings.prefetch) is not None:
self.prefetch = prefetch

if self.cache:
if self.cache or not self.prefetch:
self._cache = catalog.cache
else:
tmp_dir = catalog.cache.tmp_dir
Expand Down Expand Up @@ -122,9 +122,8 @@ def _row_iter(
ds = ds.chunk(total_rank, total_workers)
yield from ds.collect()

def __iter__(self) -> Iterator[Any]:
def _iter_with_prefetch(self) -> Generator[tuple[Any], None, None]:
total_rank, total_workers = self.get_rank_and_workers()

download_cb = CombinedDownloadCallback()
if os.getenv("DATACHAIN_SHOW_PREFETCH_PROGRESS"):
download_cb = get_download_callback(
Expand All @@ -142,10 +141,14 @@ def __iter__(self) -> Iterator[Any]:
download_cb=download_cb,
)

with download_cb, closing(rows):
with download_cb:
yield from rows

def __iter__(self) -> Iterator[list[Any]]:
with closing(self._iter_with_prefetch()) as rows:
yield from map(self._process_row, rows)

def _process_row(self, row_features):
def _process_row(self, row_features: Iterable[Any]) -> list[Any]:
row = []
for fr in row_features:
if hasattr(fr, "read"):
Expand Down
41 changes: 41 additions & 0 deletions tests/func/test_pytorch.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import os
from contextlib import closing

import open_clip
import pytest
import torch
Expand All @@ -7,6 +10,7 @@
from torchvision.transforms import v2

from datachain.lib.dc import DataChain
from datachain.lib.file import File
from datachain.lib.pytorch import PytorchDataset


Expand Down Expand Up @@ -80,6 +84,43 @@ def test_to_pytorch(fake_dataset):
assert img.size() == Size([3, 64, 64])


@pytest.mark.parametrize("cache", (True, False))
@pytest.mark.parametrize("prefetch", (0, 10))
def test_prefetch(mocker, catalog, fake_dataset, cache, prefetch):
    """Rows' files end up in the dataset cache iff prefetch is enabled.

    Also checks that the (temporary) cache directory survives `close()`
    only when caching is enabled.
    """
    catalog.cache.clear()

    dataset = fake_dataset.limit(10)
    ds = dataset.settings(cache=cache, prefetch=prefetch).to_pytorch()

    # Capture the bound method and cache reference *before* patching `ds`
    # below, so `check_prefetched` wraps the real implementation.
    iter_with_prefetch = ds._iter_with_prefetch
    _cache = ds._cache

    def is_prefetched(file: File):
        # A prefetched file has the dataset's cache wired into its catalog.
        assert file._catalog
        assert file._catalog.cache == _cache
        return _cache.contains(file)

    def check_prefetched():
        # Wraps `_iter_with_prefetch`: verifies cache state for every row,
        # then passes the row through unchanged.
        for row in iter_with_prefetch():
            files = [f for f in row if isinstance(f, File)]
            assert files
            files_not_in_cache = [f for f in files if not is_prefetched(f)]
            if prefetch:
                assert not files_not_in_cache, "Some files are not in cache"
            else:
                assert files == files_not_in_cache, "Some files are in cache"
            yield row

    # we peek internally with `_iter_with_prefetch` to check if the files are prefetched
    # as `__iter__` transforms them.
    m = mocker.patch.object(ds, "_iter_with_prefetch", wraps=check_prefetched)
    with closing(ds), closing(iter(ds)) as rows:
        assert next(rows)
    m.assert_called_once()
    # cache directory should be removed after `close()` if the cache is not enabled
    assert os.path.exists(_cache.cache_dir) == cache


def test_hf_to_pytorch(catalog, fake_image_dir):
hf_ds = load_dataset("imagefolder", data_dir=fake_image_dir)
chain = DataChain.from_hf(hf_ds)
Expand Down
28 changes: 27 additions & 1 deletion tests/unit/test_cache.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os

import pytest

from datachain.cache import DataChainCache
from datachain.cache import DataChainCache, get_temp_cache, temporary_cache
from datachain.lib.file import File


Expand Down Expand Up @@ -53,3 +55,27 @@ def test_remove(cache):
assert cache.contains(uid)
cache.remove(uid)
assert not cache.contains(uid)


def test_destroy(cache: DataChainCache):
    """`destroy()` removes the whole cache directory, entries included."""
    entry = File(source="s3://foo", path="data/bar", etag="xyz", size=3, location=None)
    cache.store_data(entry, b"foo")
    # sanity check: the entry is present before destruction
    assert cache.contains(entry)

    cache.destroy()
    # not just the entry — the entire cache directory must be gone
    assert not os.path.exists(cache.cache_dir)


def test_get_temp_cache(tmp_path):
    """`get_temp_cache` creates a prefixed cache directory under the given root."""
    temp_cache = get_temp_cache(tmp_path, prefix="test-")
    assert os.path.isdir(temp_cache.cache_dir)
    assert isinstance(temp_cache, DataChainCache)

    parent, dirname = os.path.split(temp_cache.cache_dir)
    assert parent == str(tmp_path)
    assert dirname.startswith("test-")


def test_temporary_cache(tmp_path):
    """`temporary_cache` yields a live cache dir and removes it on exit."""
    with temporary_cache(tmp_path, prefix="test-") as cache:
        cache_dir = cache.cache_dir
        assert os.path.isdir(cache_dir)
    # leaving the context manager must clean up the directory
    assert not os.path.exists(cache_dir)
40 changes: 40 additions & 0 deletions tests/unit/test_prefetcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import os

from datachain.cache import DataChainCache, get_temp_cache
from datachain.lib.file import File
from datachain.lib.prefetcher import rows_prefetcher


def test_prefetcher(mocker, tmp_dir, catalog):
    """Prefetched files land in an auto-created temp cache or the one supplied."""
    rows = []
    for name in "abcd":
        (tmp_dir / name).write_text(name)
        rows.append((File(path=str(tmp_dir / name)),))

    # default: a private "prefetch-" cache is created under the catalog tmp dir
    for (file,) in rows_prefetcher(catalog, rows, prefetch=5):
        assert file._catalog
        parent, dirname = os.path.split(file._catalog.cache.cache_dir)
        assert parent == catalog.cache.tmp_dir
        assert dirname.startswith("prefetch-")
        assert file._catalog.cache.contains(file)

    # explicit cache: downloads are stored in the cache passed by the caller
    custom_cache = get_temp_cache(tmp_dir)
    for (file,) in rows_prefetcher(catalog, rows, prefetch=5, cache=custom_cache):
        assert file._catalog
        assert file._catalog.cache == custom_cache
        assert custom_cache.contains(file)


def test_prefetcher_closes_temp_cache(mocker, tmp_dir, catalog):
    """Closing the prefetch generator early tears down its temporary cache."""
    rows = []
    for name in "abcd":
        (tmp_dir / name).write_text(name)
        rows.append((File(path=str(tmp_dir / name)),))

    destroy_spy = mocker.spy(DataChainCache, "destroy")

    gen = rows_prefetcher(catalog, rows, prefetch=5)
    # start iterating so the temp cache actually gets created...
    next(gen)
    # ...then close mid-iteration: the temp cache must be destroyed
    gen.close()
    assert destroy_spy.called
35 changes: 35 additions & 0 deletions tests/unit/test_pytorch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import os

import pytest

from datachain.cache import DataChainCache
from datachain.lib.pytorch import PytorchDataset
from datachain.lib.settings import Settings


@pytest.mark.parametrize(
    "cache,prefetch", [(True, 0), (True, 10), (False, 10), (False, 0)]
)
def test_cache(catalog, cache, prefetch):
    """The dataset reuses the catalog cache unless prefetching without caching."""
    ds = PytorchDataset(
        "fake", 1, catalog, dc_settings=Settings(cache=cache, prefetch=prefetch)
    )
    assert ds.cache == cache
    assert ds.prefetch == prefetch

    if cache or not prefetch:
        # caching enabled (or no prefetch): the shared catalog cache is used
        assert ds._cache is catalog.cache
        return

    # prefetch without cache: a private "prefetch-" temp cache is created
    assert ds._cache is not catalog.cache
    parent, dirname = os.path.split(ds._cache.cache_dir)
    assert parent == catalog.cache.tmp_dir
    assert dirname.startswith("prefetch-")


@pytest.mark.parametrize("cache", [True, False])
def test_close(mocker, catalog, cache):
    """`close()` destroys the cache only when caching is disabled."""
    ds = PytorchDataset("fake", 1, catalog, dc_settings=Settings(cache=cache))
    destroy_spy = mocker.spy(DataChainCache, "destroy")

    ds.close()
    # a shared, persistent cache must survive close(); a temp one must not
    assert destroy_spy.called == (not cache)

0 comments on commit 92d0cc5

Please sign in to comment.