hoist temporary cache creation to Mapper
skshetry committed Jan 3, 2025
1 parent 4742ea0 commit f9e580d
Showing 8 changed files with 204 additions and 183 deletions.
6 changes: 6 additions & 0 deletions src/datachain/catalog/catalog.py
@@ -534,6 +534,12 @@ def find_column_to_str( # noqa: PLR0911
return ""


def clone_catalog_with_cache(catalog: "Catalog", cache: "DataChainCache") -> "Catalog":
clone = catalog.copy()
clone.cache = cache
return clone


class Catalog:
def __init__(
self,
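For orientation, a minimal usage sketch of the new helper (not part of the commit; `catalog` stands for any existing Catalog instance, and `temporary_cache` is the context manager that `Mapper.run` relies on further down in this diff):

    from datachain.cache import temporary_cache
    from datachain.catalog.catalog import clone_catalog_with_cache

    # create a throwaway cache under the catalog's tmp directory
    with temporary_cache(catalog.cache.tmp_dir, prefix="prefetch-") as tmp_cache:
        patched = clone_catalog_with_cache(catalog, tmp_cache)
        # ... prefetch files through `patched`; the original catalog keeps its own cache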
17 changes: 5 additions & 12 deletions src/datachain/lib/file.py
@@ -269,25 +269,18 @@ def ensure_cached(self) -> None:
client = self._catalog.get_client(self.source)
client.download(self, callback=self._download_cb)

async def _prefetch(
self,
catalog: Optional["Catalog"] = None,
download_cb: Optional["Callback"] = None,
) -> bool:
async def _prefetch(self, download_cb: Optional["Callback"] = None) -> bool:
from datachain.client.hf import HfClient

catalog = catalog or self._catalog
download_cb = download_cb or self._download_cb
if catalog is None:
if self._catalog is None:
raise RuntimeError("cannot prefetch file because catalog is not setup")

client = catalog.get_client(self.source)
client = self._catalog.get_client(self.source)
if client.protocol == HfClient.protocol:
self._set_stream(catalog, self._caching_enabled, download_cb=download_cb)
return False

await client._download(self, callback=download_cb)
self._set_stream(catalog, caching_enabled=True) # reset download callback
await client._download(self, callback=download_cb or self._download_cb)
self._download_cb = DEFAULT_CALLBACK
return True

def get_local_path(self) -> Optional[str]:
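With the `catalog` parameter gone, `_prefetch` always resolves the catalog from the file itself; callers that want prefetching to go through a temporary cache are expected to attach the patched catalog beforehand. A rough sketch with hypothetical variable names (run inside an event loop):

    # assumption: `patched` was built via clone_catalog_with_cache(catalog, tmp_cache)
    file._set_stream(patched, caching_enabled=False, download_cb=progress_cb)
    if await file._prefetch(download_cb=progress_cb):
        ...  # True means the file was actually downloaded into the cache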
65 changes: 0 additions & 65 deletions src/datachain/lib/prefetcher.py

This file was deleted.

26 changes: 16 additions & 10 deletions src/datachain/lib/pytorch.py
@@ -2,6 +2,7 @@
import os
from collections.abc import Generator, Iterable, Iterator
from contextlib import closing
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Optional

from PIL import Image
@@ -11,10 +12,10 @@
from torchvision.transforms import v2

from datachain import Session
from datachain.asyn import AsyncMapper
from datachain.cache import get_temp_cache
from datachain.catalog import Catalog, get_catalog
from datachain.lib.dc import DataChain
from datachain.lib.prefetcher import rows_prefetcher
from datachain.lib.settings import Settings
from datachain.lib.text import convert_text
from datachain.progress import CombinedDownloadCallback
@@ -105,10 +106,14 @@ def _get_catalog(self) -> "Catalog":
ms = ms_cls(*ms_args, **ms_kwargs)
wh_cls, wh_args, wh_kwargs = self._wh_params
wh = wh_cls(*wh_args, **wh_kwargs)
return Catalog(ms, wh, **self._catalog_params)
catalog = Catalog(ms, wh, **self._catalog_params)
catalog.cache = self._cache
return catalog

def _row_iter(
self, total_rank: int, total_workers: int
self,
total_rank: int,
total_workers: int,
) -> Generator[tuple[Any, ...], None, None]:
catalog = self._get_catalog()
session = Session("PyTorch", catalog=catalog)
@@ -132,16 +137,17 @@ def _iter_with_prefetch(self) -> Generator[tuple[Any], None, None]:

rows = self._row_iter(total_rank, total_workers)
if self.prefetch > 0:
catalog = self._get_catalog()
rows = rows_prefetcher(
catalog,
rows,
self.prefetch,
cache=self._cache,
from datachain.lib.udf import _prefetch_input

func = partial(
_prefetch_input,
download_cb=download_cb,
after_prefetch=download_cb.increment_file_count,
)
mapper = AsyncMapper(func, rows, workers=self.prefetch)
rows = mapper.iterate() # type: ignore[assignment]

with download_cb:
with download_cb, closing(rows):
yield from rows

def __iter__(self) -> Iterator[list[Any]]:
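Both call sites now share the same prefetching pattern: wrap an async per-row function with `AsyncMapper` so downloads overlap with row consumption. A standalone sketch (the toy coroutine stands in for `File._prefetch`; the `AsyncMapper(func, iterable, workers=n).iterate()` usage mirrors this diff):

    import asyncio

    from datachain.asyn import AsyncMapper

    async def prefetch_row(row):
        await asyncio.sleep(0.01)  # stand-in for downloading the row's File objects
        return row

    rows = ((i, f"file-{i}") for i in range(10))
    mapper = AsyncMapper(prefetch_row, rows, workers=2)
    for row in mapper.iterate():  # rows come back once their prefetch has finished
        print(row)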
150 changes: 103 additions & 47 deletions src/datachain/lib/udf.py
@@ -1,18 +1,20 @@
import contextlib
import sys
import traceback
from collections.abc import Iterable, Iterator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Callable, Optional
from contextlib import closing, nullcontext
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar

import attrs
from fsspec.callbacks import DEFAULT_CALLBACK, Callback
from pydantic import BaseModel

from datachain.asyn import AsyncMapper
from datachain.cache import temporary_cache
from datachain.catalog.catalog import clone_catalog_with_cache
from datachain.dataset import RowDict
from datachain.lib.convert.flatten import flatten
from datachain.lib.data_model import DataValue
from datachain.lib.file import File
from datachain.lib.prefetcher import rows_prefetcher
from datachain.lib.utils import AbstractUDF, DataChainError, DataChainParamsError
from datachain.query.batch import (
Batch,
@@ -24,14 +26,18 @@

if TYPE_CHECKING:
from collections import abc
from contextlib import AbstractContextManager

from typing_extensions import Self

from datachain.cache import DataChainCache as Cache
from datachain.catalog import Catalog
from datachain.lib.signal_schema import SignalSchema
from datachain.lib.udf_signature import UdfSignature
from datachain.query.batch import RowsOutput

T = TypeVar("T", bound=Sequence[Any])


class UdfError(DataChainParamsError):
def __init__(self, msg):
@@ -279,45 +285,79 @@ def process_safe(self, obj_rows):
return result_objs


def noop(*args, **kwargs):
pass


async def _prefetch_input(
row: T,
download_cb: Optional["Callback"] = None,
after_prefetch: "Callable[[], None]" = noop,
) -> T:
for obj in row:
if isinstance(obj, File) and await obj._prefetch(download_cb):
after_prefetch()
return row


class Mapper(UDFBase):
"""Inherit from this class to pass to `DataChain.map()`."""

prefetch: int = 2

def run(
def _iter_with_prefetch(
self,
udf_fields: "Sequence[str]",
udf_inputs: "Iterable[Sequence[Any]]",
catalog: "Catalog",
cache: bool,
download_cb: Callback = DEFAULT_CALLBACK,
processed_cb: Callback = DEFAULT_CALLBACK,
) -> Iterator[Iterable[UDFResult]]:
self.catalog = catalog
self.setup()
) -> "abc.Generator[Sequence[Any], None, None]":
prepared_inputs: abc.Generator[Sequence[Any], None, None] = (
self._prepare_row_and_id(row, udf_fields, cache, download_cb)
for row in udf_inputs
)
if self.prefetch > 0:
_cache = self.catalog.cache if cache else None
prepared_inputs = rows_prefetcher(
self.catalog,
prepared_inputs,
self.prefetch,
cache=_cache,
download_cb=download_cb,
)
prepared_inputs = AsyncMapper(
_prefetch_input, prepared_inputs, workers=self.prefetch
).iterate()
yield from prepared_inputs

with contextlib.closing(prepared_inputs):
for id_, *udf_args in prepared_inputs:
result_objs = self.process_safe(udf_args)
udf_output = self._flatten_row(result_objs)
output = [{"sys__id": id_} | dict(zip(self.signal_names, udf_output))]
processed_cb.relative_update(1)
yield output
def run(
self,
udf_fields: "Sequence[str]",
udf_inputs: "Iterable[Sequence[Any]]",
catalog: "Catalog",
cache: bool,
download_cb: Callback = DEFAULT_CALLBACK,
processed_cb: Callback = DEFAULT_CALLBACK,
) -> Iterator[Iterable[UDFResult]]:
_cache = catalog.cache
tmp_dir = _cache.tmp_dir
assert tmp_dir

self.teardown()
cache_ctx: AbstractContextManager[Cache] = nullcontext(_cache)
if self.prefetch > 0 and not cache:
cache_ctx = temporary_cache(tmp_dir, prefix="prefetch-")

with cache_ctx as _cache:
self.catalog = clone_catalog_with_cache(catalog, _cache)
self.setup()

prepared_inputs = self._iter_with_prefetch(
udf_fields, udf_inputs, cache, download_cb
)
with closing(prepared_inputs):
for id_, *udf_args in prepared_inputs:
result_objs = self.process_safe(udf_args)
udf_output = self._flatten_row(result_objs)
output = [
{"sys__id": id_} | dict(zip(self.signal_names, udf_output))
]
processed_cb.relative_update(1)
yield output

self.teardown()
self.catalog = catalog


class BatchMapper(UDFBase):
@@ -367,39 +407,55 @@ class Generator(UDFBase):
is_output_batched = True
prefetch: int = 2

def run(
def _iter_with_prefetch(
self,
udf_fields: "Sequence[str]",
udf_inputs: "Iterable[Sequence[Any]]",
catalog: "Catalog",
cache: bool,
download_cb: Callback = DEFAULT_CALLBACK,
processed_cb: Callback = DEFAULT_CALLBACK,
) -> Iterator[Iterable[UDFResult]]:
self.catalog = catalog
self.setup()
) -> "abc.Generator[Sequence[Any], None, None]":
prepared_inputs: abc.Generator[Sequence[Any], None, None] = (
self._prepare_row(row, udf_fields, cache, download_cb) for row in udf_inputs
)
if self.prefetch > 0:
_cache = self.catalog.cache if cache else None
prepared_inputs = rows_prefetcher(
self.catalog,
prepared_inputs,
self.prefetch,
cache=_cache,
download_cb=download_cb,
)
prepared_inputs = AsyncMapper(
_prefetch_input, prepared_inputs, workers=self.prefetch
).iterate()
yield from prepared_inputs

with contextlib.closing(prepared_inputs):
for row in prepared_inputs:
result_objs = self.process_safe(row)
udf_outputs = (self._flatten_row(row) for row in result_objs)
output = (dict(zip(self.signal_names, row)) for row in udf_outputs)
processed_cb.relative_update(1)
yield output
def run(
self,
udf_fields: "Sequence[str]",
udf_inputs: "Iterable[Sequence[Any]]",
catalog: "Catalog",
cache: bool,
download_cb: Callback = DEFAULT_CALLBACK,
processed_cb: Callback = DEFAULT_CALLBACK,
) -> Iterator[Iterable[UDFResult]]:
_cache = catalog.cache
tmp_dir = _cache.tmp_dir
assert tmp_dir

self.teardown()
cache_ctx: AbstractContextManager[Cache] = nullcontext(_cache)
if self.prefetch > 0 and not cache:
cache_ctx = temporary_cache(tmp_dir, prefix="prefetch-")

with cache_ctx as _cache:
self.catalog = clone_catalog_with_cache(catalog, _cache)
self.setup()

prepared_inputs = self._iter_with_prefetch(
udf_fields, udf_inputs, cache, download_cb
)
with closing(prepared_inputs):
for row in prepared_inputs:
result_objs = self.process_safe(row)
udf_outputs = (self._flatten_row(row) for row in result_objs)
output = (dict(zip(self.signal_names, row)) for row in udf_outputs)
processed_cb.relative_update(1)
yield output
self.teardown()
self.catalog = catalog


class Aggregator(UDFBase):
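`Mapper.run` and `Generator.run` now share the same cache-selection logic; restated as a hypothetical standalone helper (not code from the commit):

    from contextlib import nullcontext

    from datachain.cache import temporary_cache

    def prefetch_cache_ctx(catalog, prefetch: int, cache: bool):
        """Reuse the catalog's own cache unless we prefetch without caching;
        in that case create a throwaway cache under the same tmp directory,
        removed when the context manager exits."""
        _cache = catalog.cache
        assert _cache.tmp_dir
        if prefetch > 0 and not cache:
            return temporary_cache(_cache.tmp_dir, prefix="prefetch-")
        return nullcontext(_cache)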