diff --git a/.gitignore b/.gitignore index 87cfd442e..8c1997914 100644 --- a/.gitignore +++ b/.gitignore @@ -27,6 +27,7 @@ charts/*.tgz indexify_storage indexify_local_runner_cache server/indexify_storage +server/indexify_server_state local_cache .dev-tls src/state/store/snapshots/* diff --git a/python-sdk/indexify/cli.py b/python-sdk/indexify/cli.py index 2e085c88d..41b401e9b 100644 --- a/python-sdk/indexify/cli.py +++ b/python-sdk/indexify/cli.py @@ -16,7 +16,6 @@ from rich.theme import Theme from indexify.executor.agent import ExtractorAgent -from indexify.executor.function_worker import FunctionWorker from indexify.functions_sdk.image import ( DEFAULT_IMAGE_3_10, DEFAULT_IMAGE_3_11, @@ -31,7 +30,6 @@ "highlight": "magenta", } ) - console = Console(theme=custom_theme) app = typer.Typer(pretty_exceptions_enable=False, no_args_is_help=True) @@ -250,7 +248,6 @@ def _build_image(image: Image, python_sdk_path: Optional[str] = None): docker_file += "\n".join(run_strs) print(os.getcwd()) - import docker import docker.api.build docker.api.build.process_dockerfile = lambda dockerfile, path: ( diff --git a/python-sdk/indexify/executor/agent.py b/python-sdk/indexify/executor/agent.py index 8b8083460..26006c963 100644 --- a/python-sdk/indexify/executor/agent.py +++ b/python-sdk/indexify/executor/agent.py @@ -2,6 +2,7 @@ import json import traceback from concurrent.futures.process import BrokenProcessPool +from concurrent.futures.thread import ThreadPoolExecutor from importlib.metadata import version from pathlib import Path from typing import Dict, List, Optional @@ -14,6 +15,8 @@ from rich.theme import Theme from indexify.common_util import get_httpx_client +from indexify.executor.function_worker import FunctionWorker +from indexify.executor.task_reporter import TaskReporter from indexify.functions_sdk.data_objects import ( FunctionWorkerOutput, IndexifyData, @@ -25,10 +28,8 @@ from . import image_dependency_installer from .api_objects import ExecutorMetadata, Task from .downloader import DownloadedInputs, Downloader -from .executor_tasks import DownloadGraphTask, DownloadInputTask, ExtractTask -from .function_worker import FunctionWorker +from .executor_tasks import DownloadTask, RunFunctionTask, TaskEnum from .runtime_probes import ProbeInfo, RuntimeProbes -from .task_reporter import TaskReporter from .task_store import CompletedTask, TaskStore custom_theme = Theme( @@ -63,6 +64,9 @@ def __init__( name_alias: Optional[str] = None, image_version: Optional[int] = None, ): + event_loop = asyncio.get_event_loop() + self._thread_pool = ThreadPoolExecutor(max_workers=num_workers) + event_loop.set_default_executor(self._thread_pool) self.name_alias = name_alias self.image_version = image_version self._config_path = config_path @@ -75,100 +79,54 @@ def __init__( else False ) self._executor_bootstrap_failed = False + self._server_addr = server_addr console.print( f"Require Bootstrap? {self._require_image_bootstrap}", style="cyan bold" ) - self.num_workers = num_workers if config_path: console.print("Running the extractor with TLS enabled", style="cyan bold") self._protocol = "https" else: self._protocol = "http" + self._base_url = f"{self._protocol}://{self._server_addr}" self._task_store: TaskStore = TaskStore() self._executor_id = executor_id self._function_worker = FunctionWorker( - workers=num_workers, indexify_client=IndexifyClient( - service_url=f"{self._protocol}://{server_addr}", - config_path=config_path, + service_url=self._base_url, + config_path=self._config_path, ), ) self._has_registered = False - self._server_addr = server_addr - self._base_url = f"{self._protocol}://{self._server_addr}" self._code_path = code_path self._downloader = Downloader( - code_path=code_path, base_url=self._base_url, config_path=self._config_path + indexify_client=IndexifyClient( + service_url=self._base_url, + config_path=self._config_path, + ), + code_path=code_path, + base_url=self._base_url, ) - self._max_queued_tasks = 10 self._task_reporter = TaskReporter( base_url=self._base_url, executor_id=self._executor_id, - config_path=self._config_path, + indexify_client=IndexifyClient( + service_url=self._base_url, + config_path=self._config_path, + ), ) - async def task_completion_reporter(self): - console.print(Text("Starting task completion reporter", style="bold cyan")) - # We should copy only the keys and not the values - url = f"{self._protocol}://{self._server_addr}/write_content" - - while True: - outcomes = await self._task_store.task_outcomes() - for task_outcome in outcomes: - retryStr = ( - f"\nRetries: {task_outcome.reporting_retries}" - if task_outcome.reporting_retries > 0 - else "" - ) - outcome = task_outcome.task_outcome - style_outcome = ( - f"[bold red] {outcome} [/]" - if "fail" in outcome - else f"[bold green] {outcome} [/]" - ) - console.print( - Panel( - f"Reporting outcome of task: {task_outcome.task.id}, function: {task_outcome.task.compute_fn}\n" - f"Outcome: {style_outcome}\n" - f"Num Fn Outputs: {len(task_outcome.outputs or [])}\n" - f"Router Output: {task_outcome.router_output}\n" - f"Retries: {task_outcome.reporting_retries}", - title="Task Completion", - border_style="info", - ) - ) - - try: - # Send task outcome to the server - self._task_reporter.report_task_outcome(completed_task=task_outcome) - except Exception as e: - # The connection was dropped in the middle of the reporting, process, retry - console.print( - Panel( - f"Failed to report task {task_outcome.task.id}\n" - f"Exception: {type(e).__name__}({e})\n" - f"Retries: {task_outcome.reporting_retries}\n" - "Retrying...", - title="Reporting Error", - border_style="error", - ) - ) - task_outcome.reporting_retries += 1 - await asyncio.sleep(5) - continue - - self._task_store.mark_reported(task_id=task_outcome.task.id) - async def task_launcher(self): async_tasks: List[asyncio.Task] = [] fn_queue: List[FunctionInput] = [] async_tasks.append( asyncio.create_task( - self._task_store.get_runnable_tasks(), name="get_runnable_tasks" + self._task_store.get_runnable_tasks(), + name=TaskEnum.GET_RUNNABLE_TASK.value, ) ) @@ -178,49 +136,18 @@ async def task_launcher(self): task: Task = self._task_store.get_task(fn.task_id) if self._executor_bootstrap_failed: - completed_task = CompletedTask( - task=task, - outputs=[], - task_outcome="failure", - ) - self._task_store.complete(outcome=completed_task) - + self._mark_task_as_failed(task) continue # Bootstrap this executor. Fail the task if we can't. if self._require_image_bootstrap: - try: - image_info = await _get_image_info_for_compute_graph( - task, self._protocol, self._server_addr, self._config_path - ) - image_dependency_installer.executor_image_builder( - image_info, self.name_alias, self.image_version - ) - self._require_image_bootstrap = False - except Exception as e: - console.print( - Text("Failed to bootstrap the executor ", style="red bold") - + Text(f"Exception: {traceback.format_exc()}", style="red") - ) - - self._executor_bootstrap_failed = True - - completed_task = CompletedTask( - task=task, - outputs=[], - task_outcome="failure", - ) - self._task_store.complete(outcome=completed_task) - + if not self._try_bootstrap(task): continue + code_path = f"{self._code_path}/{task.namespace}/{task.compute_graph}.{task.graph_version}" async_tasks.append( - ExtractTask( - function_worker=self._function_worker, - task=task, - input=fn.input, - code_path=f"{self._code_path}/{task.namespace}/{task.compute_graph}.{task.graph_version}", - init_value=fn.init_value, + self._function_worker.run_function( + task, fn.input, fn.init_value, code_path ) ) @@ -231,132 +158,149 @@ async def task_launcher(self): async_tasks: List[asyncio.Task] = list(pending) for async_task in done: - if async_task.get_name() == "get_runnable_tasks": - if async_task.exception(): - console.print( - Text("Task Launcher Error: ", style="red bold") - + Text( - f"Failed to get runnable tasks: {async_task.exception()}", - style="red", + task_name = TaskEnum.from_value(async_task.get_name()) + match task_name: + case TaskEnum.GET_RUNNABLE_TASK: + if async_task.exception(): + self._console_log_exception( + "Task Launcher Error:", + f"Failed to get runnable tasks: {async_task.exception()}" + ) + continue + result: Dict[str, Task] = await async_task + task: Task + for _, task in result.items(): + async_tasks.append( + self._downloader.download( + task, TaskEnum.DOWNLOAD_GRAPH_TASK + ) ) - ) - continue - result: Dict[str, Task] = await async_task - task: Task - for _, task in result.items(): async_tasks.append( - DownloadGraphTask(task=task, downloader=self._downloader) - ) - async_tasks.append( - asyncio.create_task( - self._task_store.get_runnable_tasks(), - name="get_runnable_tasks", + asyncio.create_task( + self._task_store.get_runnable_tasks(), + name=TaskEnum.GET_RUNNABLE_TASK.value, + ) ) - ) - elif async_task.get_name() == "download_graph": - if async_task.exception(): - console.print( - Text( + case TaskEnum.DOWNLOAD_GRAPH_TASK: + async_task: DownloadTask + if async_task.exception(): + self._console_log_exception( f"Failed to download graph for task {async_task.task.id}\n", - style="red bold", + f"Exception: {async_task.exception()}" + ) + self._mark_task_as_failed(async_task.task) + continue + async_tasks.append( + self._downloader.download( + async_task.task, TaskEnum.DOWNLOAD_INPUT_TASK ) - + Text(f"Exception: {async_task.exception()}", style="red") - ) - completed_task = CompletedTask( - task=async_task.task, - outputs=[], - task_outcome="failure", - ) - self._task_store.complete(outcome=completed_task) - continue - async_tasks.append( - DownloadInputTask( - task=async_task.task, downloader=self._downloader ) - ) - elif async_task.get_name() == "download_input": - if async_task.exception(): - console.print( - Text( + case TaskEnum.DOWNLOAD_INPUT_TASK: + async_task: DownloadTask + if async_task.exception(): + self._console_log_exception( f"Failed to download input for task {async_task.task.id}\n", - style="red bold", + f"Exception: {async_task.exception()}" + ) + self._mark_task_as_failed(async_task.task) + continue + downloaded_inputs: DownloadedInputs = await async_task + task: Task = async_task.task + fn_queue.append( + FunctionInput( + task_id=task.id, + namespace=task.namespace, + compute_graph=task.compute_graph, + function=task.compute_fn, + input=downloaded_inputs.input, + init_value=downloaded_inputs.init_value, ) - + Text(f"Exception: {async_task.exception()}", style="red") - ) - completed_task = CompletedTask( - task=async_task.task, - outputs=[], - task_outcome="failure", - ) - self._task_store.complete(outcome=completed_task) - continue - downloaded_inputs: DownloadedInputs = await async_task - task: Task = async_task.task - fn_queue.append( - FunctionInput( - task_id=task.id, - namespace=task.namespace, - compute_graph=task.compute_graph, - function=task.compute_fn, - input=downloaded_inputs.input, - init_value=downloaded_inputs.init_value, - ) - ) - elif async_task.get_name() == "run_function": - if async_task.exception(): - completed_task = CompletedTask( - task=async_task.task, - task_outcome="failure", - outputs=[], - stderr=str(async_task.exception()), - ) - self._task_store.complete(outcome=completed_task) - continue - async_task: ExtractTask - try: - outputs: FunctionWorkerOutput = await async_task - if not outputs.success: - task_outcome = "failure" - else: - task_outcome = "success" - - completed_task = CompletedTask( - task=async_task.task, - task_outcome=task_outcome, - outputs=outputs.fn_outputs, - router_output=outputs.router_output, - stdout=outputs.stdout, - stderr=outputs.stderr, - reducer=outputs.reducer, ) - self._task_store.complete(outcome=completed_task) - except BrokenProcessPool: - self._task_store.retriable_failure(async_task.task.id) - continue - except Exception as e: - console.print( - Text( + case TaskEnum.RUN_FUNCTION_TASK: + async_task: RunFunctionTask + if async_task.exception(): + self._mark_task_as_failed( + async_task.task, str(async_task.exception()) + ) + continue + try: + outputs: FunctionWorkerOutput = await async_task + if not outputs.success: + task_outcome = "failure" + else: + task_outcome = "success" + + completed_task = CompletedTask( + task=async_task.task, + task_outcome=task_outcome, + outputs=outputs.fn_outputs, + router_output=outputs.router_output, + stdout=outputs.stdout, + stderr=outputs.stderr, + reducer=outputs.reducer, + ) + self._task_store.complete(outcome=completed_task) + except BrokenProcessPool: + self._task_store.retriable_failure(async_task.task.id) + continue + except Exception as e: + self._console_log_exception( f"Failed to execute task {async_task.task.id}\n", - style="red bold", + f"Exception: {e}" ) - + Text(f"Exception: {e}", style="red") - ) - completed_task = CompletedTask( - task=async_task.task, - task_outcome="failure", - outputs=[], + self._mark_task_as_failed( + async_task.task, str(e) + ) + continue + case _: + raise ValueError( + f"'{async_task.get_name()}' is not a valid task name." ) - self._task_store.complete(outcome=completed_task) - continue + + def _console_log_exception(self, *args: str): + errorMessage = None + for arg in args: + error_message = Text(arg) if errorMessage is None else error_message + arg + console.print(Text(error_message, style="red bold")) + + def _mark_task_as_failed(self, task: Task, stderr: str = None): + completed_task = CompletedTask( + task=task, + outputs=[], + task_outcome="failure", + stderr=stderr, + ) + self._task_store.complete(outcome=completed_task) + + def _try_bootstrap(self, task: Task) -> bool: + try: + image_info = _get_image_info_for_compute_graph( + task, self._protocol, self._server_addr, self._config_path + ) + image_dependency_installer.executor_image_builder( + image_info, self.name_alias, self.image_version + ) + self._require_image_bootstrap = False + return True + except Exception as e: + console.print( + Text("Failed to bootstrap the executor ", style="red bold") + + Text(f"Exception: {traceback.format_exc()}", style="red") + ) + + self._executor_bootstrap_failed = True + self._mark_task_as_failed(task) + return False async def run(self): + console.print("Starting Extractor Agent...", style="green") import signal asyncio.get_event_loop().add_signal_handler( signal.SIGINT, self.shutdown, asyncio.get_event_loop() ) asyncio.create_task(self.task_launcher()) - asyncio.create_task(self.task_completion_reporter()) + asyncio.create_task(self._task_reporter.run()) self._should_run = True while self._should_run: url = f"{self._protocol}://{self._server_addr}/internal/executors/{self._executor_id}/tasks" @@ -366,6 +310,7 @@ def to_sentence_case(snake_str): words = snake_str.split("_") return words[0].capitalize() + "" + " ".join(words[1:]) + console.print("Starting Probe....") runtime_probe: ProbeInfo = self._probe.probe() executor_version = version("indexify") @@ -437,15 +382,15 @@ def to_sentence_case(snake_str): async def _shutdown(self, loop): console.print(Text("shutting down agent...", style="bold yellow")) self._should_run = False + self._thread_pool.shutdown(cancel_futures=True) for task in asyncio.all_tasks(loop): task.cancel() def shutdown(self, loop): - self._function_worker.shutdown() loop.create_task(self._shutdown(loop)) -async def _get_image_info_for_compute_graph( +def _get_image_info_for_compute_graph( task: Task, protocol, server_addr, config_path: str ) -> ImageInformation: namespace = task.namespace diff --git a/python-sdk/indexify/executor/api_objects.py b/python-sdk/indexify/executor/api_objects.py index 0c4538bde..0198e7f83 100644 --- a/python-sdk/indexify/executor/api_objects.py +++ b/python-sdk/indexify/executor/api_objects.py @@ -29,10 +29,6 @@ class RouterOutput(BaseModel): edges: List[str] -class FnOutput(BaseModel): - payload: Json - - class TaskResult(BaseModel): router_output: Optional[RouterOutput] = None outcome: str diff --git a/python-sdk/indexify/executor/downloader.py b/python-sdk/indexify/executor/downloader.py index 75ac9bbdc..b8b36763a 100644 --- a/python-sdk/indexify/executor/downloader.py +++ b/python-sdk/indexify/executor/downloader.py @@ -1,7 +1,7 @@ +import asyncio import os from typing import Optional -import httpx from pydantic import BaseModel from rich.console import Console from rich.panel import Panel @@ -9,9 +9,10 @@ from indexify.functions_sdk.data_objects import IndexifyData -from ..common_util import get_httpx_client +from .. import IndexifyClient from ..functions_sdk.object_serializer import JsonSerializer, get_serializer from .api_objects import Task +from .executor_tasks import DownloadTask, TaskEnum custom_theme = Theme( { @@ -31,11 +32,29 @@ class DownloadedInputs(BaseModel): class Downloader: def __init__( - self, code_path: str, base_url: str, config_path: Optional[str] = None + self, + code_path: str, + base_url: str, + indexify_client: IndexifyClient, ): self.code_path = code_path self.base_url = base_url - self._client = get_httpx_client(config_path) + self._indexify_client = indexify_client + self._event_loop = asyncio.get_event_loop() + + def download(self, task, name): + if name == TaskEnum.DOWNLOAD_GRAPH_TASK: + coroutine = self.download_graph( + task.namespace, task.compute_graph, task.graph_version + ) + elif name == TaskEnum.DOWNLOAD_INPUT_TASK: + coroutine = self.download_input(task) + else: + raise ValueError(f"Unsupported task name: {name}") + + return DownloadTask( + task=task, coroutine=coroutine, name=name, loop=self._event_loop + ) async def download_graph(self, namespace: str, name: str, version: int) -> str: path = os.path.join(self.code_path, namespace, f"{name}.{version}") @@ -50,21 +69,7 @@ async def download_graph(self, namespace: str, name: str, version: int) -> str: ) ) - response = self._client.get( - f"{self.base_url}/internal/namespaces/{namespace}/compute_graphs/{name}/code" - ) - try: - response.raise_for_status() - except httpx.HTTPStatusError as e: - console.print( - Panel( - f"Failed to download graph: {name}\nError: {response.text}", - title="downloader error", - border_style="error", - ) - ) - raise - + response = self._indexify_client.download_graph(namespace, name) os.makedirs(os.path.dirname(path), exist_ok=True) with open(path, "wb") as f: f.write(response.content) @@ -83,25 +88,29 @@ async def download_input(self, task: Task) -> DownloadedInputs: console.print( Panel( - f"downloading input\nURL: {url} \n reducer input URL: {reducer_url}", + f"downloading input\nFunction: {task.compute_fn} \n reducer id: {task.reducer_output_id}", title="downloader", border_style="cyan", ) ) - response = self._client.get(url) - - try: - response.raise_for_status() - except httpx.HTTPStatusError as e: - console.print( - Panel( - f"failed to download input: {task.input_key}\nError: {response.text}", - title="downloader error", - border_style="error", - ) + input_id = task.input_key.split("|")[-1] + if task.invocation_id == input_id: + response = self._indexify_client.download_fn_input( + task.namespace, task.compute_graph, task.invocation_id + ) + else: + response = self._indexify_client.download_fn_output(task.input_key) + + init_value = None + if task.reducer_output_id: + init_value = self._indexify_client.download_reducer_input( + task.namespace, + task.compute_graph, + task.invocation_id, + task.compute_fn, + task.reducer_output_id, ) - raise encoder = ( "json" @@ -119,19 +128,7 @@ async def download_input(self, task: Task) -> DownloadedInputs: deserialized_content = serializer.deserialize(response.content) - if reducer_url: - init_value = self._client.get(reducer_url) - try: - init_value.raise_for_status() - except httpx.HTTPStatusError as e: - console.print( - Panel( - f"failed to download reducer output: {task.reducer_output_id}\nError: {init_value.text}", - title="downloader error", - border_style="error", - ) - ) - raise + if init_value: init_value = serializer.deserialize(init_value.content) return DownloadedInputs( input=IndexifyData( diff --git a/python-sdk/indexify/executor/executor_tasks.py b/python-sdk/indexify/executor/executor_tasks.py index e4477a169..928aef977 100644 --- a/python-sdk/indexify/executor/executor_tasks.py +++ b/python-sdk/indexify/executor/executor_tasks.py @@ -1,73 +1,59 @@ import asyncio -from typing import Optional - -from indexify.functions_sdk.data_objects import IndexifyData +from enum import Enum, unique +from typing import Coroutine from .api_objects import Task -from .downloader import Downloader -from .function_worker import FunctionWorker -class DownloadGraphTask(asyncio.Task): +class RunFunctionTask(asyncio.Task): def __init__( self, *, task: Task, - downloader: Downloader, + coroutine: Coroutine, + loop: asyncio.AbstractEventLoop, **kwargs, ): - kwargs["name"] = "download_graph" - kwargs["loop"] = asyncio.get_event_loop() + kwargs["name"] = TaskEnum.RUN_FUNCTION_TASK.value + kwargs["loop"] = loop super().__init__( - downloader.download_graph( - task.namespace, task.compute_graph, task.graph_version - ), + coroutine, **kwargs, ) self.task = task -class DownloadInputTask(asyncio.Task): +class DownloadTask(asyncio.Task): def __init__( self, *, task: Task, - downloader: Downloader, + coroutine: Coroutine, + name: str, + loop: asyncio.AbstractEventLoop, **kwargs, ): - kwargs["name"] = "download_input" - kwargs["loop"] = asyncio.get_event_loop() + if not isinstance(name, TaskEnum): + raise ValueError(f"name '{name}' must be TaskEnum") + kwargs["name"] = name.value + kwargs["loop"] = loop super().__init__( - downloader.download_input(task), + coroutine, **kwargs, ) self.task = task -class ExtractTask(asyncio.Task): - def __init__( - self, - *, - function_worker: FunctionWorker, - task: Task, - input: IndexifyData, - init_value: Optional[IndexifyData] = None, - code_path: str, - **kwargs, - ): - kwargs["name"] = "run_function" - kwargs["loop"] = asyncio.get_event_loop() - super().__init__( - function_worker.async_submit( - namespace=task.namespace, - graph_name=task.compute_graph, - fn_name=task.compute_fn, - input=input, - init_value=init_value, - code_path=code_path, - version=task.graph_version, - invocation_id=task.invocation_id, - ), - **kwargs, - ) - self.task = task +@unique +class TaskEnum(Enum): + GET_RUNNABLE_TASK = "get_runnable_tasks" + RUN_FUNCTION_TASK = "run_function" + DOWNLOAD_GRAPH_TASK = "download_graph" + DOWNLOAD_INPUT_TASK = "download_input" + + @classmethod + def from_value(cls, value): + for task_name in cls: + if task_name.value == value: + return task_name + raise ValueError(f"No task found with value {value}") diff --git a/python-sdk/indexify/executor/function_worker/__init__.py b/python-sdk/indexify/executor/function_worker/__init__.py new file mode 100644 index 000000000..1849fdc64 --- /dev/null +++ b/python-sdk/indexify/executor/function_worker/__init__.py @@ -0,0 +1,3 @@ +from .function_worker import FunctionWorker + +__all__ = ["FunctionWorker"] diff --git a/python-sdk/indexify/executor/function_worker.py b/python-sdk/indexify/executor/function_worker/function_worker.py similarity index 56% rename from python-sdk/indexify/executor/function_worker.py rename to python-sdk/indexify/executor/function_worker/function_worker.py index 3ea594a8e..e37ef3bd1 100644 --- a/python-sdk/indexify/executor/function_worker.py +++ b/python-sdk/indexify/executor/function_worker/function_worker.py @@ -1,12 +1,17 @@ +import asyncio import sys import traceback -from typing import Dict, List, Optional +from typing import List, Optional -import cloudpickle from pydantic import BaseModel from rich import print from indexify import IndexifyClient +from indexify.executor.api_objects import Task +from indexify.executor.executor_tasks import RunFunctionTask +from indexify.executor.function_worker.function_worker_utils import ( + _load_function, +) from indexify.functions_sdk.data_objects import ( FunctionWorkerOutput, IndexifyData, @@ -14,17 +19,9 @@ ) from indexify.functions_sdk.indexify_functions import ( FunctionCallResult, - GraphInvocationContext, - IndexifyFunction, - IndexifyFunctionWrapper, - IndexifyRouter, RouterCallResult, ) -function_wrapper_map: Dict[str, IndexifyFunctionWrapper] = {} - -import concurrent.futures - class FunctionRunException(Exception): def __init__( @@ -45,72 +42,60 @@ class FunctionOutput(BaseModel): stdout: str = "" stderr: str = "" - -def _load_function( - namespace: str, - graph_name: str, - fn_name: str, - code_path: str, - version: int, - invocation_id: str, - indexify_client: IndexifyClient, -): - """Load an extractor to the memory: extractor_wrapper_map.""" - global function_wrapper_map - key = f"{namespace}/{graph_name}/{version}/{fn_name}" - if key in function_wrapper_map: - return - with open(code_path, "rb") as f: - code = f.read() - pickled_functions = cloudpickle.loads(code) - context = GraphInvocationContext( - invocation_id=invocation_id, - graph_name=graph_name, - graph_version=str(version), - indexify_client=indexify_client, - ) - function_wrapper = IndexifyFunctionWrapper( - cloudpickle.loads(pickled_functions[fn_name]), - context, - ) - function_wrapper_map[key] = function_wrapper - - class FunctionWorker: def __init__( - self, workers: int = 1, indexify_client: IndexifyClient = None + self, + indexify_client: IndexifyClient = None, ) -> None: - self._executor: concurrent.futures.ProcessPoolExecutor = ( - concurrent.futures.ProcessPoolExecutor(max_workers=workers) - ) - self._workers = workers - self._indexify_client = indexify_client + self._indexify_client: IndexifyClient = indexify_client + self._loop = asyncio.get_event_loop() - async def async_submit( + def run_function( self, - namespace: str, - graph_name: str, - fn_name: str, - input: IndexifyData, + task: Task, + fn_input: IndexifyData, + init_value: IndexifyData | None, code_path: str, - version: int, - init_value: Optional[IndexifyData] = None, - invocation_id: Optional[str] = None, - ) -> FunctionWorkerOutput: + ): + return RunFunctionTask( + task=task, + coroutine=self.async_submit( + namespace=task.namespace, + graph_name=task.compute_graph, + fn_name=task.compute_fn, + input=fn_input, + init_value=init_value, + code_path=code_path, + version=task.graph_version, + invocation_id=task.invocation_id, + ), + loop=self._loop, + ) + + async def async_submit(self, **kwargs) -> FunctionWorkerOutput: try: - result = _run_function( - namespace, - graph_name, - fn_name, - input, - code_path, - version, - init_value, - invocation_id, + print(f"Submitting async function.....") + result = await _run_function( + kwargs["namespace"], + kwargs["graph_name"], + kwargs["fn_name"], + kwargs["input"], + kwargs["code_path"], + kwargs["version"], + kwargs["init_value"], + kwargs["invocation_id"], self._indexify_client, ) - # TODO - bring back running in a separate process + return FunctionWorkerOutput( + fn_outputs=result.fn_outputs, + router_output=result.router_output, + stdout=result.stdout, + stderr=result.stderr, + reducer=result.reducer, + success=result.success, + ) except Exception as e: + print(e) return FunctionWorkerOutput( stdout=e.stdout, stderr=e.stderr, @@ -118,20 +103,8 @@ async def async_submit( success=False, ) - return FunctionWorkerOutput( - fn_outputs=result.fn_outputs, - router_output=result.router_output, - stdout=result.stdout, - stderr=result.stderr, - reducer=result.reducer, - success=result.success, - ) - - def shutdown(self): - self._executor.shutdown(wait=True, cancel_futures=True) - -def _run_function( +async def _run_function( namespace: str, graph_name: str, fn_name: str, @@ -156,19 +129,15 @@ def _run_function( ) with redirect_stdout(stdout_capture), redirect_stderr(stderr_capture): try: - key = f"{namespace}/{graph_name}/{version}/{fn_name}" - if key not in function_wrapper_map: - _load_function( - namespace, - graph_name, - fn_name, - code_path, - version, - invocation_id, - indexify_client, - ) - - fn = function_wrapper_map[key] + fn = _load_function( + namespace, + graph_name, + fn_name, + code_path, + version, + invocation_id, + indexify_client, + ) if ( str(type(fn.indexify_function)) == "" @@ -179,9 +148,16 @@ def _run_function( print(router_call_result.traceback_msg, file=sys.stderr) has_failed = True else: - fn_call_result: FunctionCallResult = fn.invoke_fn_ser( - fn_name, input, init_value - ) + print(f"is function async: {fn.indexify_function.is_async}") + if not fn.indexify_function.is_async: + fn_call_result: FunctionCallResult = fn.invoke_fn_ser( + fn_name, input, init_value + ) + else: + fn_call_result: FunctionCallResult = await fn.invoke_fn_ser_async( + fn_name, input, init_value + ) + print(f"serialized function output: {fn_call_result}") is_reducer = fn.indexify_function.accumulate is not None fn_output = fn_call_result.ser_outputs if fn_call_result.traceback_msg is not None: diff --git a/python-sdk/indexify/executor/function_worker/function_worker_utils.py b/python-sdk/indexify/executor/function_worker/function_worker_utils.py new file mode 100644 index 000000000..8cfca4b43 --- /dev/null +++ b/python-sdk/indexify/executor/function_worker/function_worker_utils.py @@ -0,0 +1,36 @@ +import os +from functools import cache + +import cloudpickle + +from indexify import IndexifyClient +from indexify.functions_sdk.indexify_functions import ( + GraphInvocationContext, + IndexifyFunctionWrapper, +) + + +@cache +def _load_function( + namespace: str, + graph_name: str, + fn_name: str, + code_path: str, + version: int, + invocation_id: str, + indexify_client: IndexifyClient, +): + """Load an extractor to the memory: extractor_wrapper_map.""" + with open(code_path, "rb") as f: + code = f.read() + pickled_functions = cloudpickle.loads(code) + context = GraphInvocationContext( + invocation_id=invocation_id, + graph_name=graph_name, + graph_version=str(version), + indexify_client=indexify_client, + ) + return IndexifyFunctionWrapper( + cloudpickle.loads(pickled_functions[fn_name]), + context, + ) diff --git a/python-sdk/indexify/executor/indexify_executor.py b/python-sdk/indexify/executor/indexify_executor.py index 087c1096d..f843b205f 100644 --- a/python-sdk/indexify/executor/indexify_executor.py +++ b/python-sdk/indexify/executor/indexify_executor.py @@ -1,10 +1,11 @@ import asyncio -from typing import List, Optional +from typing import Optional import nanoid +from indexify.executor.function_worker.function_worker import FunctionWorker + from .agent import ExtractorAgent -from .function_worker import FunctionWorker def join( diff --git a/python-sdk/indexify/executor/task_reporter/__init__.py b/python-sdk/indexify/executor/task_reporter/__init__.py new file mode 100644 index 000000000..c2f246faa --- /dev/null +++ b/python-sdk/indexify/executor/task_reporter/__init__.py @@ -0,0 +1,3 @@ +from .task_reporter import TaskReporter + +__all__ = ["TaskReporter"] diff --git a/python-sdk/indexify/executor/task_reporter.py b/python-sdk/indexify/executor/task_reporter/task_reporter.py similarity index 79% rename from python-sdk/indexify/executor/task_reporter.py rename to python-sdk/indexify/executor/task_reporter/task_reporter.py index c905eb94e..cd6e9c4cc 100644 --- a/python-sdk/indexify/executor/task_reporter.py +++ b/python-sdk/indexify/executor/task_reporter/task_reporter.py @@ -1,15 +1,21 @@ -import io +import asyncio from typing import Optional import nanoid from httpx import Timeout from pydantic import BaseModel from rich import print +from rich.text import Text -from indexify.common_util import get_httpx_client +from indexify import IndexifyClient from indexify.executor.api_objects import RouterOutput as ApiRouterOutput from indexify.executor.api_objects import TaskResult -from indexify.executor.task_store import CompletedTask +from indexify.executor.task_reporter.task_reporter_utils import ( + _log, + _log_exception, + console, +) +from indexify.executor.task_store import CompletedTask, TaskStore from indexify.functions_sdk.object_serializer import get_serializer @@ -34,14 +40,36 @@ class ReportingData(BaseModel): class TaskReporter: def __init__( - self, base_url: str, executor_id: str, config_path: Optional[str] = None + self, + base_url: str, + executor_id: str, + indexify_client: IndexifyClient, ): self._base_url = base_url self._executor_id = executor_id - self._client = get_httpx_client(config_path) + self._client = indexify_client + + async def run(self): + console.print(Text("Starting TaskReporter", style="bold cyan")) + # We should copy only the keys and not the values + + while True: + outcomes = await self._task_store.task_outcomes() + for task_outcome in outcomes: + _log(task_outcome) + try: + # Send task outcome to the server + self.report_task_outcome(completed_task=task_outcome) + except Exception as e: + # The connection was dropped in the middle of the reporting, process, retry + _log_exception(task_outcome, e) + task_outcome.reporting_retries += 1 + await asyncio.sleep(5) + continue + + self._task_store.mark_reported(task_id=task_outcome.task.id) def report_task_outcome(self, completed_task: CompletedTask): - report = ReportingData() fn_outputs = [] for output in completed_task.outputs or []: diff --git a/python-sdk/indexify/executor/task_reporter/task_reporter_utils.py b/python-sdk/indexify/executor/task_reporter/task_reporter_utils.py new file mode 100644 index 000000000..dddb8ce3f --- /dev/null +++ b/python-sdk/indexify/executor/task_reporter/task_reporter_utils.py @@ -0,0 +1,47 @@ +from rich.console import Console +from rich.panel import Panel +from rich.theme import Theme + + +def _log_exception(task_outcome, e): + console.print( + Panel( + f"Failed to report task {task_outcome.task.id}\n" + f"Exception: {type(e).__name__}({e})\n" + f"Retries: {task_outcome.reporting_retries}\n" + "Retrying...", + title="Reporting Error", + border_style="error", + ) + ) + + +def _log(task_outcome): + outcome = task_outcome.task_outcome + style_outcome = ( + f"[bold red] {outcome} [/]" + if "fail" in outcome + else f"[bold green] {outcome} [/]" + ) + console.print( + Panel( + f"Reporting outcome of task: {task_outcome.task.id}, function: {task_outcome.task.compute_fn}\n" + f"Outcome: {style_outcome}\n" + f"Num Fn Outputs: {len(task_outcome.outputs or [])}\n" + f"Router Output: {task_outcome.router_output}\n" + f"Retries: {task_outcome.reporting_retries}", + title="Task Completion", + border_style="info", + ) + ) + + +custom_theme = Theme( + { + "info": "cyan", + "warning": "yellow", + "error": "red", + "highlight": "magenta", + } +) +console = Console(theme=custom_theme) diff --git a/python-sdk/indexify/functions_sdk/indexify_functions.py b/python-sdk/indexify/functions_sdk/indexify_functions.py index b57db40ba..d1b0add0d 100644 --- a/python-sdk/indexify/functions_sdk/indexify_functions.py +++ b/python-sdk/indexify/functions_sdk/indexify_functions.py @@ -84,10 +84,14 @@ class IndexifyFunction: placement_constraints: List[PlacementConstraints] = [] accumulate: Optional[Type[Any]] = None encoder: Optional[str] = "cloudpickle" + is_async: bool = False def run(self, *args, **kwargs) -> Union[List[Any], Any]: pass + async def async_run(self, *args, **kwargs) -> Union[List[Any], Any]: + pass + def partial(self, **kwargs) -> Callable: from functools import partial @@ -170,6 +174,9 @@ def construct(fn): def run(self, *args, **kwargs): return fn(*args, **kwargs) + async def async_run(self, *args, **kwargs): + return await fn(*args, **kwargs) + # Apply original signature and annotations to run method run.__signature__ = fn_sig run.__annotations__ = fn_hints @@ -186,6 +193,8 @@ def run(self, *args, **kwargs): "accumulate": accumulate, "encoder": encoder, "run": run, + "async_run": async_run, + "is_async": inspect.iscoroutinefunction(fn), } return type("IndexifyFunction", (IndexifyFunction,), attrs) @@ -276,6 +285,30 @@ def run_fn( ) return output, None + async def run_fn_async( + self, input: Union[Dict, Type[BaseModel]], acc: Type[Any] = None + ) -> Tuple[List[Any], Optional[str]]: + args = [] + kwargs = {} + if acc is not None: + args.append(acc) + if isinstance(input, dict): + kwargs = input + else: + args.append(input) + + try: + extracted_data = await self.indexify_function.run(*args, **kwargs) + except Exception as e: + return [], traceback.format_exc() + if extracted_data is None: + return [], None + + output = ( + extracted_data if isinstance(extracted_data, list) else [extracted_data] + ) + return output, None + def invoke_fn_ser( self, name: str, input: IndexifyData, acc: Optional[Any] = None ) -> FunctionCallResult: @@ -299,6 +332,29 @@ def invoke_fn_ser( ] return FunctionCallResult(ser_outputs=ser_outputs, traceback_msg=err) + async def invoke_fn_ser_async( + self, name: str, input: IndexifyData, acc: Optional[Any] = None + ) -> FunctionCallResult: + input = self.deserialize_input(name, input) + serializer = get_serializer(self.indexify_function.encoder) + if acc is not None: + acc = self.indexify_function.accumulate.model_validate( + serializer.deserialize(acc.payload) + ) + if acc is None and self.indexify_function.accumulate is not None: + acc = self.indexify_function.accumulate.model_validate( + self.indexify_function.accumulate() + ) + outputs, err = await self.run_fn_async(input, acc=acc) + ser_outputs = [ + IndexifyData( + payload=serializer.serialize(output), + encoder=self.indexify_function.encoder, + ) + for output in outputs + ] + return FunctionCallResult(ser_outputs=ser_outputs, traceback_msg=err) + def invoke_router(self, name: str, input: IndexifyData) -> RouterCallResult: input = self.deserialize_input(name, input) edges, err = self.run_router(input) diff --git a/python-sdk/indexify/http_client.py b/python-sdk/indexify/http_client.py index b2a080cb4..467357c41 100644 --- a/python-sdk/indexify/http_client.py +++ b/python-sdk/indexify/http_client.py @@ -164,6 +164,31 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): self.close() + def download_graph(self, namespace: str, compute_graph: str): + return self._get( + f"internal/namespaces/{namespace}/compute_graphs/{compute_graph}/code" + ) + + def download_fn_input(self, namespace: str, compute_graph: str, invocation_id: str): + return self._get( + f"namespaces/{namespace}/compute_graphs/{compute_graph}/invocations/{invocation_id}/payload" + ) + + def download_fn_output(self, input_key: str): + return self._get(f"internal/fn_outputs/{input_key}") + + def download_reducer_input( + self, + namespace: str, + compute_graph: str, + invocation_id: str, + compute_fn: str, + reducer_output_id: str, + ): + return self._get( + f"namespaces/{namespace}/compute_graphs/{compute_graph}/invocations/{invocation_id}/fn/{compute_fn}/output/{reducer_output_id}" + ) + def register_compute_graph(self, graph: Graph, additional_modules): graph_metadata = graph.definition() serialized_code = cloudpickle.dumps(graph.serialize(additional_modules)) diff --git a/python-sdk/tests/test_graph_behaviours.py b/python-sdk/tests/test_graph_behaviours.py index 73c2b819c..8498a3506 100644 --- a/python-sdk/tests/test_graph_behaviours.py +++ b/python-sdk/tests/test_graph_behaviours.py @@ -227,7 +227,23 @@ def remote_or_local_pipeline(pipeline, remote=True): return pipeline +@indexify_function() +async def async_simple_function(x: int) -> int: + return x * x + + class TestGraphBehaviors(unittest.TestCase): + def test_async_simple_function(self): + graph = Graph( + name="test_async_simple_function", + description="test", + start_node=async_simple_function, + ) + graph = RemoteGraph.deploy(graph) + invocation_id = graph.run(block_until_done=True, x=10) + output = graph.output(invocation_id, "async_simple_function") + self.assertEqual(output, [100]) + @parameterized.expand([(False), (True)]) def test_simple_function(self, is_remote): graph = Graph(