From 950170b5790b2c9201e70a47f657fcac04584a36 Mon Sep 17 00:00:00 2001
From: Saugat Pachhai (सौगात)
Date: Fri, 3 Jan 2025 21:40:29 +0545
Subject: [PATCH] refactor udfs

Extract the prefetch plumbing shared by Mapper, Generator, and
PytorchDataset into a module-level _prefetch_inputs() helper, and the
prefetch-cache selection into a _get_cache() helper. Pass the catalog
explicitly through _prepare_row()/_parse_row() instead of storing it on
the UDF instance, removing the mutable UDFBase.catalog attribute.
setup()/teardown() now run outside the cache context.
---
 src/datachain/lib/pytorch.py |  20 +++---
 src/datachain/lib/udf.py     | 136 +++++++++++++++--------------------
 2 files changed, 66 insertions(+), 90 deletions(-)

diff --git a/src/datachain/lib/pytorch.py b/src/datachain/lib/pytorch.py
index bcb0a9337..5b22835cc 100644
--- a/src/datachain/lib/pytorch.py
+++ b/src/datachain/lib/pytorch.py
@@ -2,7 +2,6 @@
 import os
 from collections.abc import Generator, Iterable, Iterator
 from contextlib import closing
-from functools import partial
 from typing import TYPE_CHECKING, Any, Callable, Optional
 
 from PIL import Image
@@ -12,7 +11,6 @@
 from torchvision.transforms import v2
 
 from datachain import Session
-from datachain.asyn import AsyncMapper
 from datachain.cache import get_temp_cache
 from datachain.catalog import Catalog, get_catalog
 from datachain.lib.dc import DataChain
@@ -128,6 +126,8 @@ def _row_iter(
         yield from ds.collect()
 
     def _iter_with_prefetch(self) -> Generator[tuple[Any], None, None]:
+        from datachain.lib.udf import _prefetch_inputs
+
         total_rank, total_workers = self.get_rank_and_workers()
         download_cb = CombinedDownloadCallback()
         if os.getenv("DATACHAIN_SHOW_PREFETCH_PROGRESS"):
@@ -136,16 +136,12 @@ def _iter_with_prefetch(self) -> Generator[tuple[Any], None, None]:
             )
 
         rows = self._row_iter(total_rank, total_workers)
-        if self.prefetch > 0:
-            from datachain.lib.udf import _prefetch_input
-
-            func = partial(
-                _prefetch_input,
-                download_cb=download_cb,
-                after_prefetch=download_cb.increment_file_count,
-            )
-            mapper = AsyncMapper(func, rows, workers=self.prefetch)
-            rows = mapper.iterate()  # type: ignore[assignment]
+        rows = _prefetch_inputs(
+            rows,
+            self.prefetch,
+            download_cb=download_cb,
+            after_prefetch=download_cb.increment_file_count,
+        )
 
         with download_cb, closing(rows):
             yield from rows
diff --git a/src/datachain/lib/udf.py b/src/datachain/lib/udf.py
index 002ef9c9f..e5b52dca8 100644
--- a/src/datachain/lib/udf.py
+++ b/src/datachain/lib/udf.py
@@ -1,8 +1,9 @@
 import sys
 import traceback
-from collections.abc import Iterable, Iterator, Mapping, Sequence
+from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
 from contextlib import closing, nullcontext
-from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar
+from functools import partial
+from typing import TYPE_CHECKING, Any, Optional, TypeVar
 
 import attrs
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
@@ -154,12 +155,10 @@ def process(self, file) -> list[float]:
     """
 
     is_output_batched = False
-    catalog: "Optional[Catalog]"
 
     def __init__(self):
         self.params: Optional[SignalSchema] = None
         self.output = None
-        self.catalog = None
         self._func = None
 
     def process(self, *args, **kwargs):
@@ -248,26 +247,23 @@ def _obj_to_list(obj):
         return flatten(obj) if isinstance(obj, BaseModel) else [obj]
 
     def _parse_row(
-        self, row_dict: RowDict, cache: bool, download_cb: Callback
+        self, row_dict: RowDict, catalog: "Catalog", cache: bool, download_cb: Callback
     ) -> list[DataValue]:
         assert self.params
         row = [row_dict[p] for p in self.params.to_udf_spec()]
         obj_row = self.params.row_to_objs(row)
         for obj in obj_row:
             if isinstance(obj, File):
-                assert self.catalog is not None
-                obj._set_stream(
-                    self.catalog, caching_enabled=cache, download_cb=download_cb
-                )
+                obj._set_stream(catalog, caching_enabled=cache, download_cb=download_cb)
         return obj_row
 
-    def _prepare_row(self, row, udf_fields, cache, download_cb):
+    def _prepare_row(self, row, udf_fields, catalog, cache, download_cb):
         row_dict = RowDict(zip(udf_fields, row))
-        return self._parse_row(row_dict, cache, download_cb)
+        return self._parse_row(row_dict, catalog, cache, download_cb)
 
-    def _prepare_row_and_id(self, row, udf_fields, cache, download_cb):
+    def _prepare_row_and_id(self, row, udf_fields, catalog, cache, download_cb):
         row_dict = RowDict(zip(udf_fields, row))
-        udf_input = self._parse_row(row_dict, cache, download_cb)
+        udf_input = self._parse_row(row_dict, catalog, cache, download_cb)
         return row_dict["sys__id"], *udf_input
 
     def process_safe(self, obj_rows):
@@ -300,28 +296,37 @@ async def _prefetch_input(
     return row
 
 
+def _prefetch_inputs(
+    prepared_inputs: "Iterable[T]",
+    prefetch: int = 0,
+    download_cb: Optional["Callback"] = None,
+    after_prefetch: "Callable[[], None]" = noop,
+) -> "abc.Generator[T, None, None]":
+    if prefetch > 0:
+        f = partial(
+            _prefetch_input,
+            download_cb=download_cb,
+            after_prefetch=after_prefetch,
+        )
+        prepared_inputs = AsyncMapper(f, prepared_inputs, workers=prefetch).iterate()  # type: ignore[assignment]
+    yield from prepared_inputs
+
+
+def _get_cache(
+    cache: "Cache", prefetch: int = 0, use_cache: bool = False
+) -> "AbstractContextManager[Cache]":
+    tmp_dir = cache.tmp_dir
+    assert tmp_dir
+    if prefetch and not use_cache:
+        return temporary_cache(tmp_dir, prefix="prefetch-")
+    return nullcontext(cache)
+
+
 class Mapper(UDFBase):
     """Inherit from this class to pass to `DataChain.map()`."""
 
     prefetch: int = 2
 
-    def _iter_with_prefetch(
-        self,
-        udf_fields: "Sequence[str]",
-        udf_inputs: "Iterable[Sequence[Any]]",
-        cache: bool,
-        download_cb: Callback = DEFAULT_CALLBACK,
-    ) -> "abc.Generator[Sequence[Any], None, None]":
-        prepared_inputs: abc.Generator[Sequence[Any], None, None] = (
-            self._prepare_row_and_id(row, udf_fields, cache, download_cb)
-            for row in udf_inputs
-        )
-        if self.prefetch > 0:
-            prepared_inputs = AsyncMapper(
-                _prefetch_input, prepared_inputs, workers=self.prefetch
-            ).iterate()
-        yield from prepared_inputs
-
     def run(
         self,
         udf_fields: "Sequence[str]",
@@ -331,21 +336,17 @@ def run(
         download_cb: Callback = DEFAULT_CALLBACK,
         processed_cb: Callback = DEFAULT_CALLBACK,
     ) -> Iterator[Iterable[UDFResult]]:
-        _cache = catalog.cache
-        tmp_dir = _cache.tmp_dir
-        assert tmp_dir
-
-        cache_ctx: AbstractContextManager[Cache] = nullcontext(_cache)
-        if self.prefetch > 0 and not cache:
-            cache_ctx = temporary_cache(tmp_dir, prefix="prefetch-")
+        self.setup()
 
-        with cache_ctx as _cache:
-            self.catalog = clone_catalog_with_cache(catalog, _cache)
-            self.setup()
+        with _get_cache(catalog.cache, self.prefetch, use_cache=cache) as _cache:
+            catalog = clone_catalog_with_cache(catalog, _cache)
 
-            prepared_inputs = self._iter_with_prefetch(
-                udf_fields, udf_inputs, cache, download_cb
+            prepared_inputs: abc.Generator[Sequence[Any], None, None] = (
+                self._prepare_row_and_id(row, udf_fields, catalog, cache, download_cb)
+                for row in udf_inputs
             )
+            prepared_inputs = _prefetch_inputs(prepared_inputs, self.prefetch)
+
             with closing(prepared_inputs):
                 for id_, *udf_args in prepared_inputs:
                     result_objs = self.process_safe(udf_args)
@@ -356,8 +357,7 @@ def run(
                     processed_cb.relative_update(1)
                     yield output
 
-            self.teardown()
-            self.catalog = catalog
+        self.teardown()
 
 
 class BatchMapper(UDFBase):
@@ -374,14 +374,15 @@ def run(
         download_cb: Callback = DEFAULT_CALLBACK,
         processed_cb: Callback = DEFAULT_CALLBACK,
     ) -> Iterator[Iterable[UDFResult]]:
-        self.catalog = catalog
         self.setup()
 
         for batch in udf_inputs:
             n_rows = len(batch.rows)
             row_ids, *udf_args = zip(
                 *[
-                    self._prepare_row_and_id(row, udf_fields, cache, download_cb)
+                    self._prepare_row_and_id(
+                        row, udf_fields, catalog, cache, download_cb
+                    )
                     for row in batch.rows
                 ]
             )
@@ -407,22 +408,6 @@ class Generator(UDFBase):
     is_output_batched = True
     prefetch: int = 2
 
-    def _iter_with_prefetch(
-        self,
-        udf_fields: "Sequence[str]",
-        udf_inputs: "Iterable[Sequence[Any]]",
-        cache: bool,
-        download_cb: Callback = DEFAULT_CALLBACK,
-    ) -> "abc.Generator[Sequence[Any], None, None]":
-        prepared_inputs: abc.Generator[Sequence[Any], None, None] = (
-            self._prepare_row(row, udf_fields, cache, download_cb) for row in udf_inputs
-        )
-        if self.prefetch > 0:
-            prepared_inputs = AsyncMapper(
-                _prefetch_input, prepared_inputs, workers=self.prefetch
-            ).iterate()
-        yield from prepared_inputs
-
     def run(
         self,
         udf_fields: "Sequence[str]",
@@ -432,21 +417,17 @@ def run(
         download_cb: Callback = DEFAULT_CALLBACK,
         processed_cb: Callback = DEFAULT_CALLBACK,
     ) -> Iterator[Iterable[UDFResult]]:
-        _cache = catalog.cache
-        tmp_dir = _cache.tmp_dir
-        assert tmp_dir
-
-        cache_ctx: AbstractContextManager[Cache] = nullcontext(_cache)
-        if self.prefetch > 0 and not cache:
-            cache_ctx = temporary_cache(tmp_dir, prefix="prefetch-")
+        self.setup()
 
-        with cache_ctx as _cache:
-            self.catalog = clone_catalog_with_cache(catalog, _cache)
-            self.setup()
+        with _get_cache(catalog.cache, self.prefetch, use_cache=cache) as _cache:
+            catalog = clone_catalog_with_cache(catalog, _cache)
 
-            prepared_inputs = self._iter_with_prefetch(
-                udf_fields, udf_inputs, cache, download_cb
+            prepared_inputs: abc.Generator[Sequence[Any], None, None] = (
+                self._prepare_row(row, udf_fields, catalog, cache, download_cb)
+                for row in udf_inputs
             )
+            prepared_inputs = _prefetch_inputs(prepared_inputs, self.prefetch)
+
             with closing(prepared_inputs):
                 for row in prepared_inputs:
                     result_objs = self.process_safe(row)
@@ -454,8 +435,8 @@ def run(
                     output = (dict(zip(self.signal_names, row)) for row in udf_outputs)
                     processed_cb.relative_update(1)
                     yield output
-            self.teardown()
-            self.catalog = catalog
+
+        self.teardown()
 
 
 class Aggregator(UDFBase):
@@ -472,13 +453,12 @@ def run(
         download_cb: Callback = DEFAULT_CALLBACK,
         processed_cb: Callback = DEFAULT_CALLBACK,
     ) -> Iterator[Iterable[UDFResult]]:
-        self.catalog = catalog
         self.setup()
 
         for batch in udf_inputs:
             udf_args = zip(
                 *[
-                    self._prepare_row(row, udf_fields, cache, download_cb)
+                    self._prepare_row(row, udf_fields, catalog, cache, download_cb)
                     for row in batch.rows
                 ]
             )
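
Note: the heart of this refactor is that _prefetch_inputs() hides the
AsyncMapper behind a plain generator, so Mapper.run(), Generator.run(), and
PytorchDataset._iter_with_prefetch() can all wrap the result in closing(...)
whether or not prefetching is enabled. Below is a minimal, self-contained
sketch of the same pattern, using a thread pool with a bounded look-ahead
window as a stand-in for datachain's AsyncMapper; the names prefetch_iter and
_download are illustrative, not datachain APIs.

    from collections import deque
    from collections.abc import Iterable, Iterator
    from concurrent.futures import ThreadPoolExecutor
    from contextlib import closing


    def _download(row: int) -> int:
        # Stand-in for _prefetch_input(): a real version would warm a cache.
        return row


    def prefetch_iter(rows: Iterable[int], prefetch: int = 0) -> Iterator[int]:
        """Yield rows in order; if prefetch > 0, keep that many downloads in flight."""
        if prefetch <= 0:
            yield from rows
            return
        with ThreadPoolExecutor(max_workers=prefetch) as pool:
            pending: deque = deque()
            for row in rows:
                pending.append(pool.submit(_download, row))
                # Once the look-ahead window is full, hand the oldest result
                # to the consumer while newer downloads continue in the pool.
                if len(pending) >= prefetch:
                    yield pending.popleft().result()
            while pending:  # drain whatever is still in flight
                yield pending.popleft().result()


    rows = prefetch_iter(range(5), prefetch=2)
    with closing(rows):  # same consumer code with or without prefetching
        print(list(rows))

Because prefetch_iter() (like _prefetch_inputs()) is a generator function,
nothing is submitted until the consumer starts iterating, and closing() gives
early exits a single cleanup path in both the prefetching and non-prefetching
cases.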