Skip to content

Commit

Permalink
Improve script validation (home-assistant#32461)
Browse files Browse the repository at this point in the history
  • Loading branch information
balloob authored Mar 5, 2020
1 parent da7c551 commit 6a21afa
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 90 deletions.
21 changes: 15 additions & 6 deletions homeassistant/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@
CONF_API_KEY = "api_key"
CONF_API_VERSION = "api_version"
CONF_AT = "at"
CONF_AUTHENTICATION = "authentication"
CONF_AUTH_MFA_MODULES = "auth_mfa_modules"
CONF_AUTH_PROVIDERS = "auth_providers"
CONF_AUTHENTICATION = "authentication"
CONF_BASE = "base"
CONF_BEFORE = "before"
CONF_BELOW = "below"
Expand All @@ -57,11 +57,13 @@
CONF_COMMAND_STATE = "command_state"
CONF_COMMAND_STOP = "command_stop"
CONF_CONDITION = "condition"
CONF_CONTINUE_ON_TIMEOUT = "continue_on_timeout"
CONF_COVERS = "covers"
CONF_CURRENCY = "currency"
CONF_CUSTOMIZE = "customize"
CONF_CUSTOMIZE_DOMAIN = "customize_domain"
CONF_CUSTOMIZE_GLOB = "customize_glob"
CONF_DELAY = "delay"
CONF_DELAY_TIME = "delay_time"
CONF_DEVICE = "device"
CONF_DEVICE_CLASS = "device_class"
Expand All @@ -82,6 +84,8 @@
CONF_ENTITY_NAMESPACE = "entity_namespace"
CONF_ENTITY_PICTURE_TEMPLATE = "entity_picture_template"
CONF_EVENT = "event"
CONF_EVENT_DATA = "event_data"
CONF_EVENT_DATA_TEMPLATE = "event_data_template"
CONF_EXCLUDE = "exclude"
CONF_FILE_PATH = "file_path"
CONF_FILENAME = "filename"
Expand All @@ -95,15 +99,15 @@
CONF_HS = "hs"
CONF_ICON = "icon"
CONF_ICON_TEMPLATE = "icon_template"
CONF_INCLUDE = "include"
CONF_ID = "id"
CONF_INCLUDE = "include"
CONF_IP_ADDRESS = "ip_address"
CONF_LATITUDE = "latitude"
CONF_LONGITUDE = "longitude"
CONF_LIGHTS = "lights"
CONF_LONGITUDE = "longitude"
CONF_MAC = "mac"
CONF_METHOD = "method"
CONF_MAXIMUM = "maximum"
CONF_METHOD = "method"
CONF_MINIMUM = "minimum"
CONF_MODE = "mode"
CONF_MONITORED_CONDITIONS = "monitored_conditions"
Expand All @@ -130,14 +134,18 @@
CONF_RECIPIENT = "recipient"
CONF_REGION = "region"
CONF_RESOURCE = "resource"
CONF_RESOURCES = "resources"
CONF_RESOURCE_TEMPLATE = "resource_template"
CONF_RESOURCES = "resources"
CONF_RGB = "rgb"
CONF_ROOM = "room"
CONF_SCAN_INTERVAL = "scan_interval"
CONF_SCENE = "scene"
CONF_SENDER = "sender"
CONF_SENSOR_TYPE = "sensor_type"
CONF_SENSORS = "sensors"
CONF_SERVICE = "service"
CONF_SERVICE_DATA = "data"
CONF_SERVICE_TEMPLATE = "service_template"
CONF_SHOW_ON_MAP = "show_on_map"
CONF_SLAVE = "slave"
CONF_SOURCE = "source"
Expand All @@ -159,11 +167,12 @@
CONF_USERNAME = "username"
CONF_VALUE_TEMPLATE = "value_template"
CONF_VERIFY_SSL = "verify_ssl"
CONF_WAIT_TEMPLATE = "wait_template"
CONF_WEBHOOK_ID = "webhook_id"
CONF_WEEKDAY = "weekday"
CONF_WHITE_VALUE = "white_value"
CONF_WHITELIST = "whitelist"
CONF_WHITELIST_EXTERNAL_DIRS = "whitelist_external_dirs"
CONF_WHITE_VALUE = "white_value"
CONF_XY = "xy"
CONF_ZONE = "zone"

Expand Down
99 changes: 73 additions & 26 deletions homeassistant/helpers/config_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,27 @@
CONF_ALIAS,
CONF_BELOW,
CONF_CONDITION,
CONF_CONTINUE_ON_TIMEOUT,
CONF_DELAY,
CONF_DEVICE_ID,
CONF_DOMAIN,
CONF_ENTITY_ID,
CONF_ENTITY_NAMESPACE,
CONF_EVENT,
CONF_EVENT_DATA,
CONF_EVENT_DATA_TEMPLATE,
CONF_FOR,
CONF_PLATFORM,
CONF_SCAN_INTERVAL,
CONF_SCENE,
CONF_SERVICE,
CONF_SERVICE_TEMPLATE,
CONF_STATE,
CONF_TIMEOUT,
CONF_UNIT_SYSTEM_IMPERIAL,
CONF_UNIT_SYSTEM_METRIC,
CONF_VALUE_TEMPLATE,
CONF_WAIT_TEMPLATE,
ENTITY_MATCH_ALL,
ENTITY_MATCH_NONE,
SUN_EVENT_SUNRISE,
Expand Down Expand Up @@ -722,7 +731,7 @@ def key_value_validator(value: Any) -> Dict[str, Any]:

if key_value not in value_schemas:
raise vol.Invalid(
f"Unexpected key {key_value}. Expected {', '.join(value_schemas)}"
f"Unexpected value for {key}: '{key_value}'. Expected {', '.join(value_schemas)}"
)

return cast(Dict[str, Any], value_schemas[key_value](value))
Expand Down Expand Up @@ -800,24 +809,24 @@ def make_entity_service_schema(
EVENT_SCHEMA = vol.Schema(
{
vol.Optional(CONF_ALIAS): string,
vol.Required("event"): string,
vol.Optional("event_data"): dict,
vol.Optional("event_data_template"): {match_all: template_complex},
vol.Required(CONF_EVENT): string,
vol.Optional(CONF_EVENT_DATA): dict,
vol.Optional(CONF_EVENT_DATA_TEMPLATE): {match_all: template_complex},
}
)

SERVICE_SCHEMA = vol.All(
vol.Schema(
{
vol.Optional(CONF_ALIAS): string,
vol.Exclusive("service", "service name"): service,
vol.Exclusive("service_template", "service name"): template,
vol.Exclusive(CONF_SERVICE, "service name"): service,
vol.Exclusive(CONF_SERVICE_TEMPLATE, "service name"): template,
vol.Optional("data"): dict,
vol.Optional("data_template"): {match_all: template_complex},
vol.Optional(CONF_ENTITY_ID): comp_entity_ids,
}
),
has_at_least_one_key("service", "service_template"),
has_at_least_one_key(CONF_SERVICE, CONF_SERVICE_TEMPLATE),
)

NUMERIC_STATE_CONDITION_SCHEMA = vol.All(
Expand Down Expand Up @@ -943,7 +952,7 @@ def make_entity_service_schema(
_SCRIPT_DELAY_SCHEMA = vol.Schema(
{
vol.Optional(CONF_ALIAS): string,
vol.Required("delay"): vol.Any(
vol.Required(CONF_DELAY): vol.Any(
vol.All(time_period, positive_timedelta), template, template_complex
),
}
Expand All @@ -952,9 +961,9 @@ def make_entity_service_schema(
_SCRIPT_WAIT_TEMPLATE_SCHEMA = vol.Schema(
{
vol.Optional(CONF_ALIAS): string,
vol.Required("wait_template"): template,
vol.Required(CONF_WAIT_TEMPLATE): template,
vol.Optional(CONF_TIMEOUT): vol.All(time_period, positive_timedelta),
vol.Optional("continue_on_timeout"): boolean,
vol.Optional(CONF_CONTINUE_ON_TIMEOUT): boolean,
}
)

Expand All @@ -964,19 +973,57 @@ def make_entity_service_schema(

DEVICE_ACTION_SCHEMA = DEVICE_ACTION_BASE_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA)

_SCRIPT_SCENE_SCHEMA = vol.Schema({vol.Required("scene"): entity_domain("scene")})

SCRIPT_SCHEMA = vol.All(
ensure_list,
[
vol.Any(
SERVICE_SCHEMA,
_SCRIPT_DELAY_SCHEMA,
_SCRIPT_WAIT_TEMPLATE_SCHEMA,
EVENT_SCHEMA,
CONDITION_SCHEMA,
DEVICE_ACTION_SCHEMA,
_SCRIPT_SCENE_SCHEMA,
)
],
)
_SCRIPT_SCENE_SCHEMA = vol.Schema({vol.Required(CONF_SCENE): entity_domain("scene")})

SCRIPT_ACTION_DELAY = "delay"
SCRIPT_ACTION_WAIT_TEMPLATE = "wait_template"
SCRIPT_ACTION_CHECK_CONDITION = "condition"
SCRIPT_ACTION_FIRE_EVENT = "event"
SCRIPT_ACTION_CALL_SERVICE = "call_service"
SCRIPT_ACTION_DEVICE_AUTOMATION = "device"
SCRIPT_ACTION_ACTIVATE_SCENE = "scene"


def determine_script_action(action: dict) -> str:
"""Determine action type."""
if CONF_DELAY in action:
return SCRIPT_ACTION_DELAY

if CONF_WAIT_TEMPLATE in action:
return SCRIPT_ACTION_WAIT_TEMPLATE

if CONF_CONDITION in action:
return SCRIPT_ACTION_CHECK_CONDITION

if CONF_EVENT in action:
return SCRIPT_ACTION_FIRE_EVENT

if CONF_DEVICE_ID in action:
return SCRIPT_ACTION_DEVICE_AUTOMATION

if CONF_SCENE in action:
return SCRIPT_ACTION_ACTIVATE_SCENE

return SCRIPT_ACTION_CALL_SERVICE


ACTION_TYPE_SCHEMAS: Dict[str, Callable[[Any], dict]] = {
SCRIPT_ACTION_CALL_SERVICE: SERVICE_SCHEMA,
SCRIPT_ACTION_DELAY: _SCRIPT_DELAY_SCHEMA,
SCRIPT_ACTION_WAIT_TEMPLATE: _SCRIPT_WAIT_TEMPLATE_SCHEMA,
SCRIPT_ACTION_FIRE_EVENT: EVENT_SCHEMA,
SCRIPT_ACTION_CHECK_CONDITION: CONDITION_SCHEMA,
SCRIPT_ACTION_DEVICE_AUTOMATION: DEVICE_ACTION_SCHEMA,
SCRIPT_ACTION_ACTIVATE_SCENE: _SCRIPT_SCENE_SCHEMA,
}


def script_action(value: Any) -> dict:
"""Validate a script action."""
if not isinstance(value, dict):
raise vol.Invalid("expected dictionary")

return ACTION_TYPE_SCHEMAS[determine_script_action(value)](value)


SCRIPT_SCHEMA = vol.All(ensure_list, [script_action])
Loading

0 comments on commit 6a21afa

Please sign in to comment.