Skip to content

Commit

Permalink
[Serving] Add ParallelRun router for running multiple steps in parall…
Browse files Browse the repository at this point in the history
…el threads/processes (mlrun#2136)
  • Loading branch information
yaronha authored Jul 18, 2022
1 parent d28fc99 commit 640f925
Show file tree
Hide file tree
Showing 3 changed files with 230 additions and 1 deletion.
187 changes: 187 additions & 0 deletions mlrun/serving/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,15 @@
from .v2_serving import _ModelLogPusher


class ExecutorTypes:
thread = "thread"
process = "process"

@staticmethod
def all():
return [ExecutorTypes.thread, ExecutorTypes.process]


class BaseModelRouter(RouterToDict):
"""base model router class"""

Expand Down Expand Up @@ -985,3 +994,181 @@ def preprocess(self, event):
event.body["inputs"], as_list=True
)
return event


class ParallelRun(BaseModelRouter):
def __init__(
self,
context=None,
name: str = None,
routes=None,
extend_event=None,
executor_type: ExecutorTypes = None,
**kwargs,
):
"""Process multiple steps (child routes) in parallel and merge the results
By default the results dict from each step are merged (by key), when setting the `extend_event`
the results will start from the event body dict (values can be overwritten)
Users can overwrite the merger() method to implement custom merging logic.
Example::
# create a function with a parallel router and 3 children
fn = mlrun.new_function("parallel", kind="serving")
graph = fn.set_topology(
"router",
mlrun.serving.routers.ParallelRun(extend_event=True, executor_type=executor),
)
graph.add_route("child1", class_name="Cls1")
graph.add_route("child2", class_name="Cls2", my_arg={"c": 7})
graph.add_route("child3", handler="my_handler")
server = fn.to_mock_server()
resp = server.test("", {"x": 8})
:param context: for internal use (passed in init)
:param name: step name
:param routes: for internal use (routes passed in init)
:param executor_type: Parallelism mechanism, "thread" or "process"
:param extend_event: True will add the event body to the result
:param input_path: when specified selects the key/path in the event to use as body
this require that the event body will behave like a dict, example:
event: {"data": {"a": 5, "b": 7}}, input_path="data.b" means request body will be 7
:param result_path: selects the key/path in the event to write the results to
this require that the event body will behave like a dict, example:
event: {"x": 5} , result_path="resp" means the returned response will be written
to event["y"] resulting in {"x": 5, "resp": <result>}
:param vote_type: Voting type to be used (from `VotingTypes`).
by default will try to self-deduct upon the first event:
- float prediction type: regression
- int prediction type: classification
:param kwargs: extra arguments
"""
super().__init__(context, name, routes, **kwargs)
self.name = name or "ParallelRun"
if executor_type and executor_type not in ExecutorTypes.all():
raise ValueError(
f"executor_type must be one of {' | '.join(ExecutorTypes.all())}"
)
self.executor_type = executor_type
self.extend_event = extend_event

self._pool = None

def _init_pool(self):
if self._pool is None:
if self.executor_type == ExecutorTypes.process:
# init the context and route on the worker side (cannot be pickeled)
server = self.context.server.to_dict()
routes = {}
for key, route in self.routes.items():
step = copy.copy(route)
step.context = None
step._parent = None
if step._object:
step._object.context = None
routes[key] = step
executor_class = concurrent.futures.ProcessPoolExecutor
self._pool = executor_class(
max_workers=len(self.routes),
initializer=init_pool,
initargs=(
server,
routes,
),
)
else:
executor_class = concurrent.futures.ThreadPoolExecutor
self._pool = executor_class(max_workers=len(self.routes))

return self._pool

def _shutdown_pool(self):
if self._pool is not None:
self._pool.shutdown()
self._pool = None

def merger(self, body, results):
"""Merging logic
input the event body and a dict of route results and returns a dict with merged results
"""
for result in results.values():
body.update(result)
return body

def do_event(self, event, *args, **kwargs):
# Handle and verify the request
original_body = event.body
event.body = _extract_input_data(self._input_path, event.body)
event = self.preprocess(event)
event = self._pre_handle_event(event)

# Should we terminate the event?
if hasattr(event, "terminated") and event.terminated:
event.body = _update_result_body(
self._result_path, original_body, event.body
)
self._shutdown_pool()
return event

# Verify we use the V2 protocol
results = self._parallel_run(event)
response = copy.copy(event)
if self.extend_event:
body = copy.copy(event.body)
else:
body = {}
response.body = self.merger(body, results)
response = self.postprocess(response)

event.body = _update_result_body(
self._result_path, original_body, response.body if response else None
)
return event

def _parallel_run(self, event):
futures = []
results = {}
executor = self._init_pool()
for route in self.routes.keys():
if self.executor_type == ExecutorTypes.process:
future = executor.submit(_wrap_step, route, copy.copy(event))
else:
step = self.routes[route]
future = executor.submit(
_wrap_method, route, step.run, copy.copy(event)
)

futures.append(future)

for future in concurrent.futures.as_completed(futures):
try:
key, result = future.result()
results[key] = result.body
except Exception as exc:
logger.error(traceback.format_exc())
print(f"child route generated an exception: {exc}")
self.context.logger.debug(f"Collected results from children: {results}")
return results


def init_pool(server_spec, routes):
server = mlrun.serving.GraphServer.from_dict(server_spec)
server.init_states(None, None)
global local_routes
for route in routes.values():
route.context = server.context
if route._object:
route._object.context = server.context
local_routes = routes


def _wrap_step(route, event):
return route, local_routes[route].run(event)


def _wrap_method(route, handler, event):
return route, handler(event)
2 changes: 1 addition & 1 deletion mlrun/serving/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,7 @@ def add_route(
:param function: function this step should run in
"""

if not route and not class_name:
if not route and not class_name and not handler:
raise MLRunInvalidArgumentError("route or class_name must be specified")
if not route:
route = TaskStep(class_name, class_args, handler=handler)
Expand Down
42 changes: 42 additions & 0 deletions tests/serving/test_parallel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import pytest

import mlrun
import mlrun.serving


class Echo:
"""example class"""

def __init__(self, context, name=None, data={}):
self.context = context
self.name = name
self.data = data

def do(self, x):
self.context.logger.info("test text")
return self.data


def my_hnd(event):
"""example handler"""
return {"mul": event["x"] * 2}


@pytest.mark.parametrize("executor", mlrun.serving.routers.ExecutorTypes.all())
def test_parallel(executor):
fn = mlrun.new_function("tests", kind="serving")
graph = fn.set_topology(
"router",
mlrun.serving.routers.ParallelRun(extend_event=True, executor_type=executor),
)
graph.add_route("c1", class_name="Echo", data={"a": 1, "b": 2})
graph.add_route("c2", class_name="Echo", data={"c": 7})
graph.add_route("c3", handler="my_hnd")

server = fn.to_mock_server()

resp = server.test(body={"x": 8})
assert resp == {"x": 8, "a": 1, "b": 2, "c": 7, "mul": 16}

resp = server.test("", {"x": 9})
assert resp == {"x": 9, "a": 1, "b": 2, "c": 7, "mul": 18}

0 comments on commit 640f925

Please sign in to comment.