Skip to content

Commit

Permalink
[numpy] Fix users of NumPy APIs that are removed in NumPy 2.0.
Browse files Browse the repository at this point in the history
This change migrates users of APIs removed in NumPy 2.0 to their recommended replacements (https://numpy.org/devdocs/numpy_2_0_migration_guide.html).

PiperOrigin-RevId: 656419848
  • Loading branch information
hawkinsp authored and tensorflower-gardener committed Jul 26, 2024
1 parent 64d1946 commit 9b07a2e
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 6 deletions.
8 changes: 6 additions & 2 deletions tensorflow_probability/python/bijectors/bijector_test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,12 @@ def assert_scalar_congruency(bijector,
ten_x_pts = np.linspace(lower_x, upper_x, num=10).astype(np.float32)
if bijector.dtype is not None:
ten_x_pts = ten_x_pts.astype(dtype_util.as_numpy_dtype(bijector.dtype))
lower_x = np.cast[dtype_util.as_numpy_dtype(bijector.dtype)](lower_x)
upper_x = np.cast[dtype_util.as_numpy_dtype(bijector.dtype)](upper_x)
lower_x = np.asarray(
lower_x, dtype=dtype_util.as_numpy_dtype(bijector.dtype)
)
upper_x = np.asarray(
upper_x, dtype=dtype_util.as_numpy_dtype(bijector.dtype)
)
forward_on_10_pts = bijector.forward(ten_x_pts)

# Set the lower/upper limits in the range of the bijector.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def dist_lambda(t):
model.compile(optimizer=optimizer, loss=negloglik)

model.fit(x, y, epochs=1, verbose=True, batch_size=32, validation_split=0.2)
self.assertGreater(model.history.history["val_loss"][0], -np.Inf)
self.assertGreater(model.history.history["val_loss"][0], -np.inf)


if __name__ == "__main__":
Expand Down
11 changes: 10 additions & 1 deletion tensorflow_probability/python/internal/dtype_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,15 @@ def _unify_dtype(current, new):
lambda dt, h: base_dtype(h if dt is None else dt), dtype, dtype_hint)


def _issctype(x):
if not isinstance(x, (type, np.dtype)):
return False
try:
return np.dtype(x) != np.object_
except: # pylint: disable=bare-except
return False


def convert_to_dtype(tensor_or_dtype, dtype=None, dtype_hint=None):
"""Get a dtype from a list/tensor/dtype using convert_to_tensor semantics."""
if tensor_or_dtype is None:
Expand All @@ -244,7 +253,7 @@ def convert_to_dtype(tensor_or_dtype, dtype=None, dtype_hint=None):
# Numpy dtypes defer to dtype/dtype_hint
elif isinstance(tensor_or_dtype, np.ndarray):
dt = base_dtype(dtype or dtype_hint or tensor_or_dtype.dtype)
elif np.issctype(tensor_or_dtype):
elif _issctype(tensor_or_dtype):
dt = base_dtype(dtype or dtype_hint or tensor_or_dtype)
else:
# If this is a Python object, call `convert_to_tensor` and grab the dtype.
Expand Down
6 changes: 4 additions & 2 deletions tensorflow_probability/python/internal/dtype_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,9 @@ def testMax(self, dtype, expected_maxval):
disable_numpy=True,
reason='`convert_to_tensor` respects array dtypes in numpy backend.')
def testConvertToDtype(self, tensor_or_dtype, dtype, dtype_hint):
if np.issctype(tensor_or_dtype):
if isinstance(tensor_or_dtype, np.generic) or hasattr(
tensor_or_dtype, 'dtype'
):
example_tensor = np.zeros([], tensor_or_dtype)
elif isinstance(tensor_or_dtype, tf.DType):
example_tensor = tf.zeros([], tensor_or_dtype)
Expand All @@ -203,7 +205,7 @@ def testConvertToDtype(self, tensor_or_dtype, dtype, dtype_hint):
disable_jax=True,
reason='`convert_to_tensor` only raises in TF mode.')
def testConvertToDTypeRaises(self, tensor_or_dtype, dtype, dtype_hint):
if np.issctype(tensor_or_dtype):
if isinstance(tensor_or_dtype, np.generic):
example_tensor = np.zeros([], tensor_or_dtype)
elif isinstance(tensor_or_dtype, tf.DType):
example_tensor = tf.zeros([], tensor_or_dtype)
Expand Down

0 comments on commit 9b07a2e

Please sign in to comment.