Skip to content

Commit

Permalink
Refactor partition refinement
Browse files Browse the repository at this point in the history
  • Loading branch information
JanCBrammer committed Sep 27, 2023
1 parent 247ae66 commit d954dc6
Showing 1 changed file with 17 additions and 7 deletions.
24 changes: 17 additions & 7 deletions tucan/canonicalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@


def partition_molecule_by_attribute(m: nx.Graph, attribute: str) -> nx.Graph:
# Node degree (i.e., number of neighbors) is encoded in length of individual attribute sequences.
attr_seqs = [attribute_sequence(m, atom, attribute) for atom in m]
unique_attr_seqs = sorted(set(attr_seqs))
unique_attr_seqs_to_partitions = dict(
Expand All @@ -20,16 +21,25 @@ def partition_molecule_by_attribute(m: nx.Graph, attribute: str) -> nx.Graph:
return m_partitioned


def get_number_of_partitions(m: nx.Graph) -> int:
return max(nx.get_node_attributes(m, "partition").values())


def refine_partitions(m: nx.Graph) -> Iterator[nx.Graph]:
current_partitions = nx.get_node_attributes(m, "partition").values()
n_current_partitions = get_number_of_partitions(m)

if n_current_partitions == m.number_of_nodes() - 1:
# partitions are discrete (i.e., each node in a separate partition)
return m

m_refined = partition_molecule_by_attribute(m, "partition")
refined_partitions = nx.get_node_attributes(m_refined, "partition").values()
if get_number_of_partitions(m_refined) == n_current_partitions:
# no refinement possible
return m

yield m_refined

while max(current_partitions) != max(refined_partitions):
yield m_refined
current_partitions = refined_partitions
m_refined = partition_molecule_by_attribute(m_refined, "partition")
refined_partitions = nx.get_node_attributes(m_refined, "partition").values()
yield from refine_partitions(m_refined)


def assign_canonical_labels(m: nx.Graph) -> dict[int, int]:
Expand Down

0 comments on commit d954dc6

Please sign in to comment.