-
Notifications
You must be signed in to change notification settings - Fork 308
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for exogenous regressors (#125)
* Simplify implementations of SARIMA, ETS, VectorAR. We remove code duplication from these files, and we also remove the online training features of ETS. This should make the ETS model easier to use at inference time, since time_series_prev is no longer entangled with online inference features. * Bugfixes. * Simpler TransformSequence implementation. * Refactor argument order for anomaly detector train We change the order from (train_data, anomaly_labels, train_config, post_rule_train_config) to (train_data, train_config, anomaly_labels, post_rule_train_config). This brings it in line with the base signature. * Allow custom color plots. * Add exogenous regressor support for Prophet. * Add exog regressor support for SARIMA. * Add exog_data param to downstream models. We allow ForecastingDetector's, ForecasterEnsemble's, and LayeredForecaster's to accept the param exog_data in both train() and forecast() methods. However, layered models (especially autoML) do not yet support training with exogenous data. * Make base train() abstract. * Make exogenous pre-processing more rigorous. * Add exog regressor support to evaluators. * Add exog regressor test for Prophet. * Abstract away the notion of grid search. * Add ensemble & evaluator test coverage for exog. * Fix build failures. * Remove exog_data reference from SeasonalityLayer * More fixes. * Add exogenous regressor support to layered models. * Silence Prophet deserialization warnings. * Fix 2-layer AutoSarima bug. * Change train to _train for DetectorEnsemble. * Fix typos. * Slight cleanup of ensemble code. * Simplify ensemble cross-val to use evaluators. * Add save/load test coverage for exog. * Make layers aware of exogenous regressors. * Deprecate Python 3.6 & update version. * More rigorous handling of kwargs in LayeredModel. * Rename ForecasterWithExog to ForecasterExogBase * More robust support for inverse transforms. Use named variables rather than integer indexing. This ensures that we can invert multivariate forecasts. * Fix how model_kwargs is set in layered models. * Make time series more JSON-compatible. * Various bugfixes. * More systematic post-processing of forecasts. * Fix docs error. * Remove RMSE value assertions from boostingtrees. * Skip univariate VectorAR test as before. * Skip spark tests on Python 3.10 * Optimize application of inverse transforms. The inverse of many transforms is just the identity. This commit adds an optimization which skips applying the inverse altogether if this is the case. * Reduce size of walmart_mini to prevent OOM errors. * Update pyspark session fixture. * Make test_vector_ar smaller. * Remove python3.6 fallback code from conj priors. * models.anomaly.utils -> models.utils.torch_utils * Update default settings of ARIMA models. I figured out that the enforce_invertibility=True and enforce_stationarity=True settings were previously causing segfaults in the unit tests because of an out-of-memory error. I have updated the tests to use smaller data size to circumvent the error. * Fix failures from SARIMA model update. * Try unpersisting dataframes in spark tests. This could ameliorate OOM issues (if that's the cause of test failures). * Increase Spark network timeout for unit tests. * Run spark tests separately for 3.8/3.9. * Update test_forecast_ensemble. * Add tutorial on exogenous regressors. * Add auto-retry to tests. * Use cached docs from gh-pages branch. Also improve git robustness of build_docs.sh.
- Loading branch information
Showing
88 changed files
with
2,376 additions
and
6,772 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
name: build | ||
name: tests | ||
|
||
on: | ||
push: | ||
|
@@ -7,18 +7,18 @@ on: | |
branches: [ main ] | ||
|
||
jobs: | ||
build: | ||
tests: | ||
|
||
runs-on: ubuntu-latest | ||
strategy: | ||
fail-fast: false | ||
matrix: | ||
python-version: ["3.6", "3.7", "3.8", "3.9", "3.10"] | ||
python-version: ["3.7", "3.8", "3.9", "3.10"] | ||
|
||
steps: | ||
- uses: actions/checkout@v2 | ||
- uses: actions/checkout@v3 | ||
- name: Set up Python ${{ matrix.python-version }} | ||
uses: actions/setup-python@v2 | ||
uses: actions/setup-python@v4 | ||
with: | ||
python-version: ${{ matrix.python-version }} | ||
|
||
|
@@ -34,45 +34,41 @@ jobs: | |
- name: Test with pytest | ||
id: test | ||
run: | | ||
# Get a comma-separated list of the directories of all python source files | ||
source_files=$(for f in $(find merlion -iname "*.py"); do echo -n ",$f"; done) | ||
script="import os; print(','.join({os.path.dirname(f) for f in '$source_files'.split(',') if f}))" | ||
source_modules=$(python -c "$script") | ||
# A BLAS bug causes high-dim multivar Bayesian LR test to segfault in 3.6. Run the test first to avoid. | ||
if [[ $PYTHON_VERSION == 3.6 ]]; then | ||
python -m pytest -v tests/change_point/test_conj_prior.py | ||
coverage run --source=${source_modules} -L -m pytest -v --ignore tests/change_point/test_conj_prior.py | ||
else | ||
coverage run --source=${source_modules} -L -m pytest -v | ||
fi | ||
# Obtain code coverage from coverage report | ||
coverage report | ||
coverage xml -o .github/badges/coverage.xml | ||
COVERAGE=`coverage report | grep "TOTAL" | grep -Eo "[0-9\.]+%"` | ||
echo "##[set-output name=coverage;]${COVERAGE}" | ||
# Choose a color based on code coverage | ||
COVERAGE=${COVERAGE/\%/} | ||
if (($COVERAGE > 90)); then | ||
COLOR=brightgreen | ||
elif (($COVERAGE > 80)); then | ||
COLOR=green | ||
elif (($COVERAGE > 70)); then | ||
COLOR=yellow | ||
elif (($COVERAGE > 60)); then | ||
COLOR=orange | ||
else | ||
COLOR=red | ||
fi | ||
echo "##[set-output name=color;]${COLOR}" | ||
uses: nick-fields/retry@v2 | ||
env: | ||
PYTHON_VERSION: ${{ matrix.python-version }} | ||
with: | ||
max_attempts: 3 | ||
timeout_minutes: 40 | ||
command: | | ||
# Get a comma-separated list of the directories of all python source files | ||
source_files=$(for f in $(find merlion -iname "*.py"); do echo -n ",$f"; done) | ||
script="import os; print(','.join({os.path.dirname(f) for f in '$source_files'.split(',') if f}))" | ||
source_modules=$(python -c "$script") | ||
# Run tests & obtain code coverage from coverage report. | ||
coverage run --source=${source_modules} -L -m pytest -v -s | ||
coverage report && coverage xml -o .github/badges/coverage.xml | ||
COVERAGE=`coverage report | grep "TOTAL" | grep -Eo "[0-9\.]+%"` | ||
echo "##[set-output name=coverage;]${COVERAGE}" | ||
# Choose a color based on code coverage | ||
COVERAGE=${COVERAGE/\%/} | ||
if (($COVERAGE > 90)); then | ||
COLOR=brightgreen | ||
elif (($COVERAGE > 80)); then | ||
COLOR=green | ||
elif (($COVERAGE > 70)); then | ||
COLOR=yellow | ||
elif (($COVERAGE > 60)); then | ||
COLOR=orange | ||
else | ||
COLOR=red | ||
fi | ||
echo "##[set-output name=color;]${COLOR}" | ||
- name: Create coverage badge | ||
if: ${{ github.ref == 'refs/heads/main' && matrix.python-version == '3.8' }} | ||
if: ${{ github.ref == 'refs/heads/main' && matrix.python-version == '3.10' }} | ||
uses: emibcn/[email protected] | ||
with: | ||
label: coverage | ||
|
@@ -81,8 +77,8 @@ jobs: | |
path: .github/badges/coverage.svg | ||
|
||
- name: Push badge to badges branch | ||
uses: s0/git-publish-subdir-action@develop | ||
if: ${{ github.ref == 'refs/heads/main' && matrix.python-version == '3.8' }} | ||
uses: s0/git-publish-subdir-action@v2.5.1 | ||
if: ${{ github.ref == 'refs/heads/main' && matrix.python-version == '3.10' }} | ||
env: | ||
REPO: self | ||
BRANCH: badges | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.