wip: prefetch cache

skshetry committed Dec 24, 2024
1 parent 60256d6 commit 15c30fb

Showing 7 changed files with 191 additions and 43 deletions.
41 changes: 22 additions & 19 deletions examples/get_started/torch-loader.py
@@ -7,6 +7,7 @@
 
 import multiprocessing
 import os
+from contextlib import closing
 from posixpath import basename
 
 import torch
@@ -56,7 +57,7 @@ def forward(self, x):
 if __name__ == "__main__":
     ds = (
         DataChain.from_storage(STORAGE, type="image")
-        .settings(cache=True, prefetch=25)
+        .settings(prefetch=25)
         .filter(C("file.path").glob("*.jpg"))
         .map(
             label=lambda path: label_to_int(basename(path)[:3], CLASSES),
@@ -65,10 +66,11 @@ def forward(self, x):
         )
     )
 
+    dataset = ds.to_pytorch(transform=transform)
     train_loader = DataLoader(
-        ds.to_pytorch(transform=transform),
+        dataset,
         batch_size=25,
-        num_workers=max(4, os.cpu_count() or 2),
+        num_workers=min(4, os.cpu_count() or 2),
         persistent_workers=True,
         multiprocessing_context=multiprocessing.get_context("spawn"),
     )
@@ -78,19 +80,20 @@ def forward(self, x):
     optimizer = optim.Adam(model.parameters(), lr=0.001)
 
     # Train the model
-    for epoch in range(NUM_EPOCHS):
-        with tqdm(
-            train_loader, desc=f"epoch {epoch + 1}/{NUM_EPOCHS}", unit="batch"
-        ) as loader:
-            for data in loader:
-                inputs, labels = data
-                optimizer.zero_grad()
-
-                # Forward pass
-                outputs = model(inputs)
-                loss = criterion(outputs, labels)
-
-                # Backward pass and optimize
-                loss.backward()
-                optimizer.step()
-                loader.set_postfix(loss=loss.item())
+    with closing(dataset):
+        for epoch in range(NUM_EPOCHS):
+            with tqdm(
+                train_loader, desc=f"epoch {epoch + 1}/{NUM_EPOCHS}", unit="batch"
+            ) as loader:
+                for data in loader:
+                    inputs, labels = data
+                    optimizer.zero_grad()
+
+                    # Forward pass
+                    outputs = model(inputs)
+                    loss = criterion(outputs, labels)
+
+                    # Backward pass and optimize
+                    loss.backward()
+                    optimizer.step()
+                    loader.set_postfix(loss=loss.item())
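
Two details worth noting in this example: `num_workers` now caps at four workers (`min`) instead of forcing at least four (`max`), and the training loop is wrapped in `closing(dataset)` because, without `cache=True`, prefetched files land in a temporary cache that has to be cleaned up. A sketch of what the context manager expands to (not part of the diff):

    dataset = ds.to_pytorch(transform=transform)
    try:
        ...  # training loop over train_loader
    finally:
        # PytorchDataset.close() destroys the per-run "prefetch-" temp cache;
        # it is a no-op when the chain was created with cache=True.
        dataset.close()
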
25 changes: 25 additions & 0 deletions src/datachain/cache.py
@@ -1,4 +1,8 @@
 import os
+import shutil
+from collections.abc import Iterator
+from contextlib import contextmanager
+from tempfile import mkdtemp
 from typing import TYPE_CHECKING, Optional
 
 from dvc_data.hashfile.db.local import LocalHashFileDB
@@ -20,6 +24,23 @@ def try_scandir(path):
         pass
 
 
+def get_temp_cache(tmp_dir: str, prefix: Optional[str] = None) -> "DataChainCache":
+    cache_dir = mkdtemp(prefix=prefix, dir=tmp_dir)
+    return DataChainCache(cache_dir, tmp_dir=tmp_dir)
+
+
+@contextmanager
+def temporary_cache(
+    tmp_dir: str, prefix: Optional[str] = None, delete: bool = False
+) -> Iterator["DataChainCache"]:
+    cache = get_temp_cache(tmp_dir, prefix=prefix)
+    try:
+        yield cache
+    finally:
+        if delete:
+            cache.destroy()
+
+
 class DataChainCache:
     def __init__(self, cache_dir: str, tmp_dir: str):
         self.odb = LocalHashFileDB(
@@ -94,6 +115,10 @@ def clear(self):
         """
         self.odb.clear()
 
+    def destroy(self):
+        # `clear` leaves the prefix directory structure intact.
+        shutil.rmtree(self.cache_dir)
+
     def get_total_size(self) -> int:
         total = 0
         for subdir in try_scandir(self.odb.path):
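
A usage sketch of the new helpers (the scratch path is a placeholder; any writable `tmp_dir` works):

    from datachain.cache import temporary_cache

    with temporary_cache("/tmp/datachain", prefix="prefetch-", delete=True) as cache:
        ...  # files downloaded through `cache` live in an isolated mkdtemp directory
    # delete=True calls cache.destroy(), i.e. shutil.rmtree of the whole tree;
    # the default delete=False leaves the directory in place for reuse.
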
64 changes: 64 additions & 0 deletions src/datachain/lib/prefetcher.py
@@ -0,0 +1,64 @@
+from contextlib import contextmanager, nullcontext
+from functools import partial
+from typing import TYPE_CHECKING, Optional
+
+from fsspec.callbacks import DEFAULT_CALLBACK, Callback
+
+from datachain.asyn import AsyncMapper
+from datachain.cache import temporary_cache
+from datachain.catalog.catalog import Catalog
+from datachain.lib.file import File
+
+if TYPE_CHECKING:
+    from contextlib import AbstractContextManager
+
+    from datachain.cache import DataChainCache as Cache
+
+
+def noop(*args, **kwargs):
+    pass
+
+
+async def _prefetch_input(row, catalog, download_cb):
+    try:
+        callback = download_cb.increment_file_count
+    except AttributeError:
+        callback = noop
+
+    for obj in row:
+        if isinstance(obj, File):
+            obj._set_stream(catalog, True, download_cb)
+            await obj._prefetch()
+            callback()
+    return row
+
+
+@contextmanager
+def catalog_with_cache(catalog: Catalog, cache):
+    ocache = catalog.cache
+    try:
+        catalog.cache = cache
+        yield catalog
+    finally:
+        catalog.cache = ocache
+
+
+def rows_prefetcher(
+    catalog,
+    rows,
+    prefetch: int,
+    cache: Optional["Cache"] = None,
+    download_cb: Callback = DEFAULT_CALLBACK,
+):
+    cache_ctx: AbstractContextManager[Cache]
+    if cache:
+        cache_ctx = nullcontext(cache)
+    else:
+        tmp_dir = catalog.cache.tmp_dir
+        assert tmp_dir
+        cache_ctx = temporary_cache(tmp_dir, prefix="prefetch-")
+
+    with cache_ctx as prefetch_cache, catalog_with_cache(catalog, prefetch_cache):
+        func = partial(_prefetch_input, download_cb=download_cb, catalog=catalog)
+        mapper = AsyncMapper(func, rows, workers=prefetch)
+        yield from mapper.iterate()
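
`rows_prefetcher` is a generator, so callers drive it with `contextlib.closing` (as the `pytorch.py` and `udf.py` changes below do) so that the temporary cache and the catalog's cache swap are unwound even on early exit. A minimal sketch, assuming `catalog` and an iterable of `rows` are in scope:

    from contextlib import closing

    from datachain.lib.prefetcher import rows_prefetcher

    gen = rows_prefetcher(catalog, rows, prefetch=8)  # cache=None: temp "prefetch-" cache
    with closing(gen):
        for row in gen:
            ...  # File objects in `row` have already been downloaded
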
40 changes: 35 additions & 5 deletions src/datachain/lib/pytorch.py
@@ -1,5 +1,7 @@
 import logging
+import os
 from collections.abc import Iterator
+from contextlib import closing
 from typing import TYPE_CHECKING, Any, Callable, Optional
 
 from PIL import Image
@@ -9,11 +11,14 @@
 from torchvision.transforms import v2
 
 from datachain import Session
-from datachain.asyn import AsyncMapper
+from datachain.cache import get_temp_cache
 from datachain.catalog import Catalog, get_catalog
 from datachain.lib.dc import DataChain
+from datachain.lib.prefetcher import rows_prefetcher
 from datachain.lib.settings import Settings
 from datachain.lib.text import convert_text
+from datachain.progress import CombinedDownloadCallback
+from datachain.query.dataset import get_download_callback
 
 if TYPE_CHECKING:
     from torchvision.transforms.v2 import Transform
@@ -75,6 +80,17 @@ def __init__(
         if (prefetch := dc_settings.prefetch) is not None:
             self.prefetch = prefetch
 
+        if self.cache:
+            self._cache = catalog.cache
+        else:
+            tmp_dir = catalog.cache.tmp_dir
+            assert tmp_dir
+            self._cache = get_temp_cache(tmp_dir, prefix="prefetch-")
+
+    def close(self):
+        if not self.cache:
+            self._cache.destroy()
+
     def _init_catalog(self, catalog: "Catalog"):
         # For compatibility with multiprocessing,
         # we can only store params in __init__(), as Catalog isn't picklable
@@ -107,11 +123,25 @@ def _rows_iter(self, total_rank: int, total_workers: int):
     def __iter__(self) -> Iterator[Any]:
         total_rank, total_workers = self.get_rank_and_workers()
         rows = self._rows_iter(total_rank, total_workers)
-        if self.prefetch > 0:
-            from datachain.lib.udf import _prefetch_input
-
-            rows = AsyncMapper(_prefetch_input, rows, workers=self.prefetch).iterate()
-        yield from map(self._process_row, rows)
+        download_cb = CombinedDownloadCallback()
+        if os.getenv("DATACHAIN_SHOW_PREFETCH_PROGRESS"):
+            download_cb = get_download_callback(
+                f"{total_rank}/{total_workers}", position=total_rank
+            )
+
+        if self.prefetch > 0:
+            catalog = self._get_catalog()
+            rows = rows_prefetcher(
+                catalog,
+                rows,
+                self.prefetch,
+                cache=self._cache,
+                download_cb=download_cb,
+            )
+
+        with download_cb, closing(rows):
+            yield from map(self._process_row, rows)
 
     def _process_row(self, row_features):
         row = []
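
The per-rank download bars are opt-in; any non-empty value of the environment variable enables them, giving each worker a bar labelled with its rank and stacked at its own `position`. Sketch:

    import os

    os.environ["DATACHAIN_SHOW_PREFETCH_PROGRESS"] = "1"  # only truthiness is checked
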
31 changes: 17 additions & 14 deletions src/datachain/lib/udf.py
@@ -8,11 +8,11 @@
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 from pydantic import BaseModel
 
-from datachain.asyn import AsyncMapper
 from datachain.dataset import RowDict
 from datachain.lib.convert.flatten import flatten
 from datachain.lib.data_model import DataValue
 from datachain.lib.file import File
+from datachain.lib.prefetcher import rows_prefetcher
 from datachain.lib.utils import AbstractUDF, DataChainError, DataChainParamsError
 from datachain.query.batch import (
     Batch,
@@ -280,13 +280,6 @@ def process_safe(self, obj_rows):
         return result_objs
 
 
-async def _prefetch_input(row):
-    for obj in row:
-        if isinstance(obj, File):
-            await obj._prefetch()
-    return row
-
-
 class Mapper(UDFBase):
     """Inherit from this class to pass to `DataChain.map()`."""
 
@@ -308,9 +301,14 @@ def run(
             for row in udf_inputs
         )
         if self.prefetch > 0:
-            prepared_inputs = AsyncMapper(
-                _prefetch_input, prepared_inputs, workers=self.prefetch
-            ).iterate()
+            _cache = self.catalog.cache if cache else None
+            prepared_inputs = rows_prefetcher(
+                self.catalog,
+                prepared_inputs,
+                self.prefetch,
+                cache=_cache,
+                download_cb=download_cb,
+            )
 
         with contextlib.closing(prepared_inputs):
             for id_, *udf_args in prepared_inputs:
@@ -385,9 +383,14 @@ def run(
             self._prepare_row(row, udf_fields, cache, download_cb) for row in udf_inputs
         )
         if self.prefetch > 0:
-            prepared_inputs = AsyncMapper(
-                _prefetch_input, prepared_inputs, workers=self.prefetch
-            ).iterate()
+            _cache = self.catalog.cache if cache else None
+            prepared_inputs = rows_prefetcher(
+                self.catalog,
+                prepared_inputs,
+                self.prefetch,
+                cache=_cache,
+                download_cb=download_cb,
+            )
 
         with contextlib.closing(prepared_inputs):
             for row in prepared_inputs:
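
Both `run()` methods now share one rule: reuse the catalog's persistent cache only when the UDF runs with `cache=True`, and otherwise pass `cache=None` so that `rows_prefetcher` builds and tears down a throwaway cache. For users this means prefetching no longer requires caching, mirroring the torch-loader example above (`STORAGE`, `label_to_int`, and `CLASSES` as defined there):

    ds = (
        DataChain.from_storage(STORAGE, type="image")
        .settings(prefetch=25)  # previously only useful together with cache=True
        .map(
            label=lambda path: label_to_int(basename(path)[:3], CLASSES),
            params=["file.path"],
        )
    )
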
19 changes: 18 additions & 1 deletion src/datachain/progress.py
@@ -5,6 +5,7 @@
 from threading import RLock
 from typing import Any, ClassVar
 
+from fsspec import Callback
 from fsspec.callbacks import TqdmCallback
 from tqdm import tqdm
 
@@ -132,8 +133,24 @@ def format_dict(self):
         return d
 
 
-class CombinedDownloadCallback(TqdmCallback):
+class CombinedDownloadCallback(Callback):
     def set_size(self, size):
         # This is a no-op to prevent fsspec's .get_file() from setting the combined
         # download size to the size of the current file.
         pass
+
+    def increment_file_count(self, n: int = 1) -> None:
+        pass
+
+
+class TqdmCombinedDownloadCallback(CombinedDownloadCallback, TqdmCallback):
+    def __init__(self, tqdm_kwargs=None, *args, **kwargs):
+        self.files_count = 0
+        tqdm_kwargs = tqdm_kwargs or {}
+        tqdm_kwargs.setdefault("postfix", {}).setdefault("files", self.files_count)
+        super().__init__(tqdm_kwargs, *args, **kwargs)
+
+    def increment_file_count(self, n: int = 1) -> None:
+        self.files_count += n
+        if self.tqdm is not None:
+            self.tqdm.postfix = f"{self.files_count} files"
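
Splitting the tqdm behavior into a subclass keeps `CombinedDownloadCallback` a silent aggregator, while `increment_file_count` is the hook that `_prefetch_input` probes for via `AttributeError`. A construction sketch mirroring `get_download_callback` below (assuming fsspec's usual `relative_update` byte hook and context-manager support, which the `with download_cb` in `pytorch.py` relies on):

    from datachain.progress import TqdmCombinedDownloadCallback

    cb = TqdmCombinedDownloadCallback(
        {"desc": "Download", "unit": "B", "unit_scale": True, "unit_divisor": 1024}
    )
    with cb:
        cb.relative_update(1024)   # bytes advance the bar
        cb.increment_file_count()  # postfix becomes "1 files"
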
14 changes: 10 additions & 4 deletions src/datachain/query/dataset.py
@@ -44,7 +44,7 @@
 from datachain.error import DatasetNotFoundError, QueryScriptCancelError
 from datachain.func.base import Function
 from datachain.lib.udf import UDFAdapter
-from datachain.progress import CombinedDownloadCallback
+from datachain.progress import CombinedDownloadCallback, TqdmCombinedDownloadCallback
 from datachain.sql.functions.random import rand
 from datachain.utils import (
     batched,
@@ -348,9 +348,15 @@ def process_udf_outputs(
             warehouse.insert_rows(udf_table, row_chunk)
 
 
-def get_download_callback() -> Callback:
-    return CombinedDownloadCallback(
-        {"desc": "Download", "unit": "B", "unit_scale": True, "unit_divisor": 1024}
+def get_download_callback(suffix: str = "", **kwargs) -> CombinedDownloadCallback:
+    return TqdmCombinedDownloadCallback(
+        {
+            "desc": "Download" + suffix,
+            "unit": "B",
+            "unit_scale": True,
+            "unit_divisor": 1024,
+            **kwargs,
+        },
     )
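
`get_download_callback` keeps its existing no-argument call sites working while letting `pytorch.py` label and stack per-worker bars; extra kwargs flow through to tqdm. Note the suffix is concatenated directly, so the per-rank bars read e.g. `Download0/4`. Sketch:

    cb = get_download_callback()                    # unchanged default bar
    cb0 = get_download_callback("0/4", position=0)  # per-rank bar, as in pytorch.py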


