Skip to content

Commit

Permalink
Various improvements and additions
Browse files Browse the repository at this point in the history
  • Loading branch information
gtopper committed Nov 7, 2024
1 parent 57a22be commit e9a1814
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 19 deletions.
39 changes: 25 additions & 14 deletions storey/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1433,9 +1433,11 @@ def set_table(self, key, table):


class ParallelExecutionRunnable:
execution_mechanism = "multiprocessing"
execution_mechanism = None

def __init__(self, name):
if not self.execution_mechanism:
raise ValueError("ParallelExecutionRunnable's execution_mechanism attribute must be overridden")
self.name = name

def init(self):
Expand All @@ -1449,42 +1451,51 @@ class ParallelExecution(Flow):
def __init__(self, runnables, **kwargs):
super().__init__(**kwargs)
self._runnables = runnables
self._runnable_by_name = {}

def select_runnables(self, event):
return self._runnables

def _init(self):
super()._init()
num_process = 0
num_thread = 0
names = set()
for runnable in self._runnables:
if runnable.name in names:
if runnable.name in self._runnable_by_name:
raise ValueError(f"ParallelExecutionRunnable name '{runnable.name}' is not unique")
names.add(runnable.name)
self._runnable_by_name[runnable.name] = runnable
runnable.init()
if runnable.execution_mechanism == "multiprocessing":
num_process += 1
elif runnable.execution_mechanism == "thread":
elif runnable.execution_mechanism == "threading":
num_thread += 1
elif runnable.execution_mechanism != "async":
elif runnable.execution_mechanism not in ("asyncio", "naive"):
raise ValueError(f"Unsupported execution mechanism: {runnable.execution_mechanism}")
self._executors = {}
if num_process:
self._executors["multiprocessing"] = ProcessPoolExecutor(max_workers=num_process)
if num_thread:
self._executors["thread"] = ThreadPoolExecutor(max_workers=num_thread)
self._executors["threading"] = ThreadPoolExecutor(max_workers=num_thread)

async def _do(self, event):
if event is _termination_obj:
return await self._do_downstream(_termination_obj)
else:
tasks = []
for runnable in self._runnables:
if runnable.execution_mechanism == "async":
task = asyncio.get_running_loop().create_task(runnable.run(event))
runnables = self.select_runnables(event)
futures = []
for runnable in runnables:
if isinstance(runnable, str):
runnable = self._runnable_by_name[runnable]
if runnable.execution_mechanism == "asyncio":
future = asyncio.get_running_loop().create_task(runnable.run(event))
elif runnable.execution_mechanism == "naive":
future = asyncio.get_running_loop().create_future()
future.set_result(runnable.run(event))
else:
executor = self._executors[runnable.execution_mechanism]
task = asyncio.get_running_loop().run_in_executor(executor, runnable.run, event)
tasks.append(task)
results = await asyncio.gather(*tasks)
future = asyncio.get_running_loop().run_in_executor(executor, runnable.run, event)
futures.append(future)
results = await asyncio.gather(*futures)
event.body = {"inputs": event.body, "outputs": {}}
for index, result in enumerate(results):
event.body["outputs"][self._runnables[index].name] = result
Expand Down
39 changes: 34 additions & 5 deletions tests/test_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4652,6 +4652,7 @@ def test_filters_type():


class RunnableBusyWait(ParallelExecutionRunnable):
execution_mechanism = "multiprocessing"
_result = 0

def init(self):
Expand All @@ -4665,7 +4666,7 @@ def run(self, event):


class RunnableSleep(ParallelExecutionRunnable):
execution_mechanism = "thread"
execution_mechanism = "threading"
_result = 0

def init(self):
Expand All @@ -4677,7 +4678,7 @@ def run(self, event):


class RunnableAsyncSleep(ParallelExecutionRunnable):
execution_mechanism = "async"
execution_mechanism = "asyncio"
_result = 0

def init(self):
Expand All @@ -4688,6 +4689,24 @@ async def run(self, event):
return self._result


class RunnableAsyncNaive(ParallelExecutionRunnable):
execution_mechanism = "naive"
_result = 0

def init(self):
self._result = 1

def run(self, event):
return self._result


class RunnableWithError(ParallelExecutionRunnable):
execution_mechanism = "naive"

def run(self, event):
raise Exception("This shouldn't run!")


def test_parallel_execution_uniqueness():
runnables = [
RunnableBusyWait("x"),
Expand All @@ -4706,8 +4725,15 @@ def test_parallel_execution():
RunnableSleep("sleep2"),
RunnableAsyncSleep("asleep1"),
RunnableAsyncSleep("asleep2"),
RunnableAsyncSleep("naive"),
RunnableWithError("error"),
]
parallel_execution = ParallelExecution(runnables)

class MyParallelExecution(ParallelExecution):
def select_runnables(self, event):
return [runnable.name for runnable in runnables if runnable.name != "error"]

parallel_execution = MyParallelExecution(runnables)
reduce = Reduce([], lambda acc, x: acc + [x])

source = SyncEmitSource()
Expand All @@ -4720,7 +4746,10 @@ def test_parallel_execution():
result = controller.await_termination()
end = time.monotonic()

assert end - start < len(runnables)
assert end - start < 6
assert result == [
{"inputs": 0, "outputs": {"busy1": 1, "busy2": 1, "sleep1": 1, "sleep2": 1, "asleep1": 1, "asleep2": 1}}
{
"inputs": 0,
"outputs": {"busy1": 1, "busy2": 1, "sleep1": 1, "sleep2": 1, "asleep1": 1, "asleep2": 1, "naive": 1},
}
]

0 comments on commit e9a1814

Please sign in to comment.