Skip to content

Commit

Permalink
WIP: Prune refinement-tree
Browse files Browse the repository at this point in the history
  • Loading branch information
JanCBrammer committed Aug 30, 2024
1 parent c775fad commit 44dd22d
Show file tree
Hide file tree
Showing 4 changed files with 217 additions and 151 deletions.
323 changes: 187 additions & 136 deletions docs/refinement_tree.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ drawing = [
"notebook",
"ipywidgets",
]
dev = ["pytest", "syrupy"]
dev = ["pytest", "pytest-timeout", "syrupy"]


[tool.pytest.ini_options]
Expand Down
1 change: 1 addition & 0 deletions tests/test_canonicalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def test_permutation(m):
assert m.edges != m_permu.edges


@pytest.mark.timeout(10)
def test_permutation_invariance(m):
permutation_invariance(m)

Expand Down
42 changes: 28 additions & 14 deletions tucan/canonicalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from tucan.graph_utils import attribute_sequence
import networkx as nx
from typing import Generator
from collections import Counter, deque
from collections import Counter
from itertools import combinations


def partition_molecule_by_attribute(m: nx.Graph, attribute: str) -> nx.Graph:
Expand Down Expand Up @@ -66,25 +67,38 @@ def get_refinement_tree_node_children(m: nx.Graph) -> Generator[nx.Graph, None,
yield m_artificially_refined


def get_discrete_partitionings(m: nx.Graph) -> Generator[nx.Graph, None, None]:
def filter_out_automorphisms(ms: list[nx.Graph]) -> list[nx.Graph]:
# TODO: Make this more efficient. E.g., compare labelings as in `get_canonical_molecule()`?
node_matcher = nx.algorithms.isomorphism.categorical_node_match(PARTITION, 0)
filtered_ms = set(ms)
for m_i, m_j in combinations(ms, 2):
if nx.is_isomorphic(m_i, m_j, node_match=node_matcher):
filtered_ms.discard(m_j)

return list(filtered_ms)


def get_refinement_tree_levels(m: nx.Graph) -> Generator[list[nx.Graph], None, None]:
"""
Build BFS refinement-tree and return its leaves (i.e., discrete partitionings).
rtn = refinement-tree-node
Build BFS refinement-tree and yield each level.
"""
rtn_queue = deque([m])
parents = [m]

while rtn_queue:
rtn = rtn_queue.popleft()

if partitioning_is_discrete(rtn):
yield rtn
continue
while True:
yield filter_out_automorphisms(parents)
if all(map(partitioning_is_discrete, parents)):
return

rtn_queue.extend(list(get_refinement_tree_node_children(rtn)))
children = [
child
for parent in parents
for child in get_refinement_tree_node_children(parent)
]
parents = children


def get_canonical_molecule(ms: list[nx.Graph]) -> nx.Graph:
m_canonical = None
m_canonical = ms[0]
canonical_labeling = [[0, 0]]

for m in ms:
Expand All @@ -105,6 +119,6 @@ def get_canonical_molecule(ms: list[nx.Graph]) -> nx.Graph:
def canonicalize_molecule(m: nx.Graph) -> nx.Graph:
m_partitioned_by_invariant_code = partition_molecule_by_attribute(m, INVARIANT_CODE)
m_refined = list(refine_partitions(m_partitioned_by_invariant_code))[-1]
ms_discrete_partitionings = list(get_discrete_partitionings(m_refined))
ms_discrete_partitionings = list(get_refinement_tree_levels(m_refined))[-1]

return get_canonical_molecule(ms_discrete_partitionings)

0 comments on commit 44dd22d

Please sign in to comment.