Skip to content

Commit

Permalink
Retrieve number of partitions more efficiently
Browse files Browse the repository at this point in the history
  • Loading branch information
JanCBrammer committed Dec 15, 2024
1 parent 81cbfc9 commit 9b2a7ea
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 10 deletions.
4 changes: 2 additions & 2 deletions docs/refinement_tree.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"outputs": [],
"source": [
"from tucan.io import graph_from_file\n",
"from tucan.canonicalization import partition_molecule_by_attribute, refine_partitions, get_refinement_tree_node_children, get_refinement_tree_levels, get_number_of_partitions\n",
"from tucan.canonicalization import partition_molecule_by_attribute, refine_partitions, get_refinement_tree_node_children, get_refinement_tree_levels\n",
"from tucan.visualization import draw_molecules\n",
"from tucan.graph_attributes import PARTITION, INVARIANT_CODE"
]
Expand Down Expand Up @@ -163,7 +163,7 @@
"refinement_tree_levels = list(get_refinement_tree_levels(m_refined))\n",
"print(len(refinement_tree_levels))\n",
"for level in refinement_tree_levels:\n",
" draw_molecules(level, [f\"N partitions = {get_number_of_partitions(m)}\" for m in level], highlight=PARTITION)"
" draw_molecules(level, [f\"N partitions = {m.graph[\"n_partitions\"]}\" for m in level], highlight=PARTITION)"
]
},
{
Expand Down
13 changes: 5 additions & 8 deletions tucan/canonicalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,19 @@ def partition_molecule_by_attribute(
nx.set_node_attributes(
m_partitioned, dict(zip(list(m_partitioned), partitions)), PARTITION
)
m_partitioned.graph["n_partitions"] = len(unique_attr_seqs)

return m_partitioned


def get_number_of_partitions(m: nx.Graph) -> int:
return len(set(nx.get_node_attributes(m, PARTITION).values()))


def partitioning_is_discrete(m):
return get_number_of_partitions(m) == m.number_of_nodes()
return m.graph["n_partitions"] == m.number_of_nodes()


def refine_partitions(m: nx.Graph) -> Generator[nx.Graph, None, None]:
n_partitions = get_number_of_partitions(m)
n_partitions = m.graph["n_partitions"]
m_refined = partition_molecule_by_attribute(m, PARTITION, copy=False)
n_partitions_refined = get_number_of_partitions(m_refined)
n_partitions_refined = m_refined.graph["n_partitions"]

if n_partitions == n_partitions_refined:
# No more refinement possible.
Expand All @@ -53,7 +50,7 @@ def get_target_partition(m: nx.Graph) -> int:


def get_refinement_tree_node_children(m: nx.Graph) -> Generator[nx.Graph, None, None]:
n_partitions = get_number_of_partitions(m)
n_partitions = m.graph["n_partitions"]
target_partition = get_target_partition(m)

for atom, partition in m.nodes(data=PARTITION): # type: ignore
Expand Down

0 comments on commit 9b2a7ea

Please sign in to comment.