diff --git a/spinoffs/autobnn/autobnn/BUILD b/spinoffs/autobnn/autobnn/BUILD index 669f3d936d..79e53539e2 100644 --- a/spinoffs/autobnn/autobnn/BUILD +++ b/spinoffs/autobnn/autobnn/BUILD @@ -118,7 +118,7 @@ py_test( ":estimators", ":kernels", ":operators", - # tensorflow_probability/python/internal:test_util dep, + # absl/testing:absltest dep, ], ) @@ -271,11 +271,11 @@ py_test( ":operators", ":training_util", ":util", + # absl/testing:absltest dep, # chex dep, # google/protobuf:use_fast_cpp_protos dep, # jax dep, # numpy dep, - # tensorflow_probability/python/internal:test_util dep, ], ) @@ -300,6 +300,6 @@ py_test( # google/protobuf:use_fast_cpp_protos dep, # jax dep, # numpy dep, - # tensorflow_probability/python/internal:test_util dep, + # tensorflow_probability/python/internal:test_util.jax dep, ], ) diff --git a/spinoffs/autobnn/autobnn/estimators_test.py b/spinoffs/autobnn/autobnn/estimators_test.py index 6fc50456f8..c63c157e2f 100644 --- a/spinoffs/autobnn/autobnn/estimators_test.py +++ b/spinoffs/autobnn/autobnn/estimators_test.py @@ -16,14 +16,15 @@ import jax import numpy as np -from tensorflow_probability.python.internal import test_util from autobnn import estimators from autobnn import kernels from autobnn import operators from autobnn import util +from absl.testing import absltest -class AutoBNNTest(test_util.TestCase): + +class AutoBNNTest(absltest.TestCase): def test_train_map(self): seed = jax.random.PRNGKey(20231018) @@ -114,7 +115,7 @@ def test_summary(self): autobnn.fit(x_train, y_train) summary_lines = autobnn.summary().split('\n') - self.assertEqual(len(summary_lines), 8, f'Unexpected {len(summary_lines)=}') + self.assertLen(summary_lines, 8, f'Unexpected {len(summary_lines)=}') for line in summary_lines: self.assertRegex( @@ -126,4 +127,4 @@ def test_summary(self): if __name__ == '__main__': - test_util.main() + absltest.main() diff --git a/spinoffs/autobnn/autobnn/training_util_test.py b/spinoffs/autobnn/autobnn/training_util_test.py index f139a8468d..9bee487f72 100644 --- a/spinoffs/autobnn/autobnn/training_util_test.py +++ b/spinoffs/autobnn/autobnn/training_util_test.py @@ -18,14 +18,15 @@ import jax import jax.numpy as jnp import numpy as np -from tensorflow_probability.python.internal import test_util from autobnn import kernels from autobnn import operators from autobnn import training_util from autobnn import util +from absl.testing import absltest -class TrainingUtilTest(test_util.TestCase): + +class TrainingUtilTest(absltest.TestCase): def test__filter_stuck_chains_doesnt_overfilter(self): noise_scale = 0.001 * np.random.randn(64, 100, 1) @@ -210,4 +211,4 @@ def _init(seed): if __name__ == '__main__': - test_util.main() + absltest.main() diff --git a/spinoffs/autobnn/autobnn/util_test.py b/spinoffs/autobnn/autobnn/util_test.py index 491adb290f..fbfaeb3e36 100644 --- a/spinoffs/autobnn/autobnn/util_test.py +++ b/spinoffs/autobnn/autobnn/util_test.py @@ -17,9 +17,9 @@ import jax import jax.numpy as jnp import numpy as np -from tensorflow_probability.python.internal import test_util from autobnn import kernels from autobnn import util +from tensorflow_probability.substrates.jax.internal import test_util class UtilTest(test_util.TestCase): @@ -58,14 +58,19 @@ def test_transform(self): p = bnn.init(seed, jnp.ones((1, 10), dtype=jnp.float32)) # Softplus(low=0.2) bijector - self.assertEqual(0.2 + jax.nn.softplus(p['params']['noise_scale']), - transform(p)['params']['noise_scale']) - self.assertEqual(jnp.exp(p['params']['amplitude']), - transform(p)['params']['amplitude']) + self.assertEqual( + 0.2 + jax.nn.softplus(p['params']['noise_scale']), + transform(p)['params']['noise_scale'], + ) + self.assertEqual( + jnp.exp(p['params']['amplitude']), transform(p)['params']['amplitude'] + ) # Identity bijector - self.assertAllEqual(p['params']['dense2']['kernel'], - transform(p)['params']['dense2']['kernel']) + self.assertAllEqual( + p['params']['dense2']['kernel'], + transform(p)['params']['dense2']['kernel'], + ) if __name__ == '__main__':