diff --git a/custom_components/calendarific/__init__.py b/custom_components/calendarific/__init__.py index 58eb513..3bffc01 100644 --- a/custom_components/calendarific/__init__.py +++ b/custom_components/calendarific/__init__.py @@ -7,6 +7,7 @@ import voluptuous as vol from homeassistant.config_entries import ConfigEntry +from homeassistant.const import Platform from homeassistant.core import HomeAssistant from homeassistant import config_entries @@ -37,8 +38,6 @@ _LOGGER = logging.getLogger(__name__) -holiday_list = [] - def setup(hass, config): """Set up platform using YAML.""" if DOMAIN in config: @@ -54,14 +53,12 @@ def setup(hass, config): async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry): - hass.async_create_task( - hass.config_entries.async_forward_entry_setup(entry, "sensor") - ) + await hass.config_entries.async_forward_entry_setups(entry, [Platform.SENSOR]) return True async def async_unload_entry(hass, entry): """Unload a config entry.""" - return await hass.config_entries.async_forward_entry_unload(entry, "sensor") + return await hass.config_entries.async_forward_entry_unload(entry, Platform.SENSOR) class CalendarificApiReader: @@ -96,6 +93,9 @@ def get_description(self,holiday_name): return next(i for i in self._holidays if i['name'] == holiday_name)['description'] except: return "NOT FOUND" + + def get_holidays(self): + return [item['name'] for item in self._holidays] def update(self): if self._lastupdated == datetime.now().date(): @@ -121,10 +121,6 @@ def update(self): return self._error_logged = False self._next_holidays = response['response']['holidays'] - global holiday_list - holiday_list = [] - for holiday in self._holidays: - holiday_list.append(holiday['name']) return True diff --git a/custom_components/calendarific/config_flow.py b/custom_components/calendarific/config_flow.py index b111ea2..b489b1e 100644 --- a/custom_components/calendarific/config_flow.py +++ b/custom_components/calendarific/config_flow.py @@ -5,6 +5,7 @@ import voluptuous as vol from homeassistant import config_entries +from homeassistant import core from homeassistant.const import CONF_NAME from homeassistant.core import HomeAssistant, callback @@ -27,8 +28,6 @@ CONF_UNIT_OF_MEASUREMENT, ) -from . import holiday_list - _LOGGER = logging.getLogger(__name__) @callback @@ -46,9 +45,11 @@ def __init__(self) -> None: self._errors = {} self._data = {} self._data["unique_id"] = str(uuid.uuid4()) + hass = core.async_get_hass() + self._holiday_list = hass.data[DOMAIN]["apiReader"].get_holidays() async def async_step_user(self, user_input=None): - if holiday_list == []: + if self._holiday_list == []: return self.async_abort(reason="no_holidays_found") self._errors = {} if user_input is not None: @@ -84,7 +85,7 @@ async def _show_user_form(self, user_input): if CONF_UNIT_OF_MEASUREMENT in user_input: unit_of_measurement = user_input[CONF_UNIT_OF_MEASUREMENT] data_schema = OrderedDict() - data_schema[vol.Required(CONF_HOLIDAY, default=holiday)] = vol.In(holiday_list) + data_schema[vol.Required(CONF_HOLIDAY, default=holiday)] = vol.In(self._holiday_list) data_schema[vol.Optional(CONF_NAME, default=name)] = str data_schema[vol.Required(CONF_UNIT_OF_MEASUREMENT, default=unit_of_measurement)] = str data_schema[vol.Required(CONF_ICON_NORMAL, default=icon_normal)] = str