diff --git a/examples/get_started/torch-loader.py b/examples/get_started/torch-loader.py
index a77e30257..d8a559a9b 100644
--- a/examples/get_started/torch-loader.py
+++ b/examples/get_started/torch-loader.py
@@ -7,6 +7,7 @@
 import multiprocessing
 import os
+from contextlib import closing
 from posixpath import basename
 
 import torch
@@ -56,7 +57,7 @@ def forward(self, x):
 if __name__ == "__main__":
     ds = (
         DataChain.from_storage(STORAGE, type="image")
-        .settings(cache=True, prefetch=25)
+        .settings(prefetch=25)
         .filter(C("file.path").glob("*.jpg"))
         .map(
             label=lambda path: label_to_int(basename(path)[:3], CLASSES),
@@ -65,10 +66,11 @@ def forward(self, x):
         )
     )
 
+    dataset = ds.to_pytorch(transform=transform)
     train_loader = DataLoader(
-        ds.to_pytorch(transform=transform),
+        dataset,
         batch_size=25,
-        num_workers=max(4, os.cpu_count() or 2),
+        num_workers=min(4, os.cpu_count() or 2),
         persistent_workers=True,
         multiprocessing_context=multiprocessing.get_context("spawn"),
     )
@@ -78,19 +80,20 @@ def forward(self, x):
     optimizer = optim.Adam(model.parameters(), lr=0.001)
 
     # Train the model
-    for epoch in range(NUM_EPOCHS):
-        with tqdm(
-            train_loader, desc=f"epoch {epoch + 1}/{NUM_EPOCHS}", unit="batch"
-        ) as loader:
-            for data in loader:
-                inputs, labels = data
-                optimizer.zero_grad()
-
-                # Forward pass
-                outputs = model(inputs)
-                loss = criterion(outputs, labels)
-
-                # Backward pass and optimize
-                loss.backward()
-                optimizer.step()
-                loader.set_postfix(loss=loss.item())
+    with closing(dataset):
+        for epoch in range(NUM_EPOCHS):
+            with tqdm(
+                train_loader, desc=f"epoch {epoch + 1}/{NUM_EPOCHS}", unit="batch"
+            ) as loader:
+                for data in loader:
+                    inputs, labels = data
+                    optimizer.zero_grad()
+
+                    # Forward pass
+                    outputs = model(inputs)
+                    loss = criterion(outputs, labels)
+
+                    # Backward pass and optimize
+                    loss.backward()
+                    optimizer.step()
+                    loader.set_postfix(loss=loss.item())
diff --git a/src/datachain/cache.py b/src/datachain/cache.py
index 8b3e8b07f..ba6e78bc0 100644
--- a/src/datachain/cache.py
+++ b/src/datachain/cache.py
@@ -1,4 +1,8 @@
 import os
+import shutil
+from collections.abc import Iterator
+from contextlib import contextmanager
+from tempfile import mkdtemp
 from typing import TYPE_CHECKING, Optional
 
 from dvc_data.hashfile.db.local import LocalHashFileDB
@@ -20,6 +24,23 @@ def try_scandir(path):
         pass
 
 
+def get_temp_cache(tmp_dir: str, prefix: Optional[str] = None) -> "DataChainCache":
+    cache_dir = mkdtemp(prefix=prefix, dir=tmp_dir)
+    return DataChainCache(cache_dir, tmp_dir=tmp_dir)
+
+
+@contextmanager
+def temporary_cache(
+    tmp_dir: str, prefix: Optional[str] = None, delete: bool = True
+) -> Iterator["DataChainCache"]:
+    cache = get_temp_cache(tmp_dir, prefix=prefix)
+    try:
+        yield cache
+    finally:
+        if delete:
+            cache.destroy()
+
+
 class DataChainCache:
     def __init__(self, cache_dir: str, tmp_dir: str):
         self.odb = LocalHashFileDB(
@@ -94,6 +115,10 @@ def clear(self):
         """
         self.odb.clear()
 
+    def destroy(self):
+        # `clear` leaves the prefix directory structure intact.
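+        # `rmtree` removes the directory itself, so temporary prefetch caches
+        # created via get_temp_cache() do not pile up under tmp_dir.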
+        shutil.rmtree(self.cache_dir)
+
     def get_total_size(self) -> int:
         total = 0
         for subdir in try_scandir(self.odb.path):
diff --git a/src/datachain/lib/prefetcher.py b/src/datachain/lib/prefetcher.py
new file mode 100644
index 000000000..7e044244f
--- /dev/null
+++ b/src/datachain/lib/prefetcher.py
@@ -0,0 +1,64 @@
+from contextlib import contextmanager, nullcontext
+from functools import partial
+from typing import TYPE_CHECKING, Optional
+
+from fsspec.callbacks import DEFAULT_CALLBACK, Callback
+
+from datachain.asyn import AsyncMapper
+from datachain.cache import temporary_cache
+from datachain.catalog.catalog import Catalog
+from datachain.lib.file import File
+
+if TYPE_CHECKING:
+    from contextlib import AbstractContextManager
+
+    from datachain.cache import DataChainCache as Cache
+
+
+def noop(*args, **kwargs):
+    pass
+
+
+async def _prefetch_input(row, catalog, download_cb):
+    try:
+        callback = download_cb.increment_file_count
+    except AttributeError:
+        callback = noop
+
+    for obj in row:
+        if isinstance(obj, File):
+            obj._set_stream(catalog, True, download_cb)
+            await obj._prefetch()
+            callback()
+    return row
+
+
+@contextmanager
+def catalog_with_cache(catalog: Catalog, cache):
+    ocache = catalog.cache
+    try:
+        catalog.cache = cache
+        yield catalog
+    finally:
+        catalog.cache = ocache
+
+
+def rows_prefetcher(
+    catalog,
+    rows,
+    prefetch: int,
+    cache: Optional["Cache"] = None,
+    download_cb: Callback = DEFAULT_CALLBACK,
+):
+    cache_ctx: AbstractContextManager[Cache]
+    if cache:
+        cache_ctx = nullcontext(cache)
+    else:
+        tmp_dir = catalog.cache.tmp_dir
+        assert tmp_dir
+        cache_ctx = temporary_cache(tmp_dir, prefix="prefetch-")
+
+    with cache_ctx as prefetch_cache, catalog_with_cache(catalog, prefetch_cache):
+        func = partial(_prefetch_input, download_cb=download_cb, catalog=catalog)
+        mapper = AsyncMapper(func, rows, workers=prefetch)
+        yield from mapper.iterate()
diff --git a/src/datachain/lib/pytorch.py b/src/datachain/lib/pytorch.py
index e85fb0aae..691da8d97 100644
--- a/src/datachain/lib/pytorch.py
+++ b/src/datachain/lib/pytorch.py
@@ -1,5 +1,7 @@
 import logging
+import os
 from collections.abc import Iterator
+from contextlib import closing
 from typing import TYPE_CHECKING, Any, Callable, Optional
 
 from PIL import Image
@@ -9,11 +11,14 @@
 from torchvision.transforms import v2
 
 from datachain import Session
-from datachain.asyn import AsyncMapper
+from datachain.cache import get_temp_cache
 from datachain.catalog import Catalog, get_catalog
 from datachain.lib.dc import DataChain
+from datachain.lib.prefetcher import rows_prefetcher
 from datachain.lib.settings import Settings
 from datachain.lib.text import convert_text
+from datachain.progress import CombinedDownloadCallback
+from datachain.query.dataset import get_download_callback
 
 if TYPE_CHECKING:
     from torchvision.transforms.v2 import Transform
@@ -75,6 +80,17 @@ def __init__(
         if (prefetch := dc_settings.prefetch) is not None:
             self.prefetch = prefetch
 
+        if self.cache:
+            self._cache = catalog.cache
+        else:
+            tmp_dir = catalog.cache.tmp_dir
+            assert tmp_dir
+            self._cache = get_temp_cache(tmp_dir, prefix="prefetch-")
+
+    def close(self):
+        if not self.cache:
+            self._cache.destroy()
+
     def _init_catalog(self, catalog: "Catalog"):
         # For compatibility with multiprocessing,
         # we can only store params in __init__(), as Catalog isn't picklable
@@ -107,11 +123,25 @@ def _rows_iter(self, total_rank: int, total_workers: int):
 
     def __iter__(self) -> Iterator[Any]:
         total_rank, total_workers = self.get_rank_and_workers()
         rows = self._rows_iter(total_rank, total_workers)
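+        # NOTE: `rows` is a lazy generator; when prefetch is enabled it is
+        # wrapped by rows_prefetcher() below without materializing the dataset.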
-        if self.prefetch > 0:
-            from datachain.lib.udf import _prefetch_input
-            rows = AsyncMapper(_prefetch_input, rows, workers=self.prefetch).iterate()
-        yield from map(self._process_row, rows)
+        download_cb = CombinedDownloadCallback()
+        if os.getenv("DATACHAIN_SHOW_PREFETCH_PROGRESS"):
+            download_cb = get_download_callback(
+                f"{total_rank}/{total_workers}", position=total_rank
+            )
+
+        if self.prefetch > 0:
+            catalog = self._get_catalog()
+            rows = rows_prefetcher(
+                catalog,
+                rows,
+                self.prefetch,
+                cache=self._cache,
+                download_cb=download_cb,
+            )
+
+        with download_cb, closing(rows):
+            yield from map(self._process_row, rows)
 
     def _process_row(self, row_features):
         row = []
diff --git a/src/datachain/lib/udf.py b/src/datachain/lib/udf.py
index d708c0330..d6da37f74 100644
--- a/src/datachain/lib/udf.py
+++ b/src/datachain/lib/udf.py
@@ -8,11 +8,11 @@
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 from pydantic import BaseModel
 
-from datachain.asyn import AsyncMapper
 from datachain.dataset import RowDict
 from datachain.lib.convert.flatten import flatten
 from datachain.lib.data_model import DataValue
 from datachain.lib.file import File
+from datachain.lib.prefetcher import rows_prefetcher
 from datachain.lib.utils import AbstractUDF, DataChainError, DataChainParamsError
 from datachain.query.batch import (
     Batch,
@@ -280,13 +280,6 @@ def process_safe(self, obj_rows):
         return result_objs
 
 
-async def _prefetch_input(row):
-    for obj in row:
-        if isinstance(obj, File):
-            await obj._prefetch()
-    return row
-
-
 class Mapper(UDFBase):
     """Inherit from this class to pass to `DataChain.map()`."""
 
@@ -308,9 +301,14 @@ def run(
             for row in udf_inputs
         )
         if self.prefetch > 0:
-            prepared_inputs = AsyncMapper(
-                _prefetch_input, prepared_inputs, workers=self.prefetch
-            ).iterate()
+            _cache = self.catalog.cache if cache else None
+            prepared_inputs = rows_prefetcher(
+                self.catalog,
+                prepared_inputs,
+                self.prefetch,
+                cache=_cache,
+                download_cb=download_cb,
+            )
 
         with contextlib.closing(prepared_inputs):
             for id_, *udf_args in prepared_inputs:
@@ -385,9 +383,14 @@ def run(
             self._prepare_row(row, udf_fields, cache, download_cb)
             for row in udf_inputs
         )
         if self.prefetch > 0:
-            prepared_inputs = AsyncMapper(
-                _prefetch_input, prepared_inputs, workers=self.prefetch
-            ).iterate()
+            _cache = self.catalog.cache if cache else None
+            prepared_inputs = rows_prefetcher(
+                self.catalog,
+                prepared_inputs,
+                self.prefetch,
+                cache=_cache,
+                download_cb=download_cb,
+            )
 
         with contextlib.closing(prepared_inputs):
             for row in prepared_inputs:
diff --git a/src/datachain/progress.py b/src/datachain/progress.py
index 5507742ff..7eafd759b 100644
--- a/src/datachain/progress.py
+++ b/src/datachain/progress.py
@@ -5,6 +5,7 @@
 from threading import RLock
 from typing import Any, ClassVar
 
+from fsspec import Callback
 from fsspec.callbacks import TqdmCallback
 from tqdm import tqdm
 
@@ -132,8 +133,24 @@ def format_dict(self):
         return d
 
 
-class CombinedDownloadCallback(TqdmCallback):
+class CombinedDownloadCallback(Callback):
     def set_size(self, size):
         # This is a no-op to prevent fsspec's .get_file() from setting the combined
         # download size to the size of the current file.
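+        # The combined progress value grows only through relative_update()
+        # calls made as each file downloads.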
pass + + def increment_file_count(self, n: int = 1) -> None: + pass + + +class TqdmCombinedDownloadCallback(CombinedDownloadCallback, TqdmCallback): + def __init__(self, tqdm_kwargs=None, *args, **kwargs): + self.files_count = 0 + tqdm_kwargs = tqdm_kwargs or {} + tqdm_kwargs.setdefault("postfix", {}).setdefault("files", self.files_count) + super().__init__(tqdm_kwargs, *args, **kwargs) + + def increment_file_count(self, n: int = 1) -> None: + self.files_count += n + if self.tqdm is not None: + self.tqdm.postfix = f"{self.files_count} files" diff --git a/src/datachain/query/dataset.py b/src/datachain/query/dataset.py index 3761bed4b..5f0f4f11c 100644 --- a/src/datachain/query/dataset.py +++ b/src/datachain/query/dataset.py @@ -44,7 +44,7 @@ from datachain.error import DatasetNotFoundError, QueryScriptCancelError from datachain.func.base import Function from datachain.lib.udf import UDFAdapter -from datachain.progress import CombinedDownloadCallback +from datachain.progress import CombinedDownloadCallback, TqdmCombinedDownloadCallback from datachain.sql.functions.random import rand from datachain.utils import ( batched, @@ -348,9 +348,15 @@ def process_udf_outputs( warehouse.insert_rows(udf_table, row_chunk) -def get_download_callback() -> Callback: - return CombinedDownloadCallback( - {"desc": "Download", "unit": "B", "unit_scale": True, "unit_divisor": 1024} +def get_download_callback(suffix: str = "", **kwargs) -> CombinedDownloadCallback: + return TqdmCombinedDownloadCallback( + { + "desc": "Download" + suffix, + "unit": "B", + "unit_scale": True, + "unit_divisor": 1024, + **kwargs, + }, )
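
Reviewer's usage sketch (not part of the patch): how the pieces above fit together
from user code. The bucket URI is a placeholder and `prefetch=8`/`batch_size=16` are
arbitrary; the `settings()`, `to_pytorch()`, and `close()` calls are the ones shown
in the diffs above.

    from contextlib import closing

    from torch.utils.data import DataLoader

    from datachain import DataChain

    ds = (
        DataChain.from_storage("gs://datachain-demo/dogs-and-cats/", type="image")
        # prefetch now works without cache=True: files land in a temporary
        # cache owned by the PytorchDataset instead of the shared catalog cache.
        .settings(prefetch=8)
    )

    dataset = ds.to_pytorch()  # owns a temp prefetch cache unless cache=True
    loader = DataLoader(dataset, batch_size=16)

    # Set DATACHAIN_SHOW_PREFETCH_PROGRESS=1 to get per-worker download bars.
    # closing() guarantees dataset.close() runs, which calls
    # DataChainCache.destroy() (shutil.rmtree) on the temporary cache.
    with closing(dataset):
        for batch in loader:
            ...  # train / evaluate

The key change this mirrors: the dataset object, not the chain, owns the prefetch
cache, so it must be closed explicitly (or wrapped in closing()) once the loader
is done with it, as in the updated torch-loader.py example.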