Skip to content

Commit

Permalink
Allow unit tests to disable 64 bit precision mode in jax.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 550617289
  • Loading branch information
ursk authored and jburnim committed Jul 28, 2023
1 parent 3f42739 commit 714d547
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tensorflow_probability/python/internal/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2010,9 +2010,9 @@ def getTestCaseNames(self, testCaseClass): # pylint:disable=invalid-name
return names


def main(jax_mode=JAX_MODE):
def main(jax_mode=JAX_MODE, jax_enable_x64=True):
"""Test main function that injects a custom loader."""
if jax_mode:
if jax_mode and jax_enable_x64:
from jax.config import config # pylint: disable=g-import-not-at-top
config.update('jax_enable_x64', True)

Expand Down

0 comments on commit 714d547

Please sign in to comment.