diff --git a/kats/detectors/prophet_detector.py b/kats/detectors/prophet_detector.py index c5e16680..bf94a13b 100644 --- a/kats/detectors/prophet_detector.py +++ b/kats/detectors/prophet_detector.py @@ -18,12 +18,6 @@ import numpy as np import pandas as pd -from fbprophet import Prophet as FbProphet -from fbprophet.make_holidays import make_holidays_df as make_holidays_df_0 -from fbprophet.serialize import ( - model_from_json as model_from_json_0, - model_to_json as model_to_json_0, -) from kats.consts import ( DataError, DataInsufficientError, @@ -36,11 +30,8 @@ from kats.detectors.detector_consts import AnomalyResponse, ConfidenceBand from kats.models.prophet import predict from prophet import Prophet -from prophet.make_holidays import make_holidays_df as make_holidays_df_1 -from prophet.serialize import ( - model_from_json as model_from_json_1, - model_to_json as model_to_json_1, -) +from prophet.make_holidays import make_holidays_df +from prophet.serialize import model_from_json, model_to_json from pyre_extensions import ParameterSpecification from scipy.stats import norm @@ -172,35 +163,6 @@ class ProphetScoreFunction(Enum): z_score = "z_score" -class ProphetVersion(Enum): - fbprophet = "fbprophet" - prophet = "prophet" - - def make_holidays_df(self, *args: P.args, **kwargs: P.kwargs) -> pd.DataFrame: - if self == ProphetVersion.fbprophet: - return make_holidays_df_0(*args, **kwargs) - else: - return make_holidays_df_1(*args, **kwargs) - - def model_from_json(self, *args: P.args, **kwargs: P.kwargs) -> FbProphet | Prophet: - if self == ProphetVersion.fbprophet: - return model_from_json_0(*args, **kwargs) - else: - return model_from_json_1(*args, **kwargs) - - def model_to_json(self, *args: P.args, **kwargs: P.kwargs) -> str: - if self == ProphetVersion.fbprophet: - return model_to_json_0(*args, **kwargs) - else: - return model_to_json_1(*args, **kwargs) - - def create_prophet(self, *args: P.args, **kwargs: P.kwargs) -> FbProphet | Prophet: - if self == ProphetVersion.fbprophet: - return FbProphet(*args, **kwargs) - else: - return Prophet(*args, **kwargs) # pyre-ignore - - SCORE_FUNC_DICT: Dict[str, Any] = { ProphetScoreFunction.deviation_from_predicted_val.value: deviation_from_predicted_val, ProphetScoreFunction.z_score.value: z_score, @@ -324,7 +286,6 @@ def get_holiday_dates( holidays: Optional[pd.DataFrame] = None, country_holidays: Optional[str] = None, dates: Optional[pd.Series] = None, - prophet_version: ProphetVersion = ProphetVersion.prophet, ) -> pd.Series: if dates is None: return pd.Series() @@ -333,7 +294,7 @@ def get_holiday_dates( if holidays is not None: all_holidays = holidays.copy() if country_holidays: - country_holidays_df = prophet_version.make_holidays_df( + country_holidays_df = make_holidays_df( year_list=year_list, country=country_holidays ) all_holidays = pd.concat((all_holidays, country_holidays_df), sort=False) @@ -343,17 +304,6 @@ def get_holiday_dates( return all_holidays -def deserialize_model( - serialized_model: bytes, -) -> Tuple[FbProphet | Prophet, ProphetVersion]: - model_json = json.loads(serialized_model) - if "__fbprophet_version" in model_json: - prophet_version = ProphetVersion.fbprophet - else: - prophet_version = ProphetVersion.prophet - return prophet_version.model_from_json(serialized_model), prophet_version - - class ProphetDetectorModel(DetectorModel): """Prophet based anomaly detection model. @@ -373,8 +323,7 @@ class ProphetDetectorModel(DetectorModel): """ - model: Optional[FbProphet | Prophet] = None - prophet_version: ProphetVersion = ProphetVersion.prophet + model: Optional[Prophet] = None seasonalities: Dict[SeasonalityTypes, Union[bool, str]] = {} seasonalities_to_fit: Dict[SeasonalityTypes, Union[bool, str]] = {} @@ -423,7 +372,7 @@ def __init__( """ if serialized_model: - self.model, self.prophet_version = deserialize_model(serialized_model) + self.model = model_from_json(serialized_model) else: self.model = None @@ -465,7 +414,7 @@ def serialize(self) -> bytes: Returns: json containing information of the model. """ - return str.encode(self.prophet_version.model_to_json(self.model)) + return str.encode(model_to_json(self.model)) def fit_predict( self, @@ -529,7 +478,6 @@ def fit( self.outlier_threshold, uncertainty_samples=self.outlier_removal_uncertainty_samples, vectorize=self.vectorize, - prophet_version=self.prophet_version, ) # seasonalities depends on current time series self.seasonalities_to_fit = seasonalities_processing( @@ -559,7 +507,7 @@ def fit( self.holidays = pd.DataFrame(self.holidays_list) # No incremental training. Create a model and train from scratch - model = self.prophet_version.create_prophet( + model = Prophet( interval_width=self.scoring_confidence_interval, uncertainty_samples=self.uncertainty_samples, daily_seasonality=self.seasonalities_to_fit[SeasonalityTypes.DAY], @@ -651,7 +599,7 @@ def predict( pd.DataFrame(self.holidays_list) if self.holidays_list else None ) holidays_df: Optional[pd.Series] = get_holiday_dates( - custom_holidays, self.country_holidays, data.time, self.prophet_version + custom_holidays, self.country_holidays, data.time ) if holidays_df is not None: scores_ts = pd.Series(list(scores.value), index=data.time) @@ -679,7 +627,6 @@ def _remove_outliers( outlier_ci_threshold: float = 0.99, uncertainty_samples: float = OUTLIER_REMOVAL_UNCERTAINTY_SAMPLES, vectorize: bool = False, - prophet_version: ProphetVersion = ProphetVersion.prophet, ) -> pd.DataFrame: """ Remove outliers from the time series by fitting a Prophet model to the time series @@ -689,7 +636,7 @@ def _remove_outliers( ts_dates_df = pd.DataFrame({PROPHET_TIME_COLUMN: ts_df.iloc[:, 0]}) - model = prophet_version.create_prophet( + model = Prophet( interval_width=outlier_ci_threshold, uncertainty_samples=uncertainty_samples ) with ExitStack() as stack: @@ -711,8 +658,7 @@ def _remove_outliers( class ProphetTrendDetectorModel(DetectorModel): """Prophet based trend detection model.""" - model: Optional[FbProphet | Prophet] = None - prophet_version: ProphetVersion = ProphetVersion.prophet + model: Optional[Prophet] = None def __init__( self, @@ -722,7 +668,7 @@ def __init__( changepoint_prior_scale: float = 0.01, ) -> None: if serialized_model: - self.model, self.prophet_version = deserialize_model(serialized_model) + self.model = model_from_json(serialized_model) else: self.model = None @@ -738,7 +684,7 @@ def serialize(self) -> bytes: Returns: json containing information of the model. """ - return str.encode(self.prophet_version.model_to_json(self.model)) + return str.encode(model_to_json(self.model)) def _zeros_ts(self, data: pd.DataFrame) -> TimeSeriesData: return TimeSeriesData( @@ -755,7 +701,7 @@ def fit_predict( historical_data: Optional[TimeSeriesData] = None, **kwargs: Any, ) -> AnomalyResponse: - model = self.prophet_version.create_prophet( + model = Prophet( changepoint_range=self.changepoint_range, weekly_seasonality=self.weekly_seasonality, changepoint_prior_scale=self.changepoint_prior_scale, diff --git a/kats/detectors/threshold_detector.py b/kats/detectors/threshold_detector.py index a0e2902b..cb0da095 100644 --- a/kats/detectors/threshold_detector.py +++ b/kats/detectors/threshold_detector.py @@ -9,11 +9,11 @@ import numpy as np import pandas as pd -from fbprophet import Prophet -from fbprophet.serialize import model_from_json, model_to_json from kats.consts import DEFAULT_VALUE_NAME, TimeSeriesData from kats.detectors.detector import DetectorModel from kats.detectors.detector_consts import AnomalyResponse +from prophet import Prophet +from prophet.serialize import model_from_json, model_to_json class StaticThresholdModel(DetectorModel): diff --git a/kats/models/prophet.py b/kats/models/prophet.py index e4ff4c17..e0871cf9 100644 --- a/kats/models/prophet.py +++ b/kats/models/prophet.py @@ -12,14 +12,12 @@ try: # Prophet is an optional dependency for kats. - from fbprophet import Prophet as FbProphet from prophet import Prophet _no_prophet = False except ImportError: _no_prophet = True Prophet = Dict[str, Any] # for Pyre - FbProphet = Dict[str, Any] # for Pyre import numpy as np import numpy.typing as npt @@ -138,7 +136,7 @@ def __init__( extra_regressors: Optional[List[Dict[str, Any]]] = None, ) -> None: if _no_prophet: - raise RuntimeError("requires fbprophet to be installed") + raise RuntimeError("requires prophet to be installed") super().__init__() self.growth = growth self.changepoints = changepoints @@ -252,7 +250,7 @@ class ProphetModel(Model[ProphetParams]): def __init__(self, data: TimeSeriesData, params: ProphetParams) -> None: super().__init__(data, params) if _no_prophet: - raise RuntimeError("requires fbprophet to be installed") + raise RuntimeError("requires prophet to be installed") self.data: TimeSeriesData = data self._data_params_validation() @@ -517,7 +515,7 @@ def get_parameter_search_space() -> List[Dict[str, object]]: # From now on, the main logics are from github PR https://github.com/facebook/prophet/pull/2186 with some modifications. def predict_uncertainty( - prophet_model: Prophet | FbProphet, df: pd.DataFrame, vectorized: bool + prophet_model: Prophet, df: pd.DataFrame, vectorized: bool ) -> pd.DataFrame: """Prediction intervals for yhat and trend. @@ -548,7 +546,7 @@ def predict_uncertainty( def _sample_predictive_trend_vectorized( - prophet_model: Prophet | FbProphet, + prophet_model: Prophet, df: pd.DataFrame, n_samples: int, iteration: int = 0, @@ -594,7 +592,7 @@ def _sample_predictive_trend_vectorized( def _sample_trend_uncertainty( - prophet_model: Prophet | FbProphet, + prophet_model: Prophet, n_samples: int, df: pd.DataFrame, iteration: int = 0, @@ -683,7 +681,7 @@ def _make_trend_shift_matrix( def predict( - prophet_model: Prophet | FbProphet, + prophet_model: Prophet, df: Optional[pd.DataFrame] = None, vectorized: bool = False, ) -> pd.DataFrame: @@ -730,7 +728,7 @@ def predict( def sample_model_vectorized( - prophet_model: Prophet | FbProphet, + prophet_model: Prophet, df: pd.DataFrame, seasonal_features: pd.DataFrame, iteration: int, @@ -761,7 +759,7 @@ def sample_model_vectorized( def sample_posterior_predictive( - prophet_model: Prophet | FbProphet, df: pd.DataFrame, vectorized: bool + prophet_model: Prophet, df: pd.DataFrame, vectorized: bool ) -> Dict[str, npt.NDArray]: """Generate posterior samples of a trained Prophet model. @@ -836,7 +834,7 @@ def _make_historical_mat_time( def _logistic_uncertainty( - prophet_model: Prophet | FbProphet, + prophet_model: Prophet, mat: npt.NDArray, deltas: npt.NDArray, k: float, @@ -905,7 +903,7 @@ def _piecewise_linear_vectorize( def sample_linear_predictive_trend_vectorize( - prophet_model: Prophet | FbProphet, + prophet_model: Prophet, df: pd.DataFrame, sample_size: int, iteration: int, diff --git a/kats/tests/detectors/test_prophet_detector.py b/kats/tests/detectors/test_prophet_detector.py index b8747811..b67ef732 100644 --- a/kats/tests/detectors/test_prophet_detector.py +++ b/kats/tests/detectors/test_prophet_detector.py @@ -13,17 +13,14 @@ import numpy as np import pandas as pd -from fbprophet import Prophet as FbProphet # @manual from kats.consts import TimeSeriesData from kats.data.utils import load_air_passengers from kats.detectors.detector_consts import AnomalyResponse from kats.detectors.prophet_detector import ( - deserialize_model, get_holiday_dates, ProphetDetectorModel, ProphetScoreFunction, ProphetTrendDetectorModel, - ProphetVersion, SeasonalityTypes, to_seasonality, ) @@ -960,67 +957,13 @@ def test_asymmetric_noise_signal(self, test_index: int) -> None: response2.scores.value[test_index], response1.scores.value[test_index] ) - # pyre-fixme[56]: Pyre was not able to infer the type of the decorator - @parameterized.expand( - [ - [FbProphet, ProphetVersion.fbprophet], - [Prophet, ProphetVersion.prophet], - ] - ) - def test_deserialize_model( - self, - prophet_cls: Type[FbProphet | Prophet], - prophet_version: ProphetVersion, - ) -> None: - ts = self.create_random_ts(0, 100, 10, 2) - ProphetDetectorModel.prophet_version = prophet_version - detector_model = ProphetDetectorModel() - detector_model.fit(ts[:90]) - self.assertIsInstance(detector_model.model, prophet_cls) - serialized_model = detector_model.serialize() - deserialized_model, deserialized_prophet_version = deserialize_model( - serialized_model - ) - self.assertEqual(deserialized_prophet_version, prophet_version) - self.assertIsInstance(deserialized_model, prophet_cls) - deserialized_detector_model = ProphetDetectorModel( - serialized_model=serialized_model - ) - self.assertIsInstance(deserialized_detector_model.model, prophet_cls) - anomaly_response_original = detector_model.predict(ts[90:]) - anomaly_response_deserialized = deserialized_detector_model.predict(ts[90:]) - np.testing.assert_almost_equal( - anomaly_response_original.scores.value.to_numpy(), - anomaly_response_deserialized.scores.value.to_numpy(), - ) - original_predicted_ts = anomaly_response_original.predicted_ts - assert original_predicted_ts is not None # for pyre - deserialized_predicted_ts = anomaly_response_deserialized.predicted_ts - assert deserialized_predicted_ts is not None # for pyre - np.testing.assert_almost_equal( - original_predicted_ts.value.to_numpy(), - deserialized_predicted_ts.value.to_numpy(), - ) - - # pyre-fixme[56]: Pyre was not able to infer the type of the decorator - @parameterized.expand( - [ - [ProphetVersion.fbprophet, "__fbprophet_version"], - [ProphetVersion.prophet, "__prophet_version"], - ] - ) - def test_serialized_prophet_version_key( - self, - prophet_version: ProphetVersion, - prophet_version_key: str, - ) -> None: + def test_serialized_prophet_version_key(self) -> None: ts = self.create_random_ts(0, 100, 10, 2) - ProphetDetectorModel.prophet_version = prophet_version detector_model = ProphetDetectorModel() detector_model.fit(ts[:90]) serialized_model = detector_model.serialize() model_json = json.loads(serialized_model) - self.assertIn(prophet_version_key, model_json) + self.assertIn("__prophet_version", model_json) class TestProphetTrendDetectorModel(TestCase): diff --git a/kats/tests/models/test_prophet_model.py b/kats/tests/models/test_prophet_model.py index b210acba..5be3ac44 100644 --- a/kats/tests/models/test_prophet_model.py +++ b/kats/tests/models/test_prophet_model.py @@ -116,7 +116,7 @@ def mock_prophet_import(module: Any, *args: Any, **kwargs: Any) -> None: cls.mock_imports = patch("builtins.__import__", side_effect=mock_prophet_import) - def test_fbprophet_not_installed(self) -> None: + def test_prophet_not_installed(self) -> None: # Unload prophet module so its imports can be mocked as necessary del sys.modules["kats.models.prophet"]