Skip to content

Commit

Permalink
Add bad_indices_policy to TF Probability wrappers.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 655007060
  • Loading branch information
mrry authored and tensorflower-gardener committed Jul 23, 2024
1 parent 8a5daf0 commit 53a6d68
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions tensorflow_probability/python/internal/backend/numpy/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,13 @@ def _sort(values, axis=-1, direction='ASCENDING', name=None): # pylint: disable


# TODO(b/140685491): Add unit-test.
def _tensor_scatter_nd_add(tensor, indices, updates, name=None): # pylint: disable=unused-argument
def _tensor_scatter_nd_add(
tensor, indices, updates, bad_indices_policy='', name=None): # pylint: disable=unused-argument
"""Numpy implementation of `tf.tensor_scatter_nd_add`."""
indices = _convert_to_tensor(indices)
tensor = _convert_to_tensor(tensor)
updates = _convert_to_tensor(updates)
del bad_indices_policy
indices = tuple(
indices[..., i] for i in range(indices.shape[-1])) # TODO(b/140685491)
if JAX_MODE:
Expand All @@ -132,11 +134,13 @@ def _tensor_scatter_nd_add(tensor, indices, updates, name=None): # pylint: disa


# TODO(b/140685491): Add unit-test.
def _tensor_scatter_nd_sub(tensor, indices, updates, name=None): # pylint: disable=unused-argument
def _tensor_scatter_nd_sub(
tensor, indices, updates, bad_indices_policy='', name=None): # pylint: disable=unused-argument
"""Numpy implementation of `tf.tensor_scatter_nd_sub`."""
indices = _convert_to_tensor(indices)
tensor = _convert_to_tensor(tensor)
updates = _convert_to_tensor(updates)
del bad_indices_policy
indices = tuple(
indices[..., i] for i in range(indices.shape[-1])) # TODO(b/140685491)
if JAX_MODE:
Expand All @@ -146,11 +150,13 @@ def _tensor_scatter_nd_sub(tensor, indices, updates, name=None): # pylint: disa


# TODO(b/140685491): Add unit-test.
def _tensor_scatter_nd_update(tensor, indices, updates, name=None): # pylint: disable=unused-argument
def _tensor_scatter_nd_update(
tensor, indices, updates, bad_indices_policy='', name=None): # pylint: disable=unused-argument
"""Numpy implementation of `tf.tensor_scatter_nd_update`."""
indices = _convert_to_tensor(indices)
tensor = _convert_to_tensor(tensor)
updates = _convert_to_tensor(updates)
del bad_indices_policy
indices = tuple(
indices[..., i] for i in range(indices.shape[-1])) # TODO(b/140685491)
if JAX_MODE:
Expand Down

0 comments on commit 53a6d68

Please sign in to comment.