diff --git a/tensorflow_probability/python/distributions/categorical.py b/tensorflow_probability/python/distributions/categorical.py index 92dec6c4ef..b63a4dfb70 100644 --- a/tensorflow_probability/python/distributions/categorical.py +++ b/tensorflow_probability/python/distributions/categorical.py @@ -333,6 +333,32 @@ def _entropy(self): _mul_exp(log_probs, log_probs), axis=-1) + def _mean(self): + if self._logits is None: + # If we only have probs, there's not much we can do to ensure numerical + # precision. + log_probs = tf.math.log(self._probs) + else: + log_probs = tf.math.log_softmax(self._logits) + labels = tf.range(self._num_categories(log_probs), dtype=log_probs.dtype) + mean = tf.reduce_sum(_mul_exp(labels, log_probs), axis=-1) + tensorshape_util.set_shape(mean, log_probs.shape[:-1]) + return mean + + def _variance(self): + if self._logits is None: + # If we only have probs, there's not much we can do to ensure numerical + # precision. + log_probs = tf.math.log(self._probs) + else: + log_probs = tf.math.log_softmax(self._logits) + labels = tf.range(self._num_categories(log_probs), dtype=log_probs.dtype) + mean = tf.reduce_sum(_mul_exp(labels, log_probs), axis=-1, keepdims=True) + var = tf.reduce_sum( + _mul_exp(tf.math.squared_difference(labels, mean), log_probs), axis=-1) + tensorshape_util.set_shape(var, log_probs.shape[:-1]) + return var + def _mode(self): x = self._probs if self._logits is None else self._logits mode = tf.cast(tf.argmax(x, axis=-1), self.dtype) diff --git a/tensorflow_probability/python/distributions/categorical_test.py b/tensorflow_probability/python/distributions/categorical_test.py index 671881e848..5b1684cb89 100644 --- a/tensorflow_probability/python/distributions/categorical_test.py +++ b/tensorflow_probability/python/distributions/categorical_test.py @@ -500,6 +500,25 @@ def testLogPMFShapeNoBatch(self): self.assertEqual(3, tensorshape_util.rank(log_prob.shape)) self.assertAllEqual([2, 2, 2], log_prob.shape) + def testMean(self): + histograms = np.array([[[0.2, 0.8], [0.6, 0.4]]]) + dist = categorical.Categorical( + tf.math.log(histograms) - 50., validate_args=True) + self.assertAllClose([[0.8, 0.4]], self.evaluate(dist.mean())) + + def testMeanHuge(self): + num_logits = 10_000_000 + dist = categorical.Categorical( + tf.zeros(num_logits), validate_args=True) + self.assertAllClose(num_logits / 2, self.evaluate(dist.mean())) + + def testVariance(self): + histograms = np.array([[[0.2, 0.8], [0.6, 0.4]]]) + dist = categorical.Categorical( + tf.math.log(histograms) - 50., validate_args=True) + self.assertAllClose( + [[0.2 * 0.8, 0.6 * 0.4]], self.evaluate(dist.variance())) + def testMode(self): histograms = np.array([[[0.2, 0.8], [0.6, 0.4]]]) dist = categorical.Categorical(