refactor udfs
skshetry committed Jan 3, 2025
1 parent bf24f57 commit 950170b
Showing 2 changed files with 66 additions and 90 deletions.
20 changes: 8 additions & 12 deletions src/datachain/lib/pytorch.py
@@ -2,7 +2,6 @@
import os
from collections.abc import Generator, Iterable, Iterator
from contextlib import closing
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Optional

from PIL import Image
@@ -12,7 +11,6 @@
from torchvision.transforms import v2

from datachain import Session
from datachain.asyn import AsyncMapper
from datachain.cache import get_temp_cache
from datachain.catalog import Catalog, get_catalog
from datachain.lib.dc import DataChain
@@ -128,6 +126,8 @@ def _row_iter(
yield from ds.collect()

def _iter_with_prefetch(self) -> Generator[tuple[Any], None, None]:
from datachain.lib.udf import _prefetch_inputs

total_rank, total_workers = self.get_rank_and_workers()
download_cb = CombinedDownloadCallback()
if os.getenv("DATACHAIN_SHOW_PREFETCH_PROGRESS"):
@@ -136,16 +136,12 @@ def _iter_with_prefetch(self) -> Generator[tuple[Any], None, None]:
)

rows = self._row_iter(total_rank, total_workers)
if self.prefetch > 0:
from datachain.lib.udf import _prefetch_input

func = partial(
_prefetch_input,
download_cb=download_cb,
after_prefetch=download_cb.increment_file_count,
)
mapper = AsyncMapper(func, rows, workers=self.prefetch)
rows = mapper.iterate() # type: ignore[assignment]
rows = _prefetch_inputs(
rows,
self.prefetch,
download_cb=download_cb,
after_prefetch=download_cb.increment_file_count,
)

with download_cb, closing(rows):
yield from rows
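
For reference, a minimal standalone sketch of the consolidated prefetch path (assuming only the _prefetch_inputs helper introduced in src/datachain/lib/udf.py below; the row iterable, prefetch count, and callback wiring here are illustrative, not part of this commit):

    from contextlib import closing

    from datachain.lib.udf import _prefetch_inputs

    def iter_rows_with_prefetch(rows, prefetch=2, download_cb=None):
        # With prefetch > 0 the helper drives _prefetch_input through an
        # AsyncMapper pool; with prefetch == 0 it is a plain passthrough.
        prefetched = _prefetch_inputs(rows, prefetch, download_cb=download_cb)
        with closing(prefetched):
            yield from prefetched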
136 changes: 58 additions & 78 deletions src/datachain/lib/udf.py
@@ -1,8 +1,9 @@
import sys
import traceback
from collections.abc import Iterable, Iterator, Mapping, Sequence
from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
from contextlib import closing, nullcontext
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar
from functools import partial
from typing import TYPE_CHECKING, Any, Optional, TypeVar

import attrs
from fsspec.callbacks import DEFAULT_CALLBACK, Callback
@@ -154,12 +155,10 @@ def process(self, file) -> list[float]:
"""

is_output_batched = False
catalog: "Optional[Catalog]"

def __init__(self):
self.params: Optional[SignalSchema] = None
self.output = None
self.catalog = None
self._func = None

def process(self, *args, **kwargs):
@@ -248,26 +247,23 @@ def _obj_to_list(obj):
return flatten(obj) if isinstance(obj, BaseModel) else [obj]

def _parse_row(
self, row_dict: RowDict, cache: bool, download_cb: Callback
self, row_dict: RowDict, catalog: "Catalog", cache: bool, download_cb: Callback
) -> list[DataValue]:
assert self.params
row = [row_dict[p] for p in self.params.to_udf_spec()]
obj_row = self.params.row_to_objs(row)
for obj in obj_row:
if isinstance(obj, File):
assert self.catalog is not None
obj._set_stream(
self.catalog, caching_enabled=cache, download_cb=download_cb
)
obj._set_stream(catalog, caching_enabled=cache, download_cb=download_cb)
return obj_row

def _prepare_row(self, row, udf_fields, cache, download_cb):
def _prepare_row(self, row, udf_fields, catalog, cache, download_cb):
row_dict = RowDict(zip(udf_fields, row))
return self._parse_row(row_dict, cache, download_cb)
return self._parse_row(row_dict, catalog, cache, download_cb)

def _prepare_row_and_id(self, row, udf_fields, cache, download_cb):
def _prepare_row_and_id(self, row, udf_fields, catalog, cache, download_cb):
row_dict = RowDict(zip(udf_fields, row))
udf_input = self._parse_row(row_dict, cache, download_cb)
udf_input = self._parse_row(row_dict, catalog, cache, download_cb)
return row_dict["sys__id"], *udf_input

def process_safe(self, obj_rows):
@@ -300,28 +296,37 @@ async def _prefetch_input(
return row


def _prefetch_inputs(
prepared_inputs: "Iterable[T]",
prefetch: int = 0,
download_cb: Optional["Callback"] = None,
after_prefetch: "Callable[[], None]" = noop,
) -> "abc.Generator[T, None, None]":
if prefetch > 0:
f = partial(
_prefetch_input,
download_cb=download_cb,
after_prefetch=after_prefetch,
)
prepared_inputs = AsyncMapper(f, prepared_inputs, workers=prefetch).iterate() # type: ignore[assignment]
yield from prepared_inputs
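
A quick usage note on the contract this helper preserves (the values below are hypothetical): with prefetch=0 it simply yields the inputs unchanged, so callers can leave the call in place unconditionally; only a positive prefetch engages the AsyncMapper pool, and even then the items themselves are yielded as-is after _prefetch_input has warmed their files.

    # prefetch=0: pure passthrough, no async machinery involved
    assert list(_prefetch_inputs(iter([1, 2, 3]), 0)) == [1, 2, 3]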


def _get_cache(
cache: "Cache", prefetch: int = 0, use_cache: bool = False
) -> "AbstractContextManager[Cache]":
tmp_dir = cache.tmp_dir
assert tmp_dir
if prefetch and not use_cache:
return temporary_cache(tmp_dir, prefix="prefetch-")
return nullcontext(cache)
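
A minimal sketch of how this helper pairs with the cloned catalog in the run() bodies below (here catalog stands in for whatever Catalog instance the caller already holds; _get_cache and clone_catalog_with_cache are already in scope in this module):

    # prefetch > 0 with use_cache=False yields a throwaway "prefetch-" cache
    # that is cleaned up on exit; otherwise the existing cache is reused.
    with _get_cache(catalog.cache, prefetch=2, use_cache=False) as _cache:
        catalog = clone_catalog_with_cache(catalog, _cache)
        ...  # prepare rows against the cloned catalog and run the UDF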


class Mapper(UDFBase):
"""Inherit from this class to pass to `DataChain.map()`."""

prefetch: int = 2

def _iter_with_prefetch(
self,
udf_fields: "Sequence[str]",
udf_inputs: "Iterable[Sequence[Any]]",
cache: bool,
download_cb: Callback = DEFAULT_CALLBACK,
) -> "abc.Generator[Sequence[Any], None, None]":
prepared_inputs: abc.Generator[Sequence[Any], None, None] = (
self._prepare_row_and_id(row, udf_fields, cache, download_cb)
for row in udf_inputs
)
if self.prefetch > 0:
prepared_inputs = AsyncMapper(
_prefetch_input, prepared_inputs, workers=self.prefetch
).iterate()
yield from prepared_inputs

def run(
self,
udf_fields: "Sequence[str]",
@@ -331,21 +336,17 @@ def run(
download_cb: Callback = DEFAULT_CALLBACK,
processed_cb: Callback = DEFAULT_CALLBACK,
) -> Iterator[Iterable[UDFResult]]:
_cache = catalog.cache
tmp_dir = _cache.tmp_dir
assert tmp_dir

cache_ctx: AbstractContextManager[Cache] = nullcontext(_cache)
if self.prefetch > 0 and not cache:
cache_ctx = temporary_cache(tmp_dir, prefix="prefetch-")
self.setup()

with cache_ctx as _cache:
self.catalog = clone_catalog_with_cache(catalog, _cache)
self.setup()
with _get_cache(catalog.cache, self.prefetch, use_cache=cache) as _cache:
catalog = clone_catalog_with_cache(catalog, _cache)

prepared_inputs = self._iter_with_prefetch(
udf_fields, udf_inputs, cache, download_cb
prepared_inputs: abc.Generator[Sequence[Any], None, None] = (
self._prepare_row_and_id(row, udf_fields, catalog, cache, download_cb)
for row in udf_inputs
)
prepared_inputs = _prefetch_inputs(prepared_inputs, self.prefetch)

with closing(prepared_inputs):
for id_, *udf_args in prepared_inputs:
result_objs = self.process_safe(udf_args)
@@ -356,8 +357,7 @@ def run(
processed_cb.relative_update(1)
yield output

self.teardown()
self.catalog = catalog
self.teardown()
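
One practical consequence, shown as an illustrative user-level sketch (the subclass, its output, and the use of File.read() are assumptions for the example, not part of this commit): a Mapper subclass no longer reaches for a catalog attribute, because _prepare_row_and_id binds each File's stream to the cloned catalog before process() runs.

    from datachain.lib.udf import Mapper

    class SizeMapper(Mapper):
        # Illustrative mapper: the incoming File is already wired to the
        # (possibly temporary prefetch) cache by run(), so the body only
        # deals with the data itself.
        def process(self, file) -> int:
            return len(file.read())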


class BatchMapper(UDFBase):
@@ -374,14 +374,15 @@ def run(
download_cb: Callback = DEFAULT_CALLBACK,
processed_cb: Callback = DEFAULT_CALLBACK,
) -> Iterator[Iterable[UDFResult]]:
self.catalog = catalog
self.setup()

for batch in udf_inputs:
n_rows = len(batch.rows)
row_ids, *udf_args = zip(
*[
self._prepare_row_and_id(row, udf_fields, cache, download_cb)
self._prepare_row_and_id(
row, udf_fields, catalog, cache, download_cb
)
for row in batch.rows
]
)
@@ -407,22 +408,6 @@ class Generator(UDFBase):
is_output_batched = True
prefetch: int = 2

def _iter_with_prefetch(
self,
udf_fields: "Sequence[str]",
udf_inputs: "Iterable[Sequence[Any]]",
cache: bool,
download_cb: Callback = DEFAULT_CALLBACK,
) -> "abc.Generator[Sequence[Any], None, None]":
prepared_inputs: abc.Generator[Sequence[Any], None, None] = (
self._prepare_row(row, udf_fields, cache, download_cb) for row in udf_inputs
)
if self.prefetch > 0:
prepared_inputs = AsyncMapper(
_prefetch_input, prepared_inputs, workers=self.prefetch
).iterate()
yield from prepared_inputs

def run(
self,
udf_fields: "Sequence[str]",
@@ -432,30 +417,26 @@ def run(
download_cb: Callback = DEFAULT_CALLBACK,
processed_cb: Callback = DEFAULT_CALLBACK,
) -> Iterator[Iterable[UDFResult]]:
_cache = catalog.cache
tmp_dir = _cache.tmp_dir
assert tmp_dir

cache_ctx: AbstractContextManager[Cache] = nullcontext(_cache)
if self.prefetch > 0 and not cache:
cache_ctx = temporary_cache(tmp_dir, prefix="prefetch-")
self.setup()

with cache_ctx as _cache:
self.catalog = clone_catalog_with_cache(catalog, _cache)
self.setup()
with _get_cache(catalog.cache, self.prefetch, use_cache=cache) as _cache:
catalog = clone_catalog_with_cache(catalog, _cache)

prepared_inputs = self._iter_with_prefetch(
udf_fields, udf_inputs, cache, download_cb
prepared_inputs: abc.Generator[Sequence[Any], None, None] = (
self._prepare_row(row, udf_fields, catalog, cache, download_cb)
for row in udf_inputs
)
prepared_inputs = _prefetch_inputs(prepared_inputs, self.prefetch)

with closing(prepared_inputs):
for row in prepared_inputs:
result_objs = self.process_safe(row)
udf_outputs = (self._flatten_row(row) for row in result_objs)
output = (dict(zip(self.signal_names, row)) for row in udf_outputs)
processed_cb.relative_update(1)
yield output
self.teardown()
self.catalog = catalog

self.teardown()


class Aggregator(UDFBase):
@@ -472,13 +453,12 @@ def run(
download_cb: Callback = DEFAULT_CALLBACK,
processed_cb: Callback = DEFAULT_CALLBACK,
) -> Iterator[Iterable[UDFResult]]:
self.catalog = catalog
self.setup()

for batch in udf_inputs:
udf_args = zip(
*[
self._prepare_row(row, udf_fields, cache, download_cb)
self._prepare_row(row, udf_fields, catalog, cache, download_cb)
for row in batch.rows
]
)
