From 9b2a7ea7b6d0b14d7bb948463801ee9b5bad22b8 Mon Sep 17 00:00:00 2001 From: "Jan C. Brammer" Date: Sun, 15 Dec 2024 12:37:37 +0000 Subject: [PATCH] Retrieve number of partitions more efficiently --- docs/refinement_tree.ipynb | 4 ++-- tucan/canonicalization.py | 13 +++++-------- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/docs/refinement_tree.ipynb b/docs/refinement_tree.ipynb index e987134..e4f4bc1 100644 --- a/docs/refinement_tree.ipynb +++ b/docs/refinement_tree.ipynb @@ -7,7 +7,7 @@ "outputs": [], "source": [ "from tucan.io import graph_from_file\n", - "from tucan.canonicalization import partition_molecule_by_attribute, refine_partitions, get_refinement_tree_node_children, get_refinement_tree_levels, get_number_of_partitions\n", + "from tucan.canonicalization import partition_molecule_by_attribute, refine_partitions, get_refinement_tree_node_children, get_refinement_tree_levels\n", "from tucan.visualization import draw_molecules\n", "from tucan.graph_attributes import PARTITION, INVARIANT_CODE" ] @@ -163,7 +163,7 @@ "refinement_tree_levels = list(get_refinement_tree_levels(m_refined))\n", "print(len(refinement_tree_levels))\n", "for level in refinement_tree_levels:\n", - " draw_molecules(level, [f\"N partitions = {get_number_of_partitions(m)}\" for m in level], highlight=PARTITION)" + " draw_molecules(level, [f\"N partitions = {m.graph[\"n_partitions\"]}\" for m in level], highlight=PARTITION)" ] }, { diff --git a/tucan/canonicalization.py b/tucan/canonicalization.py index 4d8ab7f..4f4b574 100644 --- a/tucan/canonicalization.py +++ b/tucan/canonicalization.py @@ -20,22 +20,19 @@ def partition_molecule_by_attribute( nx.set_node_attributes( m_partitioned, dict(zip(list(m_partitioned), partitions)), PARTITION ) + m_partitioned.graph["n_partitions"] = len(unique_attr_seqs) return m_partitioned -def get_number_of_partitions(m: nx.Graph) -> int: - return len(set(nx.get_node_attributes(m, PARTITION).values())) - - def partitioning_is_discrete(m): - return get_number_of_partitions(m) == m.number_of_nodes() + return m.graph["n_partitions"] == m.number_of_nodes() def refine_partitions(m: nx.Graph) -> Generator[nx.Graph, None, None]: - n_partitions = get_number_of_partitions(m) + n_partitions = m.graph["n_partitions"] m_refined = partition_molecule_by_attribute(m, PARTITION, copy=False) - n_partitions_refined = get_number_of_partitions(m_refined) + n_partitions_refined = m_refined.graph["n_partitions"] if n_partitions == n_partitions_refined: # No more refinement possible. @@ -53,7 +50,7 @@ def get_target_partition(m: nx.Graph) -> int: def get_refinement_tree_node_children(m: nx.Graph) -> Generator[nx.Graph, None, None]: - n_partitions = get_number_of_partitions(m) + n_partitions = m.graph["n_partitions"] target_partition = get_target_partition(m) for atom, partition in m.nodes(data=PARTITION): # type: ignore