Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Raise error in Choice on duplicate outlets #545

Merged
merged 3 commits into from
Nov 17, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions storey/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from asyncio import Task
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Union
from typing import Any, Callable, Collection, Dict, Iterable, List, Optional, Set, Union

import aiohttp

Expand Down Expand Up @@ -363,19 +363,21 @@ def _init(self):
# TODO: hacky way of supporting mlrun preview, which replaces targets with a DFTarget
self._passthrough_for_preview = list(self._name_to_outlet) == ["dataframe"]

def select_outlets(self, event) -> List[str]:
def select_outlets(self, event) -> Collection[str]:
"""
Override this method to route events based on a customer logic. The default implementation will route all
events to all outlets.
"""
return list(self._name_to_outlet.keys())
return self._name_to_outlet.keys()

async def _do(self, event):
if event is _termination_obj:
return await self._do_downstream(_termination_obj)
else:
event_body = event if self._full_event else event.body
outlet_names = self.select_outlets(event_body)
if len(set(outlet_names)) != len(outlet_names):
raise ValueError(f"select_outlets() returned duplicate outlets: {outlet_names}")
assaf758 marked this conversation as resolved.
Show resolved Hide resolved
outlets = []
if self._passthrough_for_preview:
outlet = self._name_to_outlet["dataframe"]
Expand Down
Loading