Skip to content

Commit

Permalink
2/n remove fbprophet from Kats
Browse files Browse the repository at this point in the history
Summary:
This diff fully removes the `fbprophet` dependency from Kats.

bye

Differential Revision: D67025007

fbshipit-source-id: 875ee7278e48042cf1d17ae69a14a62490fb0d7b
  • Loading branch information
islijepcevic authored and facebook-github-bot committed Dec 17, 2024
1 parent cb3a584 commit 26160ad
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 141 deletions.
80 changes: 13 additions & 67 deletions kats/detectors/prophet_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,6 @@

import numpy as np
import pandas as pd
from fbprophet import Prophet as FbProphet
from fbprophet.make_holidays import make_holidays_df as make_holidays_df_0
from fbprophet.serialize import (
model_from_json as model_from_json_0,
model_to_json as model_to_json_0,
)
from kats.consts import (
DataError,
DataInsufficientError,
Expand All @@ -36,11 +30,8 @@
from kats.detectors.detector_consts import AnomalyResponse, ConfidenceBand
from kats.models.prophet import predict
from prophet import Prophet
from prophet.make_holidays import make_holidays_df as make_holidays_df_1
from prophet.serialize import (
model_from_json as model_from_json_1,
model_to_json as model_to_json_1,
)
from prophet.make_holidays import make_holidays_df
from prophet.serialize import model_from_json, model_to_json
from pyre_extensions import ParameterSpecification
from scipy.stats import norm

Expand Down Expand Up @@ -172,35 +163,6 @@ class ProphetScoreFunction(Enum):
z_score = "z_score"


class ProphetVersion(Enum):
fbprophet = "fbprophet"
prophet = "prophet"

def make_holidays_df(self, *args: P.args, **kwargs: P.kwargs) -> pd.DataFrame:
if self == ProphetVersion.fbprophet:
return make_holidays_df_0(*args, **kwargs)
else:
return make_holidays_df_1(*args, **kwargs)

def model_from_json(self, *args: P.args, **kwargs: P.kwargs) -> FbProphet | Prophet:
if self == ProphetVersion.fbprophet:
return model_from_json_0(*args, **kwargs)
else:
return model_from_json_1(*args, **kwargs)

def model_to_json(self, *args: P.args, **kwargs: P.kwargs) -> str:
if self == ProphetVersion.fbprophet:
return model_to_json_0(*args, **kwargs)
else:
return model_to_json_1(*args, **kwargs)

def create_prophet(self, *args: P.args, **kwargs: P.kwargs) -> FbProphet | Prophet:
if self == ProphetVersion.fbprophet:
return FbProphet(*args, **kwargs)
else:
return Prophet(*args, **kwargs) # pyre-ignore


SCORE_FUNC_DICT: Dict[str, Any] = {
ProphetScoreFunction.deviation_from_predicted_val.value: deviation_from_predicted_val,
ProphetScoreFunction.z_score.value: z_score,
Expand Down Expand Up @@ -324,7 +286,6 @@ def get_holiday_dates(
holidays: Optional[pd.DataFrame] = None,
country_holidays: Optional[str] = None,
dates: Optional[pd.Series] = None,
prophet_version: ProphetVersion = ProphetVersion.prophet,
) -> pd.Series:
if dates is None:
return pd.Series()
Expand All @@ -333,7 +294,7 @@ def get_holiday_dates(
if holidays is not None:
all_holidays = holidays.copy()
if country_holidays:
country_holidays_df = prophet_version.make_holidays_df(
country_holidays_df = make_holidays_df(
year_list=year_list, country=country_holidays
)
all_holidays = pd.concat((all_holidays, country_holidays_df), sort=False)
Expand All @@ -343,17 +304,6 @@ def get_holiday_dates(
return all_holidays


def deserialize_model(
serialized_model: bytes,
) -> Tuple[FbProphet | Prophet, ProphetVersion]:
model_json = json.loads(serialized_model)
if "__fbprophet_version" in model_json:
prophet_version = ProphetVersion.fbprophet
else:
prophet_version = ProphetVersion.prophet
return prophet_version.model_from_json(serialized_model), prophet_version


class ProphetDetectorModel(DetectorModel):
"""Prophet based anomaly detection model.
Expand All @@ -373,8 +323,7 @@ class ProphetDetectorModel(DetectorModel):
"""

model: Optional[FbProphet | Prophet] = None
prophet_version: ProphetVersion = ProphetVersion.prophet
model: Optional[Prophet] = None
seasonalities: Dict[SeasonalityTypes, Union[bool, str]] = {}
seasonalities_to_fit: Dict[SeasonalityTypes, Union[bool, str]] = {}

Expand Down Expand Up @@ -423,7 +372,7 @@ def __init__(
"""

if serialized_model:
self.model, self.prophet_version = deserialize_model(serialized_model)
self.model = model_from_json(serialized_model)
else:
self.model = None

Expand Down Expand Up @@ -465,7 +414,7 @@ def serialize(self) -> bytes:
Returns:
The model serialized to JSON and encoded as bytes.
"""
return str.encode(self.prophet_version.model_to_json(self.model))
return str.encode(model_to_json(self.model))

def fit_predict(
self,
Expand Down Expand Up @@ -529,7 +478,6 @@ def fit(
self.outlier_threshold,
uncertainty_samples=self.outlier_removal_uncertainty_samples,
vectorize=self.vectorize,
prophet_version=self.prophet_version,
)
# seasonalities depend on the current time series
self.seasonalities_to_fit = seasonalities_processing(
Expand Down Expand Up @@ -559,7 +507,7 @@ def fit(
self.holidays = pd.DataFrame(self.holidays_list)

# No incremental training. Create a model and train from scratch
model = self.prophet_version.create_prophet(
model = Prophet(
interval_width=self.scoring_confidence_interval,
uncertainty_samples=self.uncertainty_samples,
daily_seasonality=self.seasonalities_to_fit[SeasonalityTypes.DAY],
Expand Down Expand Up @@ -651,7 +599,7 @@ def predict(
pd.DataFrame(self.holidays_list) if self.holidays_list else None
)
holidays_df: Optional[pd.Series] = get_holiday_dates(
custom_holidays, self.country_holidays, data.time, self.prophet_version
custom_holidays, self.country_holidays, data.time
)
if holidays_df is not None:
scores_ts = pd.Series(list(scores.value), index=data.time)
Expand Down Expand Up @@ -679,7 +627,6 @@ def _remove_outliers(
outlier_ci_threshold: float = 0.99,
uncertainty_samples: float = OUTLIER_REMOVAL_UNCERTAINTY_SAMPLES,
vectorize: bool = False,
prophet_version: ProphetVersion = ProphetVersion.prophet,
) -> pd.DataFrame:
"""
Remove outliers from the time series by fitting a Prophet model to the time series
Expand All @@ -689,7 +636,7 @@ def _remove_outliers(

ts_dates_df = pd.DataFrame({PROPHET_TIME_COLUMN: ts_df.iloc[:, 0]})

model = prophet_version.create_prophet(
model = Prophet(
interval_width=outlier_ci_threshold, uncertainty_samples=uncertainty_samples
)
with ExitStack() as stack:
Expand All @@ -711,8 +658,7 @@ def _remove_outliers(
class ProphetTrendDetectorModel(DetectorModel):
"""Prophet based trend detection model."""

model: Optional[FbProphet | Prophet] = None
prophet_version: ProphetVersion = ProphetVersion.prophet
model: Optional[Prophet] = None

def __init__(
self,
Expand All @@ -722,7 +668,7 @@ def __init__(
changepoint_prior_scale: float = 0.01,
) -> None:
if serialized_model:
self.model, self.prophet_version = deserialize_model(serialized_model)
self.model = model_from_json(serialized_model)
else:
self.model = None

Expand All @@ -738,7 +684,7 @@ def serialize(self) -> bytes:
Returns:
The model serialized to JSON and encoded as bytes.
"""
return str.encode(self.prophet_version.model_to_json(self.model))
return str.encode(model_to_json(self.model))

def _zeros_ts(self, data: pd.DataFrame) -> TimeSeriesData:
return TimeSeriesData(
Expand All @@ -755,7 +701,7 @@ def fit_predict(
historical_data: Optional[TimeSeriesData] = None,
**kwargs: Any,
) -> AnomalyResponse:
model = self.prophet_version.create_prophet(
model = Prophet(
changepoint_range=self.changepoint_range,
weekly_seasonality=self.weekly_seasonality,
changepoint_prior_scale=self.changepoint_prior_scale,
Expand Down
4 changes: 2 additions & 2 deletions kats/detectors/threshold_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@

import numpy as np
import pandas as pd
from fbprophet import Prophet
from fbprophet.serialize import model_from_json, model_to_json
from kats.consts import DEFAULT_VALUE_NAME, TimeSeriesData
from kats.detectors.detector import DetectorModel
from kats.detectors.detector_consts import AnomalyResponse
from prophet import Prophet
from prophet.serialize import model_from_json, model_to_json


class StaticThresholdModel(DetectorModel):
Expand Down
22 changes: 10 additions & 12 deletions kats/models/prophet.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,12 @@

try:
# Prophet is an optional dependency for kats.
from fbprophet import Prophet as FbProphet
from prophet import Prophet

_no_prophet = False
except ImportError:
_no_prophet = True
Prophet = Dict[str, Any] # for Pyre
FbProphet = Dict[str, Any] # for Pyre

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -138,7 +136,7 @@ def __init__(
extra_regressors: Optional[List[Dict[str, Any]]] = None,
) -> None:
if _no_prophet:
raise RuntimeError("requires fbprophet to be installed")
raise RuntimeError("requires prophet to be installed")
super().__init__()
self.growth = growth
self.changepoints = changepoints
Expand Down Expand Up @@ -252,7 +250,7 @@ class ProphetModel(Model[ProphetParams]):
def __init__(self, data: TimeSeriesData, params: ProphetParams) -> None:
super().__init__(data, params)
if _no_prophet:
raise RuntimeError("requires fbprophet to be installed")
raise RuntimeError("requires prophet to be installed")
self.data: TimeSeriesData = data
self._data_params_validation()

Expand Down Expand Up @@ -517,7 +515,7 @@ def get_parameter_search_space() -> List[Dict[str, object]]:

# From here on, the main logic is adapted from GitHub PR https://github.com/facebook/prophet/pull/2186, with some modifications.
def predict_uncertainty(
prophet_model: Prophet | FbProphet, df: pd.DataFrame, vectorized: bool
prophet_model: Prophet, df: pd.DataFrame, vectorized: bool
) -> pd.DataFrame:
"""Prediction intervals for yhat and trend.
Expand Down Expand Up @@ -548,7 +546,7 @@ def predict_uncertainty(


def _sample_predictive_trend_vectorized(
prophet_model: Prophet | FbProphet,
prophet_model: Prophet,
df: pd.DataFrame,
n_samples: int,
iteration: int = 0,
Expand Down Expand Up @@ -594,7 +592,7 @@ def _sample_predictive_trend_vectorized(


def _sample_trend_uncertainty(
prophet_model: Prophet | FbProphet,
prophet_model: Prophet,
n_samples: int,
df: pd.DataFrame,
iteration: int = 0,
Expand Down Expand Up @@ -683,7 +681,7 @@ def _make_trend_shift_matrix(


def predict(
prophet_model: Prophet | FbProphet,
prophet_model: Prophet,
df: Optional[pd.DataFrame] = None,
vectorized: bool = False,
) -> pd.DataFrame:
Expand Down Expand Up @@ -730,7 +728,7 @@ def predict(


def sample_model_vectorized(
prophet_model: Prophet | FbProphet,
prophet_model: Prophet,
df: pd.DataFrame,
seasonal_features: pd.DataFrame,
iteration: int,
Expand Down Expand Up @@ -761,7 +759,7 @@ def sample_model_vectorized(


def sample_posterior_predictive(
prophet_model: Prophet | FbProphet, df: pd.DataFrame, vectorized: bool
prophet_model: Prophet, df: pd.DataFrame, vectorized: bool
) -> Dict[str, npt.NDArray]:
"""Generate posterior samples of a trained Prophet model.
Expand Down Expand Up @@ -836,7 +834,7 @@ def _make_historical_mat_time(


def _logistic_uncertainty(
prophet_model: Prophet | FbProphet,
prophet_model: Prophet,
mat: npt.NDArray,
deltas: npt.NDArray,
k: float,
Expand Down Expand Up @@ -905,7 +903,7 @@ def _piecewise_linear_vectorize(


def sample_linear_predictive_trend_vectorize(
prophet_model: Prophet | FbProphet,
prophet_model: Prophet,
df: pd.DataFrame,
sample_size: int,
iteration: int,
Expand Down
Loading

0 comments on commit 26160ad

Please sign in to comment.