From addc2c43408388eda100d4a4d6969845fc013c0f Mon Sep 17 00:00:00 2001 From: andrey-git Date: Sun, 22 Jan 2017 21:19:50 +0200 Subject: [PATCH] Allow easier customization of whole domain, entity lists, globs. (#5215) --- homeassistant/config.py | 38 ++++++---- homeassistant/helpers/config_validation.py | 9 +++ homeassistant/helpers/customize.py | 80 ++++++++++++++++++++ homeassistant/helpers/entity.py | 18 +---- tests/helpers/test_config_validation.py | 28 +++++++ tests/helpers/test_customize.py | 87 ++++++++++++++++++++++ tests/helpers/test_entity.py | 6 +- tests/test_config.py | 67 +++++++++++++---- 8 files changed, 285 insertions(+), 48 deletions(-) create mode 100644 homeassistant/helpers/customize.py create mode 100644 tests/helpers/test_customize.py diff --git a/homeassistant/config.py b/homeassistant/config.py index eb29212a67d0f6..bbfee5730a81e6 100644 --- a/homeassistant/config.py +++ b/homeassistant/config.py @@ -6,7 +6,7 @@ import shutil from types import MappingProxyType # pylint: disable=unused-import -from typing import Any, Tuple # NOQA +from typing import Any, List, Tuple # NOQA import voluptuous as vol @@ -14,15 +14,15 @@ CONF_LATITUDE, CONF_LONGITUDE, CONF_NAME, CONF_PACKAGES, CONF_UNIT_SYSTEM, CONF_TIME_ZONE, CONF_CUSTOMIZE, CONF_ELEVATION, CONF_UNIT_SYSTEM_METRIC, CONF_UNIT_SYSTEM_IMPERIAL, CONF_TEMPERATURE_UNIT, TEMP_CELSIUS, - __version__) -from homeassistant.core import valid_entity_id, DOMAIN as CONF_CORE + CONF_ENTITY_ID, __version__) +from homeassistant.core import DOMAIN as CONF_CORE from homeassistant.exceptions import HomeAssistantError from homeassistant.loader import get_component from homeassistant.util.yaml import load_yaml import homeassistant.helpers.config_validation as cv -from homeassistant.helpers.entity import set_customize from homeassistant.util import dt as date_util, location as loc_util from homeassistant.util.unit_system import IMPERIAL_SYSTEM, METRIC_SYSTEM +from homeassistant.helpers.customize import set_customize _LOGGER = logging.getLogger(__name__) @@ -87,19 +87,24 @@ """ -def _valid_customize(value): - """Config validator for customize.""" - if not isinstance(value, dict): - raise vol.Invalid('Expected dictionary') +CUSTOMIZE_SCHEMA_ENTRY = vol.Schema({ + vol.Required(CONF_ENTITY_ID): vol.All( + cv.ensure_list_csv, vol.Length(min=1), [cv.string], [vol.Lower]) +}, extra=vol.ALLOW_EXTRA) - for key, val in value.items(): - if not valid_entity_id(key): - raise vol.Invalid('Invalid entity ID: {}'.format(key)) - if not isinstance(val, dict): - raise vol.Invalid('Value of {} is not a dictionary'.format(key)) +def _convert_old_config(inp: Any) -> List: + if not isinstance(inp, dict): + return cv.ensure_list(inp) + if CONF_ENTITY_ID in inp: + return [inp] # sigle entry + res = [] - return value + inp = vol.Schema({cv.match_all: dict})(inp) + for key, val in inp.items(): + val[CONF_ENTITY_ID] = key + res.append(val) + return res PACKAGES_CONFIG_SCHEMA = vol.Schema({ @@ -116,7 +121,8 @@ def _valid_customize(value): CONF_UNIT_SYSTEM: cv.unit_system, CONF_TIME_ZONE: cv.time_zone, vol.Required(CONF_CUSTOMIZE, - default=MappingProxyType({})): _valid_customize, + default=MappingProxyType({})): vol.All( + _convert_old_config, [CUSTOMIZE_SCHEMA_ENTRY]), vol.Optional(CONF_PACKAGES, default={}): PACKAGES_CONFIG_SCHEMA, }) @@ -301,7 +307,7 @@ def set_time_zone(time_zone_str): if CONF_TIME_ZONE in config: set_time_zone(config.get(CONF_TIME_ZONE)) - set_customize(config.get(CONF_CUSTOMIZE) or {}) + set_customize(hass, config.get(CONF_CUSTOMIZE) or {}) if CONF_UNIT_SYSTEM in config: if config[CONF_UNIT_SYSTEM] == CONF_UNIT_SYSTEM_IMPERIAL: diff --git a/homeassistant/helpers/config_validation.py b/homeassistant/helpers/config_validation.py index b78eedec8c2f03..0c28dbdd78e323 100644 --- a/homeassistant/helpers/config_validation.py +++ b/homeassistant/helpers/config_validation.py @@ -376,6 +376,8 @@ def validator(value): """Validate ordered dict.""" config = OrderedDict() + if not isinstance(value, dict): + raise vol.Invalid('Value {} is not a dictionary'.format(value)) for key, val in value.items(): v_res = item_validator({key: val}) config.update(v_res) @@ -385,6 +387,13 @@ def validator(value): return validator +def ensure_list_csv(value: Any) -> Sequence: + """Ensure that input is a list or make one from comma-separated string.""" + if isinstance(value, str): + return [member.strip() for member in value.split(',')] + return ensure_list(value) + + # Validator helpers def key_dependency(key, dependency): diff --git a/homeassistant/helpers/customize.py b/homeassistant/helpers/customize.py new file mode 100644 index 00000000000000..b03a89ff40f4d7 --- /dev/null +++ b/homeassistant/helpers/customize.py @@ -0,0 +1,80 @@ +"""A helper module for customization.""" +import collections +from typing import Dict, List +import fnmatch + +from homeassistant.const import CONF_ENTITY_ID +from homeassistant.core import HomeAssistant, split_entity_id + +_OVERWRITE_KEY = 'overwrite' +_OVERWRITE_CACHE_KEY = 'overwrite_cache' + + +def set_customize(hass: HomeAssistant, customize: List[Dict]) -> None: + """Overwrite all current customize settings. + + Async friendly. + """ + hass.data[_OVERWRITE_KEY] = customize + hass.data[_OVERWRITE_CACHE_KEY] = {} + + +def get_overrides(hass: HomeAssistant, entity_id: str) -> Dict: + """Return a dictionary of overrides related to entity_id. + + Whole-domain overrides are of lowest priorities, + then glob on entity ID, and finally exact entity_id + matches are of highest priority. + + The lookups are cached. + """ + if _OVERWRITE_CACHE_KEY in hass.data and \ + entity_id in hass.data[_OVERWRITE_CACHE_KEY]: + return hass.data[_OVERWRITE_CACHE_KEY][entity_id] + if _OVERWRITE_KEY not in hass.data: + return {} + domain_result = {} # type: Dict[str, Any] + glob_result = {} # type: Dict[str, Any] + exact_result = {} # type: Dict[str, Any] + domain = split_entity_id(entity_id)[0] + + def clean_entry(entry: Dict) -> Dict: + """Clean up entity-matching keys.""" + entry.pop(CONF_ENTITY_ID, None) + return entry + + def deep_update(target: Dict, source: Dict) -> None: + """Deep update a dictionary.""" + for key, value in source.items(): + if isinstance(value, collections.Mapping): + updated_value = target.get(key, {}) + # If the new value is map, but the old value is not - + # overwrite the old value. + if not isinstance(updated_value, collections.Mapping): + updated_value = {} + deep_update(updated_value, value) + target[key] = updated_value + else: + target[key] = source[key] + + for rule in hass.data[_OVERWRITE_KEY]: + if CONF_ENTITY_ID in rule: + entities = rule[CONF_ENTITY_ID] + if domain in entities: + deep_update(domain_result, rule) + if entity_id in entities: + deep_update(exact_result, rule) + for entity_id_glob in entities: + if entity_id_glob == entity_id: + continue + if fnmatch.fnmatchcase(entity_id, entity_id_glob): + deep_update(glob_result, rule) + break + result = {} + deep_update(result, clean_entry(domain_result)) + deep_update(result, clean_entry(glob_result)) + deep_update(result, clean_entry(exact_result)) + if _OVERWRITE_CACHE_KEY not in hass.data: + hass.data[_OVERWRITE_CACHE_KEY] = {} + hass.data[_OVERWRITE_CACHE_KEY][entity_id] = result + return result diff --git a/homeassistant/helpers/entity.py b/homeassistant/helpers/entity.py index 0d2f56f18072dc..438de6a66d3997 100644 --- a/homeassistant/helpers/entity.py +++ b/homeassistant/helpers/entity.py @@ -4,7 +4,7 @@ import functools as ft from timeit import default_timer as timer -from typing import Any, Optional, List, Dict +from typing import Optional, List from homeassistant.const import ( ATTR_ASSUMED_STATE, ATTR_FRIENDLY_NAME, ATTR_HIDDEN, ATTR_ICON, @@ -16,9 +16,7 @@ from homeassistant.util import ensure_unique_string, slugify from homeassistant.util.async import ( run_coroutine_threadsafe, run_callback_threadsafe) - -# Entity attributes that we will overwrite -_OVERWRITE = {} # type: Dict[str, Any] +from homeassistant.helpers.customize import get_overrides _LOGGER = logging.getLogger(__name__) @@ -57,16 +55,6 @@ def async_generate_entity_id(entity_id_format: str, name: Optional[str], entity_id_format.format(slugify(name)), current_ids) -def set_customize(customize: Dict[str, Any]) -> None: - """Overwrite all current customize settings. - - Async friendly. - """ - global _OVERWRITE - - _OVERWRITE = {key.lower(): val for key, val in customize.items()} - - class Entity(object): """An abstract class for Home Assistant entities.""" @@ -254,7 +242,7 @@ def async_update_ha_state(self, force_refresh=False): end - start) # Overwrite properties that have been set in the config file. - attr.update(_OVERWRITE.get(self.entity_id, {})) + attr.update(get_overrides(self.hass, self.entity_id)) # Remove hidden property if false so it won't show up. if not attr.get(ATTR_HIDDEN, True): diff --git a/tests/helpers/test_config_validation.py b/tests/helpers/test_config_validation.py index 252f7f60c95d85..7255447cd498f7 100644 --- a/tests/helpers/test_config_validation.py +++ b/tests/helpers/test_config_validation.py @@ -165,6 +165,25 @@ def test_entity_ids(): ] +def test_ensure_list_csv(): + """Test ensure_list_csv.""" + schema = vol.Schema(cv.ensure_list_csv) + + options = ( + None, + 12, + [], + ['string'], + 'string1,string2' + ) + for value in options: + schema(value) + + assert schema('string1, string2 ') == [ + 'string1', 'string2' + ] + + def test_event_schema(): """Test event_schema validation.""" options = ( @@ -429,6 +448,15 @@ def test_has_at_least_one_key(): schema(value) +def test_ordered_dict_only_dict(): + """Test ordered_dict validator.""" + schema = vol.Schema(cv.ordered_dict(cv.match_all, cv.match_all)) + + for value in (None, [], 100, 'hello'): + with pytest.raises(vol.MultipleInvalid): + schema(value) + + def test_ordered_dict_order(): """Test ordered_dict validator.""" schema = vol.Schema(cv.ordered_dict(int, cv.string)) diff --git a/tests/helpers/test_customize.py b/tests/helpers/test_customize.py new file mode 100644 index 00000000000000..e3fd1e325b004c --- /dev/null +++ b/tests/helpers/test_customize.py @@ -0,0 +1,87 @@ +"""Test the customize helper.""" +import homeassistant.helpers.customize as customize + + +class MockHass(object): + """Mock object for HassAssistant.""" + + data = {} + + +class TestHelpersCustomize(object): + """Test homeassistant.helpers.customize module.""" + + def setup_method(self, method): + """Setup things to be run when tests are started.""" + self.entity_id = 'test.test' + self.hass = MockHass() + + def _get_overrides(self, overrides): + customize.set_customize(self.hass, overrides) + return customize.get_overrides(self.hass, self.entity_id) + + def test_override_single_value(self): + """Test entity customization through configuration.""" + result = self._get_overrides([ + {'entity_id': [self.entity_id], 'key': 'value'}]) + + assert result == {'key': 'value'} + + def test_override_multiple_values(self): + """Test entity customization through configuration.""" + result = self._get_overrides([ + {'entity_id': [self.entity_id], 'key1': 'value1'}, + {'entity_id': [self.entity_id], 'key2': 'value2'}]) + + assert result == {'key1': 'value1', 'key2': 'value2'} + + def test_override_same_value(self): + """Test entity customization through configuration.""" + result = self._get_overrides([ + {'entity_id': [self.entity_id], 'key': 'value1'}, + {'entity_id': [self.entity_id], 'key': 'value2'}]) + + assert result == {'key': 'value2'} + + def test_override_by_domain(self): + """Test entity customization through configuration.""" + result = self._get_overrides([ + {'entity_id': ['test'], 'key': 'value'}]) + + assert result == {'key': 'value'} + + def test_override_by_glob(self): + """Test entity customization through configuration.""" + result = self._get_overrides([ + {'entity_id': ['test.?e*'], 'key': 'value'}]) + + assert result == {'key': 'value'} + + def test_override_exact_over_glob_over_domain(self): + """Test entity customization through configuration.""" + result = self._get_overrides([ + {'entity_id': ['test.test'], 'key1': 'valueExact'}, + {'entity_id': ['test.tes?'], + 'key1': 'valueGlob', + 'key2': 'valueGlob'}, + {'entity_id': ['test'], + 'key1': 'valueDomain', + 'key2': 'valueDomain', + 'key3': 'valueDomain'}]) + + assert result == { + 'key1': 'valueExact', + 'key2': 'valueGlob', + 'key3': 'valueDomain'} + + def test_override_deep_dict(self): + """Test we can overwrite hidden property to True.""" + result = self._get_overrides( + [{'entity_id': [self.entity_id], + 'test': {'key1': 'value1', 'key2': 'value2'}}, + {'entity_id': [self.entity_id], + 'test': {'key3': 'value3', 'key2': 'value22'}}]) + assert result['test'] == { + 'key1': 'value1', + 'key2': 'value22', + 'key3': 'value3'} diff --git a/tests/helpers/test_entity.py b/tests/helpers/test_entity.py index b2db8277085738..9ec016ccfcd9b3 100644 --- a/tests/helpers/test_entity.py +++ b/tests/helpers/test_entity.py @@ -6,6 +6,7 @@ import pytest import homeassistant.helpers.entity as entity +from homeassistant.helpers.customize import set_customize from homeassistant.const import ATTR_HIDDEN from tests.common import get_test_home_assistant @@ -78,7 +79,6 @@ def setup_method(self, method): def teardown_method(self, method): """Stop everything that was started.""" - entity.set_customize({}) self.hass.stop() def test_default_hidden_not_in_attributes(self): @@ -88,7 +88,9 @@ def test_default_hidden_not_in_attributes(self): def test_overwriting_hidden_property_to_true(self): """Test we can overwrite hidden property to True.""" - entity.set_customize({self.entity.entity_id: {ATTR_HIDDEN: True}}) + set_customize( + self.hass, + [{'entity_id': [self.entity.entity_id], ATTR_HIDDEN: True}]) self.entity.update_ha_state() state = self.hass.states.get(self.entity.entity_id) diff --git a/tests/test_config.py b/tests/test_config.py index 455ebe33c616be..3976948605623c 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -170,16 +170,17 @@ def test_create_default_config_returns_none_if_write_error(self, os.path.join(CONFIG_DIR, 'non_existing_dir/'), False)) self.assertTrue(mock_print.called) + # pylint: disable=no-self-use def test_core_config_schema(self): """Test core config schema.""" for value in ( - {CONF_UNIT_SYSTEM: 'K'}, - {'time_zone': 'non-exist'}, - {'latitude': '91'}, - {'longitude': -181}, - {'customize': 'bla'}, - {'customize': {'invalid_entity_id': {}}}, - {'customize': {'light.sensor': 100}}, + {CONF_UNIT_SYSTEM: 'K'}, + {'time_zone': 'non-exist'}, + {'latitude': '91'}, + {'longitude': -181}, + {'customize': 'bla'}, + {'customize': {'light.sensor': 100}}, + {'customize': {'entity_id': []}}, ): with pytest.raises(MultipleInvalid): config_util.CORE_CONFIG_SCHEMA(value) @@ -196,13 +197,7 @@ def test_core_config_schema(self): }, }) - def test_entity_customization(self): - """Test entity customization through configuration.""" - config = {CONF_LATITUDE: 50, - CONF_LONGITUDE: 50, - CONF_NAME: 'Test', - CONF_CUSTOMIZE: {'test.test': {'hidden': True}}} - + def _compute_state(self, config): run_coroutine_threadsafe( config_util.async_process_ha_core_config(self.hass, config), self.hass.loop).result() @@ -214,10 +209,50 @@ def test_entity_customization(self): self.hass.block_till_done() - state = self.hass.states.get('test.test') + return self.hass.states.get('test.test') + + def test_entity_customization_false(self): + """Test entity customization through configuration.""" + config = {CONF_LATITUDE: 50, + CONF_LONGITUDE: 50, + CONF_NAME: 'Test', + CONF_CUSTOMIZE: { + 'test.test': {'hidden': False}}} + + state = self._compute_state(config) + + assert 'hidden' not in state.attributes + + def test_entity_customization(self): + """Test entity customization through configuration.""" + config = {CONF_LATITUDE: 50, + CONF_LONGITUDE: 50, + CONF_NAME: 'Test', + CONF_CUSTOMIZE: {'test.test': {'hidden': True}}} + + state = self._compute_state(config) assert state.attributes['hidden'] + def test_entity_customization_comma_separated(self): + """Test entity customization through configuration.""" + config = {CONF_LATITUDE: 50, + CONF_LONGITUDE: 50, + CONF_NAME: 'Test', + CONF_CUSTOMIZE: [ + {'entity_id': 'test.not_test,test,test.not_t*', + 'key1': 'value1'}, + {'entity_id': 'test.test,not_test,test.not_t*', + 'key2': 'value2'}, + {'entity_id': 'test.not_test,not_test,test.t*', + 'key3': 'value3'}]} + + state = self._compute_state(config) + + assert state.attributes['key1'] == 'value1' + assert state.attributes['key2'] == 'value2' + assert state.attributes['key3'] == 'value3' + @mock.patch('homeassistant.config.shutil') @mock.patch('homeassistant.config.os') def test_remove_lib_on_upgrade(self, mock_os, mock_shutil): @@ -229,6 +264,7 @@ def test_remove_lib_on_upgrade(self, mock_os, mock_shutil): mock_open = mock.mock_open() with mock.patch('homeassistant.config.open', mock_open, create=True): opened_file = mock_open.return_value + # pylint: disable=no-member opened_file.readline.return_value = ha_version self.hass.config.path = mock.Mock() @@ -258,6 +294,7 @@ def test_not_remove_lib_if_not_upgrade(self, mock_os, mock_shutil): mock_open = mock.mock_open() with mock.patch('homeassistant.config.open', mock_open, create=True): opened_file = mock_open.return_value + # pylint: disable=no-member opened_file.readline.return_value = ha_version self.hass.config.path = mock.Mock()