diff --git a/tensorflow_probability/python/distributions/hidden_markov_model.py b/tensorflow_probability/python/distributions/hidden_markov_model.py index 4426801375..bc30c5a8d8 100644 --- a/tensorflow_probability/python/distributions/hidden_markov_model.py +++ b/tensorflow_probability/python/distributions/hidden_markov_model.py @@ -1181,6 +1181,36 @@ def _reduce_one_step(): return ps.cond(self.num_steps > 1, _reduce_multiple_steps, _reduce_one_step) + def single_step_prediction(self, observation, prediction_distribution=None): + """ + Function to run single prediction step based on incoming observation data point and the current prediction + distribution of the hmm model. + If no prediction_distribution is given (typically in the initial step), then the current distribution is derived + from the priors of the hmm (the initial_distribution over the states). In a forecasting + model that runs on live data, the first step would require initialisation while subsequent steps would use the + previous step's prediction distribution as input. + The prediction distribution is updated and returned from this function. + """ + observation = tf.convert_to_tensor(observation, name='observations') + + if prediction_distribution is None: + num = self.initial_distribution.log_prob(range(self.num_states_static)) \ + + self.observation_distribution.log_prob(observation) + else: + if not isinstance(prediction_distribution, distribution.Distribution): + raise TypeError('If prediction_distribution is provided, it must be a Distribution object, ' + 'but saw: %s' % prediction_distribution) + num = tf.math.log(prediction_distribution.probs) \ + + self.observation_distribution.log_prob(observation) + + filtering_distribution = tf.exp(num - tf.reduce_logsumexp(num)) + prediction_distribution = tf.tensordot(self.transition_distribution.probs_parameter(), + filtering_distribution, axes=1) + prediction_distribution = categorical.Categorical(probs=prediction_distribution) + observation_prediction = tf.tensordot(prediction_distribution.probs, self.observation_distribution.mean(), axes=1) + + return prediction_distribution, observation_prediction + # pylint: disable=protected-access def _default_event_space_bijector(self): return (self._observation_distribution.