Skip to content

Commit

Permalink
add method to compute all mcf measures at once
Browse files Browse the repository at this point in the history
  • Loading branch information
d-schindler committed Nov 21, 2024
1 parent d69237d commit 89fcde0
Show file tree
Hide file tree
Showing 6 changed files with 127 additions and 2 deletions.
15 changes: 15 additions & 0 deletions mcf/io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""I/O functions."""

import pickle


def save_results(results, filename="results.pkl"):
"""Save results in a pickle."""
with open(filename, "wb") as results_file:
pickle.dump(results, results_file)


def load_results(filename="results.pkl"): # pragma: no cover
"""Load results from a pickle."""
with open(filename, "rb") as results_file:
return pickle.load(results_file)
51 changes: 51 additions & 0 deletions mcf/mcf_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from tqdm import tqdm

from mcf.io import save_results
from mcf.measures import (
compute_bettis,
compute_partition_size,
Expand Down Expand Up @@ -130,6 +131,56 @@ def compute_persistent_conflict(self):
c_1, c_2, c = compute_persistent_conflict(self)
return c_1, c_2, c

def compute_all_measures(
self,
file_path="mcf_results.pkl",
):
"""Construct filtration, compute PH and compute all derived measures."""

# build filtration
self.build_filtration()

# compute persistent homology
self.compute_persistence()

# obtain persistence
persistence = [
self.filtration_gudhi.persistence_intervals_in_dimension(dim)
for dim in range(self.max_dim)
]

# compute Betti numbers
betti_0, betti_1, betti_2 = self.compute_bettis()

# compute size of partitions
s_partitions = self.compute_partition_size()

# compute persistent hierarchy
h, h_bar = self.compute_persistent_hierarchy()

# compute persistent conflict
c_1, c_2, c = self.compute_persistent_conflict()

# compile results dictionary
mcf_results = {}
mcf_results["filtration_indices"] = self.filtration_indices
mcf_results["max_dim"] = self.max_dim
mcf_results["persistence"] = persistence
mcf_results["betti_0"] = betti_0
mcf_results["betti_1"] = betti_1
mcf_results["betti_2"] = betti_2
mcf_results["s_partitions"] = s_partitions
mcf_results["h"] = h
mcf_results["h_bar"] = h_bar
mcf_results["c_1"] = c_1
mcf_results["c_2"] = c_2
mcf_results["c"] = c

# save results
save_results(mcf_results, file_path)

return mcf_results


class MCNF(MCF):
"""Class to construct MCNF from a sequence of partitions using equivalent
Expand Down
2 changes: 1 addition & 1 deletion mcf/measures.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def _compute_death_count(mcf, dim):
death_count[i] = np.sum(all_deaths == mcf.filtration_indices[i])

# count inf
death_count[mcf.n_partitions] = np.sum(all_deaths == np.Inf)
death_count[mcf.n_partitions] = np.sum(all_deaths == np.inf)

return death_count

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from setuptools import find_packages
from setuptools import setup

__version__ = "0.0.4"
__version__ = "0.0.5"

setup(
name="MCF",
Expand Down
29 changes: 29 additions & 0 deletions tests/test_mcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,32 @@ def test_compute_persistent_conflict():
assert np.array_equal(c_1, np.array([0.0, 0.0, 0.0, 1.0, -1.0]))
assert np.array_equal(c_2, np.array([0.0, 0.0, 0.0, 0.0, 0.0]))
assert np.array_equal(c, np.array([0.0, 0.0, 0.0, 1.0, -1.0]))


def test_compute_all_measures():
"""Test for computing all MCF measures."""

# initialise MCF object
mcf = MCF()
mcf.load_data(partitions, filtration_indices)

# compute all MCF measures
mcf_results = mcf.compute_all_measures()

# check if results match
assert np.array_equal(mcf_results["filtration_indices"], filtration_indices)
assert mcf_results["max_dim"] == 3
assert np.array_equal(
mcf_results["persistence"][0], np.array([[1.0, 2.0], [1.0, 3.0], [1.0, np.inf]])
)
assert np.array_equal(mcf_results["persistence"][1], np.array([[4.0, 5.0]]))
assert len(mcf_results["persistence"][2]) == 0
assert np.array_equal(mcf_results["betti_0"], np.array([3, 2, 1, 1, 1]))
assert np.array_equal(mcf_results["betti_1"], np.array([0, 0, 0, 1, 0]))
assert np.array_equal(mcf_results["betti_2"], np.array([0, 0, 0, 0, 0]))
assert np.array_equal(mcf_results["s_partitions"], np.array([3, 2, 2, 2, 1]))
assert np.array_equal(mcf_results["h"], np.array([1.0, 1.0, 0.5, 0.5, 1.0]))
assert mcf_results["h_bar"] == 0.75
assert np.array_equal(mcf_results["c_1"], np.array([0.0, 0.0, 0.0, 1.0, -1.0]))
assert np.array_equal(mcf_results["c_2"], np.array([0.0, 0.0, 0.0, 0.0, 0.0]))
assert np.array_equal(mcf_results["c"], np.array([0.0, 0.0, 0.0, 1.0, -1.0]))
30 changes: 30 additions & 0 deletions tests/test_mcnf.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,33 @@ def test_compute_persistent_conflict():
assert np.array_equal(c_1, np.array([0.0, 0.0, 0.0, 1.0, -1.0]))
assert np.array_equal(c_2, np.array([0.0, 0.0, 0.0, 0.0, 0.0]))
assert np.array_equal(c, np.array([0.0, 0.0, 0.0, 1.0, -1.0]))


def test_compute_all_measures():
"""Test for computing all MCF measures."""

# initialise MCNF object
mcnf = MCNF()
mcnf.load_data(partitions, filtration_indices)

# compute all MCNF measures
mcnf_results = mcnf.compute_all_measures()

# check if results match
assert np.array_equal(mcnf_results["filtration_indices"], filtration_indices)
assert mcnf_results["max_dim"] == 3
assert np.array_equal(
mcnf_results["persistence"][0],
np.array([[1.0, 2.0], [1.0, 3.0], [1.0, np.inf]]),
)
assert np.array_equal(mcnf_results["persistence"][1], np.array([[4.0, 5.0]]))
assert len(mcnf_results["persistence"][2]) == 0
assert np.array_equal(mcnf_results["betti_0"], np.array([3, 2, 1, 1, 1]))
assert np.array_equal(mcnf_results["betti_1"], np.array([0, 0, 0, 1, 0]))
assert np.array_equal(mcnf_results["betti_2"], np.array([0, 0, 0, 0, 0]))
assert np.array_equal(mcnf_results["s_partitions"], np.array([3, 2, 2, 2, 1]))
assert np.array_equal(mcnf_results["h"], np.array([1.0, 1.0, 0.5, 0.5, 1.0]))
assert mcnf_results["h_bar"] == 0.75
assert np.array_equal(mcnf_results["c_1"], np.array([0.0, 0.0, 0.0, 1.0, -1.0]))
assert np.array_equal(mcnf_results["c_2"], np.array([0.0, 0.0, 0.0, 0.0, 0.0]))
assert np.array_equal(mcnf_results["c"], np.array([0.0, 0.0, 0.0, 1.0, -1.0]))

0 comments on commit 89fcde0

Please sign in to comment.