prefetch: use a separate temporary cache for prefetching
Unless `cache=True`, `prefetch` will use a separate temporary cache
that resides in the `.datachain/tmp/prefetch-<random>` directory.

The temporary directory is deleted automatically once prefetching
is done. With `cache=True`, the regular cache is reused and is not
deleted at the end.
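
As a rough sketch, the two modes look like this (the storage URI is a
placeholder; the `settings()` calls mirror the example script updated
below):

```python
from datachain import DataChain

# No cache=True: prefetched files go to a temporary
# .datachain/tmp/prefetch-<random> cache that is removed afterwards.
ds = DataChain.from_storage("gs://my-bucket/images/", type="image").settings(prefetch=25)

# cache=True: prefetched files land in the regular cache and are kept.
ds_cached = DataChain.from_storage("gs://my-bucket/images/", type="image").settings(
    cache=True, prefetch=25
)
```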

Please note that auto-cleanup does not work for PyTorch datasets
because there is no way to invoke cleanup from the Dataset side.
The DataLoader may still have cached data or rows even after
the Dataset has finished iterating. As a result, values associated
with a catalog/cache instance can outlive the Dataset instance.
One potential solution is to implement a custom DataLoader; another
is to provide a user-facing API.

In this PR, I have implemented the latter: `PytorchDataset` now
includes a `close()` method, which can be used to clean up the
temporary prefetch cache.

E.g.:
```python
from contextlib import closing

with closing(dataset):
    ...  # iterate over the DataLoader here
```
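
Alternatively, `close()` can be called directly once iteration is
finished; a sketch assuming `ds`, `transform`, and `DataLoader` from
the example below:

```python
dataset = ds.to_pytorch(transform=transform)
try:
    for inputs, labels in DataLoader(dataset, batch_size=25):
        ...
finally:
    dataset.close()  # removes the temporary prefetch cache; a no-op when cache=True
```
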
skshetry committed Dec 23, 2024
1 parent aed6d96 commit afae789
Showing 7 changed files with 193 additions and 41 deletions.
examples/get_started/torch-loader.py (33 changes: 18 additions & 15 deletions)
@@ -7,6 +7,7 @@

import multiprocessing
import os
from contextlib import closing
from posixpath import basename

import torch
@@ -56,7 +57,7 @@ def forward(self, x):
if __name__ == "__main__":
ds = (
DataChain.from_storage(STORAGE, type="image")
.settings(cache=True, prefetch=25)
.settings(prefetch=25)
.filter(C("file.path").glob("*.jpg"))
.map(
label=lambda path: label_to_int(basename(path)[:3], CLASSES),
@@ -65,10 +66,11 @@ def forward(self, x):
)
)

dataset = ds.to_pytorch(transform=transform)
train_loader = DataLoader(
ds.to_pytorch(transform=transform),
dataset,
batch_size=25,
num_workers=max(4, os.cpu_count() or 2),
num_workers=min(4, os.cpu_count() or 2),
persistent_workers=True,
multiprocessing_context=multiprocessing.get_context("spawn"),
)
@@ -82,15 +84,16 @@ def forward(self, x):
with tqdm(
train_loader, desc=f"epoch {epoch + 1}/{NUM_EPOCHS}", unit="batch"
) as loader:
for data in loader:
inputs, labels = data
optimizer.zero_grad()

# Forward pass
outputs = model(inputs)
loss = criterion(outputs, labels)

# Backward pass and optimize
loss.backward()
optimizer.step()
loader.set_postfix(loss=loss.item())
with closing(dataset):
for data in loader:
inputs, labels = data
optimizer.zero_grad()

# Forward pass
outputs = model(inputs)
loss = criterion(outputs, labels)

# Backward pass and optimize
loss.backward()
optimizer.step()
loader.set_postfix(loss=loss.item())
src/datachain/cache.py (30 changes: 30 additions & 0 deletions)
@@ -1,9 +1,15 @@
import os
import shutil
from collections.abc import Iterator
from contextlib import contextmanager
from tempfile import mkdtemp
from typing import TYPE_CHECKING, Optional

from dvc_data.hashfile.db.local import LocalHashFileDB
from dvc_objects.fs.local import LocalFileSystem
from fsspec.callbacks import Callback, TqdmCallback
from tqdm import tqdm
from tqdm.std import TqdmDefaultWriteLock

from .progress import Tqdm

@@ -12,6 +18,9 @@
from datachain.lib.file import File


tqdm.set_lock(TqdmDefaultWriteLock())


def try_scandir(path):
try:
with os.scandir(path) as it:
@@ -20,6 +29,23 @@ def try_scandir(path):
pass


def get_temp_cache(tmp_dir: str, prefix: Optional[str] = None) -> "DataChainCache":
cache_dir = mkdtemp(prefix=prefix, dir=tmp_dir)
return DataChainCache(cache_dir, tmp_dir=tmp_dir)


@contextmanager
def temporary_cache(
tmp_dir: str, prefix: Optional[str] = None, delete: bool = False
) -> Iterator["DataChainCache"]:
cache = get_temp_cache(tmp_dir, prefix=prefix)
try:
yield cache
finally:
if delete:
cache.destroy()


class DataChainCache:
def __init__(self, cache_dir: str, tmp_dir: str):
self.odb = LocalHashFileDB(
@@ -94,6 +120,10 @@ def clear(self):
"""
self.odb.clear()

def destroy(self):
# `clear` leaves the prefix directory structure intact.
shutil.rmtree(self.cache_dir)

def get_total_size(self) -> int:
total = 0
for subdir in try_scandir(self.odb.path):
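
For reference, a minimal sketch of the new cache helpers (the temp
directory here is illustrative; prefetching uses `catalog.cache.tmp_dir`):

```python
import tempfile

from datachain.cache import get_temp_cache, temporary_cache

tmp_dir = tempfile.gettempdir()  # placeholder for catalog.cache.tmp_dir

# Caller-managed: destroy() removes the whole cache directory, unlike clear(),
# which leaves the prefix directory structure intact.
cache = get_temp_cache(tmp_dir, prefix="prefetch-")
cache.destroy()

# Scoped: with delete=True the cache directory is removed on exit.
with temporary_cache(tmp_dir, prefix="prefetch-", delete=True) as cache:
    ...
```
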
src/datachain/lib/prefetcher.py (64 changes: 64 additions & 0 deletions)
@@ -0,0 +1,64 @@
from contextlib import contextmanager, nullcontext
from functools import partial
from typing import TYPE_CHECKING, Optional

from fsspec.callbacks import DEFAULT_CALLBACK, Callback

from datachain.asyn import AsyncMapper
from datachain.cache import temporary_cache
from datachain.catalog.catalog import Catalog
from datachain.lib.file import File

if TYPE_CHECKING:
from contextlib import AbstractContextManager

from datachain.cache import DataChainCache as Cache


def noop(*args, **kwargs):
pass


async def _prefetch_input(row, catalog, download_cb):
try:
callback = download_cb.increment_file_count
except AttributeError:
callback = noop

for obj in row:
if isinstance(obj, File):
obj._set_stream(catalog, True, download_cb)
await obj._prefetch()
callback()
return row


@contextmanager
def catalog_with_cache(catalog: Catalog, cache):
ocache = catalog.cache
try:
catalog.cache = cache
yield catalog
finally:
catalog.cache = ocache


def rows_prefetcher(
catalog,
rows,
prefetch: int,
cache: Optional["Cache"] = None,
download_cb: Callback = DEFAULT_CALLBACK,
):
cache_ctx: AbstractContextManager[Cache]
if cache:
cache_ctx = nullcontext(cache)
else:
tmp_dir = catalog.cache.tmp_dir
assert tmp_dir
cache_ctx = temporary_cache(tmp_dir, prefix="prefetch-")

with cache_ctx as prefetch_cache, catalog_with_cache(catalog, prefetch_cache):
func = partial(_prefetch_input, download_cb=download_cb, catalog=catalog)
mapper = AsyncMapper(func, rows, workers=prefetch)
yield from mapper.iterate()
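
Roughly, `rows_prefetcher` is consumed as a closable generator; a sketch
(the empty `rows` iterable is a stand-in for UDF input rows containing
`File` objects, and calling `get_catalog()` with no arguments is assumed
to return the default catalog; see the `udf.py` and `pytorch.py` call
sites below):

```python
from contextlib import closing

from datachain.catalog import get_catalog
from datachain.lib.prefetcher import rows_prefetcher

catalog = get_catalog()  # its cache provides the tmp_dir that hosts the temporary cache
rows = iter([])          # stand-in for rows of UDF inputs

# With cache=None, files are prefetched into a temporary "prefetch-" cache
# under catalog.cache.tmp_dir rather than the persistent cache.
prefetched = rows_prefetcher(catalog, rows, prefetch=4)
with closing(prefetched):
    for row in prefetched:
        ...  # File objects in the row have already been downloaded
```
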
src/datachain/lib/pytorch.py (40 changes: 35 additions & 5 deletions)
@@ -1,5 +1,7 @@
import logging
import os
from collections.abc import Iterator
from contextlib import closing
from typing import TYPE_CHECKING, Any, Callable, Optional

from PIL import Image
@@ -9,11 +11,14 @@
from torchvision.transforms import v2

from datachain import Session
from datachain.asyn import AsyncMapper
from datachain.cache import get_temp_cache
from datachain.catalog import Catalog, get_catalog
from datachain.lib.dc import DataChain
from datachain.lib.prefetcher import rows_prefetcher
from datachain.lib.settings import Settings
from datachain.lib.text import convert_text
from datachain.progress import CombinedDownloadCallback
from datachain.query.dataset import get_download_callback

if TYPE_CHECKING:
from torchvision.transforms.v2 import Transform
@@ -75,6 +80,17 @@ def __init__(
if (prefetch := dc_settings.prefetch) is not None:
self.prefetch = prefetch

if self.cache:
self._cache = catalog.cache
else:
tmp_dir = catalog.cache.tmp_dir
assert tmp_dir
self._cache = get_temp_cache(tmp_dir, prefix="prefetch-")

def close(self):
if not self.cache:
self._cache.destroy()

def _init_catalog(self, catalog: "Catalog"):
# For compatibility with multiprocessing,
# we can only store params in __init__(), as Catalog isn't picklable
@@ -107,11 +123,25 @@ def _rows_iter(self, total_rank: int, total_workers: int):
def __iter__(self) -> Iterator[Any]:
total_rank, total_workers = self.get_rank_and_workers()
rows = self._rows_iter(total_rank, total_workers)
if self.prefetch > 0:
from datachain.lib.udf import _prefetch_input

rows = AsyncMapper(_prefetch_input, rows, workers=self.prefetch).iterate()
yield from map(self._process_row, rows)
download_cb = CombinedDownloadCallback()
if os.getenv("DATACHAIN_SHOW_PREFETCH_PROGRESS"):
download_cb = get_download_callback(
f"{total_rank}/{total_workers}", position=total_rank
)

if self.prefetch > 0:
catalog = self._get_catalog()
rows = rows_prefetcher(
catalog,
rows,
self.prefetch,
cache=self._cache,
download_cb=download_cb,
)

with download_cb, closing(rows):
yield from map(self._process_row, rows)

def _process_row(self, row_features):
row = []
src/datachain/lib/udf.py (32 changes: 18 additions & 14 deletions)
@@ -8,11 +8,11 @@
from fsspec.callbacks import DEFAULT_CALLBACK, Callback
from pydantic import BaseModel

from datachain.asyn import AsyncMapper
from datachain.dataset import RowDict
from datachain.lib.convert.flatten import flatten
from datachain.lib.data_model import DataValue
from datachain.lib.file import File
from datachain.lib.prefetcher import rows_prefetcher
from datachain.lib.utils import AbstractUDF, DataChainError, DataChainParamsError
from datachain.query.batch import (
Batch,
@@ -280,13 +280,6 @@ def process_safe(self, obj_rows):
return result_objs


async def _prefetch_input(row):
for obj in row:
if isinstance(obj, File):
await obj._prefetch()
return row


class Mapper(UDFBase):
"""Inherit from this class to pass to `DataChain.map()`."""

@@ -308,9 +301,14 @@ def run(
for row in udf_inputs
)
if self.prefetch > 0:
prepared_inputs = AsyncMapper(
_prefetch_input, prepared_inputs, workers=self.prefetch
).iterate()
_cache = self.catalog.cache if cache else None
prepared_inputs = rows_prefetcher(
self.catalog,
prepared_inputs,
self.prefetch,
cache=_cache,
download_cb=download_cb,
)

with contextlib.closing(prepared_inputs):
for id_, *udf_args in prepared_inputs:
@@ -384,10 +382,16 @@ def run(
prepared_inputs: abc.Generator[Sequence[Any], None, None] = (
self._prepare_row(row, udf_fields, cache, download_cb) for row in udf_inputs
)

if self.prefetch > 0:
prepared_inputs = AsyncMapper(
_prefetch_input, prepared_inputs, workers=self.prefetch
).iterate()
_cache = self.catalog.cache if cache else None
prepared_inputs = rows_prefetcher(
self.catalog,
prepared_inputs,
self.prefetch,
cache=_cache,
download_cb=download_cb,
)

with contextlib.closing(prepared_inputs):
for row in prepared_inputs:
src/datachain/progress.py (21 changes: 18 additions & 3 deletions)
@@ -2,16 +2,15 @@

import logging
import sys
from threading import RLock
from typing import Any, ClassVar

from fsspec import Callback
from fsspec.callbacks import TqdmCallback
from tqdm import tqdm

from datachain.utils import env2bool

logger = logging.getLogger(__name__)
tqdm.set_lock(RLock())


class Tqdm(tqdm):
@@ -132,8 +131,24 @@ def format_dict(self):
return d


class CombinedDownloadCallback(TqdmCallback):
class CombinedDownloadCallback(Callback):
def set_size(self, size):
# This is a no-op to prevent fsspec's .get_file() from setting the combined
# download size to the size of the current file.
pass

def increment_file_count(self, n: int = 1) -> None:
pass


class TqdmCombinedDownloadCallback(CombinedDownloadCallback, TqdmCallback):
def __init__(self, tqdm_kwargs=None, *args, **kwargs):
self.files_count = 0
tqdm_kwargs = tqdm_kwargs or {}
tqdm_kwargs.setdefault("postfix", {}).setdefault("files", self.files_count)
super().__init__(tqdm_kwargs, *args, **kwargs)

def increment_file_count(self, n: int = 1) -> None:
self.files_count += n
if self.tqdm is not None:
self.tqdm.postfix = f"{self.files_count} files"
src/datachain/query/dataset.py (14 changes: 10 additions & 4 deletions)
@@ -44,7 +44,7 @@
from datachain.error import DatasetNotFoundError, QueryScriptCancelError
from datachain.func.base import Function
from datachain.lib.udf import UDFAdapter
from datachain.progress import CombinedDownloadCallback
from datachain.progress import CombinedDownloadCallback, TqdmCombinedDownloadCallback
from datachain.sql.functions.random import rand
from datachain.utils import (
batched,
@@ -348,9 +348,15 @@ def process_udf_outputs(
warehouse.insert_rows(udf_table, row_chunk)


def get_download_callback() -> Callback:
return CombinedDownloadCallback(
{"desc": "Download", "unit": "B", "unit_scale": True, "unit_divisor": 1024}
def get_download_callback(suffix: str = "", **kwargs) -> CombinedDownloadCallback:
return TqdmCombinedDownloadCallback(
{
"desc": "Download" + suffix,
"unit": "B",
"unit_scale": True,
"unit_divisor": 1024,
**kwargs,
},
)


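
For reference, a sketch of the extended signature in use, mirroring the
`PytorchDataset.__iter__` call above (the rank/worker values are illustrative):

```python
from datachain.query.dataset import get_download_callback

# One progress bar per worker: the suffix labels the bar, position stacks the tqdm bars.
download_cb = get_download_callback("1/4", position=1)
with download_cb:
    download_cb.increment_file_count()  # reflected in the bar's "N files" postfix
```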
