From 351148e35a82a68efcf705ac31211357c48d8fae Mon Sep 17 00:00:00 2001 From: danielward27 Date: Thu, 13 Feb 2025 09:35:40 +0000 Subject: [PATCH] Fix numerical inverse log det --- flowjax/bijections/utils.py | 2 +- tests/test_bijections/test_bijections.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/flowjax/bijections/utils.py b/flowjax/bijections/utils.py index 50f26a0..5f46e5d 100644 --- a/flowjax/bijections/utils.py +++ b/flowjax/bijections/utils.py @@ -299,4 +299,4 @@ def transform_and_log_det(self, x, condition=None): def inverse_and_log_det(self, y, condition=None): x = self.inverter(self.bijection, y, condition) _, log_det = self.bijection.transform_and_log_det(x, condition) - return x, log_det + return x, -log_det diff --git a/tests/test_bijections/test_bijections.py b/tests/test_bijections/test_bijections.py index 24de4d9..46cc3d8 100644 --- a/tests/test_bijections/test_bijections.py +++ b/tests/test_bijections/test_bijections.py @@ -212,7 +212,7 @@ cond_shape=(), ), "NumericalInverse": lambda: NumericalInverse( - Affine(5), + Affine(5, 2), root_finder_to_inverter( partial(bisection_search, lower=-1, upper=1, atol=1e-7), ),