From d8416ff6151a58f2c2609e5a386e8ce4c7d2b893 Mon Sep 17 00:00:00 2001 From: Gal Topper Date: Wed, 13 Nov 2024 19:56:47 +0800 Subject: [PATCH 1/3] Raise error in `Choice` on duplicate outlets --- storey/flow.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/storey/flow.py b/storey/flow.py index 7f2df3f0..f55679be 100644 --- a/storey/flow.py +++ b/storey/flow.py @@ -23,7 +23,7 @@ from asyncio import Task from collections import defaultdict from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor -from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Union +from typing import Any, Callable, Collection, Dict, Iterable, List, Optional, Set, Union import aiohttp @@ -363,12 +363,12 @@ def _init(self): # TODO: hacky way of supporting mlrun preview, which replaces targets with a DFTarget self._passthrough_for_preview = list(self._name_to_outlet) == ["dataframe"] - def select_outlets(self, event) -> List[str]: + def select_outlets(self, event) -> Collection[str]: """ Override this method to route events based on a customer logic. The default implementation will route all events to all outlets. """ - return list(self._name_to_outlet.keys()) + return self._name_to_outlet.keys() async def _do(self, event): if event is _termination_obj: @@ -376,6 +376,8 @@ async def _do(self, event): else: event_body = event if self._full_event else event.body outlet_names = self.select_outlets(event_body) + if len(set(outlet_names)) != len(outlet_names): + raise ValueError(f"select_outlets() returned duplicate outlets: {outlet_names}") outlets = [] if self._passthrough_for_preview: outlet = self._name_to_outlet["dataframe"] From 999b037b9bdc516d53a504ba2824553b580bb9d5 Mon Sep 17 00:00:00 2001 From: Gal Topper Date: Thu, 14 Nov 2024 12:55:16 +0800 Subject: [PATCH 2/3] Move validation --- storey/flow.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/storey/flow.py b/storey/flow.py index f55679be..e1381c35 100644 --- a/storey/flow.py +++ b/storey/flow.py @@ -376,13 +376,13 @@ async def _do(self, event): else: event_body = event if self._full_event else event.body outlet_names = self.select_outlets(event_body) - if len(set(outlet_names)) != len(outlet_names): - raise ValueError(f"select_outlets() returned duplicate outlets: {outlet_names}") outlets = [] if self._passthrough_for_preview: outlet = self._name_to_outlet["dataframe"] outlets.append(outlet) else: + if len(set(outlet_names)) != len(outlet_names): + raise ValueError(f"select_outlets() returned duplicate outlets: {outlet_names}") for outlet_name in outlet_names: if outlet_name not in self._name_to_outlet: raise ValueError( From 42c4ff219682540986b15fda5c70771f4fd551b6 Mon Sep 17 00:00:00 2001 From: Gal Topper Date: Thu, 14 Nov 2024 13:07:41 +0800 Subject: [PATCH 3/3] Add tests, improve error messages --- storey/flow.py | 5 ++++- tests/test_flow.py | 44 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 1 deletion(-) diff --git a/storey/flow.py b/storey/flow.py index e1381c35..6956df30 100644 --- a/storey/flow.py +++ b/storey/flow.py @@ -382,7 +382,10 @@ async def _do(self, event): outlets.append(outlet) else: if len(set(outlet_names)) != len(outlet_names): - raise ValueError(f"select_outlets() returned duplicate outlets: {outlet_names}") + raise ValueError( + "select_outlets() returned duplicate outlets among the defined outlets: " + + ", ".join(outlet_names) + ) for outlet_name in outlet_names: if outlet_name not in self._name_to_outlet: raise ValueError( diff --git a/tests/test_flow.py b/tests/test_flow.py index 755597ce..9911efab 100644 --- a/tests/test_flow.py +++ b/tests/test_flow.py @@ -1728,6 +1728,50 @@ def select_outlets(self, event): assert termination_result == expected +def test_duplicate_choice(): + class DuplicateChoice(Choice): + def select_outlets(self, event): + outlets = ["all_events", "all_events"] + return outlets + + source = SyncEmitSource() + duplicate_choice = DuplicateChoice(termination_result_fn=lambda x, y: x + y) + all_events = Map(lambda x: x, name="all_events") + + source.to(duplicate_choice).to(all_events) + + controller = source.run() + controller.emit(0) + controller.terminate() + with pytest.raises( + ValueError, + match=r"select_outlets\(\) returned duplicate outlets among the defined outlets: all_events, all_events", + ): + controller.await_termination() + + +def test_nonexistent_choice(): + class NonexistentChoice(Choice): + def select_outlets(self, event): + outlets = ["wrong"] + return outlets + + source = SyncEmitSource() + nonexistent_choice = NonexistentChoice(termination_result_fn=lambda x, y: x + y) + all_events = Map(lambda x: x, name="all_events") + + source.to(nonexistent_choice).to(all_events) + + controller = source.run() + controller.emit(0) + controller.terminate() + with pytest.raises( + ValueError, + match=r"select_outlets\(\) returned outlet name 'wrong', which is not one of the defined outlets: all_events", + ): + controller.await_termination() + + def test_metadata(): def mapf(x): x.key = x.key + 1