diff --git a/tests/conftest.py b/tests/conftest.py index 6b781cb..96df847 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,6 +4,8 @@ import networkx as nx from pathlib import Path from networkx.algorithms.components import is_connected + +from tucan.graph_attributes import ATOMIC_NUMBER, ELEMENT_SYMBOL from tucan.io import graph_from_file @@ -25,8 +27,8 @@ def graph_from_dimacs(filepath): graph = nx.Graph() graph.add_nodes_from(node_labels) - nx.set_node_attributes(graph, "C", "element_symbol") - nx.set_node_attributes(graph, 6, "atomic_number") + nx.set_node_attributes(graph, "C", ELEMENT_SYMBOL) + nx.set_node_attributes(graph, 6, ATOMIC_NUMBER) graph.add_edges_from(bonds) return graph diff --git a/tests/io/test_molfile_v2000_reader.py b/tests/io/test_molfile_v2000_reader.py index d364d72..3b3efe9 100644 --- a/tests/io/test_molfile_v2000_reader.py +++ b/tests/io/test_molfile_v2000_reader.py @@ -1,9 +1,22 @@ import pytest from pathlib import Path + +from tucan.graph_attributes import ( + ATOMIC_NUMBER, + CHG, + ELEMENT_SYMBOL, + MASS, + PARTITION, + RAD, + X_COORD, + Y_COORD, + Z_COORD, +) from tucan.io import graph_from_file from tucan.io.exception import MolfileParserException from tucan.io.molfile_v2000_reader import ( _merge_tuples_into_additional_attributes, + _parse_atom_line, _parse_atom_value_assignments, _to_int, _to_float, @@ -27,6 +40,63 @@ def test_graphs_from_v2000_and_v3000_molfiles_match(mol): assert e1 == e2 +@pytest.mark.parametrize( + "line, expected_additional_attr", + [ + ( + " 0.0000 0.0000 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0", + {}, + ), + ( + " 0.0000 0.0000 0.0000 C 0 1 0 0 0 0 0 0 0 0 0 0", + {CHG: 3}, + ), + ( + " 0.0000 0.0000 0.0000 C 0 2 0 0 0 0 0 0 0 0 0 0", + {CHG: 2}, + ), + ( + " 0.0000 0.0000 0.0000 C 0 3 0 0 0 0 0 0 0 0 0 0", + {CHG: 1}, + ), + ( + " 0.0000 0.0000 0.0000 C 0 4 0 0 0 0 0 0 0 0 0 0", + {RAD: 2}, + ), + ( + " 0.0000 0.0000 0.0000 C 0 5 0 0 0 0 0 0 0 0 0 0", + {CHG: -1}, + ), + ( + " 0.0000 0.0000 0.0000 C 0 6 0 0 0 0 0 0 0 0 0 0", + {CHG: -2}, + ), + ( + " 0.0000 0.0000 0.0000 C 0 7 0 0 0 0 0 0 0 0 0 0", + {CHG: -3}, + ), + ( + " 0.0000 0.0000 0.0000 C 0 8 0 0 0 0 0 0 0 0 0 0", + {}, # ignored + ), + ], +) +def test_parse_atom_line_charge_field(line, expected_additional_attr): + expected_attrs = { + ELEMENT_SYMBOL: "C", + ATOMIC_NUMBER: 6, + PARTITION: 0, + X_COORD: 0, + Y_COORD: 0, + Z_COORD: 0, + } + expected_attrs.update(expected_additional_attr) + + atom_attrs = _parse_atom_line(line) + + assert atom_attrs == expected_attrs + + @pytest.mark.parametrize( "tuples, additional_attrs, expected_additional_attrs_after_merge", [ @@ -37,13 +107,13 @@ def test_graphs_from_v2000_and_v3000_molfiles_match(mol): (2, 3), # atom index is not in additional_attrs yet ], { - 0: {"chg": 2}, # will add new key - 1: {"mass": 1}, # will overwrite value + 0: {CHG: 2}, # will add new key + 1: {MASS: 1}, # will overwrite value }, { - 0: {"chg": 2, "mass": 2}, - 1: {"mass": 13}, - 2: {"mass": 3}, + 0: {CHG: 2, MASS: 2}, + 1: {MASS: 13}, + 2: {MASS: 3}, }, ), ], @@ -51,7 +121,7 @@ def test_graphs_from_v2000_and_v3000_molfiles_match(mol): def test_merge_tuples_into_additional_attributes( tuples, additional_attrs, expected_additional_attrs_after_merge ): - _merge_tuples_into_additional_attributes(tuples, "mass", additional_attrs) + _merge_tuples_into_additional_attributes(tuples, MASS, additional_attrs) assert additional_attrs == expected_additional_attrs_after_merge diff --git a/tests/io/test_molfile_v3000_reader.py b/tests/io/test_molfile_v3000_reader.py index 45ceb8f..90508fe 100644 --- a/tests/io/test_molfile_v3000_reader.py +++ b/tests/io/test_molfile_v3000_reader.py @@ -1,5 +1,16 @@ import pytest import re + +from tucan.graph_attributes import ( + ATOMIC_NUMBER, + BOND_TYPE, + CHG, + ELEMENT_SYMBOL, + PARTITION, + X_COORD, + Y_COORD, + Z_COORD, +) from tucan.io import graph_from_file, graph_from_molfile_text from tucan.io.exception import MolfileParserException from tucan.io.molfile_v3000_reader import ( @@ -21,154 +32,154 @@ def test_parsing_atom_block(): atom_attrs, star_atoms = _parse_atom_block(filecontent) assert atom_attrs == { 0: { - "element_symbol": "C", - "atomic_number": 6, - "partition": 0, - "x_coord": 11.137, - "y_coord": -9.481, - "z_coord": 0.0, + ELEMENT_SYMBOL: "C", + ATOMIC_NUMBER: 6, + PARTITION: 0, + X_COORD: 11.137, + Y_COORD: -9.481, + Z_COORD: 0.0, }, 1: { - "element_symbol": "C", - "atomic_number": 6, - "partition": 0, - "x_coord": 12.1745, - "y_coord": -9.4807, - "z_coord": 0.0, + ELEMENT_SYMBOL: "C", + ATOMIC_NUMBER: 6, + PARTITION: 0, + X_COORD: 12.1745, + Y_COORD: -9.4807, + Z_COORD: 0.0, }, 2: { - "element_symbol": "C", - "atomic_number": 6, - "partition": 0, - "x_coord": 11.6567, - "y_coord": -9.1811, - "z_coord": 0.0, + ELEMENT_SYMBOL: "C", + ATOMIC_NUMBER: 6, + PARTITION: 0, + X_COORD: 11.6567, + Y_COORD: -9.1811, + Z_COORD: 0.0, }, 3: { - "element_symbol": "C", - "atomic_number": 6, - "partition": 0, - "x_coord": 12.1745, - "y_coord": -10.0809, - "z_coord": 0.0, + ELEMENT_SYMBOL: "C", + ATOMIC_NUMBER: 6, + PARTITION: 0, + X_COORD: 12.1745, + Y_COORD: -10.0809, + Z_COORD: 0.0, }, 4: { - "element_symbol": "C", - "atomic_number": 6, - "partition": 0, - "x_coord": 11.137, - "y_coord": -10.0835, - "z_coord": 0.0, + ELEMENT_SYMBOL: "C", + ATOMIC_NUMBER: 6, + PARTITION: 0, + X_COORD: 11.137, + Y_COORD: -10.0835, + Z_COORD: 0.0, }, 5: { - "element_symbol": "C", - "atomic_number": 6, - "partition": 0, - "x_coord": 11.658, - "y_coord": -10.3804, - "z_coord": 0.0, + ELEMENT_SYMBOL: "C", + ATOMIC_NUMBER: 6, + PARTITION: 0, + X_COORD: 11.658, + Y_COORD: -10.3804, + Z_COORD: 0.0, }, 6: { - "element_symbol": "N", - "atomic_number": 7, - "partition": 0, - "x_coord": 11.6691, - "y_coord": -7.3712, - "z_coord": 0.0, - "chg": 1, + ELEMENT_SYMBOL: "N", + ATOMIC_NUMBER: 7, + PARTITION: 0, + X_COORD: 11.6691, + Y_COORD: -7.3712, + Z_COORD: 0.0, + CHG: 1, }, 7: { - "element_symbol": "O", - "atomic_number": 8, - "partition": 0, - "x_coord": 12.1887, - "y_coord": -7.0712, - "z_coord": 0.0, - "chg": -1, + ELEMENT_SYMBOL: "O", + ATOMIC_NUMBER: 8, + PARTITION: 0, + X_COORD: 12.1887, + Y_COORD: -7.0712, + Z_COORD: 0.0, + CHG: -1, }, 8: { - "element_symbol": "O", - "atomic_number": 8, - "partition": 0, - "x_coord": 11.1495, - "y_coord": -7.0712, - "z_coord": 0.0, + ELEMENT_SYMBOL: "O", + ATOMIC_NUMBER: 8, + PARTITION: 0, + X_COORD: 11.1495, + Y_COORD: -7.0712, + Z_COORD: 0.0, }, 9: { - "element_symbol": "N", - "atomic_number": 7, - "partition": 0, - "x_coord": 8.8633, - "y_coord": -11.1246, - "z_coord": 0.0, - "chg": 1, + ELEMENT_SYMBOL: "N", + ATOMIC_NUMBER: 7, + PARTITION: 0, + X_COORD: 8.8633, + Y_COORD: -11.1246, + Z_COORD: 0.0, + CHG: 1, }, 10: { - "element_symbol": "O", - "atomic_number": 8, - "partition": 0, - "x_coord": 9.0299, - "y_coord": -12.4412, - "z_coord": 0.0, - "chg": -1, + ELEMENT_SYMBOL: "O", + ATOMIC_NUMBER: 8, + PARTITION: 0, + X_COORD: 9.0299, + Y_COORD: -12.4412, + Z_COORD: 0.0, + CHG: -1, }, 11: { - "element_symbol": "O", - "atomic_number": 8, - "partition": 0, - "x_coord": 8.3437, - "y_coord": -10.8246, - "z_coord": 0.0, + ELEMENT_SYMBOL: "O", + ATOMIC_NUMBER: 8, + PARTITION: 0, + X_COORD: 8.3437, + Y_COORD: -10.8246, + Z_COORD: 0.0, }, 12: { - "element_symbol": "N", - "atomic_number": 7, - "partition": 0, - "x_coord": 13.8431, - "y_coord": -11.1804, - "z_coord": 0.0, - "chg": 1, + ELEMENT_SYMBOL: "N", + ATOMIC_NUMBER: 7, + PARTITION: 0, + X_COORD: 13.8431, + Y_COORD: -11.1804, + Z_COORD: 0.0, + CHG: 1, }, 13: { - "element_symbol": "O", - "atomic_number": 8, - "partition": 0, - "x_coord": 14.3627, - "y_coord": -10.8804, - "z_coord": 0.0, - "chg": -1, + ELEMENT_SYMBOL: "O", + ATOMIC_NUMBER: 8, + PARTITION: 0, + X_COORD: 14.3627, + Y_COORD: -10.8804, + Z_COORD: 0.0, + CHG: -1, }, 14: { - "element_symbol": "O", - "atomic_number": 8, - "partition": 0, - "x_coord": 13.3607, - "y_coord": -12.0324, - "z_coord": 0.0, + ELEMENT_SYMBOL: "O", + ATOMIC_NUMBER: 8, + PARTITION: 0, + X_COORD: 13.3607, + Y_COORD: -12.0324, + Z_COORD: 0.0, }, 15: { - "element_symbol": "H", - "atomic_number": 1, - "partition": 0, - "x_coord": 9.4208, - "y_coord": -8.4533, - "z_coord": 0.0, + ELEMENT_SYMBOL: "H", + ATOMIC_NUMBER: 1, + PARTITION: 0, + X_COORD: 9.4208, + Y_COORD: -8.4533, + Z_COORD: 0.0, }, 16: { - "element_symbol": "H", - "atomic_number": 1, - "partition": 0, - "x_coord": 14.0661, - "y_coord": -8.4162, - "z_coord": 0.0, + ELEMENT_SYMBOL: "H", + ATOMIC_NUMBER: 1, + PARTITION: 0, + X_COORD: 14.0661, + Y_COORD: -8.4162, + Z_COORD: 0.0, }, 17: { - "element_symbol": "H", - "atomic_number": 1, - "partition": 0, - "x_coord": 11.2046, - "y_coord": -12.0581, - "z_coord": 0.0, + ELEMENT_SYMBOL: "H", + ATOMIC_NUMBER: 1, + PARTITION: 0, + X_COORD: 11.2046, + Y_COORD: -12.0581, + Z_COORD: 0.0, }, } assert len(star_atoms) == 0 @@ -178,7 +189,7 @@ def test_graph_from_file_with_multi_attachment(): graph = graph_from_file( "tests/molfiles/chromocene-multi-attachment/chromocene-multi-attachment.mol" ) - node_indices, elements = zip(*graph.nodes.data("element_symbol")) + node_indices, elements = zip(*graph.nodes.data(ELEMENT_SYMBOL)) # "star" atoms do not end up as graph nodes assert "".join(elements) == 10 * "C" + "Cr" + 10 * "H" @@ -187,36 +198,36 @@ def test_graph_from_file_with_multi_attachment(): assert node_indices == tuple(range(21)) assert list(graph.edges(data=True)) == [ - (0, 10, {"bond_type": 1}), - (0, 1, {"bond_type": 4}), - (0, 4, {"bond_type": 4}), - (0, 12, {"bond_type": 1}), - (1, 10, {"bond_type": 1}), - (1, 2, {"bond_type": 4}), - (1, 11, {"bond_type": 1}), - (2, 10, {"bond_type": 1}), - (2, 3, {"bond_type": 4}), - (2, 20, {"bond_type": 1}), - (3, 10, {"bond_type": 1}), - (3, 4, {"bond_type": 4}), - (3, 14, {"bond_type": 1}), - (4, 10, {"bond_type": 1}), - (4, 13, {"bond_type": 1}), - (5, 6, {"bond_type": 4}), - (5, 9, {"bond_type": 4}), - (5, 10, {"bond_type": 1}), - (5, 15, {"bond_type": 1}), - (6, 7, {"bond_type": 4}), - (6, 10, {"bond_type": 1}), - (6, 18, {"bond_type": 1}), - (7, 8, {"bond_type": 4}), - (7, 10, {"bond_type": 1}), - (7, 19, {"bond_type": 1}), - (8, 9, {"bond_type": 4}), - (8, 10, {"bond_type": 1}), - (8, 17, {"bond_type": 1}), - (9, 10, {"bond_type": 1}), - (9, 16, {"bond_type": 1}), + (0, 10, {BOND_TYPE: 1}), + (0, 1, {BOND_TYPE: 4}), + (0, 4, {BOND_TYPE: 4}), + (0, 12, {BOND_TYPE: 1}), + (1, 10, {BOND_TYPE: 1}), + (1, 2, {BOND_TYPE: 4}), + (1, 11, {BOND_TYPE: 1}), + (2, 10, {BOND_TYPE: 1}), + (2, 3, {BOND_TYPE: 4}), + (2, 20, {BOND_TYPE: 1}), + (3, 10, {BOND_TYPE: 1}), + (3, 4, {BOND_TYPE: 4}), + (3, 14, {BOND_TYPE: 1}), + (4, 10, {BOND_TYPE: 1}), + (4, 13, {BOND_TYPE: 1}), + (5, 6, {BOND_TYPE: 4}), + (5, 9, {BOND_TYPE: 4}), + (5, 10, {BOND_TYPE: 1}), + (5, 15, {BOND_TYPE: 1}), + (6, 7, {BOND_TYPE: 4}), + (6, 10, {BOND_TYPE: 1}), + (6, 18, {BOND_TYPE: 1}), + (7, 8, {BOND_TYPE: 4}), + (7, 10, {BOND_TYPE: 1}), + (7, 19, {BOND_TYPE: 1}), + (8, 9, {BOND_TYPE: 4}), + (8, 10, {BOND_TYPE: 1}), + (8, 17, {BOND_TYPE: 1}), + (9, 10, {BOND_TYPE: 1}), + (9, 16, {BOND_TYPE: 1}), ] @@ -256,24 +267,24 @@ def test_parsing_bond_block(): filecontent = _read_file("tests/molfiles/tnt/tnt.mol") bond_attrs = _parse_bond_block(filecontent, []) assert bond_attrs == { - (0, 4): {"bond_type": 1}, - (0, 2): {"bond_type": 2}, - (0, 15): {"bond_type": 1}, - (1, 2): {"bond_type": 1}, - (1, 3): {"bond_type": 2}, - (1, 16): {"bond_type": 1}, - (2, 6): {"bond_type": 1}, - (3, 5): {"bond_type": 1}, - (3, 12): {"bond_type": 1}, - (4, 5): {"bond_type": 2}, - (4, 9): {"bond_type": 1}, - (5, 17): {"bond_type": 1}, - (6, 8): {"bond_type": 2}, - (6, 7): {"bond_type": 1}, - (9, 11): {"bond_type": 2}, - (9, 10): {"bond_type": 1}, - (12, 14): {"bond_type": 2}, - (12, 13): {"bond_type": 1}, + (0, 4): {BOND_TYPE: 1}, + (0, 2): {BOND_TYPE: 2}, + (0, 15): {BOND_TYPE: 1}, + (1, 2): {BOND_TYPE: 1}, + (1, 3): {BOND_TYPE: 2}, + (1, 16): {BOND_TYPE: 1}, + (2, 6): {BOND_TYPE: 1}, + (3, 5): {BOND_TYPE: 1}, + (3, 12): {BOND_TYPE: 1}, + (4, 5): {BOND_TYPE: 2}, + (4, 9): {BOND_TYPE: 1}, + (5, 17): {BOND_TYPE: 1}, + (6, 8): {BOND_TYPE: 2}, + (6, 7): {BOND_TYPE: 1}, + (9, 11): {BOND_TYPE: 2}, + (9, 10): {BOND_TYPE: 1}, + (12, 14): {BOND_TYPE: 2}, + (12, 13): {BOND_TYPE: 1}, } diff --git a/tests/io/test_molfile_writer.py b/tests/io/test_molfile_writer.py index 6a0d146..429a604 100644 --- a/tests/io/test_molfile_writer.py +++ b/tests/io/test_molfile_writer.py @@ -1,4 +1,6 @@ import pytest + +from tucan.graph_attributes import X_COORD, Y_COORD, Z_COORD from tucan.io import graph_from_molfile_text, graph_to_molfile from tucan.io.molfile_writer import _add_header, _add_v30_line @@ -34,9 +36,9 @@ def test_recalculate_atom_coordinates(): ).nodes(data=True) for key, orig_atom_attrs in atoms_with_orig_coords: - assert orig_atom_attrs["x_coord"] != atoms_with_new_coords[key]["x_coord"] - assert orig_atom_attrs["y_coord"] != atoms_with_new_coords[key]["y_coord"] - assert atoms_with_new_coords[key]["z_coord"] == 0.0 + assert orig_atom_attrs[X_COORD] != atoms_with_new_coords[key][X_COORD] + assert orig_atom_attrs[Y_COORD] != atoms_with_new_coords[key][Y_COORD] + assert atoms_with_new_coords[key][Z_COORD] == 0.0 def test_add_header(): diff --git a/tests/parser/test_parser.py b/tests/parser/test_parser.py index f013b1c..cea2ced 100644 --- a/tests/parser/test_parser.py +++ b/tests/parser/test_parser.py @@ -1,5 +1,14 @@ import pytest import re + +from tucan.graph_attributes import ( + ATOMIC_NUMBER, + ELEMENT_SYMBOL, + INVARIANT_CODE, + MASS, + PARTITION, + RAD, +) from tucan.io import graph_from_tucan, TucanParserException from tucan.parser.parser import _prepare_parser, _walk_tree from tucan.test_utils import roundtrip_graph_tucan_graph_tucan_graph @@ -19,31 +28,31 @@ def _extract_atoms_from_sum_formula(s): ( "CHCl3", [ - {"element_symbol": "C", "atomic_number": 6, "partition": 0}, - {"element_symbol": "H", "atomic_number": 1, "partition": 0}, - {"element_symbol": "Cl", "atomic_number": 17, "partition": 0}, - {"element_symbol": "Cl", "atomic_number": 17, "partition": 0}, - {"element_symbol": "Cl", "atomic_number": 17, "partition": 0}, + {ELEMENT_SYMBOL: "C", ATOMIC_NUMBER: 6, PARTITION: 0}, + {ELEMENT_SYMBOL: "H", ATOMIC_NUMBER: 1, PARTITION: 0}, + {ELEMENT_SYMBOL: "Cl", ATOMIC_NUMBER: 17, PARTITION: 0}, + {ELEMENT_SYMBOL: "Cl", ATOMIC_NUMBER: 17, PARTITION: 0}, + {ELEMENT_SYMBOL: "Cl", ATOMIC_NUMBER: 17, PARTITION: 0}, ], ), ( "ClH", [ - {"element_symbol": "Cl", "atomic_number": 17, "partition": 0}, - {"element_symbol": "H", "atomic_number": 1, "partition": 0}, + {ELEMENT_SYMBOL: "Cl", ATOMIC_NUMBER: 17, PARTITION: 0}, + {ELEMENT_SYMBOL: "H", ATOMIC_NUMBER: 1, PARTITION: 0}, ], ), ( "Cu", [ - {"atomic_number": 29, "element_symbol": "Cu", "partition": 0}, + {ELEMENT_SYMBOL: "Cu", ATOMIC_NUMBER: 29, PARTITION: 0}, ], ), ( "CU", [ - {"atomic_number": 6, "element_symbol": "C", "partition": 0}, - {"atomic_number": 92, "element_symbol": "U", "partition": 0}, + {ELEMENT_SYMBOL: "C", ATOMIC_NUMBER: 6, PARTITION: 0}, + {ELEMENT_SYMBOL: "U", ATOMIC_NUMBER: 92, PARTITION: 0}, ], ), ], @@ -95,13 +104,13 @@ def _extract_node_attributes(s): "node_attributes, expected_node_attributes", [ ("", {}), - ("(1:mass=2)", {0: {"mass": 2}}), - ("(2:rad=5)", {1: {"rad": 5}}), - ("(1234:rad=5,mass=10)", {1233: {"mass": 10, "rad": 5}}), - ("(1:mass=10)(2:rad=1)", {0: {"mass": 10}, 1: {"rad": 1}}), + ("(1:mass=2)", {0: {MASS: 2}}), + ("(2:rad=5)", {1: {RAD: 5}}), + ("(1234:rad=5,mass=10)", {1233: {MASS: 10, RAD: 5}}), + ("(1:mass=10)(2:rad=1)", {0: {MASS: 10}, 1: {RAD: 1}}), ( "(1:mass=123456789)(1:rad=987654321)", - {0: {"mass": 123456789, "rad": 987654321}}, + {0: {MASS: 123456789, RAD: 987654321}}, ), ], ) @@ -133,58 +142,58 @@ def test_overriding_node_attribute_raises_exception( "C2H6O/(1-7)(2-7)(3-7)(4-8)(5-8)(6-9)(7-8)(8-9)", { 0: { - "element_symbol": "H", - "atomic_number": 1, - "partition": 0, - "invariant_code": (1, 0, 0), + ELEMENT_SYMBOL: "H", + ATOMIC_NUMBER: 1, + PARTITION: 0, + INVARIANT_CODE: (1, 0, 0), }, 1: { - "element_symbol": "H", - "atomic_number": 1, - "partition": 0, - "invariant_code": (1, 0, 0), + ELEMENT_SYMBOL: "H", + ATOMIC_NUMBER: 1, + PARTITION: 0, + INVARIANT_CODE: (1, 0, 0), }, 2: { - "element_symbol": "H", - "atomic_number": 1, - "partition": 0, - "invariant_code": (1, 0, 0), + ELEMENT_SYMBOL: "H", + ATOMIC_NUMBER: 1, + PARTITION: 0, + INVARIANT_CODE: (1, 0, 0), }, 3: { - "element_symbol": "H", - "atomic_number": 1, - "partition": 0, - "invariant_code": (1, 0, 0), + ELEMENT_SYMBOL: "H", + ATOMIC_NUMBER: 1, + PARTITION: 0, + INVARIANT_CODE: (1, 0, 0), }, 4: { - "element_symbol": "H", - "atomic_number": 1, - "partition": 0, - "invariant_code": (1, 0, 0), + ELEMENT_SYMBOL: "H", + ATOMIC_NUMBER: 1, + PARTITION: 0, + INVARIANT_CODE: (1, 0, 0), }, 5: { - "element_symbol": "H", - "atomic_number": 1, - "partition": 0, - "invariant_code": (1, 0, 0), + ELEMENT_SYMBOL: "H", + ATOMIC_NUMBER: 1, + PARTITION: 0, + INVARIANT_CODE: (1, 0, 0), }, 6: { - "element_symbol": "C", - "atomic_number": 6, - "partition": 0, - "invariant_code": (6, 0, 0), + ELEMENT_SYMBOL: "C", + ATOMIC_NUMBER: 6, + PARTITION: 0, + INVARIANT_CODE: (6, 0, 0), }, 7: { - "element_symbol": "C", - "atomic_number": 6, - "partition": 0, - "invariant_code": (6, 0, 0), + ELEMENT_SYMBOL: "C", + ATOMIC_NUMBER: 6, + PARTITION: 0, + INVARIANT_CODE: (6, 0, 0), }, 8: { - "element_symbol": "O", - "atomic_number": 8, - "partition": 0, - "invariant_code": (8, 0, 0), + ELEMENT_SYMBOL: "O", + ATOMIC_NUMBER: 8, + PARTITION: 0, + INVARIANT_CODE: (8, 0, 0), }, }, [(0, 6), (1, 6), (2, 6), (3, 7), (4, 7), (5, 8), (6, 7), (7, 8)], @@ -193,10 +202,10 @@ def test_overriding_node_attribute_raises_exception( "Xe/", { 0: { - "element_symbol": "Xe", - "atomic_number": 54, - "partition": 0, - "invariant_code": (54, 0, 0), + ELEMENT_SYMBOL: "Xe", + ATOMIC_NUMBER: 54, + PARTITION: 0, + INVARIANT_CODE: (54, 0, 0), } }, [], @@ -205,50 +214,50 @@ def test_overriding_node_attribute_raises_exception( "C2H4O/(1-5)(2-5)(3-5)(4-7)(5-6)(6-7)/(4:mass=2)(5:mass=14)(6:rad=3)(7:mass=17)", { 0: { - "element_symbol": "H", - "atomic_number": 1, - "partition": 0, - "invariant_code": (1, 0, 0), + ELEMENT_SYMBOL: "H", + ATOMIC_NUMBER: 1, + PARTITION: 0, + INVARIANT_CODE: (1, 0, 0), }, 1: { - "element_symbol": "H", - "atomic_number": 1, - "partition": 0, - "invariant_code": (1, 0, 0), + ELEMENT_SYMBOL: "H", + ATOMIC_NUMBER: 1, + PARTITION: 0, + INVARIANT_CODE: (1, 0, 0), }, 2: { - "element_symbol": "H", - "atomic_number": 1, - "partition": 0, - "invariant_code": (1, 0, 0), + ELEMENT_SYMBOL: "H", + ATOMIC_NUMBER: 1, + PARTITION: 0, + INVARIANT_CODE: (1, 0, 0), }, 3: { - "element_symbol": "H", - "atomic_number": 1, - "partition": 0, - "mass": 2, - "invariant_code": (1, 2, 0), + ELEMENT_SYMBOL: "H", + ATOMIC_NUMBER: 1, + PARTITION: 0, + MASS: 2, + INVARIANT_CODE: (1, 2, 0), }, 4: { - "element_symbol": "C", - "atomic_number": 6, - "partition": 0, - "mass": 14, - "invariant_code": (6, 14, 0), + ELEMENT_SYMBOL: "C", + ATOMIC_NUMBER: 6, + PARTITION: 0, + MASS: 14, + INVARIANT_CODE: (6, 14, 0), }, 5: { - "element_symbol": "C", - "atomic_number": 6, - "partition": 0, - "rad": 3, - "invariant_code": (6, 0, 3), + ELEMENT_SYMBOL: "C", + ATOMIC_NUMBER: 6, + PARTITION: 0, + RAD: 3, + INVARIANT_CODE: (6, 0, 3), }, 6: { - "element_symbol": "O", - "atomic_number": 8, - "partition": 0, - "mass": 17, - "invariant_code": (8, 17, 0), + ELEMENT_SYMBOL: "O", + ATOMIC_NUMBER: 8, + PARTITION: 0, + MASS: 17, + INVARIANT_CODE: (8, 17, 0), }, }, [(0, 4), (1, 4), (2, 4), (3, 6), (4, 5), (5, 6)], @@ -257,18 +266,18 @@ def test_overriding_node_attribute_raises_exception( "CH/(2-1)(1-2)/(2:rad=3,mass=13)", { 0: { - "element_symbol": "H", - "atomic_number": 1, - "partition": 0, - "invariant_code": (1, 0, 0), + ELEMENT_SYMBOL: "H", + ATOMIC_NUMBER: 1, + PARTITION: 0, + INVARIANT_CODE: (1, 0, 0), }, 1: { - "element_symbol": "C", - "atomic_number": 6, - "partition": 0, - "rad": 3, - "mass": 13, - "invariant_code": (6, 13, 3), + ELEMENT_SYMBOL: "C", + ATOMIC_NUMBER: 6, + PARTITION: 0, + RAD: 3, + MASS: 13, + INVARIANT_CODE: (6, 13, 3), }, }, [(0, 1)], diff --git a/tests/test_canonicalization.py b/tests/test_canonicalization.py index c742f9a..65717d9 100644 --- a/tests/test_canonicalization.py +++ b/tests/test_canonicalization.py @@ -3,6 +3,7 @@ partition_molecule_by_attribute, refine_partitions, ) +from tucan.graph_attributes import ATOMIC_NUMBER, PARTITION from tucan.serialization import serialize_molecule from tucan.graph_utils import permute_molecule from tucan.test_utils import permutation_invariance @@ -21,31 +22,31 @@ ) def test_partition_molecule_by_attribute(m, expected_partitions): m_partitioned = partition_molecule_by_attribute( - graph_from_file(f"tests/molfiles/{m}/{m}.mol"), "atomic_number" + graph_from_file(f"tests/molfiles/{m}/{m}.mol"), ATOMIC_NUMBER ) - partitions = sorted(nx.get_node_attributes(m_partitioned, "partition").values()) + partitions = sorted(nx.get_node_attributes(m_partitioned, PARTITION).values()) assert partitions == expected_partitions def test_partition_molecule_by_attribute_is_stable(m): - m_partitioned = partition_molecule_by_attribute(m, "atomic_number") - m_re_partitioned = partition_molecule_by_attribute(m_partitioned, "atomic_number") + m_partitioned = partition_molecule_by_attribute(m, ATOMIC_NUMBER) + m_re_partitioned = partition_molecule_by_attribute(m_partitioned, ATOMIC_NUMBER) - assert sorted( - nx.get_node_attributes(m_partitioned, "partition").values() - ) == sorted(nx.get_node_attributes(m_re_partitioned, "partition").values()) + assert sorted(nx.get_node_attributes(m_partitioned, PARTITION).values()) == sorted( + nx.get_node_attributes(m_re_partitioned, PARTITION).values() + ) def test_refine_partitions(m): - m_partitioned = partition_molecule_by_attribute(m, "atomic_number") + m_partitioned = partition_molecule_by_attribute(m, ATOMIC_NUMBER) m_refined = list(refine_partitions(m_partitioned)) assert all( max_p_i < max_p_j for max_p_i, max_p_j in pairwise( ( - max(nx.get_node_attributes(m_rfnd, "partition").values()) + max(nx.get_node_attributes(m_rfnd, PARTITION).values()) for m_rfnd in m_refined ) ) diff --git a/tucan/canonicalization.py b/tucan/canonicalization.py index cab9e53..7fcf778 100644 --- a/tucan/canonicalization.py +++ b/tucan/canonicalization.py @@ -1,3 +1,4 @@ +from tucan.graph_attributes import PARTITION, INVARIANT_CODE from tucan.graph_utils import attribute_sequence import networkx as nx from igraph import Graph as iGraph @@ -15,14 +16,14 @@ def partition_molecule_by_attribute(m: nx.Graph, attribute: str) -> nx.Graph: m_partitioned = m.copy() nx.set_node_attributes( - m_partitioned, dict(zip(list(m_partitioned), partitions)), "partition" + m_partitioned, dict(zip(list(m_partitioned), partitions)), PARTITION ) return m_partitioned def get_number_of_partitions(m: nx.Graph) -> int: - return max(nx.get_node_attributes(m, "partition").values()) + return max(nx.get_node_attributes(m, PARTITION).values()) def refine_partitions(m: nx.Graph) -> Iterator[nx.Graph]: @@ -32,7 +33,7 @@ def refine_partitions(m: nx.Graph) -> Iterator[nx.Graph]: # partitions are discrete (i.e., each node in a separate partition) return m - m_refined = partition_molecule_by_attribute(m, "partition") + m_refined = partition_molecule_by_attribute(m, PARTITION) if get_number_of_partitions(m_refined) == n_current_partitions: # no refinement possible return m @@ -61,16 +62,14 @@ def assign_canonical_labels(m: nx.Graph) -> dict[int, int]: m_igraph = iGraph.from_networkx(m) old_labels = m_igraph.vs["_nx_name"] - partitions = m_igraph.vs["partition"] + partitions = m_igraph.vs[PARTITION] canonical_labels = m_igraph.canonical_permutation(color=partitions) return dict(zip(old_labels, canonical_labels)) def canonicalize_molecule(m: nx.Graph) -> nx.Graph: - m_partitioned_by_invariant_code = partition_molecule_by_attribute( - m, "invariant_code" - ) + m_partitioned_by_invariant_code = partition_molecule_by_attribute(m, INVARIANT_CODE) m_refined = list(refine_partitions(m_partitioned_by_invariant_code)) m_partitioned = m_refined[-1] if m_refined else m_partitioned_by_invariant_code diff --git a/tucan/element_attributes.py b/tucan/element_attributes.py index bde2e2a..0abb759 100644 --- a/tucan/element_attributes.py +++ b/tucan/element_attributes.py @@ -9,6 +9,12 @@ from __future__ import annotations from typing import Final +from tucan.graph_attributes import ( + ATOMIC_NUMBER, + CHG, + RAD, +) + element_symbols = [ "H", "He", @@ -375,20 +381,20 @@ ] ELEMENT_ATTRS: Final[dict[str, dict[str, str | int]]] = { - s: {"atomic_number": n, "element_name": e, "element_color": c} + s: {ATOMIC_NUMBER: n, "element_name": e, "element_color": c} for s, n, e, c in zip( element_symbols, atomic_numbers, element_names, element_colors ) } MOLFILE_V2000_CHARGES: Final[dict[int, dict[str, int]]] = { - 1: {"chg": 3}, - 2: {"chg": 2}, - 3: {"chg": 1}, - 4: {"rad": 2}, # doublet radical - 5: {"chg": -1}, - 6: {"chg": -2}, - 7: {"chg": -3}, + 1: {CHG: 3}, + 2: {CHG: 2}, + 3: {CHG: 1}, + 4: {RAD: 2}, # doublet radical + 5: {CHG: -1}, + 6: {CHG: -2}, + 7: {CHG: -3}, } diff --git a/tucan/graph_attributes.py b/tucan/graph_attributes.py new file mode 100644 index 0000000..f792ce1 --- /dev/null +++ b/tucan/graph_attributes.py @@ -0,0 +1,21 @@ +# +# Node attributes +# +ATOMIC_NUMBER = "atomic_number" +CHG = "chg" +ELEMENT_SYMBOL = "element_symbol" +INVARIANT_CODE = "invariant_code" +MASS = "mass" +PARTITION = "partition" +RAD = "rad" +X_COORD = "x_coord" +Y_COORD = "y_coord" +Z_COORD = "z_coord" + +# for internal use +EXPLORED = "explored" + +# +# Bond attributes +# +BOND_TYPE = "bond_type" diff --git a/tucan/graph_utils.py b/tucan/graph_utils.py index 98d62bc..79ab1b6 100644 --- a/tucan/graph_utils.py +++ b/tucan/graph_utils.py @@ -3,15 +3,22 @@ import random from typing import Any, NamedTuple +from tucan.graph_attributes import ( + ATOMIC_NUMBER, + INVARIANT_CODE, + MASS, + RAD, +) + def graph_from_molecule( atom_attrs: dict[int, dict[str, Any]], bond_attrs: dict[tuple[int, int], dict[str, int]], ) -> nx.Graph: invariant_code_definitions = [ - InvariantCodeDefinition("atomic_number"), - InvariantCodeDefinition("mass", 0), - InvariantCodeDefinition("rad", 0), + InvariantCodeDefinition(ATOMIC_NUMBER), + InvariantCodeDefinition(MASS, 0), + InvariantCodeDefinition(RAD, 0), ] _add_invariant_code(atom_attrs, invariant_code_definitions) @@ -40,7 +47,7 @@ def _add_invariant_code( else attrs.get(icd.key, default_value) for icd in invariant_code_definitions ) - atom_attrs[atom].update({"invariant_code": invariant_code}) + atom_attrs[atom].update({INVARIANT_CODE: invariant_code}) def sort_molecule_by_attribute(m: nx.Graph, attribute: str) -> nx.Graph: diff --git a/tucan/io/molfile_v2000_reader.py b/tucan/io/molfile_v2000_reader.py index 4a802a2..44a01a9 100644 --- a/tucan/io/molfile_v2000_reader.py +++ b/tucan/io/molfile_v2000_reader.py @@ -4,6 +4,18 @@ MOLFILE_V2000_CHARGES, detect_hydrogen_isotopes, ) +from tucan.graph_attributes import ( + ATOMIC_NUMBER, + BOND_TYPE, + CHG, + ELEMENT_SYMBOL, + MASS, + PARTITION, + RAD, + X_COORD, + Y_COORD, + Z_COORD, +) from tucan.io.exception import MolfileParserException @@ -41,19 +53,19 @@ def _parse_atom_line(line: str) -> dict[str, Any]: element_symbol, isotope_mass = detect_hydrogen_isotopes(element_symbol) atom_attrs = { - "element_symbol": element_symbol, - "atomic_number": ELEMENT_ATTRS[element_symbol]["atomic_number"], - "partition": 0, - "x_coord": _to_float(line[0:10]), # xxxxx.xxxx - "y_coord": _to_float(line[10:20]), # yyyyy.yyyy - "z_coord": _to_float(line[20:30]), # zzzzz.zzzz + ELEMENT_SYMBOL: element_symbol, + ATOMIC_NUMBER: ELEMENT_ATTRS[element_symbol][ATOMIC_NUMBER], + PARTITION: 0, + X_COORD: _to_float(line[0:10]), # xxxxx.xxxx + Y_COORD: _to_float(line[10:20]), # yyyyy.yyyy + Z_COORD: _to_float(line[20:30]), # zzzzz.zzzz } atom_attrs |= MOLFILE_V2000_CHARGES.get(_to_int(line[36:39]), {}) # ccc # Field "dd" (mass difference) is ignored. Only consider hydrogen # isotopes (D and T) here and "M ISO" in the attribute block (later). if isotope_mass: - atom_attrs["mass"] = isotope_mass + atom_attrs[MASS] = isotope_mass return atom_attrs @@ -74,7 +86,7 @@ def _parse_bond_line( _validate_atom_index(index1, atom_attrs, line) _validate_atom_index(index2, atom_attrs, line) - bond_attrs = {"bond_type": _to_int(line[6:9])} # ttt + bond_attrs = {BOND_TYPE: _to_int(line[6:9])} # ttt return (index1, index2), bond_attrs @@ -90,20 +102,20 @@ def _parse_attribute_block( if line.startswith("M CHG"): # M CHGnn8 aaa vvv ... _merge_tuples_into_additional_attributes( - _parse_atom_value_assignments(line, atom_attrs), "chg", additional_attrs + _parse_atom_value_assignments(line, atom_attrs), CHG, additional_attrs ) reset_chg_and_rad = True elif line.startswith("M RAD"): # M RADnn8 aaa vvv ... _merge_tuples_into_additional_attributes( - _parse_atom_value_assignments(line, atom_attrs), "rad", additional_attrs + _parse_atom_value_assignments(line, atom_attrs), RAD, additional_attrs ) reset_chg_and_rad = True elif line.startswith("M ISO"): # M ISOnn8 aaa vvv ... _merge_tuples_into_additional_attributes( _parse_atom_value_assignments(line, atom_attrs), - "mass", + MASS, additional_attrs, ) reset_mass = True @@ -114,11 +126,11 @@ def _parse_attribute_block( if reset_chg_and_rad: # CHG or RAD lines supersede all charge and radical values from the atom block. - _clear_atom_attribute("chg", atom_attrs) - _clear_atom_attribute("rad", atom_attrs) + _clear_atom_attribute(CHG, atom_attrs) + _clear_atom_attribute(RAD, atom_attrs) if reset_mass: # ISO lines supersede all isotope values from the atom block. - _clear_atom_attribute("mass", atom_attrs) + _clear_atom_attribute(MASS, atom_attrs) _merge_atom_attributes_and_additional_attributes(atom_attrs, additional_attrs) diff --git a/tucan/io/molfile_v3000_reader.py b/tucan/io/molfile_v3000_reader.py index 5083b56..94e1ce9 100644 --- a/tucan/io/molfile_v3000_reader.py +++ b/tucan/io/molfile_v3000_reader.py @@ -2,6 +2,18 @@ from collections import deque from typing import Any from tucan.element_attributes import ELEMENT_ATTRS, detect_hydrogen_isotopes +from tucan.graph_attributes import ( + ATOMIC_NUMBER, + BOND_TYPE, + CHG, + ELEMENT_SYMBOL, + MASS, + PARTITION, + RAD, + X_COORD, + Y_COORD, + Z_COORD, +) from tucan.io.exception import MolfileParserException @@ -98,20 +110,20 @@ def _parse_atom_attributes( element_symbol, isotope_mass = detect_hydrogen_isotopes(element_symbol) atom_attrs = { - "element_symbol": element_symbol, - "atomic_number": ELEMENT_ATTRS[element_symbol]["atomic_number"], - "partition": 0, - "x_coord": float(line[4]), - "y_coord": float(line[5]), - "z_coord": float(line[6]), + ELEMENT_SYMBOL: element_symbol, + ATOMIC_NUMBER: ELEMENT_ATTRS[element_symbol][ATOMIC_NUMBER], + PARTITION: 0, + X_COORD: float(line[4]), + Y_COORD: float(line[5]), + Z_COORD: float(line[6]), } optional_attrs = { - "chg": [int(i.split("=")[1]) for i in line if "CHG" in i], - "mass": [int(i.split("=")[1]) for i in line if "MASS" in i] + CHG: [int(i.split("=")[1]) for i in line if "CHG" in i], + MASS: [int(i.split("=")[1]) for i in line if "MASS" in i] if not isotope_mass else [isotope_mass], - "rad": [int(i.split("=")[1]) for i in line if "RAD" in i], + RAD: [int(i.split("=")[1]) for i in line if "RAD" in i], } for key, val in optional_attrs.items(): if val: @@ -170,7 +182,7 @@ def _parse_bond_block( def _parse_bond_attributes(line: list[str]) -> dict[str, int]: - return {"bond_type": int(line[3])} + return {BOND_TYPE: int(line[3])} def _parse_bond_line_with_star_atom( diff --git a/tucan/io/molfile_writer.py b/tucan/io/molfile_writer.py index 99eb612..6169e90 100644 --- a/tucan/io/molfile_writer.py +++ b/tucan/io/molfile_writer.py @@ -1,6 +1,16 @@ from datetime import datetime import networkx as nx import tucan +from tucan.graph_attributes import ( + BOND_TYPE, + CHG, + ELEMENT_SYMBOL, + MASS, + RAD, + X_COORD, + Y_COORD, + Z_COORD, +) def graph_to_molfile(graph: nx.Graph, calc_coordinates=False) -> str: @@ -69,19 +79,17 @@ def _add_atom_block(lines: list[str], graph: nx.Graph, calc_coordinates: bool): _add_v30_line(lines, "BEGIN ATOM") for index, attrs in graph.nodes(data=True): - x = coords[index][0] if calc_coordinates else attrs.get("x_coord", 0) - y = coords[index][1] if calc_coordinates else attrs.get("y_coord", 0) - z = 0 if calc_coordinates else attrs.get("z_coord", 0) - - charge = f" CHG={chg}" if (chg := attrs.get("chg")) and -15 <= chg <= 15 else "" - radical = f" RAD={rad}" if (rad := attrs.get("rad")) and 0 < rad <= 3 else "" - atomic_mass = ( - f" MASS={mass}" if (mass := attrs.get("mass")) and mass > 0 else "" - ) + x = coords[index][0] if calc_coordinates else attrs.get(X_COORD, 0) + y = coords[index][1] if calc_coordinates else attrs.get(Y_COORD, 0) + z = 0 if calc_coordinates else attrs.get(Z_COORD, 0) + + charge = f" CHG={chg}" if (chg := attrs.get(CHG)) and -15 <= chg <= 15 else "" + radical = f" RAD={rad}" if (rad := attrs.get(RAD)) and 0 < rad <= 3 else "" + atomic_mass = f" MASS={mass}" if (mass := attrs.get(MASS)) and mass > 0 else "" _add_v30_line( lines, - f"{index + 1} {attrs['element_symbol']} {x:.6f} {y:.6f} {z:.6f} 0{charge}{radical}{atomic_mass}", + f"{index + 1} {attrs[ELEMENT_SYMBOL]} {x:.6f} {y:.6f} {z:.6f} 0{charge}{radical}{atomic_mass}", ) _add_v30_line(lines, "END ATOM") @@ -95,7 +103,7 @@ def _add_bond_block(lines: list[str], graph: nx.Graph): for index, edge in enumerate(graph.edges(data=True), start=1): node_index1, node_index2, attrs = edge - bond_type = attrs.get("bond_type", 1) + bond_type = attrs.get(BOND_TYPE, 1) _add_v30_line(lines, f"{index} {bond_type} {node_index1 + 1} {node_index2 + 1}") diff --git a/tucan/parser/parser.py b/tucan/parser/parser.py index bd69869..477991c 100644 --- a/tucan/parser/parser.py +++ b/tucan/parser/parser.py @@ -1,13 +1,21 @@ +from __future__ import annotations + import networkx as nx -from typing import Any +from typing import Any, Final from antlr4 import InputStream, CommonTokenStream from antlr4.error.ErrorListener import ErrorListener from antlr4.tree.Tree import ParseTreeWalker from tucan.element_attributes import ELEMENT_ATTRS +from tucan.graph_attributes import ( + ATOMIC_NUMBER, + ELEMENT_SYMBOL, + PARTITION, +) from tucan.graph_utils import graph_from_molecule from tucan.parser.tucanLexer import tucanLexer from tucan.parser.tucanListener import tucanListener from tucan.parser.tucanParser import tucanParser +from tucan.serialization import _SERIALIZER_NODE_ATTRIBUTE_MAPPING def graph_from_tucan(tucan: str) -> nx.Graph: @@ -28,7 +36,7 @@ def graph_from_tucan(tucan: str) -> nx.Graph: return listener.to_graph() -def _prepare_parser(to_parse: str) -> nx.Graph: +def _prepare_parser(to_parse: str) -> tucanParser: stream = InputStream(to_parse) lexer = tucanLexer(stream) token_stream = CommonTokenStream(lexer) @@ -43,13 +51,18 @@ def _prepare_parser(to_parse: str) -> nx.Graph: return parser -def _walk_tree(tree): +def _walk_tree(tree) -> TucanListenerImpl: walker = ParseTreeWalker() listener = TucanListenerImpl() walker.walk(listener, tree) return listener +_DESERIALIZER_NODE_ATTRIBUTE_MAPPING: Final[dict[str, str]] = { + v: k for k, v in _SERIALIZER_NODE_ATTRIBUTE_MAPPING.items() +} + + class TucanListenerImpl(tucanListener): def __init__(self): self._atoms = [] @@ -90,9 +103,9 @@ def _parse_sum_formula(self, formula_ctx): def _add_atoms(self, element, count): atom_attrs = { - "element_symbol": element, - "atomic_number": ELEMENT_ATTRS[element]["atomic_number"], - "partition": 0, + ELEMENT_SYMBOL: element, + ATOMIC_NUMBER: ELEMENT_ATTRS[element][ATOMIC_NUMBER], + PARTITION: 0, } self._atoms.extend([atom_attrs.copy() for _ in range(count)]) @@ -102,12 +115,13 @@ def _add_bond(self, index1, index2): def _add_node_attribute(self, node_index, key, value): attrs_for_node = self._node_attributes.setdefault(node_index - 1, {}) + attr_key = _DESERIALIZER_NODE_ATTRIBUTE_MAPPING[key] - if key in attrs_for_node: + if attr_key in attrs_for_node: raise TucanParserException( f'Atom {node_index}: Attribute "{key}" was already defined.' ) - attrs_for_node[key] = value + attrs_for_node[attr_key] = value def to_graph(self) -> nx.Graph: # node index validation @@ -115,7 +129,7 @@ def to_graph(self) -> nx.Graph: self._validate_atom_index(i1) self._validate_atom_index(i2) - sorted_atoms = sorted(self._atoms, key=lambda a: a["atomic_number"]) + sorted_atoms = sorted(self._atoms, key=lambda a: a[ATOMIC_NUMBER]) # dict of dict (atom_index -> dict of atom attributes) atoms_dict: dict[int, dict[str, Any]] = { diff --git a/tucan/serialization.py b/tucan/serialization.py index bd45d27..afcd286 100644 --- a/tucan/serialization.py +++ b/tucan/serialization.py @@ -1,13 +1,22 @@ from collections import Counter, deque + +from tucan.graph_attributes import ( + ATOMIC_NUMBER, + ELEMENT_SYMBOL, + EXPLORED, + MASS, + PARTITION, + RAD, +) from tucan.graph_utils import sort_molecule_by_attribute from operator import gt, lt, eq -from typing import Callable +from typing import Callable, Final import networkx as nx def serialize_molecule(m: nx.Graph) -> str: """Serialize a molecule.""" - m_sorted = sort_molecule_by_attribute(_assign_final_labels(m), "atomic_number") + m_sorted = sort_molecule_by_attribute(_assign_final_labels(m), ATOMIC_NUMBER) serialization = _write_sum_formula(m_sorted) serialization += f"/{_write_edge_list(m_sorted)}" node_attributes = _write_node_attributes(m_sorted) @@ -25,12 +34,19 @@ def _write_edge_list(m: nx.Graph) -> str: return edge_list_string +_SERIALIZER_NODE_ATTRIBUTE_MAPPING: Final[dict[str, str]] = { + MASS: "mass", + RAD: "rad", +} + + def _write_node_attributes(m: nx.Graph) -> str: node_attribute_string = "" - for node in sorted(m.nodes(data=True)): - label, attrs = node + for label, attrs in sorted(m.nodes(data=True)): available_attrs = [ - f"{attr}={attrs[attr]}" for attr in ("mass", "rad") if attr in attrs + f"{_SERIALIZER_NODE_ATTRIBUTE_MAPPING[attr]}={attrs[attr]}" + for attr in _SERIALIZER_NODE_ATTRIBUTE_MAPPING + if attr in attrs ] if not available_attrs: continue @@ -53,7 +69,7 @@ def _write_sum_formula(m: nx.Graph) -> str: ---------- [1] doi:10.1021/ja02046a005 """ - element_counts = Counter(nx.get_node_attributes(m, "element_symbol").values()) + element_counts = Counter(nx.get_node_attributes(m, ELEMENT_SYMBOL).values()) sum_formula_string = "" carbon_count = element_counts.pop("C", None) if carbon_count: @@ -75,20 +91,20 @@ def _assign_final_labels( smallest possible labels. This is not part of (and not required for) the canonicalization. The re-labeling is for cosmetic purposes.""" - partitions = m.nodes.data("partition") + partitions = m.nodes.data(PARTITION) labels_by_partition = _labels_by_partition(m) final_labels = {} - nx.set_node_attributes(m, False, "explored") + nx.set_node_attributes(m, False, EXPLORED) # outer loop iterates over all fragments of the graph (= graph components), # starting with the lowest unexplored node label - while unexplored := sorted([k for k, v in m.nodes(data="explored") if not v]): + while unexplored := sorted([k for k, v in m.nodes(data=EXPLORED) if not v]): atom_queue = deque([unexplored[0]]) # inner loop reaches out to all atoms in a fragment while atom_queue: a = atom_queue.pop() - if m.nodes[a]["explored"]: + if m.nodes[a][EXPLORED]: continue a_final = labels_by_partition[ partitions[a] @@ -104,22 +120,22 @@ def _assign_final_labels( n for n in neighbors if priority(partitions[a], partitions[n]) ] neighbor_traversal_order.extend(sorted(neighbors_this_priority)) - m.nodes[a]["explored"] = True + m.nodes[a][EXPLORED] = True atom_queue.extendleft(neighbor_traversal_order) assert len(final_labels) == len(m.nodes) - nx.set_node_attributes(m, False, "explored") + nx.set_node_attributes(m, False, EXPLORED) return nx.relabel_nodes(m, final_labels, copy=True) def _labels_by_partition(m: nx.Graph) -> dict[int, list[int]]: """Create dictionary of partitions to node labels.""" - partitions = set(sorted([v for _, v in m.nodes.data("partition")])) + partitions = set(sorted([v for _, v in m.nodes.data(PARTITION)])) labels_by_partition: dict[int, list[int]] = {p: [] for p in partitions} for a in m: - labels_by_partition[m.nodes[a]["partition"]].append(a) + labels_by_partition[m.nodes[a][PARTITION]].append(a) labels_by_partition.update( (k, sorted(list(v), reverse=True)) for k, v in labels_by_partition.items() ) diff --git a/tucan/visualization.py b/tucan/visualization.py index 76b90ae..7a32f90 100644 --- a/tucan/visualization.py +++ b/tucan/visualization.py @@ -4,6 +4,12 @@ import plotly.graph_objects as go import plotly.subplots as sp +from tucan.graph_attributes import ( + ATOMIC_NUMBER, + INVARIANT_CODE, + PARTITION, +) + def _draw_networkx_graph(m, highlight, labels, ax): highlight_colors = list(nx.get_node_attributes(m, highlight).values()) @@ -69,7 +75,7 @@ def _draw_networkx_graph_3d(m, highlight, labels, fig, col): def draw_molecules( - m_list, caption_list, labels=None, highlight="atomic_number", title="", dim=2 + m_list, caption_list, labels=None, highlight=ATOMIC_NUMBER, title="", dim=2 ): """Draw molecule(s). @@ -80,8 +86,10 @@ def draw_molecules( dim: int Plot in "2" (default) or "3" dimensions. """ - if highlight not in ["atomic_number", "partition"]: - print("Please select one of {'partition', 'atomic_number'} for `highlight`.") + if highlight not in [ATOMIC_NUMBER, PARTITION]: + print( + f"Please select one of {{'{PARTITION}', '{ATOMIC_NUMBER}'}} for `highlight`." + ) return n_molecules = len(m_list) @@ -117,10 +125,10 @@ def print_molecule(m, caption=""): print(caption) table = [] for atom in sorted(list(m.nodes)): - invariant_code = m.nodes[atom]["invariant_code"] - partition = m.nodes[atom]["partition"] + invariant_code = m.nodes[atom][INVARIANT_CODE] + partition = m.nodes[atom][PARTITION] neighbors = [ - (n, m.nodes[n]["invariant_code"], m.nodes[n]["partition"]) + (n, m.nodes[n][INVARIANT_CODE], m.nodes[n][PARTITION]) for n in m.neighbors(atom) ] neighbors = sorted(neighbors, key=lambda x: x[2], reverse=True)