diff --git a/tensorflow_probability/python/internal/trainable_state_util_test.py b/tensorflow_probability/python/internal/trainable_state_util_test.py index f14de01fec..e6ca7cc04e 100644 --- a/tensorflow_probability/python/internal/trainable_state_util_test.py +++ b/tensorflow_probability/python/internal/trainable_state_util_test.py @@ -271,15 +271,13 @@ def test_apply_raises_on_bad_parameters(self): def test_rewrites_yield_to_return_in_docstring(self): wrapped = trainable_state_util.as_stateless_builder( generator_with_docstring) - self.assertContainsExactSubsequence( - generator_with_docstring.__doc__, 'Yields:') + self.assertIn('Yields:', generator_with_docstring.__doc__) self.assertNotIn('Yields:', wrapped.__doc__) - self.assertContainsExactSubsequence( + self.assertIn('Test generator with a docstring.', wrapped.__doc__) + self.assertIn( + trainable_state_util._STATELESS_RETURNS_DOCSTRING, wrapped.__doc__, - 'Test generator with a docstring.') - self.assertContainsExactSubsequence( - wrapped.__doc__, - trainable_state_util._STATELESS_RETURNS_DOCSTRING) + ) def test_fitting_example(self): if not JAX_MODE: @@ -335,15 +333,13 @@ def test_structured_parameters(self): def test_rewrites_yield_to_return_in_docstring(self): wrapped = trainable_state_util.as_stateful_builder( generator_with_docstring) - self.assertContainsExactSubsequence( - generator_with_docstring.__doc__, 'Yields:') + self.assertIn('Yields:', generator_with_docstring.__doc__) self.assertNotIn('Yields:', wrapped.__doc__) - self.assertContainsExactSubsequence( - wrapped.__doc__, - 'Test generator with a docstring.') - self.assertContainsExactSubsequence( + self.assertIn('Test generator with a docstring.', wrapped.__doc__) + self.assertIn( + trainable_state_util._STATEFUL_RETURNS_DOCSTRING, wrapped.__doc__, - trainable_state_util._STATEFUL_RETURNS_DOCSTRING) + ) @test_util.jax_disable_variable_test def test_fitting_example(self):