From 53a6d684b01dc72198d3f6e79632febd4bca450c Mon Sep 17 00:00:00 2001 From: mrry Date: Mon, 22 Jul 2024 21:11:38 -0700 Subject: [PATCH] Add `bad_indices_policy` to TF Probability wrappers. PiperOrigin-RevId: 655007060 --- .../python/internal/backend/numpy/misc.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tensorflow_probability/python/internal/backend/numpy/misc.py b/tensorflow_probability/python/internal/backend/numpy/misc.py index c6977864d0..d9f1ecdd5c 100644 --- a/tensorflow_probability/python/internal/backend/numpy/misc.py +++ b/tensorflow_probability/python/internal/backend/numpy/misc.py @@ -118,11 +118,13 @@ def _sort(values, axis=-1, direction='ASCENDING', name=None): # pylint: disable # TODO(b/140685491): Add unit-test. -def _tensor_scatter_nd_add(tensor, indices, updates, name=None): # pylint: disable=unused-argument +def _tensor_scatter_nd_add( + tensor, indices, updates, bad_indices_policy='', name=None): # pylint: disable=unused-argument """Numpy implementation of `tf.tensor_scatter_nd_add`.""" indices = _convert_to_tensor(indices) tensor = _convert_to_tensor(tensor) updates = _convert_to_tensor(updates) + del bad_indices_policy indices = tuple( indices[..., i] for i in range(indices.shape[-1])) # TODO(b/140685491) if JAX_MODE: @@ -132,11 +134,13 @@ def _tensor_scatter_nd_add(tensor, indices, updates, name=None): # pylint: disa # TODO(b/140685491): Add unit-test. -def _tensor_scatter_nd_sub(tensor, indices, updates, name=None): # pylint: disable=unused-argument +def _tensor_scatter_nd_sub( + tensor, indices, updates, bad_indices_policy='', name=None): # pylint: disable=unused-argument """Numpy implementation of `tf.tensor_scatter_nd_sub`.""" indices = _convert_to_tensor(indices) tensor = _convert_to_tensor(tensor) updates = _convert_to_tensor(updates) + del bad_indices_policy indices = tuple( indices[..., i] for i in range(indices.shape[-1])) # TODO(b/140685491) if JAX_MODE: @@ -146,11 +150,13 @@ def _tensor_scatter_nd_sub(tensor, indices, updates, name=None): # pylint: disa # TODO(b/140685491): Add unit-test. -def _tensor_scatter_nd_update(tensor, indices, updates, name=None): # pylint: disable=unused-argument +def _tensor_scatter_nd_update( + tensor, indices, updates, bad_indices_policy='', name=None): # pylint: disable=unused-argument """Numpy implementation of `tf.tensor_scatter_nd_update`.""" indices = _convert_to_tensor(indices) tensor = _convert_to_tensor(tensor) updates = _convert_to_tensor(updates) + del bad_indices_policy indices = tuple( indices[..., i] for i in range(indices.shape[-1])) # TODO(b/140685491) if JAX_MODE: