-
Notifications
You must be signed in to change notification settings - Fork 310
/
Copy pathfactory.py
123 lines (111 loc) · 5.43 KB
/
factory.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
#
# Copyright (c) 2023 salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
#
"""
Contains the `ModelFactory`.
"""
import copy
import inspect
import logging
from typing import Dict, Tuple, Type, Union
import dill
from merlion.models.base import ModelBase
from merlion.utils import dynamic_import
logger = logging.getLogger(__name__)

# Registry mapping short model names to "module.path:ClassName" specs.
# ``ModelFactory.get_model_class`` passes this table to ``dynamic_import`` so that
# models can be referred to by these short names.
# NOTE: entry order is intentionally left unchanged (in case anything iterates
# this dict), so the section headers below group entries only loosely; trailing
# comments flag entries whose module path places them outside their section.
import_alias = dict(
    # Default models
    DefaultDetector="merlion.models.defaults:DefaultDetector",
    DefaultForecaster="merlion.models.defaults:DefaultForecaster",
    # Anomaly detection models
    ArimaDetector="merlion.models.anomaly.forecast_based.arima:ArimaDetector",  # forecast-based
    DynamicBaseline="merlion.models.anomaly.dbl:DynamicBaseline",
    IsolationForest="merlion.models.anomaly.isolation_forest:IsolationForest",
    LocalOutlierFactor="merlion.models.anomaly.lof:LOF",
    # Forecast-based anomaly detection models, interleaved with other detectors
    ETSDetector="merlion.models.anomaly.forecast_based.ets:ETSDetector",
    MSESDetector="merlion.models.anomaly.forecast_based.mses:MSESDetector",
    ProphetDetector="merlion.models.anomaly.forecast_based.prophet:ProphetDetector",
    RandomCutForest="merlion.models.anomaly.random_cut_forest:RandomCutForest",  # not under forecast_based
    SarimaDetector="merlion.models.anomaly.forecast_based.sarima:SarimaDetector",
    WindStats="merlion.models.anomaly.windstats:WindStats",  # not under forecast_based
    SpectralResidual="merlion.models.anomaly.spectral_residual:SpectralResidual",  # not under forecast_based
    ZMS="merlion.models.anomaly.zms:ZMS",  # not under forecast_based
    # Currently not registered (entry left commented out):
    # DeepPointAnomalyDetector="merlion.models.anomaly.deep_point_anomaly_detector:DeepPointAnomalyDetector",
    # Multivariate Anomaly Detection models
    AutoEncoder="merlion.models.anomaly.autoencoder:AutoEncoder",
    VAE="merlion.models.anomaly.vae:VAE",
    DAGMM="merlion.models.anomaly.dagmm:DAGMM",
    LSTMED="merlion.models.anomaly.lstm_ed:LSTMED",
    # Change point detection models
    BOCPD="merlion.models.anomaly.change_point.bocpd:BOCPD",
    # Forecasting models
    Arima="merlion.models.forecast.arima:Arima",
    ETS="merlion.models.forecast.ets:ETS",
    MSES="merlion.models.forecast.smoother:MSES",
    Prophet="merlion.models.forecast.prophet:Prophet",
    Sarima="merlion.models.forecast.sarima:Sarima",
    StatThreshold="merlion.models.anomaly.stat_threshold:StatThreshold",  # anomaly detector, not a forecaster
    VectorAR="merlion.models.forecast.vector_ar:VectorAR",
    RandomForestForecaster="merlion.models.forecast.trees:RandomForestForecaster",
    ExtraTreesForecaster="merlion.models.forecast.trees:ExtraTreesForecaster",
    LGBMForecaster="merlion.models.forecast.trees:LGBMForecaster",
    TransformerForecaster="merlion.models.forecast.transformer:TransformerForecaster",
    InformerForecaster="merlion.models.forecast.informer:InformerForecaster",
    AutoformerForecaster="merlion.models.forecast.autoformer:AutoformerForecaster",
    ETSformerForecaster="merlion.models.forecast.etsformer:ETSformerForecaster",
    DeepARForecaster="merlion.models.forecast.deep_ar:DeepARForecaster",
    # Ensembles
    DetectorEnsemble="merlion.models.ensemble.anomaly:DetectorEnsemble",
    ForecasterEnsemble="merlion.models.ensemble.forecast:ForecasterEnsemble",
    # Layers (modules under merlion.models.automl)
    SeasonalityLayer="merlion.models.automl.seasonality:SeasonalityLayer",
    AutoETS="merlion.models.automl.autoets:AutoETS",
    AutoProphet="merlion.models.automl.autoprophet:AutoProphet",
    AutoSarima="merlion.models.automl.autosarima:AutoSarima",
)
class ModelFactory:
    """
    Builds, loads and deserializes models from their registered names.

    Every entry point resolves a model name to its class through
    :meth:`get_model_class`, which delegates to ``dynamic_import`` together with
    the module-level ``import_alias`` table.
    """

    @classmethod
    def get_model_class(cls, name: str) -> Type[ModelBase]:
        """
        Resolve ``name`` to a model class.

        :param name: a model name such as a key of ``import_alias``
            (e.g. ``"IsolationForest"``); it is handed to ``dynamic_import``
            unchanged.
        :return: the resolved model class.
        """
        resolved_cls = dynamic_import(name, import_alias)
        return resolved_cls

    @classmethod
    def create(cls, name, return_unused_kwargs=False, **kwargs) -> Union[ModelBase, Tuple[ModelBase, Dict]]:
        """
        Instantiate the model registered under ``name`` from keyword arguments.

        Keyword arguments are routed in three stages:

        1. ``config_class.from_dict`` builds the model config and returns the
           kwargs it did not use;
        2. of those, any whose key is a parameter of the model constructor are
           passed to ``__init__`` (alongside ``config``);
        3. whatever remains is passed to the model's ``_load_state``.

        :param name: model name, resolved with :meth:`get_model_class`.
        :param return_unused_kwargs: if ``True``, only leftover kwargs whose key is
            an existing attribute of the model are passed to ``_load_state``; the
            others are returned alongside the model.
        :param kwargs: config, constructor and state keyword arguments.
        :return: the model, or ``(model, unused_kwargs)`` when
            ``return_unused_kwargs`` is ``True``.
        """
        model_cls = cls.get_model_class(name)
        config, leftover = model_cls.config_class.from_dict(kwargs, return_unused_kwargs=True)

        # Partition the leftovers (preserving their order) into constructor
        # arguments and candidate post-init state.
        ctor_params = inspect.signature(model_cls).parameters
        init_kwargs = {}
        extra = {}
        for key, value in leftover.items():
            if key in ctor_params:
                init_kwargs[key] = value
            else:
                extra[key] = value
        model = model_cls(config=config, **init_kwargs)

        if not return_unused_kwargs:
            model._load_state(extra)
            return model

        # Only kwargs naming an attribute the model already has are treated as state.
        state = {key: value for key, value in extra.items() if hasattr(model, key)}
        model._load_state(state)
        unused = {key: value for key, value in extra.items() if key not in state}
        return model, unused

    @classmethod
    def load(cls, name, model_path, **kwargs) -> ModelBase:
        """
        Load a saved model of type ``name``, or build a new one when no path is given.

        :param name: model name, resolved with :meth:`get_model_class`.
        :param model_path: location passed to the model class's ``load``. If
            ``None``, a fresh model is created with :meth:`create` instead.
        :param kwargs: forwarded to ``load`` or to :meth:`create`.
        :return: the loaded or newly created model.
        """
        if model_path is not None:
            return cls.get_model_class(name).load(model_path, **kwargs)
        return cls.create(name, **kwargs)

    @classmethod
    def load_bytes(cls, obj, **kwargs) -> ModelBase:
        """
        Deserialize a model from bytes.

        ``obj`` is first unpickled with ``dill`` only to read its first element,
        which is used as the model name. The original bytes are then passed,
        unchanged, to that model class's ``from_bytes``.

        .. warning::
            ``dill.loads`` (like ``pickle.loads``) can execute arbitrary code while
            deserializing. Only pass bytes that come from a trusted source.

        :param obj: serialized model bytes.
        :param kwargs: forwarded to ``from_bytes``.
        :return: the deserialized model.
        """
        # NOTE(review): from_bytes presumably deserializes the payload again, so
        # the bytes may be unpickled twice; a cheap-to-read name header would avoid that.
        model_name = dill.loads(obj)[0]
        model_cls = cls.get_model_class(model_name)
        return model_cls.from_bytes(obj, **kwargs)
def instantiate_or_copy_model(model: Union[dict, ModelBase]):
    """
    Return a new model object built from ``model``.

    :param model: either an existing ``ModelBase`` instance, which is deep-copied,
        or a ``dict`` of keyword arguments forwarded to ``ModelFactory.create``.
        A dict must therefore contain a ``name`` key, plus any config,
        constructor or state kwargs. If it sets ``return_unused_kwargs=True``,
        a ``(model, unused_kwargs)`` tuple is returned instead of a bare model.
    :return: the deep copy, or the model created from the dict.
    :raises TypeError: if ``model`` is neither a ``dict`` nor a ``ModelBase``.
        Any exception raised by ``ModelFactory.create`` is logged together with
        the offending dict and then re-raised unchanged.
    """
    if isinstance(model, ModelBase):
        # Deep copy so the returned model shares no state with the caller's instance.
        return copy.deepcopy(model)
    if isinstance(model, dict):
        try:
            return ModelFactory.create(**model)
        except Exception:
            # Record the bad config for debugging, then use a bare ``raise`` so the
            # original exception and traceback propagate untouched.
            logger.error("Invalid `dict` specifying a model config.\n\nGot %s", model)
            raise
    raise TypeError(f"Expected model to be a `dict` or `ModelBase`. Got {model}")