Skip to content

Commit

Permalink
Add error on duplicate runnable selection, similar to mlrun#545
Browse files Browse the repository at this point in the history
  • Loading branch information
gtopper committed Nov 17, 2024
1 parent e9a1814 commit cbf7e32
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 3 deletions.
4 changes: 4 additions & 0 deletions storey/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1483,9 +1483,13 @@ async def _do(self, event):
else:
runnables = self.select_runnables(event)
futures = []
runnables_encountered = set()
for runnable in runnables:
if isinstance(runnable, str):
runnable = self._runnable_by_name[runnable]
if id(runnable) in runnables_encountered:
raise ValueError(f"select_runnables() returned more than one outlet named '{runnable.name}'")
runnables_encountered.add(id(runnable))
if runnable.execution_mechanism == "asyncio":
future = asyncio.get_running_loop().create_task(runnable.run(event))
elif runnable.execution_mechanism == "naive":
Expand Down
28 changes: 25 additions & 3 deletions tests/test_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4689,7 +4689,7 @@ async def run(self, event):
return self._result


class RunnableAsyncNaive(ParallelExecutionRunnable):
class RunnableNaiveNoOp(ParallelExecutionRunnable):
execution_mechanism = "naive"
_result = 0

Expand All @@ -4707,16 +4707,38 @@ def run(self, event):
raise Exception("This shouldn't run!")


def test_parallel_execution_uniqueness():
def test_parallel_execution_runnable_uniqueness():
runnables = [
RunnableBusyWait("x"),
RunnableBusyWait("x"),
]
parallel_execution = ParallelExecution(runnables)
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="ParallelExecutionRunnable name 'x' is not unique"):
parallel_execution._init()


def test_select_runnable_uniqueness():
runnables = [
RunnableNaiveNoOp("x"),
RunnableNaiveNoOp("y"),
]

class MyParallelExecution(ParallelExecution):
def select_runnables(self, event):
return ["x", "x"]

parallel_execution = MyParallelExecution(runnables)

source = SyncEmitSource()
source.to(parallel_execution)

controller = source.run()
controller.emit(0)
controller.terminate()
with pytest.raises(ValueError, match=r"select_runnables\(\) returned more than one outlet named 'x'"):
controller.await_termination()


def test_parallel_execution():
runnables = [
RunnableBusyWait("busy1"),
Expand Down

0 comments on commit cbf7e32

Please sign in to comment.