-
Notifications
You must be signed in to change notification settings - Fork 310
/
Copy path__init__.py
50 lines (46 loc) · 2.1 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
#
# Copyright (c) 2023 salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
#
"""
Datasets for time series forecasting. These are plain time series that carry
no labels of any kind.
"""
from typing import Optional

from ts_datasets.base import BaseDataset
from ts_datasets.forecast.custom import CustomDataset
from ts_datasets.forecast.m4 import M4
from ts_datasets.forecast.energy_power import EnergyPower
from ts_datasets.forecast.seattle_trail import SeattleTrail
from ts_datasets.forecast.solar_plant import SolarPlant
# Public API of this package: the ``get_dataset`` factory plus every dataset
# class it can dispatch to (``get_dataset`` looks names up in this list).
__all__ = [
    "get_dataset",
    "CustomDataset",
    "M4",
    "EnergyPower",
    "SeattleTrail",
    "SolarPlant",
]
def get_dataset(dataset_name: str, rootdir: Optional[str] = None, **kwargs) -> BaseDataset:
    """
    Instantiate one of the forecasting datasets exported by this module by name.

    :param dataset_name: the name of the dataset to load, formatted as
        ``<name>`` or ``<name>_<subset>``, e.g. ``EnergyPower`` or ``M4_Hourly``.
        Only the first underscore separates the name from the subset, so the
        subset part may itself contain underscores.
    :param rootdir: the directory where the desired dataset is stored. Not
        required if the package :py:mod:`ts_datasets` is installed in editable
        mode, i.e. with flag ``-e``.
    :param kwargs: keyword arguments for the data loader you are trying to load.
        If ``dataset_name`` contains a subset, it is forwarded as ``subset=...``
        and takes precedence over any ``subset`` passed here.
    :return: the data loader for the desired dataset (and subset)
    :raises KeyError: if ``<name>`` is not one of the dataset classes listed in
        this module's ``__all__``.
    :raises ValueError: if a subset is requested for a dataset class that does
        not define a ``valid_subsets`` attribute.
    """
    # partition() splits on the first "_" only (same as split("_", maxsplit=1));
    # ``sep`` is empty when no subset was given.
    name, sep, subset = dataset_name.partition("_")
    has_subset = bool(sep)

    # Every public name except this factory function is a dataset class.
    valid_datasets = set(__all__) - {"get_dataset"}
    if name not in valid_datasets:
        # Sort the names so the error message is deterministic across runs.
        raise KeyError(
            "Dataset should be formatted as <name> or "
            "<name>_<subset>, where <name> is one of "
            f"{sorted(valid_datasets)}. Got {dataset_name} instead."
        )
    cls = globals()[name]

    if has_subset:
        if not hasattr(cls, "valid_subsets"):
            raise ValueError(
                f"Dataset {name} does not have any subsets, "
                f"but attempted to load subset {subset} by "
                f"specifying dataset name {dataset_name}."
            )
        # Validation of the subset value itself is left to the dataset class.
        kwargs["subset"] = subset
    return cls(rootdir=rootdir, **kwargs)