Skip to content

Commit

Permalink
Implement pre-fetching in map() and gen() (#521)
Browse files Browse the repository at this point in the history
* Use threading in AsyncMapper.produce()

* Implement prefetching in .gen() and .map()

* Avoid user code error in name_len()

* asyncmapper: shutdown producer on generator close (#597)

---------

Co-authored-by: skshetry <[email protected]>
  • Loading branch information
rlamy and skshetry authored Nov 20, 2024
1 parent a9a0649 commit 21857af
Show file tree
Hide file tree
Showing 10 changed files with 170 additions and 55 deletions.
40 changes: 36 additions & 4 deletions src/datachain/asyn.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
import asyncio
from collections.abc import AsyncIterable, Awaitable, Coroutine, Iterable, Iterator
import threading
from collections.abc import (
AsyncIterable,
Awaitable,
Coroutine,
Generator,
Iterable,
Iterator,
)
from concurrent.futures import ThreadPoolExecutor
from heapq import heappop, heappush
from typing import Any, Callable, Generic, Optional, TypeVar
Expand Down Expand Up @@ -47,16 +55,39 @@ def __init__(
self.loop = get_loop() if loop is None else loop
self.pool = ThreadPoolExecutor(workers)
self._tasks: set[asyncio.Task] = set()
self._shutdown_producer = threading.Event()

def start_task(self, coro: Coroutine) -> asyncio.Task:
    """Schedule `coro` on this mapper's event loop and track the resulting task.

    The task is kept in `self._tasks` while it is pending and is removed
    from that set by a done-callback once it finishes.
    """
    scheduled = self.loop.create_task(coro)
    # Register the cleanup first; it only fires after the task completes.
    scheduled.add_done_callback(self._tasks.discard)
    self._tasks.add(scheduled)
    return scheduled

async def produce(self) -> None:
def _produce(self) -> None:
for item in self.iterable:
await self.work_queue.put(item)
if self._shutdown_producer.is_set():
return
fut = asyncio.run_coroutine_threadsafe(self.work_queue.put(item), self.loop)
fut.result() # wait until the item is in the queue

async def produce(self) -> None:
await self.to_thread(self._produce)

def shutdown_producer(self) -> None:
    """
    Signal the producer to stop and discard whatever is left in the work queue.

    Sets the `_shutdown_producer` event, which the producer checks before
    enqueueing each item. The producer waits for every `work_queue.put()` to
    complete, so the queued items are drained here as well: that lets a
    pending put finish (if the queue was full) and the producer return to
    the point where it checks the event.
    """
    self._shutdown_producer.set()
    q = self.work_queue
    while True:
        # Take-then-handle instead of `while not q.empty()`: a worker may
        # consume the last item between the emptiness check and the get,
        # which would otherwise surface as QueueEmpty during cleanup.
        try:
            q.get_nowait()
        except asyncio.QueueEmpty:
            break
        q.task_done()

async def worker(self) -> None:
while (item := await self.work_queue.get()) is not None:
Expand Down Expand Up @@ -132,7 +163,7 @@ async def _break_iteration(self) -> None:
self.result_queue.get_nowait()
await self.result_queue.put(None)

def iterate(self, timeout=None) -> Iterable[ResultT]:
def iterate(self, timeout=None) -> Generator[ResultT, None, None]:
init = asyncio.run_coroutine_threadsafe(self.init(), self.loop)
init.result(timeout=1)
async_run = asyncio.run_coroutine_threadsafe(self.run(), self.loop)
Expand All @@ -145,6 +176,7 @@ def iterate(self, timeout=None) -> Iterable[ResultT]:
if exc := async_run.exception():
raise exc
finally:
self.shutdown_producer()
if not async_run.done():
async_run.cancel()

Expand Down
5 changes: 4 additions & 1 deletion src/datachain/data_storage/warehouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,10 @@ def dataset_select_paginated(
if limit < page_size:
paginated_query = paginated_query.limit(None).limit(limit)

results = self.dataset_rows_select(paginated_query.offset(offset))
# Ensure we're using a thread-local connection
with self.clone() as wh:
# Cursor results are not thread-safe, so we convert them to a list
results = list(wh.dataset_rows_select(paginated_query.offset(offset)))

processed = False
for row in results:
Expand Down
7 changes: 6 additions & 1 deletion src/datachain/lib/dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ def settings(
parallel=None,
workers=None,
min_task_size=None,
prefetch: Optional[int] = None,
sys: Optional[bool] = None,
) -> "Self":
"""Change settings for chain.
Expand All @@ -360,7 +361,7 @@ def settings(
if sys is None:
sys = self._sys
settings = copy.copy(self._settings)
settings.add(Settings(cache, parallel, workers, min_task_size))
settings.add(Settings(cache, parallel, workers, min_task_size, prefetch))
return self._evolve(settings=settings, _sys=sys)

def reset_settings(self, settings: Optional[Settings] = None) -> "Self":
Expand Down Expand Up @@ -882,6 +883,8 @@ def map(
```
"""
udf_obj = self._udf_to_obj(Mapper, func, params, output, signal_map)
if (prefetch := self._settings.prefetch) is not None:
udf_obj.prefetch = prefetch

return self._evolve(
query=self._query.add_signals(
Expand Down Expand Up @@ -919,6 +922,8 @@ def gen(
```
"""
udf_obj = self._udf_to_obj(Generator, func, params, output, signal_map)
if (prefetch := self._settings.prefetch) is not None:
udf_obj.prefetch = prefetch
return self._evolve(
query=self._query.generate(
udf_obj.to_udf_wrapper(),
Expand Down
5 changes: 5 additions & 0 deletions src/datachain/lib/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,11 @@ def ensure_cached(self) -> None:
client = self._catalog.get_client(self.source)
client.download(self, callback=self._download_cb)

async def _prefetch(self) -> None:
    """Asynchronously download this file through its client when caching is enabled.

    Does nothing when `_caching_enabled` is false.
    """
    if not self._caching_enabled:
        return
    await self._catalog.get_client(self.source)._download(
        self, callback=self._download_cb
    )

def get_local_path(self) -> Optional[str]:
"""Return path to a file in a local cache.
Expand Down
12 changes: 11 additions & 1 deletion src/datachain/lib/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,19 @@ def __init__(self, msg):


class Settings:
def __init__(self, cache=None, parallel=None, workers=None, min_task_size=None):
def __init__(
self,
cache=None,
parallel=None,
workers=None,
min_task_size=None,
prefetch=None,
):
self._cache = cache
self.parallel = parallel
self._workers = workers
self.min_task_size = min_task_size
self.prefetch = prefetch

if not isinstance(cache, bool) and cache is not None:
raise SettingsError(
Expand Down Expand Up @@ -66,3 +74,5 @@ def add(self, settings: "Settings"):
self.parallel = settings.parallel or self.parallel
self._workers = settings._workers or self._workers
self.min_task_size = settings.min_task_size or self.min_task_size
if settings.prefetch is not None:
self.prefetch = settings.prefetch
63 changes: 45 additions & 18 deletions src/datachain/lib/udf.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import contextlib
import sys
import traceback
from collections.abc import Iterable, Iterator, Mapping, Sequence
Expand All @@ -7,6 +8,7 @@
from fsspec.callbacks import DEFAULT_CALLBACK, Callback
from pydantic import BaseModel

from datachain.asyn import AsyncMapper
from datachain.dataset import RowDict
from datachain.lib.convert.flatten import flatten
from datachain.lib.data_model import DataValue
Expand All @@ -21,6 +23,8 @@
)

if TYPE_CHECKING:
from collections import abc

from typing_extensions import Self

from datachain.catalog import Catalog
Expand Down Expand Up @@ -276,9 +280,18 @@ def process_safe(self, obj_rows):
return result_objs


async def _prefetch_input(row):
    """Prefetch every `File` object found in `row`, then return `row` unchanged.

    Files are prefetched one after another, in the order they appear.
    """
    # Lazy filter so the type check and the await stay interleaved per item.
    file_objs = (item for item in row if isinstance(item, File))
    for file_obj in file_objs:
        await file_obj._prefetch()
    return row


class Mapper(UDFBase):
"""Inherit from this class to pass to `DataChain.map()`."""

prefetch: int = 2

def run(
self,
udf_fields: "Sequence[str]",
Expand All @@ -290,16 +303,22 @@ def run(
) -> Iterator[Iterable[UDFResult]]:
self.catalog = catalog
self.setup()

for row in udf_inputs:
id_, *udf_args = self._prepare_row_and_id(
row, udf_fields, cache, download_cb
)
result_objs = self.process_safe(udf_args)
udf_output = self._flatten_row(result_objs)
output = [{"sys__id": id_} | dict(zip(self.signal_names, udf_output))]
processed_cb.relative_update(1)
yield output
prepared_inputs: abc.Generator[Sequence[Any], None, None] = (
self._prepare_row_and_id(row, udf_fields, cache, download_cb)
for row in udf_inputs
)
if self.prefetch > 0:
prepared_inputs = AsyncMapper(
_prefetch_input, prepared_inputs, workers=self.prefetch
).iterate()

with contextlib.closing(prepared_inputs):
for id_, *udf_args in prepared_inputs:
result_objs = self.process_safe(udf_args)
udf_output = self._flatten_row(result_objs)
output = [{"sys__id": id_} | dict(zip(self.signal_names, udf_output))]
processed_cb.relative_update(1)
yield output

self.teardown()

Expand Down Expand Up @@ -349,6 +368,7 @@ class Generator(UDFBase):
"""Inherit from this class to pass to `DataChain.gen()`."""

is_output_batched = True
prefetch: int = 2

def run(
self,
Expand All @@ -361,14 +381,21 @@ def run(
) -> Iterator[Iterable[UDFResult]]:
self.catalog = catalog
self.setup()

for row in udf_inputs:
udf_args = self._prepare_row(row, udf_fields, cache, download_cb)
result_objs = self.process_safe(udf_args)
udf_outputs = (self._flatten_row(row) for row in result_objs)
output = (dict(zip(self.signal_names, row)) for row in udf_outputs)
processed_cb.relative_update(1)
yield output
prepared_inputs: abc.Generator[Sequence[Any], None, None] = (
self._prepare_row(row, udf_fields, cache, download_cb) for row in udf_inputs
)
if self.prefetch > 0:
prepared_inputs = AsyncMapper(
_prefetch_input, prepared_inputs, workers=self.prefetch
).iterate()

with contextlib.closing(prepared_inputs):
for row in prepared_inputs:
result_objs = self.process_safe(row)
udf_outputs = (self._flatten_row(row) for row in result_objs)
output = (dict(zip(self.signal_names, row)) for row in udf_outputs)
processed_cb.relative_update(1)
yield output

self.teardown()

Expand Down
52 changes: 25 additions & 27 deletions src/datachain/query/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,33 +473,31 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None:
# Otherwise process single-threaded (faster for smaller UDFs)
warehouse = self.catalog.warehouse

with contextlib.closing(
batching(warehouse.dataset_select_paginated, query)
) as udf_inputs:
download_cb = get_download_callback()
processed_cb = get_processed_callback()
generated_cb = get_generated_callback(self.is_generator)
try:
udf_results = self.udf.run(
udf_fields,
udf_inputs,
self.catalog,
self.is_generator,
self.cache,
download_cb,
processed_cb,
)
process_udf_outputs(
warehouse,
udf_table,
udf_results,
self.udf,
cb=generated_cb,
)
finally:
download_cb.close()
processed_cb.close()
generated_cb.close()
udf_inputs = batching(warehouse.dataset_select_paginated, query)
download_cb = get_download_callback()
processed_cb = get_processed_callback()
generated_cb = get_generated_callback(self.is_generator)
try:
udf_results = self.udf.run(
udf_fields,
udf_inputs,
self.catalog,
self.is_generator,
self.cache,
download_cb,
processed_cb,
)
process_udf_outputs(
warehouse,
udf_table,
udf_results,
self.udf,
cb=generated_cb,
)
finally:
download_cb.close()
processed_cb.close()
generated_cb.close()

warehouse.insert_rows_done(udf_table)

Expand Down
7 changes: 5 additions & 2 deletions tests/func/test_datachain.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,17 +213,20 @@ def test_from_storage_dependencies(cloud_test_catalog, cloud_type):


@pytest.mark.parametrize("use_cache", [True, False])
def test_map_file(cloud_test_catalog, use_cache):
@pytest.mark.parametrize("prefetch", [0, 2])
def test_map_file(cloud_test_catalog, use_cache, prefetch):
ctc = cloud_test_catalog

def new_signal(file: File) -> str:
assert bool(file.get_local_path()) is (use_cache and prefetch > 0)
with file.open() as f:
return file.name + " -> " + f.read().decode("utf-8")

dc = (
DataChain.from_storage(ctc.src_uri, session=ctc.session)
.settings(cache=use_cache)
.settings(cache=use_cache, prefetch=prefetch)
.map(signal=new_signal)
.save()
)

expected = {
Expand Down
2 changes: 1 addition & 1 deletion tests/scripts/name_len_slow.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,5 @@ def name_len(file):
"gs://dvcx-datalakes/dogs-and-cats/",
anon=True,
).filter(C("file.path").glob("*cat*")).settings(parallel=1).map(
name_len, params=["file.path"], output={"name_len": int}
name_len, params=["file"], output={"name_len": int}
).save("name_len")
Loading

0 comments on commit 21857af

Please sign in to comment.