From a455824a227532ee08582580beeb99ccd5e253db Mon Sep 17 00:00:00 2001 From: ursk Date: Thu, 28 Mar 2024 07:58:04 -0700 Subject: [PATCH] Add AutoBNN example colab. PiperOrigin-RevId: 619931273 --- .../examples/Forecasting_With_AutoBNN.ipynb | 398 ++++++++++++++++++ 1 file changed, 398 insertions(+) create mode 100644 discussion/examples/Forecasting_With_AutoBNN.ipynb diff --git a/discussion/examples/Forecasting_With_AutoBNN.ipynb b/discussion/examples/Forecasting_With_AutoBNN.ipynb new file mode 100644 index 0000000000..9a14bb16db --- /dev/null +++ b/discussion/examples/Forecasting_With_AutoBNN.ipynb @@ -0,0 +1,398 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "-XQVajJ1o2O2" + }, + "source": [ + "# Probabilistic Time Series Forecasting with the AutoBNN Package\n", + "\n", + "This notebook introduces the Tensorflow Probability AutoBNN package. It provides examples for forecasting with the Mauna Loa atmospheric CO2 dataset, using either learned or manually specified kernels.\n", + "\n", + "We recommend using a [TPU runtime](https://colab.sandbox.google.com/notebooks/tpu.ipynb) for this colab.\n", + "\n", + "## Imports" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "id": "-WNWeIl_8giU" + }, + "outputs": [], + "source": [ + "!pip install autobnn\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "\n", + "from autobnn import estimators\n", + "from autobnn import operators\n", + "from autobnn import kernels\n", + "from autobnn import likelihoods\n", + "from autobnn import training_util\n", + "\n", + "seed = jax.random.PRNGKey(0)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "BIpknrQ4p3CL" + }, + "source": [ + "## Mauna Loa CO2 Data\n", + "The Mauna Loa atmospheric CO2 dataset is available for download from the Scripps Institute. We construct a Pandas dataframe with the monthly readings of CO2 concentration, holding out the final 10 years to forecast." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "id": "-JhX3mWuFIwC" + }, + "outputs": [], + "source": [ + "# The Mauna Loa data as taken from\n", + "# https://scrippsco2.ucsd.edu/assets/data/atmospheric/stations/in_situ_co2/monthly/monthly_in_situ_co2_mlo.csv\n", + "# Data is provided by the Scripps Institution of Oceanography at UC San Diego under a CC BY license\n", + "# (https://creativecommons.org/licenses/by/4.0/).\n", + "co2_by_month = np.array('320.62,321.60,322.39,323.70,324.08,323.75,322.38,320.36,318.64,318.10,319.78,321.03,322.33,322.50,323.04,324.42,325.00,324.09,322.54,320.92,319.25,319.39,320.73,321.96,322.57,323.15,323.89,325.02,325.57,325.36,324.14,322.11,320.33,320.25,321.32,322.89,324.00,324.42,325.63,326.66,327.38,326.71,325.88,323.66,322.38,321.78,322.85,324.12,325.06,325.98,326.93,328.14,328.08,327.67,326.34,324.69,323.10,323.06,324.01,325.13,326.17,326.68,327.17,327.79,328.92,328.57,327.36,325.43,323.36,323.56,324.80,326.01,326.77,327.63,327.75,329.73,330.07,329.09,328.04,326.32,324.84,325.20,326.50,327.55,328.55,329.56,330.30,331.50,332.48,332.07,330.87,329.31,327.51,327.18,328.16,328.64,329.35,330.71,331.48,332.65,333.09,332.25,331.18,329.39,327.43,327.37,328.46,329.57,330.40,331.40,332.04,333.31,333.97,333.60,331.90,330.06,328.56,328.34,329.49,330.76,331.75,332.56,333.50,334.58,334.88,334.33,333.05,330.94,329.30,328.94,330.31,331.68,332.93,333.42,334.70,336.07,336.75,336.27,334.92,332.75,331.59,331.16,332.40,333.85,334.97,335.38,336.64,337.76,338.01,337.89,336.54,334.68,332.76,332.55,333.92,334.95,336.23,336.76,337.96,338.88,339.47,339.29,337.73,336.09,333.92,333.86,335.29,336.73,338.01,338.36,340.07,340.77,341.47,341.17,339.56,337.60,335.88,336.02,337.10,338.21,339.24,340.48,341.38,342.51,342.91,342.25,340.49,338.43,336.69,336.86,338.36,339.61,340.75,341.61,342.70,343.57,344.14,343.35,342.06,339.81,337.98,337.86,339.26,340.49,341.38,342.52,343.10,344.94,345.76,345.32,343.98,342.38,339.87,339.99,341.15,342.99,343.70,344.50,345.28,347.06,347.43,346.80,345.39,343.28,341.07,341.35,342.98,344.22,344.97,345.99,347.42,348.35,348.93,348.25,346.56,344.67,343.09,342.80,344.24,345.56,346.30,346.95,347.85,349.55,350.21,349.55,347.94,345.90,344.85,344.17,345.66,346.90,348.02,348.48,349.42,350.99,351.85,351.26,349.51,348.10,346.45,346.36,347.81,348.96,350.43,351.73,352.22,353.59,354.22,353.79,352.38,350.43,348.73,348.88,350.07,351.34,352.76,353.07,353.68,355.42,355.67,355.12,353.90,351.67,349.80,349.99,351.30,352.52,353.66,354.70,355.38,356.20,357.16,356.23,354.81,352.91,350.96,351.18,352.83,354.21,354.72,355.75,357.16,358.60,359.34,358.24,356.17,354.02,352.15,352.21,353.75,354.99,355.99,356.72,357.81,359.15,359.66,359.25,357.02,355.00,353.01,353.31,354.16,355.40,356.70,357.17,358.38,359.46,360.28,359.60,357.57,355.52,353.69,353.99,355.34,356.80,358.37,358.91,359.97,361.26,361.69,360.94,359.55,357.48,355.84,356.00,357.58,359.04,359.97,361.00,361.64,363.45,363.80,363.26,361.89,359.45,358.05,357.75,359.56,360.70,362.05,363.24,364.02,364.71,365.41,364.97,363.65,361.48,359.45,359.61,360.76,362.33,363.18,363.99,364.56,366.36,366.80,365.63,364.47,362.50,360.19,360.78,362.43,364.28,365.33,366.15,367.31,368.61,369.30,368.88,367.64,365.78,363.90,364.23,365.46,366.97,368.15,368.87,369.59,371.14,371.00,370.35,369.27,366.93,364.64,365.13,366.68,368.00,369.14,369.46,370.51,371.66,371.83,371.69,370.12,368.12,366.62,366.73,368.29,369.53,370.28,371.50,372.12,372.86,374.02,373.31,371.62,369.55,367.96,368.09,369.68,371.24,372.44,373.08,373.52,374.85,375.55,375.40,374.02,371.48,370.70,370.25,372.08,373.78,374.68,375.62,376.11,377.65,378.35,378.13,376.61,374.48,372.98,373.00,374.35,375.69,376.79,377.36,378.39,380.50,380.62,379.55,377.76,375.83,374.05,374.22,375.84,377.44,378.34,379.61,380.08,382.05,382.24,382.08,380.67,378.67,376.42,376.80,378.31,379.96,381.37,382.02,382.56,384.37,384.92,384.03,382.28,380.48,378.81,379.06,380.14,381.66,382.58,383.71,384.34,386.23,386.41,385.87,384.45,381.84,380.86,380.86,382.36,383.61,385.07,385.84,385.83,386.77,388.51,388.05,386.25,384.08,383.09,382.78,384.01,385.11,386.65,387.12,388.52,389.57,390.16,389.62,388.07,386.08,384.65,384.33,386.05,387.49,388.55,390.07,391.01,392.38,393.22,392.24,390.33,388.52,386.84,387.16,388.67,389.81,391.30,391.92,392.45,393.37,394.28,393.69,392.59,390.21,389.00,388.93,390.24,391.80,393.07,393.35,394.36,396.43,396.87,395.88,394.52,392.54,391.13,391.01,392.95,394.34,395.61,396.85,397.26,398.35,399.98,398.87,397.37,395.41,393.39,393.70,395.19,396.82,397.92,398.10,399.47,401.33,401.88,401.31,399.07,397.21,395.40,395.65,397.23,398.79,399.85,400.31,401.51,403.45,404.10,402.88,401.61,399.00,397.50,398.28,400.24,401.89,402.65,404.16,404.85,407.57,407.66,407.00,404.50,402.24,401.01,401.50,403.64,404.55,406.07,406.64,407.06,408.95,409.91,409.12,407.20,405.24,403.27,403.64,405.17,406.75,408.05,408.34,409.25,410.30,411.30,410.88,408.90,407.10,405.59,405.99,408.12,409.23,410.92'.split(',')).astype(np.float32)\n", + "num_forecast_steps = 12 * 10 # Forecast the final ten years, given previous data\n", + "co2_by_month_training_data = co2_by_month[:-num_forecast_steps]\n", + "\n", + "co2_dates = np.arange(\"1966-01\", \"2019-02\", dtype=\"datetime64[M]\")" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "id": "ak2nUqgm9KLL" + }, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(figsize=(16, 4))\n", + "ax.plot(co2_dates[:-num_forecast_steps], co2_by_month_training_data, lw=2)\n", + "plt.title(\"Monthly Average CO2 Concentration, Mauna Loa, Hawaii\")\n", + "plt.xlabel(\"Year\")\n", + "plt.ylabel(\"CO2 concentration / ppm\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "GCxNHbFxhhaG" + }, + "source": [ + "The data is passed to the fitting function as an array of time indices `x_train` (normalized to span the range of 0 to 1), which form the input to the neural network, and an array of target values `y_train` which the network predicts." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "id": "jvobMqNkhriD" + }, + "outputs": [], + "source": [ + "x_test = np.arange(len(co2_by_month), dtype=np.float32)\n", + "y_test = co2_by_month\n", + "x_train = x_test[:-num_forecast_steps]\n", + "y_train = co2_by_month_training_data\n", + "\n", + "x_scale = x_train.max()\n", + "x_train = x_train / x_scale\n", + "x_test = x_test / x_scale\n", + "\n", + "one_year = jnp.array(12. / x_scale, dtype=jnp.float32) # monthly data" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "mNVpbSqX_RTH" + }, + "source": [ + "## Modeling with a hand-crafted GP kernel\n", + "\n", + "This section explores applying the AutoBNN framework with a hand-crafted kernel, building upon the example in section 5.4.3 of [Rasmussen & Williams (2006)](https://gaussianprocess.org/gpml/chapters/RW5.pdf) for modeling the Mauna Loa atmospheric CO2 dataset. Their example aimed to showcase the power and flexibility of Gaussian Processes, and their chosen kernel serves as a valuable test case for our framework.\n", + "\n", + "Rasmussen & Williams select the following components for their model:\n", + "* a squared exponential (SE) kernel to capture the smooth trend.\n", + "* a periodic kernel with a period of one year, multiplied with an SE to allow deviations from exact periodicity.\n", + "* a rational quadratic to model medium term irregularities.\n", + "* a squared exponential noise term.\n", + "\n", + "The authors note that \"one could also have used a squared exponential form for [the rational quadratic] component\", which we do here as our framework does not currently implement the rational quadratic. The squared exponential noise term is provided by our Normal likelihood model. With these translations, the Rasumussen & Williams model is easily specified in the following AutoBNN code:" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "id": "YCYxk6OF_kQN" + }, + "outputs": [], + "source": [ + "smooth_trend = kernels.ExponentiatedQuadraticBNN(\n", + " amplitude_scale=4., length_scale_scale=one_year * 4.)\n", + "\n", + "# For annoying technical reasons, any kernels used in a Multiply operator\n", + "# must pass the going_to_be_multiplied=True option in their constructor.\n", + "# (This is because the Multiply operator doesn't use the final layer of\n", + "# weights from its input BNNs, and going_to_be_multiplied=True tells the\n", + "# BNN not to create them so that they are not allocated or optimized over.)\n", + "seasonal = operators.Multiply(bnns=[\n", + " kernels.PeriodicBNN(period=one_year, going_to_be_multiplied=True),\n", + " kernels.ExponentiatedQuadraticBNN(amplitude_scale=4.,\n", + " length_scale_scale=one_year,\n", + " going_to_be_multiplied=True)])\n", + "\n", + "medium_term = kernels.ExponentiatedQuadraticBNN(length_scale_scale=one_year)\n", + "\n", + "# To train a BNN, we must define a likelihood model that says how likely\n", + "# the observed data is given the BNN's predictions. The\n", + "# NormalLikelihoodLogisticNoise likelihood model says that\n", + "# P(data | predictions) ~ N(mean=predictions, scale=n)\n", + "# with\n", + "# n ~ Logistic(location=0, scale=log_noise_scale)\n", + "likelihood = likelihoods.NormalLikelihoodLogisticNoise(log_noise_scale=0.1)\n", + "\n", + "kernel = operators.WeightedSum(likelihood_model=likelihood,\n", + " bnns=[smooth_trend, medium_term, seasonal])\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-RBh985WhfxO" + }, + "source": [ + " We fit the model using a maximum a posteriori (MAP) estimate. We train an ensemble of 64 fits, which helps explore multiple modes of the posterior and results in better uncertainty than a single MAP fit.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "id": "6RgXzLjphhLk" + }, + "outputs": [], + "source": [ + "%%time\n", + "seed = jax.random.PRNGKey(0)\n", + "fit_seed, pred_seed, seed = jax.random.split(seed, 3)\n", + "\n", + "params, loss = training_util.fit_bnn_map(\n", + " kernel,\n", + " seed=fit_seed,\n", + " x_train=x_train[..., None],\n", + " y_train=y_train[..., None],\n", + " num_particles=32,\n", + " # optimizer_kwargs:\n", + " num_iters=2_000,\n", + " learning_rate=0.01\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NSAxnwWn3Kzw" + }, + "source": [ + "## Visualizing the Results of the hand-crafted model\n", + "\n", + "After computing predictions on the held-out set, we use our visualization utilities to show loss curves and predictions. The loss curves serve mainly as a diagnostic to ensure all 64 particles converge to a low loss, and do not diverge or get stuck. The prediction plot shows ground truth data and the model fit. The fit shows both individual particles (gray) as well as the mean prediction with 5% and 95% confidence intervals (blue). The fit here captures the overall upwards trend and sawtooth shape of the periodic component.\n", + "After about 3 years of prediction intervals that contain the held out data, the model begins to underpredict, indicating that it did not quite capture the trend in the data, and is overly confident in this prediction.\n", + "\n", + "While it's hard to compare directly with the results in Rasmussen & Williams (they fit on an earlier subset of the data and do not compare against a held-out set), we note that imperfections in our fit could be due to slight differences in the kernel specification. This serves as a reminder how difficult and error-prone kernel design can be.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "id": "BONZi4bS19dO" + }, + "outputs": [], + "source": [ + "preds = jax.vmap(lambda p: kernel.apply(p, x_test[..., None]))(params)\n", + "\n", + "draws = likelihood.sample(params['params'], preds.squeeze(), seed=pred_seed, sample_shape=100)\n", + "lo, mid, p90, hi = jnp.percentile(draws.squeeze(), jnp.array([2.5, 50., 90., 97.5]), axis=(0, 1))\n", + "\n", + "plot = training_util.plot_results(co2_dates,\n", + " preds,\n", + " dates_test=co2_dates[-num_forecast_steps:],\n", + " y_test=co2_by_month[-num_forecast_steps:],\n", + " p2_5=lo,\n", + " p50=mid,\n", + " p97_5=hi,\n", + " dates_train=co2_dates[:-num_forecast_steps],\n", + " y_train=co2_by_month[:-num_forecast_steps],\n", + " diagnostics=loss,\n", + " log_scale=True,\n", + " show_particles=True,\n", + " left_limit=50*12,\n", + " right_limit=10*12)\n", + "\n", + "plot.axes[1].set_xlabel(\"Year\")\n", + "plot.axes[1].set_ylabel(\"CO2 concentration / ppm\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QgYDt9oIFpFx" + }, + "source": [ + "## Metrics\n", + "\n", + "The `make_results_dataframe` provides SMAPE, MASE and MSIS, [common forecasting metrics](https://medium.com/@vinitkothari.24/time-series-evaluation-metrics-mape-vs-wmape-vs-smape-which-one-to-use-why-and-when-part1-32d3852b4779), for this fit. We compute metrics to 10 years out (specified by `num_forecast_steps`) and report the fit at 60 months (i.e. 5 years) out as an example." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "id": "rM_9ZVBkJuJW" + }, + "outputs": [], + "source": [ + "training_util.make_results_dataframe(mid[-num_forecast_steps:],\n", + " y_test=co2_by_month[-num_forecast_steps:],\n", + " y_train=co2_by_month[:-num_forecast_steps],\n", + " p2_5=lo[-num_forecast_steps:],\n", + " p90=p90[-num_forecast_steps:],\n", + " p97_5=hi[-num_forecast_steps:]\n", + " ).iloc[60]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "kQU9ulXf8972" + }, + "source": [ + "## Structure Discovery with Continuous Relaxations\n", + "\n", + "Manually specifying a kernel, like we did above, can be useful in many situations. The true power of the AutoBNN framework, though, lies in its ability to discover structure directly from the data. Below we run the same MAP estimator with the `sum_of_products` model, which is one of the model families distributed with the AutoBNN packages.\n", + "\n", + "This kernel uses a continuous relaxation over periodic, linear, exponentiated quadratic, second-degree polynomial and identity leaves as the basic structure. It then combines four such subtrees by first adding them in pairs, and multiplying the final two subtrees. This is a structure that we empirically found to be suitable to model a wide variety of time series. It can serve as a useful starting point for further exploration. Other model families supporting structure discovery can be found in the [models.py](https://github.com/tensorflow/probability/blob/main/tensorflow_probability/python/experimental/autobnn/models.py) file." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "id": "DHvq-xN-tqW7" + }, + "outputs": [], + "source": [ + "%%time\n", + "fit_seed, pred_seed, seed = jax.random.split(seed, 3)\n", + "\n", + "est = estimators.AutoBnnMapEstimator(\n", + " 'sum_of_products',\n", + " likelihood_model='normal_likelihood_logistic_noise',\n", + " seed=jax.random.PRNGKey(0),\n", + " periods=(one_year,),\n", + " num_particles=32,\n", + ")\n", + "\n", + "est = est.fit(x_train[..., None], y_train[..., None])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9JZG6zUy_hcd" + }, + "source": [ + "## Visualizing the Results of the AutoBNN model\n", + "\n", + "Compared to the hand-crafted kernel above, the prediction plot shows a tighter fit to the training data, with a graceful increase in uncertainty deeper into the forecast horizon." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "id": "ThD4Kur0vjyw" + }, + "outputs": [], + "source": [ + "preds = est.predict(x_test[..., None])\n", + "lo, mid, p90, hi = est.predict_quantiles(x_test[..., None], q=[2.5, 50., 90., 97.5])\n", + "_ = training_util.plot_results(co2_dates,\n", + " preds,\n", + " dates_test=co2_dates[-num_forecast_steps:],\n", + " y_test=co2_by_month[-num_forecast_steps:],\n", + " p2_5=lo,\n", + " p50=mid,\n", + " p97_5=hi,\n", + " dates_train=co2_dates[:-num_forecast_steps],\n", + " y_train=co2_by_month[:-num_forecast_steps],\n", + " diagnostics=loss,\n", + " log_scale=True,\n", + " show_particles=True,\n", + " left_limit=50*12,\n", + " right_limit=10*12)\n", + "plot.axes[1].set_xlabel(\"Year\")\n", + "plot.axes[1].set_ylabel(\"CO2 concentration / ppm\")" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "id": "kIWIALyCy4Mv" + }, + "outputs": [], + "source": [ + "training_util.make_results_dataframe(mid[-num_forecast_steps:],\n", + " y_test=co2_by_month[-num_forecast_steps:],\n", + " y_train=co2_by_month[:-num_forecast_steps],\n", + " p2_5=lo[-num_forecast_steps:],\n", + " p90=p90[-num_forecast_steps:],\n", + " p97_5=hi[-num_forecast_steps:]\n", + " ).iloc[60]\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "84SJr4JXDEyf" + }, + "source": [ + "## Conclusion\n", + "\n", + "This colab showcases two applications of the AutoBNN framework to the Mauna Loa dataset, fitting a model with pre-defined structure and using structure discovery. While this dataset with about 500 time steps is well within reach for fitting a GP, we have successfuly run the same models with datasets of 5000 and more timesteps." + ] + } + ], + "metadata": { + "accelerator": "TPU", + "colab": { + "gpuType": "V28", + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}