Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DRAFT: Support non-numpy array backends #886

Draft
wants to merge 29 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
1331d49
FEAT: enable backend switching for base gravitational-wave transient …
ColmTalbot Oct 25, 2024
caf20c8
FEAT: support multiband and relative binning likelihoods
ColmTalbot Oct 25, 2024
924f750
FEAT: make more conversions backend agnostic
ColmTalbot Oct 26, 2024
c574a92
FEAT: use more normal conversions
ColmTalbot Oct 28, 2024
250dff2
FEAT: move backend switching code to bilby
ColmTalbot Nov 13, 2024
fcf5967
FEAT: make core prior backend agnostic
ColmTalbot Nov 14, 2024
1b85785
FEAT: make non-numpy arrays serializable
ColmTalbot Nov 14, 2024
3b9162b
BUG: fix some array conversion methods
ColmTalbot Nov 14, 2024
acb9a53
MAINT: remove unneeded requirement
ColmTalbot Nov 14, 2024
c5b8ecf
MAINT: add compat packages
ColmTalbot Nov 14, 2024
420519b
Merge branch 'bilby-dev:main' into bilback
ColmTalbot Dec 11, 2024
b853e2d
DEV: some more prior agnosticism
ColmTalbot Dec 11, 2024
2e46c9b
TEST: make all prior tests run
ColmTalbot Dec 12, 2024
7225785
Merge branch 'bilby-dev:main' into bilback
ColmTalbot Jan 7, 2025
3ed1116
DEV: move some jax functionality to compat
ColmTalbot Jan 25, 2025
faf37f7
REFACTOR: use array backend for ln_i0
ColmTalbot Jan 25, 2025
1177fc4
make distance marginalizatio backend transparent
ColmTalbot Jan 25, 2025
dd6028a
DEV: some more prior dict array refactoring
ColmTalbot Jan 25, 2025
a1b463c
fix jax logic for distance marginalization
ColmTalbot Jan 29, 2025
71c4f6a
improve efficiency of setting up multibanding
ColmTalbot Jan 29, 2025
f4abda4
make high-dimensional gaussians jax compatible
ColmTalbot Jan 29, 2025
69635ac
make cubic spline calibration work with jax backend
ColmTalbot Jan 30, 2025
138204e
BUG: fix linspace calls
ColmTalbot Feb 4, 2025
8e247b8
ENH: fix bottleneck in relative binning for JAX
ColmTalbot Feb 4, 2025
c1bb41d
ENH: make interpolated prior backend friendly
ColmTalbot Feb 4, 2025
64cdb4d
REFACTOR: refactor backend-specific interpolation code
ColmTalbot Feb 5, 2025
3465d56
ENH: make sine gaussian model backend independent
ColmTalbot Feb 5, 2025
7c0f0c6
ENH: make roq likelihood backend independent
ColmTalbot Feb 5, 2025
86331f0
BUG: fix roq slicing
ColmTalbot Feb 5, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions .github/workflows/basic-install.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
# disable windows build test as bilby_cython is currently broken there
os: [ubuntu-latest, macos-latest]
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: ["3.10", "3.11", "3.12"]
steps:
- uses: actions/checkout@v4
Expand Down Expand Up @@ -48,8 +47,8 @@ jobs:
python -c "import bilby.hyper"
python -c "import cli_bilby"
python test/import_test.py
# - if: ${{ matrix.os != "windows-latest" }}
# run: |
# for script in $(pip show -f bilby | grep "bin\/" | xargs -I {} basename {}); do
# ${script} --help;
# done
- if: ${{ matrix.os != "windows-latest" }}
run: |
for script in $(pip show -f bilby | grep "bin\/" | xargs -I {} basename {}); do
${script} --help;
done
Empty file added bilby/compat/__init__.py
Empty file.
73 changes: 73 additions & 0 deletions bilby/compat/jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from functools import partial

import jax
import jax.numpy as jnp
from ..core.likelihood import Likelihood


def generic_bilby_likelihood_function(likelihood, parameters, use_ratio=True):
"""
A wrapper to allow a :code:`Bilby` likelihood to be used with :code:`jax`.

Parameters
==========
likelihood: bilby.core.likelihood.Likelihood
The likelihood to evaluate.
parameters: dict
The parameters to evaluate the likelihood at.
use_ratio: bool, optional
Whether to evaluate the likelihood ratio or the full likelihood.
Default is :code:`True`.
"""
parameters = {k: jnp.array(v) for k, v in parameters.items()}
likelihood.parameters.update(parameters)
if use_ratio:
return likelihood.log_likelihood_ratio()
else:
return likelihood.log_likelihood()


class JittedLikelihood(Likelihood):
"""
A wrapper to just-in-time compile a :code:`Bilby` likelihood for use with :code:`jax`.

.. note::

This is currently hardcoded to return the log likelihood ratio, regardless of
the input.

Parameters
==========
likelihood: bilby.core.likelihood.Likelihood
The likelihood to wrap.
likelihood_func: callable, optional
The function to use to evaluate the likelihood. Default is
:code:`generic_bilby_likelihood_function`. This function should take the
likelihood and parameters as arguments along with additional keyword arguments.
kwargs: dict, optional
Additional keyword arguments to pass to the likelihood function.
"""

def __init__(
self,
likelihood,
likelihood_func=generic_bilby_likelihood_function,
kwargs=None,
cast_to_float=True,
):
if kwargs is None:
kwargs = dict()
self.kwargs = kwargs
self._likelihood = likelihood
self.likelihood_func = jax.jit(partial(likelihood_func, likelihood))
self.cast_to_float = cast_to_float
super().__init__(dict())

def __getattr__(self, name):
return getattr(self._likelihood, name)

def log_likelihood_ratio(self):
ln_l = jnp.nan_to_num(self.likelihood_func(self.parameters, **self.kwargs))
if self.cast_to_float:
ln_l = float(ln_l)
return ln_l
6 changes: 6 additions & 0 deletions bilby/compat/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from typing import Union
import numpy as np

Real = Union[float, int]
ArrayLike = Union[np.ndarray, list, tuple]

34 changes: 34 additions & 0 deletions bilby/compat/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import numpy as np
from scipy._lib._array_api import array_namespace

__all__ = ["array_module", "promote_to_array"]


def array_module(arr):
if arr.__class__.__module__ == "builtins":
return np
else:
return array_namespace(arr)


def promote_to_array(args, backend, skip=None):
if skip is None:
skip = len(args)
else:
skip = len(args) - skip
if backend.__name__ != "numpy":
args = tuple(backend.array(arg) for arg in args[:skip]) + args[skip:]
return args


def xp_wrap(func):

def wrapped(self, *args, **kwargs):
if "xp" not in kwargs:
try:
kwargs["xp"] = array_module(*args)
except TypeError:
pass
return func(self, *args, **kwargs)

return wrapped
44 changes: 30 additions & 14 deletions bilby/core/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from scipy.stats import multivariate_normal

from .utils import infer_parameters_from_function, infer_args_from_function_except_n_args
from ..compat.utils import array_module


class Likelihood(object):
Expand Down Expand Up @@ -465,10 +466,16 @@ class AnalyticalMultidimensionalCovariantGaussian(Likelihood):
"""

def __init__(self, mean, cov):
self.cov = np.atleast_2d(cov)
self.mean = np.atleast_1d(mean)
self.sigma = np.sqrt(np.diag(self.cov))
self.pdf = multivariate_normal(mean=self.mean, cov=self.cov)
xp = array_module(cov)
self.cov = xp.atleast_2d(cov)
self.mean = xp.atleast_1d(mean)
self.sigma = xp.sqrt(np.diag(self.cov))
if xp == np:
self.logpdf = multivariate_normal(mean=self.mean, cov=self.cov).logpdf
else:
from functools import partial
from jax.scipy.stats.multivariate_normal import logpdf
self.logpdf = partial(logpdf, mean=self.mean, cov=self.cov)
parameters = {"x{0}".format(i): 0 for i in range(self.dim)}
super(AnalyticalMultidimensionalCovariantGaussian, self).__init__(parameters=parameters)

Expand All @@ -477,8 +484,9 @@ def dim(self):
return len(self.cov[0])

def log_likelihood(self):
x = np.array([self.parameters["x{0}".format(i)] for i in range(self.dim)])
return self.pdf.logpdf(x)
xp = array_module(self.cov)
x = xp.array([self.parameters["x{0}".format(i)] for i in range(self.dim)])
return self.logpdf(x)


class AnalyticalMultidimensionalBimodalCovariantGaussian(Likelihood):
Expand All @@ -496,12 +504,19 @@ class AnalyticalMultidimensionalBimodalCovariantGaussian(Likelihood):
"""

def __init__(self, mean_1, mean_2, cov):
self.cov = np.atleast_2d(cov)
self.sigma = np.sqrt(np.diag(self.cov))
self.mean_1 = np.atleast_1d(mean_1)
self.mean_2 = np.atleast_1d(mean_2)
self.pdf_1 = multivariate_normal(mean=self.mean_1, cov=self.cov)
self.pdf_2 = multivariate_normal(mean=self.mean_2, cov=self.cov)
xp = array_module(cov)
self.cov = xp.atleast_2d(cov)
self.sigma = xp.sqrt(np.diag(self.cov))
self.mean_1 = xp.atleast_1d(mean_1)
self.mean_2 = xp.atleast_1d(mean_2)
if xp == np:
self.logpdf_1 = multivariate_normal(mean=self.mean_1, cov=self.cov).logpdf
self.logpdf_2 = multivariate_normal(mean=self.mean_2, cov=self.cov).logpdf
else:
from functools import partial
from jax.scipy.stats.multivariate_normal import logpdf
self.logpdf_1 = partial(logpdf, mean=self.mean_1, cov=self.cov)
self.logpdf_2 = partial(logpdf, mean=self.mean_2, cov=self.cov)
parameters = {"x{0}".format(i): 0 for i in range(self.dim)}
super(AnalyticalMultidimensionalBimodalCovariantGaussian, self).__init__(parameters=parameters)

Expand All @@ -510,8 +525,9 @@ def dim(self):
return len(self.cov[0])

def log_likelihood(self):
x = np.array([self.parameters["x{0}".format(i)] for i in range(self.dim)])
return -np.log(2) + np.logaddexp(self.pdf_1.logpdf(x), self.pdf_2.logpdf(x))
xp = array_module(self.cov)
x = xp.array([self.parameters["x{0}".format(i)] for i in range(self.dim)])
return -xp.log(2) + xp.logaddexp(self.logpdf_1(x), self.logpdf_2(x))


class JointLikelihood(Likelihood):
Expand Down
Loading