ensure iterable gets closed in AsyncMapper
In `AsyncMapper`, generators are executed on a separate thread.
However, if an exception occurs, the `__exit__` method gets executed
in the main thread.

This can lead to issues with SQLite. Specifically, attempting to
close a connection outside the thread where it was created results
in the following error:

```console
sqlite3.ProgrammingError: SQLite objects created in a thread can only be used in that same thread.
The object was created in thread id 123145530425344 and this is thread id 140704344640320.
```

This PR modifies `AsyncMapper` to ensure that `__exit__` is executed in
the same thread where the generator runs.
Additionally, since the main thread might still attempt to
close the generator, I have added safeguards to avoid calling
`sqlite3.Connection.close()` more than once.

Whether the main thread also calls `__exit__` is beyond `AsyncMapper`'s control,
since it accepts arbitrary `Iterable` inputs.
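
For reference, the error above can be reproduced with the standard library alone. A minimal sketch (independent of DataChain, using sqlite3's default `check_same_thread=True` and a hypothetical database file):

```python
import sqlite3
import threading

holder = {}

def open_connection() -> None:
    # The connection is created in this worker thread, so it may only be
    # used, including closed, from this thread.
    holder["conn"] = sqlite3.connect("example.db")

worker = threading.Thread(target=open_connection)
worker.start()
worker.join()

# Closing from the main thread raises:
# sqlite3.ProgrammingError: SQLite objects created in a thread can only be
# used in that same thread.
holder["conn"].close()
```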
skshetry committed Jan 1, 2025
1 parent 7adfc0a commit b9ee297
Showing 4 changed files with 41 additions and 16 deletions.
15 changes: 10 additions & 5 deletions src/datachain/asyn.py

```diff
@@ -14,6 +14,8 @@
 
 from fsspec.asyn import get_loop
 
+from datachain.utils import safe_closing
+
 ASYNC_WORKERS = 20
 
 InputT = TypeVar("InputT", contravariant=True)  # noqa: PLC0105
@@ -64,11 +66,14 @@ def start_task(self, coro: Coroutine) -> asyncio.Task:
         return task
 
     def _produce(self) -> None:
-        for item in self.iterable:
-            if self._shutdown_producer.is_set():
-                return
-            fut = asyncio.run_coroutine_threadsafe(self.work_queue.put(item), self.loop)
-            fut.result()  # wait until the item is in the queue
+        with safe_closing(self.iterable):
+            for item in self.iterable:
+                if self._shutdown_producer.is_set():
+                    return
+                fut = asyncio.run_coroutine_threadsafe(
+                    self.work_queue.put(item), self.loop
+                )
+                fut.result()  # wait until the item is in the queue
 
     async def produce(self) -> None:
         await self.to_thread(self._produce)
```
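
To see the effect of this change outside DataChain, here is a standalone sketch. It uses `contextlib.closing`, which behaves like the new `safe_closing` for objects that do have a `close()` method; the point is that the thread running the loop is also the thread that closes the iterable, even when it exits early:

```python
import threading
from contextlib import closing

def rows():
    try:
        yield from range(5)
    finally:
        # Runs when the generator is closed and records which thread did it.
        print("generator closed in thread", threading.get_ident())

def produce(iterable):
    # Mirrors the patched _produce: the consuming thread closes the iterable.
    with closing(iterable):
        for _item in iterable:
            return  # simulate an early shutdown

gen = rows()
worker = threading.Thread(target=produce, args=(gen,))
worker.start()
worker.join()
print("main thread is", threading.get_ident())
```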
2 changes: 2 additions & 0 deletions src/datachain/data_storage/sqlite.py

```diff
@@ -239,6 +239,8 @@ def cursor(self, factory=None):
         return self.db.cursor(factory)
 
     def close(self) -> None:
+        if self.is_closed:
+            return
         self.db.close()
         self.is_closed = True
 
```
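
The guard presumably matters because `close()` can now be reached twice: once from the producer thread via `safe_closing`, and possibly again from the main thread during teardown. With the flag, the second call returns early instead of touching the already-closed `sqlite3` connection from another thread. A simplified sketch of the pattern (not the actual engine class):

```python
import sqlite3

class Engine:
    """Toy engine using the same double-close guard."""

    def __init__(self) -> None:
        self.db = sqlite3.connect(":memory:")
        self.is_closed = False

    def close(self) -> None:
        if self.is_closed:  # later calls become no-ops
            return
        self.db.close()
        self.is_closed = True

engine = Engine()
engine.close()
engine.close()  # safe: the guard short-circuits the second call
```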
25 changes: 16 additions & 9 deletions src/datachain/lib/udf.py

```diff
@@ -1,6 +1,7 @@
 import contextlib
 import sys
 import traceback
+from collections.abc import Generator as GeneratorType
 from collections.abc import Iterable, Iterator, Mapping, Sequence
 from typing import TYPE_CHECKING, Any, Callable, Optional
 
@@ -21,10 +22,9 @@
     Partition,
     RowsOutputBatch,
 )
+from datachain.utils import safe_closing
 
 if TYPE_CHECKING:
-    from collections import abc
-
     from typing_extensions import Self
 
     from datachain.catalog import Catalog
@@ -295,10 +295,13 @@ def run(
     ) -> Iterator[Iterable[UDFResult]]:
         self.catalog = catalog
         self.setup()
-        prepared_inputs: abc.Generator[Sequence[Any], None, None] = (
-            self._prepare_row_and_id(row, udf_fields, cache, download_cb)
-            for row in udf_inputs
-        )
+
+        def row_iter() -> GeneratorType[Sequence[Any], None, None]:
+            with safe_closing(udf_inputs) as rows:
+                for row in rows:
+                    yield self._prepare_row_and_id(row, udf_fields, cache, download_cb)
+
+        prepared_inputs = row_iter()
         if self.prefetch > 0:
             _cache = self.catalog.cache if cache else None
             prepared_inputs = rows_prefetcher(
@@ -378,9 +381,13 @@ def run(
     ) -> Iterator[Iterable[UDFResult]]:
         self.catalog = catalog
         self.setup()
-        prepared_inputs: abc.Generator[Sequence[Any], None, None] = (
-            self._prepare_row(row, udf_fields, cache, download_cb) for row in udf_inputs
-        )
+
+        def row_iter() -> GeneratorType[Sequence[Any], None, None]:
+            with safe_closing(udf_inputs) as rows:
+                for row in rows:
+                    yield self._prepare_row(row, udf_fields, cache, download_cb)
+
+        prepared_inputs = row_iter()
         if self.prefetch > 0:
             _cache = self.catalog.cache if cache else None
             prepared_inputs = rows_prefetcher(
```
15 changes: 13 additions & 2 deletions src/datachain/utils.py

```diff
@@ -9,6 +9,7 @@
 import sys
 import time
 from collections.abc import Iterable, Iterator, Sequence
+from contextlib import contextmanager
 from datetime import date, datetime, timezone
 from itertools import chain, islice
 from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
@@ -22,6 +23,7 @@
 
 if TYPE_CHECKING:
     import pandas as pd
+    from typing_extensions import Self
 
 NUL = b"\0"
 TIME_ZERO = datetime.fromtimestamp(0, tz=timezone.utc)
@@ -33,7 +35,7 @@
 STUDIO_URL = "https://studio.datachain.ai"
 
 
-T = TypeVar("T", bound="DataChainDir")
+T = TypeVar("T")
 
 
 class DataChainDir:
@@ -90,7 +92,7 @@ def default_root(cls) -> str:
         return osp.join(root_dir, cls.DEFAULT)
 
     @classmethod
-    def find(cls: type[T], create: bool = True) -> T:
+    def find(cls, create: bool = True) -> "Self":
         try:
             root = os.environ[cls.ENV_VAR]
         except KeyError:
@@ -479,3 +481,12 @@ def row_to_nested_dict(
     for h, v in zip(headers, row):
         nested_dict_path_set(result, h, v)
     return result
+
+
+@contextmanager
+def safe_closing(thing: T) -> Iterator[T]:
+    try:
+        yield thing
+    finally:
+        if hasattr(thing, "close"):
+            thing.close()
```
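`safe_closing` is deliberately tolerant: it only calls `close()` if the wrapped object has one, so it works for generators, DB-API cursors, and plain lists alike. A small usage sketch (with a hypothetical `Rows` class) showing how closing a wrapping generator, as `row_iter` does in `udf.py`, propagates to the underlying input:

```python
from datachain.utils import safe_closing  # added by this commit

class Rows:
    """Stand-in for an input that must be released, e.g. a cursor."""

    def __iter__(self):
        yield from range(3)

    def close(self):
        print("inner close() called")

def row_iter(rows):
    with safe_closing(rows) as rs:
        for row in rs:
            yield row

prepared = row_iter(Rows())
next(prepared)    # consume one item
prepared.close()  # GeneratorExit exits the with block, so Rows.close() runs
```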