diff --git a/CHANGELOG.md b/CHANGELOG.md index fd9fe5fe0..d12288280 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [UNRELEASED] + +### Fixed + +- Fixed naming of collection nodes (was breaking postprocessing) +- Restored compatibility with stable release of AWS executors + ## [0.212.0-rc.0] - 2023-01-13 ### Authors diff --git a/covalent/_workflow/electron.py b/covalent/_workflow/electron.py index a14d8e5df..501f29cc4 100644 --- a/covalent/_workflow/electron.py +++ b/covalent/_workflow/electron.py @@ -34,6 +34,8 @@ from .._shared_files.defaults import ( WAIT_EDGE_NAME, DefaultMetadataValues, + electron_dict_prefix, + electron_list_prefix, parameter_prefix, prefix_separator, sublattice_prefix, @@ -409,8 +411,13 @@ def connect_node_with_others( ) elif isinstance(param_value, list): + + def _auto_list_node(*args, **kwargs): + return list(args) + list_electron = Electron(function=_auto_list_node, metadata=collection_metadata) - list_electron(*param_value) + bound_electron = list_electron(*param_value) + transport_graph.set_node_value(bound_electron.node_id, "name", electron_list_prefix) transport_graph.add_edge( list_electron.node_id, node_id, @@ -420,8 +427,13 @@ def connect_node_with_others( ) elif isinstance(param_value, dict): + + def _auto_dict_node(*args, **kwargs): + return dict(kwargs) + dict_electron = Electron(function=_auto_dict_node, metadata=collection_metadata) - dict_electron(**param_value) + bound_electron = dict_electron(**param_value) + transport_graph.set_node_value(bound_electron.node_id, "name", electron_dict_prefix) transport_graph.add_edge( dict_electron.node_id, node_id, @@ -650,11 +662,3 @@ def to_decoded_electron_collection(**x): return TransportableObject.deserialize_list(collection) elif isinstance(collection, dict): return TransportableObject.deserialize_dict(collection) - - -def _auto_list_node(*args, **kwargs): - return list(args) - - -def _auto_dict_node(*args, **kwargs): - return dict(kwargs) diff --git a/tests/functional_tests/workflow_stack_test.py b/tests/functional_tests/workflow_stack_test.py index 32e649ae1..ec91727e3 100644 --- a/tests/functional_tests/workflow_stack_test.py +++ b/tests/functional_tests/workflow_stack_test.py @@ -786,3 +786,53 @@ def workflow(): workflow_result = rm.get_result(dispatch_id, wait=True) assert workflow_result.result == 25 rm._delete_result(dispatch_id) + + +def test_workflows_with_list_nodes(): + """Test workflows with auto generated list nodes""" + + @ct.electron + def sum_array(arr): + return sum(arr) + + @ct.electron + def square(x): + return x * x + + @ct.lattice + def workflow(x): + res_1 = sum_array(x) + return square(res_1) + + dispatch_id = ct.dispatch(workflow)([1, 2, 3]) + + res_obj = rm.get_result(dispatch_id, wait=True) + + assert res_obj.result == 36 + + rm._delete_result(dispatch_id) + + +def test_workflows_with_dict_nodes(): + """Test workflows with auto generated dictionary nodes""" + + @ct.electron + def sum_values(assoc_array): + return sum(assoc_array.values()) + + @ct.electron + def square(x): + return x * x + + @ct.lattice + def workflow(x): + res_1 = sum_values(x) + return square(res_1) + + dispatch_id = ct.dispatch(workflow)({"x": 1, "y": 2, "z": 3}) + + res_obj = rm.get_result(dispatch_id, wait=True) + + assert res_obj.result == 36 + + rm._delete_result(dispatch_id)