Skip to content

Commit

Permalink
Don't use test_util when it isn't needed, and when it is needed, use the
Browse files Browse the repository at this point in the history
jax substrate version.

PiperOrigin-RevId: 620044472
  • Loading branch information
ThomasColthurst authored and tensorflower-gardener committed Mar 28, 2024
1 parent ae05b63 commit 6e86fb2
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 17 deletions.
6 changes: 3 additions & 3 deletions spinoffs/autobnn/autobnn/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ py_test(
":estimators",
":kernels",
":operators",
# tensorflow_probability/python/internal:test_util dep,
# absl/testing:absltest dep,
],
)

Expand Down Expand Up @@ -271,11 +271,11 @@ py_test(
":operators",
":training_util",
":util",
# absl/testing:absltest dep,
# chex dep,
# google/protobuf:use_fast_cpp_protos dep,
# jax dep,
# numpy dep,
# tensorflow_probability/python/internal:test_util dep,
],
)

Expand All @@ -300,6 +300,6 @@ py_test(
# google/protobuf:use_fast_cpp_protos dep,
# jax dep,
# numpy dep,
# tensorflow_probability/python/internal:test_util dep,
# tensorflow_probability/python/internal:test_util.jax dep,
],
)
9 changes: 5 additions & 4 deletions spinoffs/autobnn/autobnn/estimators_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@

import jax
import numpy as np
from tensorflow_probability.python.internal import test_util
from autobnn import estimators
from autobnn import kernels
from autobnn import operators
from autobnn import util

from absl.testing import absltest

class AutoBNNTest(test_util.TestCase):

class AutoBNNTest(absltest.TestCase):

def test_train_map(self):
seed = jax.random.PRNGKey(20231018)
Expand Down Expand Up @@ -114,7 +115,7 @@ def test_summary(self):
autobnn.fit(x_train, y_train)
summary_lines = autobnn.summary().split('\n')

self.assertEqual(len(summary_lines), 8, f'Unexpected {len(summary_lines)=}')
self.assertLen(summary_lines, 8, f'Unexpected {len(summary_lines)=}')

for line in summary_lines:
self.assertRegex(
Expand All @@ -126,4 +127,4 @@ def test_summary(self):


if __name__ == '__main__':
test_util.main()
absltest.main()
7 changes: 4 additions & 3 deletions spinoffs/autobnn/autobnn/training_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@
import jax
import jax.numpy as jnp
import numpy as np
from tensorflow_probability.python.internal import test_util
from autobnn import kernels
from autobnn import operators
from autobnn import training_util
from autobnn import util

from absl.testing import absltest

class TrainingUtilTest(test_util.TestCase):

class TrainingUtilTest(absltest.TestCase):

def test__filter_stuck_chains_doesnt_overfilter(self):
noise_scale = 0.001 * np.random.randn(64, 100, 1)
Expand Down Expand Up @@ -210,4 +211,4 @@ def _init(seed):


if __name__ == '__main__':
test_util.main()
absltest.main()
19 changes: 12 additions & 7 deletions spinoffs/autobnn/autobnn/util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
import jax
import jax.numpy as jnp
import numpy as np
from tensorflow_probability.python.internal import test_util
from autobnn import kernels
from autobnn import util
from tensorflow_probability.substrates.jax.internal import test_util


class UtilTest(test_util.TestCase):
Expand Down Expand Up @@ -58,14 +58,19 @@ def test_transform(self):
p = bnn.init(seed, jnp.ones((1, 10), dtype=jnp.float32))

# Softplus(low=0.2) bijector
self.assertEqual(0.2 + jax.nn.softplus(p['params']['noise_scale']),
transform(p)['params']['noise_scale'])
self.assertEqual(jnp.exp(p['params']['amplitude']),
transform(p)['params']['amplitude'])
self.assertEqual(
0.2 + jax.nn.softplus(p['params']['noise_scale']),
transform(p)['params']['noise_scale'],
)
self.assertEqual(
jnp.exp(p['params']['amplitude']), transform(p)['params']['amplitude']
)

# Identity bijector
self.assertAllEqual(p['params']['dense2']['kernel'],
transform(p)['params']['dense2']['kernel'])
self.assertAllEqual(
p['params']['dense2']['kernel'],
transform(p)['params']['dense2']['kernel'],
)


if __name__ == '__main__':
Expand Down

0 comments on commit 6e86fb2

Please sign in to comment.