diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
new file mode 100644
index 0000000..67313c5
--- /dev/null
+++ b/.github/workflows/tests.yml
@@ -0,0 +1,50 @@
+name: CI
+
+on:
+  pull_request:
+  push:
+    branches:
+      - main
+      - "dev*"
+
+jobs:
+  test:
+    runs-on: ${{ matrix.os }}
+    defaults:
+      run:
+        shell: bash
+    strategy:
+      matrix:
+        os: [ubuntu-latest, windows-latest]
+        py-major: [3]
+        py-minor: [8, 11]
+    env:
+      python-version: |
+        ${{ format('{0}.{1}', matrix.py-major, matrix.py-minor) }}
+
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v3
+      - name: Set up Python ${{ env.python-version }}
+        uses: actions/setup-python@v4
+        with:
+          python-version: ${{ env.python-version }}
+          cache: "pip"
+      - name: Update pip and install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          python -m pip install ".[test]" pytest-sugar
+      - name: Run tests (<3.9)
+        if: matrix.py-minor < 9
+        run: |
+          python -m pytest --cov=centhesus tests
+      - name: Run tests (>=3.11)
+        if: matrix.py-minor >= 11
+        run: |
+          python -m pytest --cov=centhesus --cov-fail-under=100 tests
+      - name: Install and run linters
+        if: matrix.os == 'ubuntu-latest' && matrix.py-minor == 11
+        run: |
+          python -m pip install ".[lint]"
+          python -m black --check .
+          python -m ruff .
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..d7d195a
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,8 @@
+.coverage
+.hypothesis
+
+__pycache__
+
+*.egg-info
+
+scrap
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 546f243..d6618a2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,6 +9,8 @@ readme = "README.md"
 requires-python = ">=3.7"
 license = {text = "MIT License"}
 dependencies = [
+    "census21api@git+https://github.com/MichaelaLawrenceONS/Census21_CACD_Wrapper@dev-0.0.1",
+    "dask[complete]",
     "numpy",
     "pandas",
     "private-pgm@git+https://github.com/ryan112358/private-pgm",
@@ -30,7 +32,7 @@ test = [
     "pytest-randomly",
 ]
 lint = [
-    "black>=22.6.0,<23",
+    "black<24",
     "ruff>=0.1.1",
 ]
 dev = [
diff --git a/src/centhesus/__init__.py b/src/centhesus/__init__.py
index cb55d3c..ada8767 100644
--- a/src/centhesus/__init__.py
+++ b/src/centhesus/__init__.py
@@ -1,5 +1,7 @@
 """Synthesising the 2021 England and Wales Census with public data."""
 
+from .mst import MST
+
 __version__ = "0.0.1"
 
-__all__ = ["__version__"]
+__all__ = ["MST", "__version__"]
diff --git a/src/centhesus/mst.py b/src/centhesus/mst.py
new file mode 100644
index 0000000..28fc146
--- /dev/null
+++ b/src/centhesus/mst.py
@@ -0,0 +1,641 @@
+"""Module for the Maximum Spanning Tree generator."""
+
+import itertools
+
+import dask
+import dask.array as da
+import dask.dataframe as dd
+import networkx as nx
+import numpy as np
+from census21api import CensusAPI
+from census21api.constants import DIMENSIONS_BY_POPULATION_TYPE as DIMENSIONS
+from mbi import Domain, FactoredInference
+from scipy import sparse
+
+
+class MST:
+    """
+    Data synthesiser based on the Maximum Spanning Tree (MST) method.
+
+    This class uses the principles of the
+    [MST method](https://doi.org/10.29012/jpc.778) that won the 2018
+    NIST Differential Privacy Synthetic Data Challenge. The original
+    method makes use of a formal privacy framework to protect the
+    confidentiality of the dataset being synthesised. In our case, we
+    use the publicly available tables to create our synthetic data.
+    These tables have undergone stringent statistical disclosure control
+    to make them safe to be in the public domain.
+
+    As such, we adapt MST by removing the formal privacy mechanisms. We
+    do not add noise to the public tables, and we use Kruskal's
+    algorithm to find the true maximum spanning tree of the feature
+    graph. We still make use of the Private-PGM method to generate the
+    graphical model and subsequent synthetic data with a nominal amount
+    of noise.
+
+    The public tables are drawn from the ONS "Create a custom dataset"
+    API, which is accessed via the `census21api` package. See
+    `census21api.constants` for details of available population types,
+    area types, and dimensions.
+
+    Parameters
+    ----------
+    population_type : str
+        Population type to synthesise. Defaults to usual residents in
+        households (`"UR_HH"`).
+    area_type : str, optional
+        Area type to synthesise. If you wish to include an area type
+        column (like local authority) in the final dataset, include it
+        here. The lowest recommended level is MSOA because of issues
+        handling too-large marginal tables.
+    dimensions : list of str, optional
+        Dimensions to synthesise. All features (other than an area type)
+        you would like in the final dataset. If not specified, all
+        available dimensions will be included.
+
+    Attributes
+    ----------
+    api : census21api.CensusAPI
+        Client instance to connect to the 2021 Census API.
+    domain : mbi.Domain
+        Dictionary-like object defining the domain size of every column
+        in the synthetic dataset (area type and dimensions).
+    """
+
+    def __init__(
+        self, population_type="UR_HH", area_type=None, dimensions=None
+    ):
+        self.population_type = population_type
+        self.area_type = area_type
+        self.dimensions = dimensions or DIMENSIONS[self.population_type]
+
+        self.api = CensusAPI()
+        self.domain = self.get_domain()
+
+    def _get_domain_of_feature(self, feature):
+        """
+        Retrieve the domain for items in a feature of the API.
+
+        Parameters
+        ----------
+        feature : {"area-types", "dimensions"}
+            Feature of the API to query.
+
+        Raises
+        ------
+        ValueError
+            If `feature` is invalid.
+
+        Returns
+        -------
+        domain : dict
+            Dictionary containing the domain metadata. Empty if
+            `feature` is `"area-types"` and `self.area_type` is `None`.
+        """
+
+        if feature == "area-types" and self.area_type is None:
+            return {}
+        elif feature == "area-types":
+            items = [self.area_type]
+        elif feature == "dimensions":
+            items = self.dimensions
+        else:
+            raise ValueError(
+                "Feature must be one of 'area-types' or 'dimensions', "
+                f"not '{feature}'"
+            )
+
+        metadata = self.api.query_feature(
+            self.population_type, feature, *items
+        )
+        domain = dict(metadata[["id", "total_count"]].to_dict("split")["data"])
+
+        return domain
+
+    def get_domain(self):
+        """
+        Retrieve domain metadata from the API.
+
+        Returns
+        -------
+        domain : mbi.Domain
+            Dictionary-like object defining the domain size of every
+            column in the synthetic dataset (area type and dimensions).
+        """
+
+        area_type_domain = self._get_domain_of_feature("area-types")
+        dimension_domain = self._get_domain_of_feature("dimensions")
+
+        domain = Domain.fromdict({**area_type_domain, **dimension_domain})
+
+        return domain
+
+    def get_marginal(self, clique, flatten=True):
+        """
+        Retrieve the marginal table for a clique from the API.
+
+        The marginal is returned in a form that is ready to be
+        "measured" by the package that underpins the synthesis, `mbi`.
+
+        Parameters
+        ----------
+        clique : tuple of str
+            Tuple defining the columns of the clique to be measured.
+            Should be of the form `(col,)` or `(col1, col2)`.
+        flatten : bool
+            Whether the marginal should be flattened or not. Default is
+            `True` to work with `mbi`. Flattened marginals are NumPy
+            arrays rather than Pandas series.
+
+        Returns
+        -------
+        marginal : numpy.ndarray or pandas.Series or None
+            Marginal table if the API call succeeds and `None` if not.
+            On a success, if `flatten` is `True`, this is a flat array.
+            Otherwise, the indexed series is returned.
+        """
+
+        area_type = self.area_type or "nat"
+        dimensions = [col for col in clique if col != area_type]
+        if not dimensions:
+            dimensions = self.dimensions[0:1]
+
+        marginal = self.api.query_table(
+            self.population_type, area_type, dimensions
+        )
+
+        if marginal is not None:
+            marginal = marginal.groupby(list(clique))["count"].sum()
+            if flatten is True:
+                marginal = marginal.to_numpy().flatten()
+
+        return marginal
+
+    def measure(self, cliques):
+        """
+        Measure the marginals of a set of cliques.
+
+        This function returns a list of "measurements" to be passed to
+        the `mbi` package. Each measurement consists of a sparse
+        identity matrix, the marginal table, a nominally small float
+        representing the "noise" added to the marginal, and the clique
+        associated with the marginal.
+
+        Although we are not applying differential privacy to our tables,
+        `mbi` requires non-zero noise for each measurement to form the
+        graphical model.
+
+        If a column pair has been blocked by the API, then their
+        marginal is `None` and we skip over them.
+
+        We use `dask` to compute these marginals in parallel.
+
+        Parameters
+        ----------
+        cliques : iterable of tuple
+            The cliques to measure. These cliques should be of the form
+            `(col,)` or `(col1, col2)`.
+
+        Returns
+        -------
+        measurements : list of tuple
+            Measurement tuples for each clique.
+        """
+
+        # Materialise the cliques since we iterate over them twice:
+        # once to build the task list and once to zip the results.
+        cliques = list(cliques)
+
+        tasks = []
+        for clique in cliques:
+            marginal = dask.delayed(self.get_marginal)(clique)
+            tasks.append(marginal)
+
+        marginals = dask.compute(*tasks)
+
+        measurements = [
+            (sparse.eye(marginal.size), marginal, 1, clique)
+            for clique, marginal in zip(cliques, marginals)
+            if marginal is not None
+        ]
+
+        return measurements
+
+    def fit_model(self, measurements, iters=5000):
+        """
+        Fit a graphical model to some measurements.
+
+        Parameters
+        ----------
+        measurements : list of tuple
+            Measurement tuples associated with some cliques to fit.
+        iters : int
+            Number of iterations to use when fitting the model. Default
+            is 5000.
+
+        Returns
+        -------
+        model : mbi.GraphicalModel
+            Fitted graphical model.
+        """
+
+        engine = FactoredInference(self.domain, iters=iters)
+        model = engine.estimate(measurements)
+
+        return model
+
+    def _calculate_importance_of_pair(self, interim, pair):
+        """
+        Determine the importance of a column pair with an interim model.
+
+        Importance is defined as the L1 norm between the observed
+        marginal table for the column pair and that estimated by our
+        interim model.
+
+        Parameters
+        ----------
+        interim : mbi.GraphicalModel
+            Interim model based on one-way marginals only.
+        pair : tuple of str
+            Column pair to be assessed.
+
+        Returns
+        -------
+        weight : float or None
+            Importance of the pair given as the L1 norm between the
+            observed and estimated marginals for the pair. If the API
+            call fails, this is `None`.
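+
+        Notes
+        -----
+        As a toy example of the weighting: an observed marginal of
+        ``[10, 0]`` against an interim estimate of ``[5, 5]`` gives a
+        weight of ``10``, since ``|10 - 5| + |0 - 5| = 10``.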
+ """ + + weight = None + marginal = self.get_marginal(pair) + if marginal is not None: + estimate = interim.project(pair).datavector() + weight = np.linalg.norm(marginal - estimate, 1) + + return weight + + def _calculate_importances(self, interim): + """ + Determine every column pair's importance given an interim model. + + We use `dask` to compute these importances in parallel. + + Parameters + ---------- + interim : mbi.GraphicalModel + Interim model based on one-way marginals only. + + Returns + ------- + weights : dict + Dictionary mapping column pairs to their weight. If a column + pair is blocked by the API, it is skipped. + """ + + pairs = list(itertools.combinations(self.domain.attrs, 2)) + tasks = [] + for pair in pairs: + importance = dask.delayed(self._calculate_importance_of_pair)( + interim, pair + ) + tasks.append(importance) + + importances = dask.compute(*tasks) + + weights = { + pair: importance + for pair, importance in zip(pairs, importances) + if importance is not None + } + + return weights + + def _find_maximum_spanning_tree(self, weights): + """ + Find the maximum spanning tree given a set of edge importances. + + To find the tree, we use Kruskal's algorithm to find the minimum + spanning tree with negative weights. + + Parameters + ---------- + weights : dict + Dictionary mapping edges (column pairs) to their importance. + + Returns + ------- + tree : nx.Graph + Maximum spanning tree of all column pairs. + """ + + graph = nx.Graph() + graph.add_nodes_from(self.domain) + for edge, weight in weights.items(): + graph.add_edge(*edge, weight=-weight) + + tree = nx.minimum_spanning_tree(graph) + + return tree + + def select(self, measurements): + """ + Select the most informative two-way cliques. + + To determine how informative a column pair is, we first create + an interim graphical model from all observed one-way marginals. + Then, each column pair's importance is defined as the L1 + difference between its observed two-way marginal and the + estimated marginal from the interim model. + + With all the importances calculated, we model the column pairs + as a weighted graph where columns are nodes and an edge + represents the importance of the column pair at its endpoints. + In this way, the smallest set of the most informative column + pairs is given as the maximum spanning tree of this graph. + + The selected two-way cliques are the edges of this tree. + + Parameters + ---------- + measurements : list of tuple + One-way marginal measurements with which to fit an interim + graphical model. + + Returns + ------- + cliques : list of tuple + Edges of the maximum spanning tree of our weighted graph. + """ + + interim = self.fit_model(measurements, iters=1000) + weights = self._calculate_importances(interim) + tree = self._find_maximum_spanning_tree(weights) + + return list(tree.edges) + + @staticmethod + def _setup_generate(model, nrows, seed): + """ + Set everything up for the generation of the synthetic data. + + Parameters + ---------- + model : mbi.GraphicalModel + Model from which the synthetic data will be drawn. + nrows : int or None + Number of rows in the synthetic data. Inferred from `model` + if `None`. + seed : int or None + Pseudo-random seed. If `None`, randomness not reproducible. + + Returns + ------- + nrows : int + Number of rows to generate. + prng : dask.array.random.Generator + Pseudo-random number generator. + cliques : list of set + Cliques identified by the graphical model. + order : list of str + Order in which to synthesise the columns. 
+ """ + + nrows = int(model.total) if nrows is None else nrows + prng = da.random.default_rng(seed) + cliques = [set(clique) for clique in model.cliques] + column, *order = model.elimination_order[::-1] + + return nrows, prng, cliques, column, order + + @staticmethod + def _synthesise_column(marginal, nrows, prng, chunksize=1e6): + """ + Sample a column of given length based on a marginal. + + Columns are synthesised to match the distribution of the + marginal very closely. The process for synthesising the column + is as follows: + + 1. Scale the marginal against the total count required, and then + separate its integer and fractional components. + 2. If there are insufficient integer counts, distribute the + additional elements among the integer counts randomly + using the fractional component as a weight. In this way, the + difference between the normalised marginal and the final + counts in the synthetic data will be at most one. + 3. Create an array by repeating the index of the marginal + according to the adjusted integer counts. + 4. Permute the array to give a synthetic column. + + Parameters + ---------- + marginal : np.ndarray + Marginal counts from which to synthesise the column. + nrows : int + Number of elements in the synthesised column. + prng : dask.array.random.Generator + Pseudo-random number generator. We use this to distribute + additional elements in the synthetic column, and to shuffle + its elements after creation. + chunksize : int or float + Target size of a chunk or partition in the column. + + Returns + ------- + column : dask.dataframe.Series + Synthetic column closely matching the distribution of the + marginal. + """ + + marginal *= nrows / marginal.sum() + fractions, integers = np.modf(marginal) + + integers = integers.astype(int) + extra = nrows - integers.sum() + if extra > 0: + idx = prng.choice( + marginal.size, extra, False, fractions / fractions.sum() + ).compute() + integers[idx] += 1 + + uniques = np.arange(integers.size) + repeats = ( + unique * da.ones(shape=count, dtype=int) + for unique, count in zip(uniques, integers) + ) + + values = da.concatenate(repeats).rechunk(chunksize) + column = dd.from_dask_array(prng.permutation(values)) + + return column + + @staticmethod + def _synthesise_column_in_group_by_partition( + partition, clique, column, marginal, prng, chunksize=1e6 + ): + """ + Synthesise a column inside a groupby-apply over the partitions. + + This operation is used for synthesising columns that depend on + those that have already been synthesised. By performing this + synthesis in a group-by operation, we ensure a close matching to + the marginal distribution estimated by the graphical model given + what has already been synthesised. + + Mapping the groupby-apply operation across the partitions allows + even very large datasets (10M+) to be created without + out-of-memory errors. + + Parameters + ---------- + partition : dask.dataframe.DataFrame + Partition of the data frame on which to operate. + clique : list of str + Prerequisite columns by which to group the operation. + column : str + Name of the column to be synthesised. + marginal : np.ndarray + Marginal estimated from the graphical model for the column + and its prerequisites. + prng : dask.array.random.Generator + Pseudo-random number generator. Used to synthesise the + column within groups in this partition. + + Returns + ------- + partition : dask.dataframe.DataFrame + Partition with new synthetic column. 
+ """ + + def synthesise_in_group(group): + """Synthesise a column within a groupby-apply operation.""" + + idx = group.name + group[column] = MST._synthesise_column( + marginal[idx], group.shape[0], prng, chunksize + ) + + return group + + partition = ( + partition.groupby(list(clique)) + .apply(synthesise_in_group) + .reset_index(drop=True) + ) + + return partition + + @staticmethod + def _synthesise_first_column(model, column, nrows, prng): + """ + Sample the first column from the model as a data frame. + + Parameters + ---------- + model : mbi.GraphicalModel + Model from which to synthesise the column. + column : str + Name of the column to synthesise. + nrows : int + Number of rows to generate. + prng : dask.array.random.Generator + Pseudo-random number generator. + + Returns + ------- + data : dask.dataframe.DataFrame + Data frame containing the first synthetic column. + """ + + marginal = model.project([column]).datavector(flatten=False) + data = MST._synthesise_column(marginal, nrows, prng).to_frame( + name=column + ) + + return data + + @staticmethod + def _find_prerequisite_columns(column, cliques, used): + """ + Find the columns that inform the synthesis of a new column. + + Parameters + ---------- + column : str + Name of column to be synthesised. + cliques : list of set + Cliques identified by the graphical model. + used : set of str + Names of columns that have already been synthesised. + + Returns + ------- + prerequisites : tuple of str + All columns needed to synthesise the new column. + """ + + member_of_cliques = [clique for clique in cliques if column in clique] + prerequisites = used.intersection(set.union(*member_of_cliques)) + + return tuple(prerequisites) + + @staticmethod + def generate(model, nrows=None, seed=None): + """ + Generate a synthetic dataset from the estimated model. + + Columns are synthesised in the order determined by the graphical + model. With each column after the first, we search for all the + columns on which it depends according to the model that have + been synthesised already. + + Parameters + ---------- + model : mbi.GraphicalModel + Model from which to draw synthetic data. This model should + be fit to all the marginal tables you care about. + nrows : int, optional + Number of rows in the synthetic dataset. If not specified, + the length of the dataset is inferred from the model. + seed : int, optional + Seed for pseudo-random number generation. If not specified, + the results will not be reproducible. + + Returns + ------- + data : dask.dataframe.DataFrame + Data frame containing the synthetic data. We use Dask to + allow for larger-than-memory datasets. As such, it is lazily + executed. 
+ """ + + nrows, prng, cliques, column, order = MST._setup_generate( + model, nrows, seed + ) + data = MST._synthesise_first_column(model, column, nrows, prng) + used = {column} + + for column in order: + prerequisites = MST._find_prerequisite_columns( + column, cliques, used + ) + used.add(column) + + marginal = model.project(prerequisites + (column,)).datavector( + flatten=False + ) + + if len(prerequisites) >= 1: + data = data.map_partitions( + MST._synthesise_column_in_group_by_partition, + prerequisites, + column, + marginal, + prng, + meta={**data.dtypes, column: int}, + ) + else: + data[column] = MST._synthesise_column(marginal, nrows, prng) + + data = data.repartition("100MB") + + return data diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..2af27a9 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Make the testing modules importable.""" diff --git a/tests/mst/__init__.py b/tests/mst/__init__.py new file mode 100644 index 0000000..ca47d11 --- /dev/null +++ b/tests/mst/__init__.py @@ -0,0 +1 @@ +"""MST-level tests.""" diff --git a/tests/mst/test_fit_model.py b/tests/mst/test_fit_model.py new file mode 100644 index 0000000..fab69e7 --- /dev/null +++ b/tests/mst/test_fit_model.py @@ -0,0 +1,27 @@ +"""Unit test(s) for the model fitting in `centhesus.MST`.""" + +from hypothesis import given +from hypothesis import strategies as st +from mbi import Domain, GraphicalModel +from scipy import sparse + +from ..strategies import mocked_mst, st_single_marginals + + +@given(st_single_marginals(), st.integers(1, 5)) +def test_fit_model(params, iters): + """Test that a model can be fitted to some measurements.""" + + population_type, area_type, dimensions, clique, table = params + domain = Domain.fromdict(table.drop("count", axis=1).nunique().to_dict()) + table = table["count"] + mst = mocked_mst(population_type, area_type, dimensions, domain=domain) + + measurements = [(sparse.eye(table.size), table, 1e-12, clique)] + model = mst.fit_model(measurements, iters) + + assert isinstance(model, GraphicalModel) + assert model.domain == mst.domain + assert model.cliques == [clique] + assert model.elimination_order == list(clique) + assert model.total == table.sum() or 1 diff --git a/tests/mst/test_generate.py b/tests/mst/test_generate.py new file mode 100644 index 0000000..bebc564 --- /dev/null +++ b/tests/mst/test_generate.py @@ -0,0 +1,293 @@ +"""Unit tests for the generation methods of `centhesus.MST`.""" + +from unittest import mock + +import dask +import dask.array as da +import dask.dataframe as dd +import numpy as np +import pandas as pd +from hypothesis import assume, given, settings +from hypothesis import strategies as st +from hypothesis.extra.numpy import arrays + +from centhesus import MST + +from ..strategies import st_existing_new_columns, st_prerequisite_columns + + +@given( + st.floats(1, 100), + st.lists(st.text(), min_size=1, max_size=10), + st.lists(st.tuples(st.text(), st.text()), min_size=1, max_size=5), + st.one_of((st.just(None), st.integers(1, 100))), + st.integers(0, 10), +) +def test_setup_generate(total, elimination_order, cliques_, nrows, seed): + """Test that generation can be set up correctly.""" + + model = mock.MagicMock() + model.total = total + model.elimination_order = elimination_order + model.cliques = cliques_ + + nrows, prng, cliques, column, order = MST._setup_generate( + model, nrows, seed + ) + + assert isinstance(nrows, int) + assert nrows == total or int(model.total) + assert isinstance(prng, da.random.Generator) + 
+    assert cliques == [set(clique) for clique in cliques_]
+    assert column == elimination_order[-1]
+    assert order == elimination_order[-2::-1]
+
+
+@settings(deadline=None)
+@given(
+    arrays(
+        float,
+        st.integers(2, 10),
+        elements=st.one_of((st.just(0), st.floats(1, 50))),
+    ),
+    st.integers(10, 100),
+)
+def test_synthesise_column(marginal, total):
+    """Test a column can be synthesised from a marginal."""
+
+    assume(marginal.sum())
+
+    prng = da.random.default_rng(0)
+    column = MST._synthesise_column(marginal, total, prng)
+
+    assert isinstance(column, dd.Series)
+    assert dask.compute(*column.shape) == (total,)
+    assert column.dtype == int
+
+    uniques, counts = dask.compute(
+        *da.unique(column.to_dask_array(lengths=True), return_counts=True)
+    )
+    if len(uniques) == marginal.size:
+        assert np.array_equal(uniques, np.arange(marginal.size))
+        assert np.all(counts - marginal * total / marginal.sum() <= 1)
+    else:
+        assert set(uniques).issubset(range(marginal.size))
+        assert np.all(
+            counts - marginal[uniques] * total / marginal[uniques].sum() <= 1
+        )
+
+
+@given(st_existing_new_columns())
+def test_synthesise_column_in_group_by_partition(params):
+    """Test that a dependent column can be synthesised in groups."""
+
+    existing, new = params
+
+    num_groups = existing["a"].nunique()
+    column, prng = "foo", da.random.default_rng(0)
+    empty_marginal = [[]] * num_groups
+
+    with mock.patch("centhesus.mst.MST._synthesise_column") as synth:
+        synth.return_value = new
+        synthetic = MST._synthesise_column_in_group_by_partition(
+            existing.copy(),
+            ["a"],
+            column,
+            empty_marginal,
+            prng,
+        )
+
+    assert isinstance(synthetic, pd.DataFrame)
+    assert synthetic.shape[0] == existing.shape[0]
+    assert synthetic.columns.to_list() == [*existing.columns.to_list(), column]
+
+    assert np.array_equal(synthetic[column], new * num_groups)
+
+    assert synth.call_count == num_groups
+    for i, call in enumerate(synth.call_args_list):
+        assert call.args == ([], (existing["a"] == i).sum(), prng, 1e6)
+
+
+@settings(deadline=None)
+@given(
+    arrays(
+        int,
+        st.integers(2, 10),
+        elements=st.integers(0, 50),
+    ),
+    st.text(min_size=1),
+    st.integers(10, 100),
+)
+def test_synthesise_first_column(values, column, nrows):
+    """Test that a single column frame can be created."""
+
+    prng = da.random.default_rng(0)
+    model = mock.MagicMock()
+    model.project.return_value.datavector.return_value = "marginal"
+
+    with mock.patch("centhesus.mst.MST._synthesise_column") as synth:
+        synth.return_value = dd.from_array(values)
+        first = MST._synthesise_first_column(model, column, nrows, prng)
+
+    assert isinstance(first, dd.DataFrame)
+    assert first.columns.to_list() == [column]
+    assert np.array_equal(first[column].compute(), values)
+
+    model.project.assert_called_once_with([column])
+    model.project.return_value.datavector.assert_called_once_with(
+        flatten=False
+    )
+    synth.assert_called_once_with("marginal", nrows, prng)
+
+
+@given(st_prerequisite_columns())
+def test_find_prerequisite_columns(params):
+    """Test we can find all the columns on which another depends."""
+
+    column, cliques, used = params
+
+    prerequisites = MST._find_prerequisite_columns(column, cliques, used)
+
+    expected = set(
+        other
+        for clique in cliques
+        for other in clique
+        if column in clique and other != column and other in used
+    )
+    assert isinstance(prerequisites, tuple)
+    assert set(prerequisites) == expected
+
+
+@given(
+    st.integers(1, 100),
+    st.lists(st.text(), min_size=2, max_size=10, unique=True),
+)
+def test_generate(nrows, params):
+    """Test that generation can be executed correctly."""
+
+    column, *order = params
+
+    prng = da.random.default_rng(0)
+
+    data = mock.MagicMock()
+    data.dtypes = {"data": "dtypes"}
+    data.map_partitions.return_value = data
+    data.repartition.return_value = data
+
+    marginal = mock.MagicMock()
+
+    model = mock.MagicMock()
+    model.project.return_value.datavector.return_value = marginal
+
+    with mock.patch("centhesus.mst.MST._setup_generate") as setup, mock.patch(
+        "centhesus.mst.MST._synthesise_first_column"
+    ) as first, mock.patch(
+        "centhesus.mst.MST._find_prerequisite_columns"
+    ) as find, mock.patch(
+        "centhesus.mst.MST._synthesise_column"
+    ) as synth:
+        setup.return_value = (nrows, prng, "cliques", column, order)
+        first.return_value = data
+        find.return_value = ("prerequisites",)
+        synth.return_value = "independent"
+
+        synthetic = MST.generate(model, nrows)
+
+    setup.assert_called_once_with(model, nrows, None)
+    first.assert_called_once_with(model, column, nrows, prng)
+
+    used = {column}
+    num_subsequent_columns = len(order)
+    assert find.call_count == num_subsequent_columns
+    for call, col in zip(find.call_args_list, order):
+        assert tuple(call.args[:-1]) == (col, "cliques")
+        used.add(col)
+
+    assert used == set((column, *order))
+
+    assert model.project.call_count == num_subsequent_columns
+    for call, col in zip(model.project.call_args_list, order):
+        assert call.args == (("prerequisites", col),)
+
+    assert (
+        model.project.return_value.datavector.call_count
+        == num_subsequent_columns
+    )
+    for call in model.project.return_value.datavector.call_args_list:
+        assert call.args == ()
+        assert call.kwargs == {"flatten": False}
+
+    assert data.map_partitions.call_count == num_subsequent_columns
+    for call, col in zip(data.map_partitions.call_args_list, order):
+        assert call.args == (
+            MST._synthesise_column_in_group_by_partition,
+            ("prerequisites",),
+            col,
+            marginal,
+            prng,
+        )
+        assert call.kwargs == {"meta": {"data": "dtypes", col: int}}
+
+    synth.assert_not_called()
+
+    data.repartition.assert_called_once_with("100MB")
+
+    assert synthetic is data
+
+
+@given(
+    st.integers(1, 100),
+    st.lists(st.text(), min_size=2, max_size=10, unique=True),
+)
+def test_generate_with_extra_independents(nrows, params):
+    """Test generation executes with multiple independent columns."""
+
+    column, *order = params
+
+    prng = da.random.default_rng(0)
+
+    data = mock.MagicMock()
+    data.repartition.return_value = data
+
+    model = mock.MagicMock()
+    marginal = mock.MagicMock()
+    model.project.return_value.datavector.return_value = marginal
+
+    with mock.patch("centhesus.mst.MST._setup_generate") as setup, mock.patch(
+        "centhesus.mst.MST._synthesise_first_column"
+    ) as first, mock.patch(
+        "centhesus.mst.MST._find_prerequisite_columns"
+    ) as find, mock.patch(
+        "centhesus.mst.MST._synthesise_column"
+    ) as synth:
+        setup.return_value = (nrows, prng, "cliques", column, order)
+        first.return_value = data
+        find.return_value = ()
+        synth.return_value = "independent"
+
+        synthetic = MST.generate(model, nrows)
+
+    setup.assert_called_once_with(model, nrows, None)
+    first.assert_called_once_with(model, column, nrows, prng)
+
+    num_subsequent_columns = len(order)
+    assert model.project.call_count == num_subsequent_columns
+    for call, col in zip(model.project.call_args_list, order):
+        assert call.args == ((col,),)
+
+    assert (
+        model.project.return_value.datavector.call_count
+        == num_subsequent_columns
+    )
+    for call in model.project.return_value.datavector.call_args_list:
+        assert call.args == ()
+        assert call.kwargs == {"flatten": False}
+
+    assert synth.call_count == num_subsequent_columns
+    for call, col in zip(synth.call_args_list, order):
+        assert call.args == (marginal, nrows, prng)
+        assert hasattr(data, col)
+
+    data.repartition.assert_called_once_with("100MB")
+
+    assert synthetic is data
diff --git a/tests/mst/test_get_domain.py b/tests/mst/test_get_domain.py
new file mode 100644
index 0000000..ad35bf3
--- /dev/null
+++ b/tests/mst/test_get_domain.py
@@ -0,0 +1,94 @@
+"""Unit tests for getting domain in `centhesus.MST`."""
+
+import string
+from unittest import mock
+
+import pytest
+from hypothesis import given
+from hypothesis import strategies as st
+from mbi import Domain
+
+from ..strategies import (
+    mocked_mst,
+    st_api_parameters,
+    st_feature_metadata_parameters,
+)
+
+
+@given(st_feature_metadata_parameters())
+def test_get_domain_of_feature(params):
+    """Test the domain of a feature can be retrieved correctly."""
+
+    population_type, area_type, dimensions, feature, metadata = params
+
+    mst = mocked_mst(population_type, area_type, dimensions)
+
+    with mock.patch("centhesus.mst.CensusAPI.query_feature") as query:
+        query.return_value = metadata
+        domain = mst._get_domain_of_feature(feature)
+
+    assert isinstance(domain, dict)
+
+    items = [area_type] if feature == "area-types" else dimensions
+    assert list(domain.keys()) == metadata["id"].to_list() == items
+    assert list(domain.values()) == metadata["total_count"].to_list()
+
+    query.assert_called_once_with(population_type, feature, *items)
+
+
+@given(st_feature_metadata_parameters())
+def test_get_domain_of_feature_none_area_type(params):
+    """Test the feature domain getter when area type is None."""
+
+    population_type, _, dimensions, _, metadata = params
+
+    mst = mocked_mst(population_type, None, dimensions)
+
+    with mock.patch("centhesus.mst.CensusAPI.query_feature") as query:
+        domain = mst._get_domain_of_feature("area-types")
+
+    assert isinstance(domain, dict)
+    assert domain == {}
+
+    query.assert_not_called()
+
+
+@given(st_api_parameters(), st.text())
+def test_get_domain_of_feature_raises_error(params, feature):
+    """Test the domain getter raises an error for invalid features."""
+
+    mst = mocked_mst(*params)
+
+    with pytest.raises(ValueError, match="^Feature"):
+        mst._get_domain_of_feature(feature)
+
+
+@given(
+    st_api_parameters(),
+    st.dictionaries(
+        st.text(string.ascii_uppercase, min_size=1), st.integers()
+    ),
+    st.dictionaries(
+        st.text(string.ascii_lowercase, min_size=1), st.integers()
+    ),
+)
+def test_get_domain(params, area_type_domain, dimensions_domain):
+    """Test the domain getter can process metadata correctly."""
+
+    mst = mocked_mst(*params)
+
+    with mock.patch("centhesus.mst.MST._get_domain_of_feature") as feature:
+        feature.side_effect = [area_type_domain, dimensions_domain]
+        domain = mst.get_domain()
+
+    assert isinstance(domain, Domain)
+    assert domain.attrs == (
+        *area_type_domain.keys(),
+        *dimensions_domain.keys(),
+    )
+
+    assert feature.call_count == 2
+    assert [call.args for call in feature.call_args_list] == [
+        ("area-types",),
+        ("dimensions",),
+    ]
diff --git a/tests/mst/test_init.py b/tests/mst/test_init.py
new file mode 100644
index 0000000..e5f2ec2
--- /dev/null
+++ b/tests/mst/test_init.py
@@ -0,0 +1,64 @@
+"""Tests for the `centhesus.mst` module."""
+
+
+from census21api import CensusAPI
+from census21api.constants import DIMENSIONS_BY_POPULATION_TYPE
+from hypothesis import given
+
+from centhesus import MST
+
+from ..strategies import (
+    mocked_mst,
+    st_api_parameters,
+)
+
+
+@given(st_api_parameters())
+def test_init(params):
+    """Test instantiation of the MST class."""
+
+    population_type, area_type, dimensions = params
+
+    mst = mocked_mst(population_type, area_type, dimensions)
+
+    assert isinstance(mst, MST)
+    assert mst.population_type == population_type
+    assert mst.area_type == area_type
+    assert mst.dimensions == dimensions
+
+    assert isinstance(mst.api, CensusAPI)
+    assert mst.domain is None
+
+
+@given(st_api_parameters())
+def test_init_none_area_type(params):
+    """Test instantiation of the MST class when area type is None."""
+
+    population_type, _, dimensions = params
+
+    mst = mocked_mst(population_type, None, dimensions)
+
+    assert isinstance(mst, MST)
+    assert mst.population_type == population_type
+    assert mst.area_type is None
+    assert mst.dimensions == dimensions
+
+    assert isinstance(mst.api, CensusAPI)
+    assert mst.domain is None
+
+
+@given(st_api_parameters())
+def test_init_none_dimensions(params):
+    """Test instantiation of the MST class when dimensions is None."""
+
+    population_type, area_type, _ = params
+
+    mst = mocked_mst(population_type, area_type, None)
+
+    assert isinstance(mst, MST)
+    assert mst.population_type == population_type
+    assert mst.area_type == area_type
+    assert mst.dimensions == DIMENSIONS_BY_POPULATION_TYPE[population_type]
+
+    assert isinstance(mst.api, CensusAPI)
+    assert mst.domain is None
diff --git a/tests/mst/test_measure.py b/tests/mst/test_measure.py
new file mode 100644
index 0000000..ad8cfdd
--- /dev/null
+++ b/tests/mst/test_measure.py
@@ -0,0 +1,86 @@
+"""Unit tests for the measurement methods in `centhesus.MST`."""
+
+import platform
+from unittest import mock
+
+import dask
+import numpy as np
+import pandas as pd
+import pytest
+from hypothesis import given, settings
+from hypothesis import strategies as st
+from scipy import sparse
+
+from ..strategies import mocked_mst, st_single_marginals
+
+
+@given(st_single_marginals(), st.booleans())
+def test_get_marginal(params, flatten):
+    """Test that a marginal table can be processed correctly."""
+
+    population_type, area_type, dimensions, clique, table = params
+    mst = mocked_mst(population_type, area_type, dimensions)
+
+    with mock.patch("centhesus.mst.CensusAPI.query_table") as query:
+        query.return_value = table
+        marginal = mst.get_marginal(clique, flatten)
+
+    if flatten:
+        assert isinstance(marginal, np.ndarray)
+        assert (marginal == table["count"]).all()
+    else:
+        assert isinstance(marginal, pd.Series)
+        assert marginal.name == "count"
+        assert (marginal.reset_index() == table).all().all()
+
+    query.assert_called_once()
+
+
+@given(st_single_marginals(), st.booleans())
+def test_get_marginal_failed_call(params, flatten):
+    """Test that a failed call can still be processed."""
+
+    population_type, area_type, dimensions, clique, _ = params
+    mst = mocked_mst(population_type, area_type, dimensions)
+
+    with mock.patch("centhesus.mst.CensusAPI.query_table") as query:
+        query.return_value = None
+        marginal = mst.get_marginal(clique, flatten)
+
+    assert marginal is None
+
+    query.assert_called_once()
+
+
+@pytest.mark.skipif(
+    tuple(map(int, platform.python_version_tuple())) < (3, 9),
+    reason="Requires Python 3.9+",
+)
+@settings(deadline=None)
+@given(st_single_marginals(), st.integers(1, 5))
+def test_measure(params, num_cliques):
+    """Test a set of cliques can be measured."""
+
+    population_type, area_type, dimensions, clique, table = params
+    mst = mocked_mst(population_type, area_type, dimensions)
+
+    with mock.patch(
+        "centhesus.mst.MST.get_marginal"
+    ) as get_marginal, dask.config.set(scheduler="synchronous"):
+        get_marginal.return_value = table
+        measurements = mst.measure([clique] * num_cliques)
+
+    assert isinstance(measurements, list)
+    assert len(measurements) == num_cliques
+
+    for measurement in measurements:
+        assert isinstance(measurement, tuple)
+        assert len(measurement) == 4
+
+        ident, marg, noise, cliq = measurement
+        assert isinstance(ident, sparse._dia.dia_matrix)
+        assert ident.shape == (marg.size,) * 2
+        assert ident.sum() == marg.size
+        assert marg.equals(table)
+        assert noise == 1
+        assert cliq == clique
diff --git a/tests/mst/test_select.py b/tests/mst/test_select.py
new file mode 100644
index 0000000..004c8d3
--- /dev/null
+++ b/tests/mst/test_select.py
@@ -0,0 +1,135 @@
+"""Unit tests for the selection methods in `centhesus.MST`."""
+
+import itertools
+import platform
+from unittest import mock
+
+import dask
+import networkx as nx
+import pytest
+from hypothesis import given, settings
+
+from ..strategies import (
+    mocked_mst,
+    st_importances,
+    st_single_marginals,
+    st_subgraphs,
+)
+
+
+@given(st_single_marginals(kind="pair"))
+def test_calculate_importance_of_pair(params):
+    """Test the importance of a column pair can be calculated."""
+
+    population_type, area_type, dimensions, clique, table = params
+    table = table["count"]
+    mst = mocked_mst(population_type, area_type, dimensions)
+
+    interim = mock.MagicMock()
+    interim.project.return_value.datavector.return_value = table.sample(
+        frac=1.0
+    ).reset_index(drop=True)
+    with mock.patch("centhesus.mst.MST.get_marginal") as get_marginal:
+        get_marginal.return_value = table
+        weight = mst._calculate_importance_of_pair(interim, clique)
+
+    assert isinstance(weight, float)
+    assert weight >= 0
+
+    interim.project.assert_called_once_with(clique)
+    interim.project.return_value.datavector.assert_called_once_with()
+    get_marginal.assert_called_once_with(clique)
+
+
+@given(st_single_marginals(kind="pair"))
+def test_calculate_importance_of_pair_failed_call(params):
+    """Test that a failed call doesn't stop importance processing."""
+
+    population_type, area_type, dimensions, clique, _ = params
+    mst = mocked_mst(population_type, area_type, dimensions)
+
+    interim = mock.MagicMock()
+    with mock.patch("centhesus.mst.MST.get_marginal") as get_marginal:
+        get_marginal.return_value = None
+        weight = mst._calculate_importance_of_pair(interim, clique)
+
+    assert weight is None
+
+    interim.project.assert_not_called()
+    interim.project.return_value.datavector.assert_not_called()
+    get_marginal.assert_called_once_with(clique)
+
+
+@pytest.mark.skipif(
+    tuple(map(int, platform.python_version_tuple())) < (3, 9),
+    reason="Requires Python 3.9+",
+)
+@settings(deadline=None)
+@given(st_importances())
+def test_calculate_importances(params):
+    """Test that a set of importances can be calculated."""
+
+    population_type, area_type, dimensions, domain, importances = params
+    mst = mocked_mst(population_type, area_type, dimensions, domain=domain)
+
+    with mock.patch(
+        "centhesus.mst.MST._calculate_importance_of_pair"
+    ) as calc, dask.config.set(scheduler="synchronous"):
+        calc.side_effect = importances
+        weights = mst._calculate_importances("interim")
+
+    pairs = list(itertools.combinations(domain, 2))
+    assert calc.call_count == len(pairs)
+    call_args = [call.args for call in calc.call_args_list]
+    assert set(call_args) == set(("interim", pair) for pair in pairs)
+
+    assert isinstance(weights, dict)
+    assert set(weights.keys()) == set(pairs)
+
+    pairs_execution_order = [pair for _, pair in call_args]
+    for pair, importance in zip(pairs_execution_order, importances):
+        assert weights[pair] == importance
+
+
+@given(st_importances())
+def test_find_maximum_spanning_tree(params):
+    """Test an MST can be found from a set of importances."""
+
+    *api_params, domain, importances = params
+    mst = mocked_mst(*api_params, domain=domain)
+    weights = dict(zip(itertools.combinations(domain, 2), importances))
+
+    tree = mst._find_maximum_spanning_tree(weights)
+
+    assert isinstance(tree, nx.Graph)
+    assert set(tree.nodes) == set(domain)
+    assert set(tree.edges).issubset(weights.keys())
+    for edge in tree.edges:
+        assert tree.edges[edge]["weight"] == -weights[edge]
+
+
+@given(st_subgraphs())
+def test_select(params):
+    """Test that a set of two-way cliques can be found correctly."""
+
+    *api_params, domain, tree = params
+    mst = mocked_mst(*api_params, domain=domain)
+
+    with mock.patch("centhesus.mst.MST.fit_model") as fit, mock.patch(
+        "centhesus.mst.MST._calculate_importances"
+    ) as calc, mock.patch(
+        "centhesus.mst.MST._find_maximum_spanning_tree"
+    ) as find:
+        fit.return_value = "interim"
+        calc.return_value = "weights"
+        find.return_value = tree
+        cliques = mst.select("measurements")
+
+    possible_edges = [set(pair) for pair in itertools.combinations(domain, 2)]
+    assert isinstance(cliques, list)
+    for clique in cliques:
+        assert set(clique) in possible_edges
+
+    fit.assert_called_once_with("measurements", iters=1000)
+    calc.assert_called_once_with("interim")
+    find.assert_called_once_with("weights")
diff --git a/tests/strategies.py b/tests/strategies.py
new file mode 100644
index 0000000..07f7a47
--- /dev/null
+++ b/tests/strategies.py
@@ -0,0 +1,200 @@
+"""Custom strategies for testing the package."""
+
+import itertools
+import math
+from unittest import mock
+
+import networkx as nx
+import numpy as np
+import pandas as pd
+from census21api.constants import (
+    AREA_TYPES_BY_POPULATION_TYPE,
+    DIMENSIONS_BY_POPULATION_TYPE,
+    POPULATION_TYPES,
+)
+from hypothesis import assume
+from hypothesis import strategies as st
+from mbi import Domain
+
+from centhesus import MST
+
+
+def mocked_mst(population_type, area_type, dimensions, domain=None):
+    """Create an instance of MST with a mocked `get_domain`."""
+
+    with mock.patch("centhesus.mst.MST.get_domain") as get_domain:
+        get_domain.return_value = domain
+        mst = MST(population_type, area_type, dimensions)
+
+    get_domain.assert_called_once_with()
+
+    return mst
+
+
+@st.composite
+def st_api_parameters(draw):
+    """Create a valid set of Census API parameters."""
+
+    population_type = draw(st.sampled_from(POPULATION_TYPES))
+    area_type = draw(
+        st.sampled_from(AREA_TYPES_BY_POPULATION_TYPE[population_type]),
+    )
+    dimensions = draw(
+        st.sets(
+            st.sampled_from(DIMENSIONS_BY_POPULATION_TYPE[population_type]),
+            min_size=1,
+        ).map(sorted),
+    )
+
+    return population_type, area_type, dimensions
+
+
+@st.composite
+def st_feature_metadata_parameters(draw):
+    """Create a parameter set and feature metadata for a test."""
+
+    population_type, area_type, dimensions = draw(st_api_parameters())
+
+    feature = draw(st.sampled_from(("area-types", "dimensions")))
+    items = [area_type] if feature == "area-types" else dimensions
+    metadata = pd.DataFrame(
+        ((item, draw(st.integers())) for item in items),
+        columns=("id", "total_count"),
+    )
+
+    return population_type, area_type, dimensions, feature, metadata
+
+
+@st.composite
+def st_single_marginals(draw, kind=None):
+    """Create a marginal table and its parameters for a test."""
+
+    population_type, area_type, dimensions = draw(st_api_parameters())
+
+    min_size, max_size = 1, 2
+    if kind == "single":
+        max_size = 1
+    if kind == "pair":
+        min_size = 2
+
+    clique = draw(
+        st.sets(
+            st.sampled_from((area_type, *dimensions)),
+            min_size=min_size,
+            max_size=max_size,
+        ).map(tuple)
+    )
+
+    num_uniques = [draw(st.integers(2, 5)) for _ in clique]
+    num_rows = int(np.prod(num_uniques))
+    counts = draw(
+        st.lists(st.integers(0, 100), min_size=num_rows, max_size=num_rows)
+    )
+
+    marginal = pd.DataFrame(
+        itertools.product(*(range(num_unique) for num_unique in num_uniques)),
+        columns=clique,
+    )
+    marginal["count"] = counts
+
+    return population_type, area_type, dimensions, clique, marginal
+
+
+@st.composite
+def st_domains(draw):
+    """Create a domain and its parameters for a test."""
+
+    population_type, area_type, dimensions = draw(st_api_parameters())
+
+    num = len(dimensions) + 1
+    sizes = draw(st.lists(st.integers(2, 10), min_size=num, max_size=num))
+    domain = Domain.fromdict(dict(zip((area_type, *dimensions), sizes)))
+
+    return population_type, area_type, dimensions, domain
+
+
+@st.composite
+def st_importances(draw):
+    """Create a domain and set of importances for a test."""
+
+    population_type, area_type, dimensions, domain = draw(st_domains())
+
+    num = len(domain)
+    importances = draw(
+        st.lists(
+            st.floats(max_value=0, allow_infinity=False, allow_nan=False),
+            min_size=math.comb(num, 2),
+            max_size=math.comb(num, 2),
+        )
+    )
+
+    return population_type, area_type, dimensions, domain, importances
+
+
+@st.composite
+def st_subgraphs(draw):
+    """Create a subgraph and its parameters for a test."""
+
+    population_type, area_type, dimensions, domain = draw(st_domains())
+
+    edges = draw(
+        st.sets(st.sampled_from(list(itertools.combinations(domain, 2))))
+    )
+    graph = nx.Graph()
+    graph.add_edges_from(edges)
+
+    return population_type, area_type, dimensions, domain, graph
+
+
+@st.composite
+def st_existing_new_columns(draw):
+    """Create an existing column and a new one for a test."""
+
+    num_groups = draw(st.integers(1, 3))
+    num_rows_in_group = draw(st.integers(10, 50))
+    existing = pd.DataFrame(
+        {"a": [i for i in range(num_groups) for _ in range(num_rows_in_group)]}
+    )
+
+    new = draw(
+        st.lists(
+            st.integers(0, 3),
+            min_size=num_rows_in_group,
+            max_size=num_rows_in_group,
+        )
+    )
+
+    return existing, new
+
+
+@st.composite
+def st_prerequisite_columns(draw):
+    """Create a column, set of cliques and a used set for a test."""
+
+    columns = draw(
+        st.sets(
+            st.sampled_from(DIMENSIONS_BY_POPULATION_TYPE["UR_HH"]), min_size=2
+        ).map(list)
+    )
+    column = draw(st.sampled_from(columns))
+
+    combinations = [
+        *itertools.combinations(columns, 2),
+        *itertools.combinations(columns, 3),
+    ]
+
+    cliques = draw(
+        st.lists(
+            st.sampled_from(combinations).map(set), min_size=len(columns) - 1
+        )
+    )
+    assume(any(column in clique for clique in cliques))
+
+    used = draw(
+        st.sets(
+            st.sampled_from([col for col in columns if col != column]),
+            min_size=1,
+        )
+    )
+
+    return column, cliques, used
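
A quick way to exercise this changeset locally, mirroring the CI steps in
the workflow above (the test and lint extras are defined in pyproject.toml):

    python -m pip install ".[test]" pytest-sugar
    python -m pytest --cov=centhesus tests
    python -m pip install ".[lint]"
    python -m black --check .
    python -m ruff .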