diff --git a/mlrun/serving/routers.py b/mlrun/serving/routers.py
index 8ea8e1c5d6..9e6894ef0c 100644
--- a/mlrun/serving/routers.py
+++ b/mlrun/serving/routers.py
@@ -37,6 +37,15 @@
 from .v2_serving import _ModelLogPusher
 
 
+class ExecutorTypes:
+    thread = "thread"
+    process = "process"
+
+    @staticmethod
+    def all():
+        return [ExecutorTypes.thread, ExecutorTypes.process]
+
+
 class BaseModelRouter(RouterToDict):
     """base model router class"""
 
@@ -985,3 +994,181 @@ def preprocess(self, event):
             event.body["inputs"], as_list=True
         )
         return event
+
+
+class ParallelRun(BaseModelRouter):
+    def __init__(
+        self,
+        context=None,
+        name: str = None,
+        routes=None,
+        extend_event=None,
+        executor_type: ExecutorTypes = None,
+        **kwargs,
+    ):
+        """Process multiple steps (child routes) in parallel and merge the results
+
+        By default the result dicts from the child steps are merged (by key); when
+        `extend_event` is set, the merge starts from the event body dict (so its
+        values can be overwritten).
+
+        Users can override the merger() method to implement custom merging logic.
+
+        Example::
+
+            # create a function with a parallel router and 3 children
+            fn = mlrun.new_function("parallel", kind="serving")
+            graph = fn.set_topology(
+                "router",
+                mlrun.serving.routers.ParallelRun(extend_event=True, executor_type="thread"),
+            )
+            graph.add_route("child1", class_name="Cls1")
+            graph.add_route("child2", class_name="Cls2", my_arg={"c": 7})
+            graph.add_route("child3", handler="my_handler")
+            server = fn.to_mock_server()
+            resp = server.test("", {"x": 8})
+
+        :param context: for internal use (passed in init)
+        :param name: step name
+        :param routes: for internal use (routes passed in init)
+        :param executor_type: parallelism mechanism, "thread" or "process"
+        :param extend_event: True will add the event body to the result
+        :param input_path: when specified, selects the key/path in the event to use as the body;
+            this requires that the event body behave like a dict, for example:
+            event: {"data": {"a": 5, "b": 7}}, input_path="data.b" means the request body will be 7
+        :param result_path: selects the key/path in the event to write the results to;
+            this requires that the event body behave like a dict, for example:
+            event: {"x": 5}, result_path="resp" means the returned response will be written
+            to event["resp"], resulting in {"x": 5, "resp": <result>}
+        :param kwargs: extra arguments
+        """
+        super().__init__(context, name, routes, **kwargs)
+        self.name = name or "ParallelRun"
+        if executor_type and executor_type not in ExecutorTypes.all():
+            raise ValueError(
+                f"executor_type must be one of {' | '.join(ExecutorTypes.all())}"
+            )
+        self.executor_type = executor_type
+        self.extend_event = extend_event
+
+        self._pool = None
+
+    def _init_pool(self):
+        if self._pool is None:
+            if self.executor_type == ExecutorTypes.process:
+                # init the context and routes on the worker side (they cannot be pickled)
+                server = self.context.server.to_dict()
+                routes = {}
+                for key, route in self.routes.items():
+                    step = copy.copy(route)
+                    step.context = None
+                    step._parent = None
+                    if step._object:
+                        step._object.context = None
+                    routes[key] = step
+                executor_class = concurrent.futures.ProcessPoolExecutor
+                self._pool = executor_class(
+                    max_workers=len(self.routes),
+                    initializer=init_pool,
+                    initargs=(
+                        server,
+                        routes,
+                    ),
+                )
+            else:
+                executor_class = concurrent.futures.ThreadPoolExecutor
+                self._pool = executor_class(max_workers=len(self.routes))
+
+        return self._pool
+
+    def _shutdown_pool(self):
+        if self._pool is not None:
+            self._pool.shutdown()
+            self._pool = None
+
+    def merger(self, body, results):
+        """Merging logic
+
+        Takes the event body and a dict of route results, and returns a dict
+        with the merged results.
+        """
+        for result in results.values():
+            body.update(result)
+        return body
+
+    def do_event(self, event, *args, **kwargs):
+        # handle and verify the request
+        original_body = event.body
+        event.body = _extract_input_data(self._input_path, event.body)
+        event = self.preprocess(event)
+        event = self._pre_handle_event(event)
+
+        # should we terminate the event?
+ if hasattr(event, "terminated") and event.terminated: + event.body = _update_result_body( + self._result_path, original_body, event.body + ) + self._shutdown_pool() + return event + + # Verify we use the V2 protocol + results = self._parallel_run(event) + response = copy.copy(event) + if self.extend_event: + body = copy.copy(event.body) + else: + body = {} + response.body = self.merger(body, results) + response = self.postprocess(response) + + event.body = _update_result_body( + self._result_path, original_body, response.body if response else None + ) + return event + + def _parallel_run(self, event): + futures = [] + results = {} + executor = self._init_pool() + for route in self.routes.keys(): + if self.executor_type == ExecutorTypes.process: + future = executor.submit(_wrap_step, route, copy.copy(event)) + else: + step = self.routes[route] + future = executor.submit( + _wrap_method, route, step.run, copy.copy(event) + ) + + futures.append(future) + + for future in concurrent.futures.as_completed(futures): + try: + key, result = future.result() + results[key] = result.body + except Exception as exc: + logger.error(traceback.format_exc()) + print(f"child route generated an exception: {exc}") + self.context.logger.debug(f"Collected results from children: {results}") + return results + + +def init_pool(server_spec, routes): + server = mlrun.serving.GraphServer.from_dict(server_spec) + server.init_states(None, None) + global local_routes + for route in routes.values(): + route.context = server.context + if route._object: + route._object.context = server.context + local_routes = routes + + +def _wrap_step(route, event): + return route, local_routes[route].run(event) + + +def _wrap_method(route, handler, event): + return route, handler(event) diff --git a/mlrun/serving/states.py b/mlrun/serving/states.py index ba3953592f..6fe8214709 100644 --- a/mlrun/serving/states.py +++ b/mlrun/serving/states.py @@ -538,7 +538,7 @@ def add_route( :param function: function this step should run in """ - if not route and not class_name: + if not route and not class_name and not handler: raise MLRunInvalidArgumentError("route or class_name must be specified") if not route: route = TaskStep(class_name, class_args, handler=handler) diff --git a/tests/serving/test_parallel.py b/tests/serving/test_parallel.py new file mode 100644 index 0000000000..40aa9cb6c3 --- /dev/null +++ b/tests/serving/test_parallel.py @@ -0,0 +1,42 @@ +import pytest + +import mlrun +import mlrun.serving + + +class Echo: + """example class""" + + def __init__(self, context, name=None, data={}): + self.context = context + self.name = name + self.data = data + + def do(self, x): + self.context.logger.info("test text") + return self.data + + +def my_hnd(event): + """example handler""" + return {"mul": event["x"] * 2} + + +@pytest.mark.parametrize("executor", mlrun.serving.routers.ExecutorTypes.all()) +def test_parallel(executor): + fn = mlrun.new_function("tests", kind="serving") + graph = fn.set_topology( + "router", + mlrun.serving.routers.ParallelRun(extend_event=True, executor_type=executor), + ) + graph.add_route("c1", class_name="Echo", data={"a": 1, "b": 2}) + graph.add_route("c2", class_name="Echo", data={"c": 7}) + graph.add_route("c3", handler="my_hnd") + + server = fn.to_mock_server() + + resp = server.test(body={"x": 8}) + assert resp == {"x": 8, "a": 1, "b": 2, "c": 7, "mul": 16} + + resp = server.test("", {"x": 9}) + assert resp == {"x": 9, "a": 1, "b": 2, "c": 7, "mul": 18}