Skip to content

Commit

Permalink
QA 0.209.0: Fix collection nodes (#1469)
Browse files Browse the repository at this point in the history
* Fix collection node name prefixes

* Fix collection nodes

* Add tests

* Update changelog
  • Loading branch information
cjao authored Jan 13, 2023
1 parent 493d988 commit 7676916
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 10 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [UNRELEASED]


### Fixed

- Fixed naming of collection nodes (was breaking postprocessing)
- Restored compatibility with stable release of AWS executors

## [0.212.0-rc.0] - 2023-01-13

### Authors
Expand Down
24 changes: 14 additions & 10 deletions covalent/_workflow/electron.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
from .._shared_files.defaults import (
WAIT_EDGE_NAME,
DefaultMetadataValues,
electron_dict_prefix,
electron_list_prefix,
parameter_prefix,
prefix_separator,
sublattice_prefix,
Expand Down Expand Up @@ -409,8 +411,13 @@ def connect_node_with_others(
)

elif isinstance(param_value, list):

def _auto_list_node(*args, **kwargs):
return list(args)

list_electron = Electron(function=_auto_list_node, metadata=collection_metadata)
list_electron(*param_value)
bound_electron = list_electron(*param_value)
transport_graph.set_node_value(bound_electron.node_id, "name", electron_list_prefix)
transport_graph.add_edge(
list_electron.node_id,
node_id,
Expand All @@ -420,8 +427,13 @@ def connect_node_with_others(
)

elif isinstance(param_value, dict):

def _auto_dict_node(*args, **kwargs):
return dict(kwargs)

dict_electron = Electron(function=_auto_dict_node, metadata=collection_metadata)
dict_electron(**param_value)
bound_electron = dict_electron(**param_value)
transport_graph.set_node_value(bound_electron.node_id, "name", electron_dict_prefix)
transport_graph.add_edge(
dict_electron.node_id,
node_id,
Expand Down Expand Up @@ -650,11 +662,3 @@ def to_decoded_electron_collection(**x):
return TransportableObject.deserialize_list(collection)
elif isinstance(collection, dict):
return TransportableObject.deserialize_dict(collection)


def _auto_list_node(*args, **kwargs):
return list(args)


def _auto_dict_node(*args, **kwargs):
return dict(kwargs)
50 changes: 50 additions & 0 deletions tests/functional_tests/workflow_stack_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,3 +786,53 @@ def workflow():
workflow_result = rm.get_result(dispatch_id, wait=True)
assert workflow_result.result == 25
rm._delete_result(dispatch_id)


def test_workflows_with_list_nodes():
"""Test workflows with auto generated list nodes"""

@ct.electron
def sum_array(arr):
return sum(arr)

@ct.electron
def square(x):
return x * x

@ct.lattice
def workflow(x):
res_1 = sum_array(x)
return square(res_1)

dispatch_id = ct.dispatch(workflow)([1, 2, 3])

res_obj = rm.get_result(dispatch_id, wait=True)

assert res_obj.result == 36

rm._delete_result(dispatch_id)


def test_workflows_with_dict_nodes():
"""Test workflows with auto generated dictionary nodes"""

@ct.electron
def sum_values(assoc_array):
return sum(assoc_array.values())

@ct.electron
def square(x):
return x * x

@ct.lattice
def workflow(x):
res_1 = sum_values(x)
return square(res_1)

dispatch_id = ct.dispatch(workflow)({"x": 1, "y": 2, "z": 3})

res_obj = rm.get_result(dispatch_id, wait=True)

assert res_obj.result == 36

rm._delete_result(dispatch_id)

0 comments on commit 7676916

Please sign in to comment.