From b1ce0933510b869f47db39f05a2945d168474b09 Mon Sep 17 00:00:00 2001 From: Vladimir Rudnykh Date: Thu, 26 Dec 2024 22:44:59 +0700 Subject: [PATCH] Optimize UDF with parallel execution (#713) --------- Co-authored-by: skshetry <18718008+skshetry@users.noreply.github.com> --- src/datachain/data_storage/warehouse.py | 1 - src/datachain/lib/udf.py | 1 - src/datachain/query/batch.py | 38 +++- src/datachain/query/dataset.py | 30 ++- src/datachain/query/dispatch.py | 250 ++++++++++++------------ src/datachain/query/udf.py | 20 ++ src/datachain/query/utils.py | 42 ++++ src/datachain/utils.py | 2 +- 8 files changed, 234 insertions(+), 150 deletions(-) create mode 100644 src/datachain/query/udf.py create mode 100644 src/datachain/query/utils.py diff --git a/src/datachain/data_storage/warehouse.py b/src/datachain/data_storage/warehouse.py index fdb8f3c17..cd0c4376e 100644 --- a/src/datachain/data_storage/warehouse.py +++ b/src/datachain/data_storage/warehouse.py @@ -216,7 +216,6 @@ def dataset_select_paginated( limit = query._limit paginated_query = query.limit(page_size) - results = None offset = 0 num_yielded = 0 diff --git a/src/datachain/lib/udf.py b/src/datachain/lib/udf.py index d708c0330..c59442d6b 100644 --- a/src/datachain/lib/udf.py +++ b/src/datachain/lib/udf.py @@ -85,7 +85,6 @@ def run( udf_fields: "Sequence[str]", udf_inputs: "Iterable[RowsOutput]", catalog: "Catalog", - is_generator: bool, cache: bool, download_cb: Callback = DEFAULT_CALLBACK, processed_cb: Callback = DEFAULT_CALLBACK, diff --git a/src/datachain/query/batch.py b/src/datachain/query/batch.py index 8f24ec895..65b2f5742 100644 --- a/src/datachain/query/batch.py +++ b/src/datachain/query/batch.py @@ -7,6 +7,7 @@ from datachain.data_storage.schema import PARTITION_COLUMN_ID from datachain.data_storage.warehouse import SELECT_BATCH_SIZE +from datachain.query.utils import get_query_column, get_query_id_column if TYPE_CHECKING: from sqlalchemy import Select @@ -23,11 +24,14 @@ class RowsOutputBatch: class BatchingStrategy(ABC): """BatchingStrategy provides means of batching UDF executions.""" + is_batching: bool + @abstractmethod def __call__( self, - execute: Callable[..., Generator[Sequence, None, None]], + execute: Callable, query: "Select", + ids_only: bool = False, ) -> Generator[RowsOutput, None, None]: """Apply the provided parameters to the UDF.""" @@ -38,11 +42,16 @@ class NoBatching(BatchingStrategy): batch UDF calls. """ + is_batching = False + def __call__( self, - execute: Callable[..., Generator[Sequence, None, None]], + execute: Callable, query: "Select", + ids_only: bool = False, ) -> Generator[Sequence, None, None]: + if ids_only: + query = query.with_only_columns(get_query_id_column(query)) return execute(query) @@ -52,14 +61,20 @@ class Batch(BatchingStrategy): is passed a sequence of multiple parameter sets. """ + is_batching = True + def __init__(self, count: int): self.count = count def __call__( self, - execute: Callable[..., Generator[Sequence, None, None]], + execute: Callable, query: "Select", + ids_only: bool = False, ) -> Generator[RowsOutputBatch, None, None]: + if ids_only: + query = query.with_only_columns(get_query_id_column(query)) + # choose page size that is a multiple of the batch size page_size = math.ceil(SELECT_BATCH_SIZE / self.count) * self.count @@ -84,19 +99,30 @@ class Partition(BatchingStrategy): Dataset rows need to be sorted by the grouping column. """ + is_batching = True + def __call__( self, - execute: Callable[..., Generator[Sequence, None, None]], + execute: Callable, query: "Select", + ids_only: bool = False, ) -> Generator[RowsOutputBatch, None, None]: + id_col = get_query_id_column(query) + if (partition_col := get_query_column(query, PARTITION_COLUMN_ID)) is None: + raise RuntimeError("partition column not found in query") + + if ids_only: + query = query.with_only_columns(id_col, partition_col) + current_partition: Optional[int] = None batch: list[Sequence] = [] query_fields = [str(c.name) for c in query.selected_columns] + id_column_idx = query_fields.index("sys__id") partition_column_idx = query_fields.index(PARTITION_COLUMN_ID) ordered_query = query.order_by(None).order_by( - PARTITION_COLUMN_ID, + partition_col, *query._order_by_clauses, ) @@ -108,7 +134,7 @@ def __call__( if len(batch) > 0: yield RowsOutputBatch(batch) batch = [] - batch.append(row) + batch.append([row[id_column_idx]] if ids_only else row) if len(batch) > 0: yield RowsOutputBatch(batch) diff --git a/src/datachain/query/dataset.py b/src/datachain/query/dataset.py index 3761bed4b..1c5895489 100644 --- a/src/datachain/query/dataset.py +++ b/src/datachain/query/dataset.py @@ -43,8 +43,9 @@ from datachain.dataset import DatasetStatus, RowDict from datachain.error import DatasetNotFoundError, QueryScriptCancelError from datachain.func.base import Function -from datachain.lib.udf import UDFAdapter from datachain.progress import CombinedDownloadCallback +from datachain.query.schema import C, UDFParamSpec, normalize_param +from datachain.query.session import Session from datachain.sql.functions.random import rand from datachain.utils import ( batched, @@ -53,9 +54,6 @@ get_datachain_executable, ) -from .schema import C, UDFParamSpec, normalize_param -from .session import Session - if TYPE_CHECKING: from sqlalchemy.sql.elements import ClauseElement from sqlalchemy.sql.schema import Table @@ -65,7 +63,8 @@ from datachain.catalog import Catalog from datachain.data_storage import AbstractWarehouse from datachain.dataset import DatasetRecord - from datachain.lib.udf import UDFResult + from datachain.lib.udf import UDFAdapter, UDFResult + from datachain.query.udf import UdfInfo P = ParamSpec("P") @@ -301,7 +300,7 @@ def adjust_outputs( return row -def get_udf_col_types(warehouse: "AbstractWarehouse", udf: UDFAdapter) -> list[tuple]: +def get_udf_col_types(warehouse: "AbstractWarehouse", udf: "UDFAdapter") -> list[tuple]: """Optimization: Precompute UDF column types so these don't have to be computed in the convert_type function for each row in a loop.""" dialect = warehouse.db.dialect @@ -322,7 +321,7 @@ def process_udf_outputs( warehouse: "AbstractWarehouse", udf_table: "Table", udf_results: Iterator[Iterable["UDFResult"]], - udf: UDFAdapter, + udf: "UDFAdapter", batch_size: int = INSERT_BATCH_SIZE, cb: Callback = DEFAULT_CALLBACK, ) -> None: @@ -347,6 +346,8 @@ def process_udf_outputs( for row_chunk in batched(rows, batch_size): warehouse.insert_rows(udf_table, row_chunk) + warehouse.insert_rows_done(udf_table) + def get_download_callback() -> Callback: return CombinedDownloadCallback( @@ -366,7 +367,7 @@ def get_generated_callback(is_generator: bool = False) -> Callback: @frozen class UDFStep(Step, ABC): - udf: UDFAdapter + udf: "UDFAdapter" catalog: "Catalog" partition_by: Optional[PartitionByType] = None parallel: Optional[int] = None @@ -440,7 +441,7 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None: raise RuntimeError( "In-memory databases cannot be used with parallel processing." ) - udf_info = { + udf_info: UdfInfo = { "udf_data": filtered_cloudpickle_dumps(self.udf), "catalog_init": self.catalog.get_init_params(), "metastore_clone_params": self.catalog.metastore.clone_params(), @@ -464,8 +465,8 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None: with subprocess.Popen(cmd, env=envs, stdin=subprocess.PIPE) as process: # noqa: S603 process.communicate(process_data) - if process.poll(): - raise RuntimeError("UDF Execution Failed!") + if retval := process.poll(): + raise RuntimeError(f"UDF Execution Failed! Exit code: {retval}") else: # Otherwise process single-threaded (faster for smaller UDFs) warehouse = self.catalog.warehouse @@ -479,7 +480,6 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None: udf_fields, udf_inputs, self.catalog, - self.is_generator, self.cache, download_cb, processed_cb, @@ -496,8 +496,6 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None: processed_cb.close() generated_cb.close() - warehouse.insert_rows_done(udf_table) - except QueryScriptCancelError: self.catalog.warehouse.close() sys.exit(QUERY_SCRIPT_CANCELED_EXIT_CODE) @@ -1491,7 +1489,7 @@ def chunk(self, index: int, total: int) -> "Self": @detach def add_signals( self, - udf: UDFAdapter, + udf: "UDFAdapter", parallel: Optional[int] = None, workers: Union[bool, int] = False, min_task_size: Optional[int] = None, @@ -1535,7 +1533,7 @@ def subtract(self, dq: "DatasetQuery", on: Sequence[tuple[str, str]]) -> "Self": @detach def generate( self, - udf: UDFAdapter, + udf: "UDFAdapter", parallel: Optional[int] = None, workers: Union[bool, int] = False, min_task_size: Optional[int] = None, diff --git a/src/datachain/query/dispatch.py b/src/datachain/query/dispatch.py index 5392cf491..722f68c10 100644 --- a/src/datachain/query/dispatch.py +++ b/src/datachain/query/dispatch.py @@ -1,34 +1,37 @@ import contextlib -from collections.abc import Iterator, Sequence +from collections.abc import Iterable, Sequence from itertools import chain from multiprocessing import cpu_count from sys import stdin -from typing import Optional +from threading import Timer +from typing import TYPE_CHECKING, Optional import attrs import multiprocess from cloudpickle import load, loads from fsspec.callbacks import DEFAULT_CALLBACK, Callback from multiprocess import get_context +from sqlalchemy.sql import func from datachain.catalog import Catalog from datachain.catalog.loader import get_distributed_class -from datachain.lib.udf import UDFAdapter, UDFResult +from datachain.query.batch import RowsOutput, RowsOutputBatch from datachain.query.dataset import ( get_download_callback, get_generated_callback, get_processed_callback, process_udf_outputs, ) -from datachain.query.queue import ( - get_from_queue, - marshal, - msgpack_pack, - msgpack_unpack, - put_into_queue, - unmarshal, -) -from datachain.utils import batched_it +from datachain.query.queue import get_from_queue, put_into_queue +from datachain.query.udf import UdfInfo +from datachain.query.utils import get_query_id_column +from datachain.utils import batched, flatten + +if TYPE_CHECKING: + from sqlalchemy import Select, Table + + from datachain.data_storage import AbstractMetastore, AbstractWarehouse + from datachain.lib.udf import UDFAdapter DEFAULT_BATCH_SIZE = 10000 STOP_SIGNAL = "STOP" @@ -38,10 +41,6 @@ NOTIFY_STATUS = "NOTIFY" -def full_module_type_path(typ: type) -> str: - return f"{typ.__module__}.{typ.__qualname__}" - - def get_n_workers_from_arg(n_workers: Optional[int] = None) -> int: if not n_workers: return cpu_count() @@ -52,55 +51,42 @@ def get_n_workers_from_arg(n_workers: Optional[int] = None) -> int: def udf_entrypoint() -> int: # Load UDF info from stdin - udf_info = load(stdin.buffer) - - ( - warehouse_class, - warehouse_args, - warehouse_kwargs, - ) = udf_info["warehouse_clone_params"] - warehouse = warehouse_class(*warehouse_args, **warehouse_kwargs) + udf_info: UdfInfo = load(stdin.buffer) # Parallel processing (faster for more CPU-heavy UDFs) - dispatch = UDFDispatcher( - udf_info["udf_data"], - udf_info["catalog_init"], - udf_info["metastore_clone_params"], - udf_info["warehouse_clone_params"], - udf_fields=udf_info["udf_fields"], - cache=udf_info["cache"], - is_generator=udf_info.get("is_generator", False), - ) + dispatch = UDFDispatcher(udf_info) query = udf_info["query"] batching = udf_info["batching"] - table = udf_info["table"] n_workers = udf_info["processes"] - udf = loads(udf_info["udf_data"]) if n_workers is True: - # Use default number of CPUs (cores) - n_workers = None + n_workers = None # Use default number of CPUs (cores) + + wh_cls, wh_args, wh_kwargs = udf_info["warehouse_clone_params"] + warehouse: AbstractWarehouse = wh_cls(*wh_args, **wh_kwargs) + + total_rows = next( + warehouse.db.execute( + query.with_only_columns(func.count(query.c.sys__id)).order_by(None) + ) + )[0] with contextlib.closing( - batching(warehouse.dataset_select_paginated, query) + batching(warehouse.dataset_select_paginated, query, ids_only=True) ) as udf_inputs: download_cb = get_download_callback() processed_cb = get_processed_callback() - generated_cb = get_generated_callback(dispatch.is_generator) try: - udf_results = dispatch.run_udf_parallel( - marshal(udf_inputs), + dispatch.run_udf_parallel( + udf_inputs, + total_rows=total_rows, n_workers=n_workers, processed_cb=processed_cb, download_cb=download_cb, ) - process_udf_outputs(warehouse, table, udf_results, udf, cb=generated_cb) finally: download_cb.close() processed_cb.close() - generated_cb.close() - - warehouse.insert_rows_done(table) return 0 @@ -114,32 +100,17 @@ class UDFDispatcher: task_queue: Optional[multiprocess.Queue] = None done_queue: Optional[multiprocess.Queue] = None - def __init__( - self, - udf_data, - catalog_init_params, - metastore_clone_params, - warehouse_clone_params, - udf_fields: "Sequence[str]", - cache: bool, - is_generator: bool = False, - buffer_size: int = DEFAULT_BATCH_SIZE, - ): - self.udf_data = udf_data - self.catalog_init_params = catalog_init_params - ( - self.metastore_class, - self.metastore_args, - self.metastore_kwargs, - ) = metastore_clone_params - ( - self.warehouse_class, - self.warehouse_args, - self.warehouse_kwargs, - ) = warehouse_clone_params - self.udf_fields = udf_fields - self.cache = cache - self.is_generator = is_generator + def __init__(self, udf_info: UdfInfo, buffer_size: int = DEFAULT_BATCH_SIZE): + self.udf_data = udf_info["udf_data"] + self.catalog_init_params = udf_info["catalog_init"] + self.metastore_clone_params = udf_info["metastore_clone_params"] + self.warehouse_clone_params = udf_info["warehouse_clone_params"] + self.query = udf_info["query"] + self.table = udf_info["table"] + self.udf_fields = udf_info["udf_fields"] + self.cache = udf_info["cache"] + self.is_generator = udf_info["is_generator"] + self.is_batching = udf_info["batching"].is_batching self.buffer_size = buffer_size self.catalog = None self.task_queue = None @@ -148,12 +119,10 @@ def __init__( def _create_worker(self) -> "UDFWorker": if not self.catalog: - metastore = self.metastore_class( - *self.metastore_args, **self.metastore_kwargs - ) - warehouse = self.warehouse_class( - *self.warehouse_args, **self.warehouse_kwargs - ) + ms_cls, ms_args, ms_kwargs = self.metastore_clone_params + metastore: AbstractMetastore = ms_cls(*ms_args, **ms_kwargs) + ws_cls, ws_args, ws_kwargs = self.warehouse_clone_params + warehouse: AbstractWarehouse = ws_cls(*ws_args, **ws_kwargs) self.catalog = Catalog(metastore, warehouse, **self.catalog_init_params) self.udf = loads(self.udf_data) return UDFWorker( @@ -161,7 +130,10 @@ def _create_worker(self) -> "UDFWorker": self.udf, self.task_queue, self.done_queue, + self.query, + self.table, self.is_generator, + self.is_batching, self.cache, self.udf_fields, ) @@ -189,26 +161,27 @@ def create_input_queue(self): def run_udf_parallel( # noqa: C901, PLR0912 self, - input_rows, + input_rows: Iterable[RowsOutput], + total_rows: int, n_workers: Optional[int] = None, - input_queue=None, processed_cb: Callback = DEFAULT_CALLBACK, download_cb: Callback = DEFAULT_CALLBACK, - ) -> Iterator[Sequence[UDFResult]]: + ) -> None: n_workers = get_n_workers_from_arg(n_workers) + input_batch_size = total_rows // n_workers + if input_batch_size == 0: + input_batch_size = 1 + elif input_batch_size > DEFAULT_BATCH_SIZE: + input_batch_size = DEFAULT_BATCH_SIZE + if self.buffer_size < n_workers: raise RuntimeError( "Parallel run error: buffer size is smaller than " f"number of workers: {self.buffer_size} < {n_workers}" ) - if input_queue: - streaming_mode = True - self.task_queue = input_queue - else: - streaming_mode = False - self.task_queue = self.ctx.Queue() + self.task_queue = self.ctx.Queue() self.done_queue = self.ctx.Queue() pool = [ self.ctx.Process(name=f"Worker-UDF-{i}", target=self._run_worker) @@ -223,41 +196,41 @@ def run_udf_parallel( # noqa: C901, PLR0912 # Will be set to True when the input is exhausted input_finished = False - if not streaming_mode: - # Stop all workers after the input rows have finished processing - input_data = chain(input_rows, [STOP_SIGNAL] * n_workers) + if not self.is_batching: + input_rows = batched(flatten(input_rows), input_batch_size) - # Add initial buffer of tasks - for _ in range(self.buffer_size): - try: - put_into_queue(self.task_queue, next(input_data)) - except StopIteration: - input_finished = True - break + # Stop all workers after the input rows have finished processing + input_data = chain(input_rows, [STOP_SIGNAL] * n_workers) + + # Add initial buffer of tasks + for _ in range(self.buffer_size): + try: + put_into_queue(self.task_queue, next(input_data)) + except StopIteration: + input_finished = True + break # Process all tasks while n_workers > 0: result = get_from_queue(self.done_queue) + + if downloaded := result.get("downloaded"): + download_cb.relative_update(downloaded) + if processed := result.get("processed"): + processed_cb.relative_update(processed) + status = result["status"] - if status == NOTIFY_STATUS: - if downloaded := result.get("downloaded"): - download_cb.relative_update(downloaded) - if processed := result.get("processed"): - processed_cb.relative_update(processed) + if status in (OK_STATUS, NOTIFY_STATUS): + pass # Do nothing here elif status == FINISHED_STATUS: - # Worker finished - n_workers -= 1 - elif status == OK_STATUS: - if processed := result.get("processed"): - processed_cb.relative_update(processed) - yield msgpack_unpack(result["result"]) + n_workers -= 1 # Worker finished else: # Failed / error n_workers -= 1 if exc := result.get("exception"): raise exc raise RuntimeError("Internal error: Parallel UDF execution failed") - if status == OK_STATUS and not streaming_mode and not input_finished: + if status == OK_STATUS and not input_finished: try: put_into_queue(self.task_queue, next(input_data)) except StopIteration: @@ -311,11 +284,14 @@ def relative_update(self, inc: int = 1) -> None: @attrs.define class UDFWorker: - catalog: Catalog - udf: UDFAdapter + catalog: "Catalog" + udf: "UDFAdapter" task_queue: "multiprocess.Queue" done_queue: "multiprocess.Queue" + query: "Select" + table: "Table" is_generator: bool + is_batching: bool cache: bool udf_fields: Sequence[str] cb: Callback = attrs.field() @@ -326,30 +302,54 @@ def _default_callback(self) -> WorkerCallback: def run(self) -> None: processed_cb = ProcessedCallback() + generated_cb = get_generated_callback(self.is_generator) + udf_results = self.udf.run( self.udf_fields, - unmarshal(self.get_inputs()), + self.get_inputs(), self.catalog, - self.is_generator, self.cache, download_cb=self.cb, processed_cb=processed_cb, ) - for udf_output in udf_results: - for batch in batched_it(udf_output, DEFAULT_BATCH_SIZE): - put_into_queue( - self.done_queue, - { - "status": OK_STATUS, - "result": msgpack_pack(list(batch)), - }, - ) + process_udf_outputs( + self.catalog.warehouse, + self.table, + self.notify_and_process(udf_results, processed_cb), + self.udf, + cb=generated_cb, + ) + + put_into_queue( + self.done_queue, + {"status": FINISHED_STATUS, "processed": processed_cb.processed_rows}, + ) + + def notify_and_process(self, udf_results, processed_cb): + for row in udf_results: put_into_queue( self.done_queue, - {"status": NOTIFY_STATUS, "processed": processed_cb.processed_rows}, + {"status": OK_STATUS, "processed": processed_cb.processed_rows}, ) - put_into_queue(self.done_queue, {"status": FINISHED_STATUS}) + yield row def get_inputs(self): - while (batch := get_from_queue(self.task_queue)) != STOP_SIGNAL: - yield batch + warehouse = self.catalog.warehouse.clone() + col_id = get_query_id_column(self.query) + + if self.is_batching: + while (batch := get_from_queue(self.task_queue)) != STOP_SIGNAL: + ids = [row[0] for row in batch.rows] + rows = warehouse.dataset_rows_select(self.query.where(col_id.in_(ids))) + yield RowsOutputBatch(list(rows)) + else: + while (batch := get_from_queue(self.task_queue)) != STOP_SIGNAL: + yield from warehouse.dataset_rows_select( + self.query.where(col_id.in_(batch)) + ) + + +class RepeatTimer(Timer): + def run(self): + while not self.finished.wait(self.interval): + self.function(*self.args, **self.kwargs) diff --git a/src/datachain/query/udf.py b/src/datachain/query/udf.py new file mode 100644 index 000000000..a6046deae --- /dev/null +++ b/src/datachain/query/udf.py @@ -0,0 +1,20 @@ +from typing import TYPE_CHECKING, Any, Callable, Optional, TypedDict + +if TYPE_CHECKING: + from sqlalchemy import Select, Table + + from datachain.query.batch import BatchingStrategy + + +class UdfInfo(TypedDict): + udf_data: bytes + catalog_init: dict[str, Any] + metastore_clone_params: tuple[Callable[..., Any], list[Any], dict[str, Any]] + warehouse_clone_params: tuple[Callable[..., Any], list[Any], dict[str, Any]] + table: "Table" + query: "Select" + udf_fields: list[str] + batching: "BatchingStrategy" + processes: Optional[int] + is_generator: bool + cache: bool diff --git a/src/datachain/query/utils.py b/src/datachain/query/utils.py new file mode 100644 index 000000000..0d92226b1 --- /dev/null +++ b/src/datachain/query/utils.py @@ -0,0 +1,42 @@ +from typing import TYPE_CHECKING, Optional, Union + +from sqlalchemy import Column + +if TYPE_CHECKING: + from sqlalchemy import ColumnElement, Select, TextClause + + +ColT = Union[Column, "ColumnElement", "TextClause"] + + +def column_name(col: ColT) -> str: + """Returns column name from column element.""" + return col.name if isinstance(col, Column) else str(col) + + +def get_query_column(query: "Select", name: str) -> Optional[ColT]: + """Returns column element from query by name or None if column not found.""" + return next((col for col in query.inner_columns if column_name(col) == name), None) + + +def get_query_id_column(query: "Select") -> ColT: + """Returns ID column element from query or None if column not found.""" + col = get_query_column(query, "sys__id") + if col is None: + raise RuntimeError("sys__id column not found in query") + return col + + +def select_only_columns(query: "Select", *names: str) -> "Select": + """Returns query selecting defined columns only.""" + if not names: + return query + + cols: list[ColT] = [] + for name in names: + col = get_query_column(query, name) + if col is None: + raise ValueError(f"Column '{name}' not found in query") + cols.append(col) + + return query.with_only_columns(*cols) diff --git a/src/datachain/utils.py b/src/datachain/utils.py index 21fcd6e49..11018df08 100644 --- a/src/datachain/utils.py +++ b/src/datachain/utils.py @@ -263,7 +263,7 @@ def batched_it(iterable: Iterable[_T_co], n: int) -> Iterator[Iterator[_T_co]]: def flatten(items): for item in items: - if isinstance(item, list): + if isinstance(item, (list, tuple)): yield from item else: yield item