Skip to content

Commit

Permalink
Correctly detect when to import tf_keras rather than tf.keras.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 645101863
  • Loading branch information
SiegeLordEx authored and tensorflower-gardener committed Jun 20, 2024
1 parent abff241 commit 32ea239
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions tensorflow_probability/python/internal/tf_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,13 @@
# pylint: disable=g-import-not-at-top
# pylint: disable=unused-import
# pylint: disable=wildcard-import
_keras_version_fn = getattr(tf.keras, "version", None)
if _keras_version_fn and _keras_version_fn().startswith("3."):
try:
_keras_version_fn = getattr(tf.keras, "version", None)
_use_tf_keras = _keras_version_fn and _keras_version_fn().startswith("3.")
del _keras_version_fn
except ImportError:
_use_tf_keras = True
if _use_tf_keras:
from tf_keras import *
from tf_keras import __internal__
import tf_keras.api._v1.keras.__internal__.legacy.layers as tf1_layers
Expand All @@ -35,4 +40,4 @@
del tf1

del tf
del _keras_version_fn
del _use_tf_keras

0 comments on commit 32ea239

Please sign in to comment.