[MRG+1] Support custom scoring (#373)
* Support custom scoring

* Add to whats_new
tgsmith61591 authored Aug 3, 2020
1 parent 2cb61f0 commit d90d91e
Showing 5 changed files with 90 additions and 18 deletions.
3 changes: 3 additions & 0 deletions doc/whats_new.rst
@@ -35,6 +35,9 @@ v0.8.1) will document the latest features.
* Fix a bug where the :class:`pmdarima.model_selection.SlidingWindowForecastCV` could produce
too few splits for the given input data.

* Permit custom scoring metrics to be passed for out-of-sample scoring, as requested in
`#368 <https://github.com/alkaline-ml/pmdarima/issues/368>`_.
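As a quick illustration of this entry (not part of the changelog itself), a hedged sketch using ``auto_arima``; the ``rmse`` helper is hypothetical and assumes the usual ``out_of_sample_size``/``scoring`` keywords:

import numpy as np
import pmdarima as pm
from sklearn.metrics import mean_squared_error

def rmse(y_true, y_pred):
    # any callable with this (y_true, y_pred) signature may now be passed
    return np.sqrt(mean_squared_error(y_true, y_pred))

y = pm.datasets.load_wineind()
model = pm.auto_arima(y, seasonal=False, suppress_warnings=True,
                      out_of_sample_size=15, scoring=rmse)
print(model.oob())  # out-of-sample score computed with rmse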


`v1.6.1 <http://alkaline-ml.com/pmdarima/1.6.1/>`_
--------------------------------------------------
38 changes: 38 additions & 0 deletions pmdarima/arima/_validation.py
@@ -7,6 +7,7 @@

import numpy as np
import warnings
from sklearn import metrics

# The valid information criteria
VALID_CRITERIA = {'aic', 'aicc', 'bic', 'hqic', 'oob'}
@@ -98,3 +99,40 @@ def check_trace(trace):
    if trace:
        return 1
    return 0


def get_scoring_metric(metric):
    """Get a scoring metric by name, or pass a callable through

    Parameters
    ----------
    metric : str or callable
        The name of a scoring metric, or a custom callable function. If it is
        a callable, it must adhere to the signature::

            def func(y_true, y_pred)

        Note that the ARIMA model selection seeks to MINIMIZE the score, and
        it is up to the user to ensure that scoring methods that return
        maximizing criteria (e.g., ``r2_score``) are wrapped in a function
        that will return the negative value of the score.
    """
    if isinstance(metric, str):

        # XXX: legacy support, remap mse/mae to their long versions
        if metric == "mse":
            return metrics.mean_squared_error
        if metric == "mae":
            return metrics.mean_absolute_error

        try:
            return getattr(metrics, metric)
        except AttributeError:
            raise ValueError("'%s' is not a valid scoring method." % metric)

    if not callable(metric):
        raise TypeError("`metric` must be a valid scoring method, or a "
                        "callable, but got type=%s" % type(metric))

    # TODO: warn for potentially invalid signature?
    return metric
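
For illustration (not part of the diff), a minimal sketch of how the helper above resolves metrics, mirroring the import style used in the tests below:

from sklearn.metrics import mean_squared_error
from pmdarima.arima import _validation as val

# legacy short names remap to their long sklearn equivalents
assert val.get_scoring_metric("mse") is mean_squared_error

# any attribute of sklearn.metrics can be requested by name
assert callable(val.get_scoring_metric("median_absolute_error"))

# callables pass through untouched
assert val.get_scoring_metric(mean_squared_error) is mean_squared_error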
27 changes: 16 additions & 11 deletions pmdarima/arima/arima.py
@@ -5,7 +5,6 @@
# A user-friendly wrapper to the statsmodels ARIMA that mimics the familiar
# sklearn interface.

from sklearn.metrics import mean_absolute_error, mean_squared_error
from sklearn.utils.metaestimators import if_delegate_has_method
from sklearn.utils.validation import check_array

@@ -15,13 +14,13 @@
import numpy as np
import warnings

from . import _validation as val
from ..base import BaseARIMA
from ..compat.numpy import DTYPE # DTYPE for arrays
from ..compat.sklearn import check_is_fitted, safe_indexing
from ..compat import statsmodels as sm_compat
from ..compat import matplotlib as mpl_compat
from ..utils import get_callable, if_has_delegate, is_iterable, check_endog, \
    check_exog
from ..utils import if_has_delegate, is_iterable, check_endog, check_exog
from ..utils.visualization import _get_plt

# Get the version
@@ -31,11 +30,6 @@
    'ARIMA'
]

VALID_SCORING = {
    'mse': mean_squared_error,
    'mae': mean_absolute_error
}


def _aicc(model_results, nobs, add_constant):
"""Compute the corrected Akaike Information Criterion"""
@@ -191,9 +185,20 @@ class ARIMA(BaseARIMA):
> Score on: [5, 6]
> Append [5, 6] to end of self.arima_res_.data.endog values
scoring : str, optional (default='mse')
scoring : str or callable, optional (default='mse')
If performing validation (i.e., if ``out_of_sample_size`` > 0), the
metric to use for scoring the out-of-sample data. One of {'mse', 'mae'}
metric to use for scoring the out-of-sample data:
* If a string, must be a valid metric name importable from
``sklearn.metrics``.
* If a callable, must adhere to the function signature::
def foo_loss(y_true, y_pred)
Note that models are selected by *minimizing* loss. If using a
maximizing metric (such as ``sklearn.metrics.r2_score``), it is the
user's responsibility to wrap the function such that it returns a
negative value for minimizing.
scoring_args : dict, optional (default=None)
A dictionary of key-word arguments to be passed to the
@@ -439,7 +444,7 @@ def fit(self, y, exogenous=None, **fit_args):

        # determine the CV args, if any
        cv = self.out_of_sample_size
        scoring = get_callable(self.scoring, VALID_SCORING)
        scoring = val.get_scoring_metric(self.scoring)

        # don't allow negative, don't allow > n_samples
        cv = max(cv, 0)
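
For illustration (not part of the diff), a hedged sketch of what the reworked ``scoring`` parameter allows at the ``ARIMA`` level; the ``neg_r2`` wrapper is hypothetical and follows the docstring's advice to negate maximizing metrics:

from pmdarima.arima import ARIMA
from pmdarima.datasets import load_heartrate
from sklearn.metrics import r2_score

def neg_r2(y_true, y_pred):
    # r2_score is a maximizing criterion, so negate it for minimization
    return -r2_score(y_true, y_pred)

y = load_heartrate()
fit = ARIMA(order=(2, 1, 2), suppress_warnings=True,
            out_of_sample_size=10, scoring=neg_r2).fit(y)
print(fit.oob())  # out-of-sample score computed with neg_r2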
20 changes: 13 additions & 7 deletions pmdarima/arima/tests/test_arima.py
@@ -8,15 +8,15 @@
import pandas as pd

from pmdarima.arima import ARIMA, auto_arima, AutoARIMA
from pmdarima.arima.arima import VALID_SCORING
from pmdarima.arima import _validation as val
from pmdarima.compat.pytest import pytest_error_str
from pmdarima.datasets import load_lynx, load_wineind, load_heartrate
from pmdarima.utils import get_callable

from numpy.random import RandomState
from numpy.testing import assert_array_almost_equal, assert_almost_equal, \
assert_allclose
from statsmodels import api as sm
from sklearn.metrics import mean_squared_error

import joblib
import os
@@ -150,8 +150,14 @@ def test_predict_in_sample_exog(model, exog, confints):
    assert isinstance(res, np.ndarray)


def _two_times_mse(y_true, y_pred, **_):
    """A custom loss to test we can pass custom scoring metrics"""
    return mean_squared_error(y_true, y_pred) * 2


@pytest.mark.parametrize('as_pd', [True, False])
def test_with_oob_and_exog(as_pd):
@pytest.mark.parametrize('scoring', ['mse', _two_times_mse])
def test_with_oob_and_exog(as_pd, scoring):
    endog = hr
    exog = np.random.RandomState(1).rand(hr.shape[0], 3)
    if as_pd:
@@ -160,7 +166,7 @@ def test_with_oob_and_exog(as_pd):

    arima = ARIMA(order=(2, 1, 2),
                  suppress_warnings=True,
                  scoring='mse',
                  scoring=scoring,
                  out_of_sample_size=10).fit(y=endog, exogenous=exog)

    # show we can get oob score and preds
@@ -180,7 +186,7 @@ def test_with_oob():
    # Assert the predictions give the expected MAE/MSE
    oob_preds = arima.oob_preds_
    assert oob_preds.shape[0] == 10
    scoring = get_callable('mse', VALID_SCORING)
    scoring = val.get_scoring_metric('mse')
    assert scoring(hr[-10:], oob_preds) == oob

    # show we can fit if ooss < 0 and oob will be nan
@@ -230,7 +236,7 @@ def test_oob_for_issue_28():
out_of_sample_size=0).fit(y=hr[:-10],
exogenous=xreg[:-10, :])

    scoring = get_callable(arima_no_oob.scoring, VALID_SCORING)
    scoring = val.get_scoring_metric(arima_no_oob.scoring)
    preds = arima_no_oob.predict(n_periods=10, exogenous=xreg[-10:, :])
    assert np.allclose(oob, scoring(hr[-10:], preds), rtol=1e-2)

@@ -290,7 +296,7 @@ def test_oob_sarimax():
    oob = fit.oob()

    # compare scores:
    scoring = get_callable(fit_no_oob.scoring, VALID_SCORING)
    scoring = val.get_scoring_metric(fit_no_oob.scoring)
    no_oob_preds = fit_no_oob.predict(n_periods=15, exogenous=xreg[-15:, :])
    assert np.allclose(oob, scoring(wineind[-15:], no_oob_preds), rtol=1e-2)

20 changes: 20 additions & 0 deletions pmdarima/arima/tests/test_validation.py
@@ -159,3 +159,23 @@ def test_check_start_max_values(st, mx, argname, exp_vals, exp_err_msg):
def test_check_trace(trace, expected):
    res = val.check_trace(trace)
    assert expected == res


@pytest.mark.parametrize(
    'metric,expected_error,expected_error_msg', [
        pytest.param("mae", None, None),
        pytest.param("mse", None, None),
        pytest.param("mean_squared_error", None, None),
        pytest.param("r2_score", None, None),
        pytest.param("foo", ValueError, "is not a valid scoring"),
        pytest.param(123, TypeError, "must be a valid scoring method, or a"),
    ]
)
def test_valid_metrics(metric, expected_error, expected_error_msg):
    if not expected_error:
        assert callable(val.get_scoring_metric(metric))
    else:
        with pytest.raises(expected_error) as err:
            val.get_scoring_metric(metric)
        assert expected_error_msg in pytest_error_str(err)
