Skip to content

Commit

Permalink
Assemble attribute sequences more efficiently
Browse files: browse the repository at this point in the history
  • Loading branch information
JanCBrammer committed Dec 21, 2024
1 parent 9b2a7ea commit 8fd73f8
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 13 deletions.
4 changes: 2 additions & 2 deletions tucan/canonicalization.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from tucan.graph_attributes import PARTITION, INVARIANT_CODE
from tucan.graph_utils import attribute_sequence
from tucan.graph_utils import get_attribute_sequences
import networkx as nx
from typing import Generator
from collections import Counter
Expand All @@ -9,7 +9,7 @@ def partition_molecule_by_attribute(
m: nx.Graph, attribute: str, copy: bool = True
) -> nx.Graph:
# Node degree (i.e., number of neighbors) is encoded in length of individual attribute sequences.
attr_seqs = [attribute_sequence(m, atom, attribute) for atom in m]
attr_seqs = get_attribute_sequences(m, attribute)
unique_attr_seqs = sorted(set(attr_seqs))
unique_attr_seqs_to_partitions = dict(
zip(unique_attr_seqs, range(len(unique_attr_seqs)))
Expand Down
22 changes: 11 additions & 11 deletions tucan/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ def _add_invariant_code(

def sort_molecule_by_attribute(m: nx.Graph, attribute: str) -> nx.Graph:
"""Sort atoms by attribute."""
attr_with_labels = [
(attribute_sequence(m, atom, attribute), atom) for atom in m
] # [(A, 0), (C, 1), (B, 2)]
attr_with_labels = zip(
get_attribute_sequences(m, attribute), list(m.nodes)
) # [(A, 0), (C, 1), (B, 2)]
sorted_attr, labels_sorted_by_attr = zip(
*sorted(attr_with_labels)
) # (A, B, C), (0, 2, 1)
Expand All @@ -66,15 +66,15 @@ def sort_molecule_by_attribute(m: nx.Graph, attribute: str) -> nx.Graph:
)


def attribute_sequence(
m: nx.Graph, atom: int, attribute: str
) -> tuple[str | int | float, ...]:
attr_atom = m.nodes[atom][attribute]
attr_neighbors = sorted(
[m.nodes[n][attribute] for n in m.neighbors(atom)], reverse=True
)
def get_attribute_sequences(
m: nx.Graph, attribute: str
) -> list[tuple[str | int | float, ...]]:
m_attrs = dict(m.nodes(data=attribute)) # type: ignore

return (attr_atom, *attr_neighbors)
return [
(attr, *sorted([m_attrs[n] for n in m.neighbors(node)], reverse=True))
for node, attr in m_attrs.items()
] # type: ignore


def permute_molecule(m: nx.Graph, random_seed: float = 1.0) -> nx.Graph:
Expand Down

0 comments on commit 8fd73f8

Please sign in to comment.