From 32ea239ee3e746ef07eed23f73f3d2941e2d4598 Mon Sep 17 00:00:00 2001 From: siege Date: Thu, 20 Jun 2024 11:29:04 -0700 Subject: [PATCH] Correctly detect when to import tf_keras rather than tf.keras. PiperOrigin-RevId: 645101863 --- tensorflow_probability/python/internal/tf_keras.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tensorflow_probability/python/internal/tf_keras.py b/tensorflow_probability/python/internal/tf_keras.py index 5f1cdf4cff..61a5a2755b 100644 --- a/tensorflow_probability/python/internal/tf_keras.py +++ b/tensorflow_probability/python/internal/tf_keras.py @@ -20,8 +20,13 @@ # pylint: disable=g-import-not-at-top # pylint: disable=unused-import # pylint: disable=wildcard-import -_keras_version_fn = getattr(tf.keras, "version", None) -if _keras_version_fn and _keras_version_fn().startswith("3."): +try: + _keras_version_fn = getattr(tf.keras, "version", None) + _use_tf_keras = _keras_version_fn and _keras_version_fn().startswith("3.") + del _keras_version_fn +except ImportError: + _use_tf_keras = True +if _use_tf_keras: from tf_keras import * from tf_keras import __internal__ import tf_keras.api._v1.keras.__internal__.legacy.layers as tf1_layers @@ -35,4 +40,4 @@ del tf1 del tf -del _keras_version_fn +del _use_tf_keras