Improve validation of device trigger config (#26910)

* Improve validation of device trigger config

* Remove action and condition checks

* Move config validation to own file

* Fix tests

* Fixes

* Fixes

* Small tweak
This commit is contained in:
Erik Montnemery 2019-09-27 17:48:48 +02:00 committed by GitHub
parent 588bc26661
commit e57e7e8449
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 210 additions and 31 deletions

View file

@ -0,0 +1,60 @@
"""Config validation helper for the automation integration."""
import asyncio
import importlib
import voluptuous as vol
from homeassistant.const import CONF_PLATFORM
from homeassistant.config import async_log_exception, config_without_domain
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_per_platform
from homeassistant.loader import IntegrationNotFound
from . import CONF_TRIGGER, DOMAIN, PLATFORM_SCHEMA
# mypy: allow-untyped-calls, allow-untyped-defs
# mypy: no-check-untyped-defs, no-warn-return-any
async def async_validate_config_item(hass, config, full_config=None):
"""Validate config item."""
try:
config = PLATFORM_SCHEMA(config)
triggers = []
for trigger in config[CONF_TRIGGER]:
trigger_platform = importlib.import_module(
"..{}".format(trigger[CONF_PLATFORM]), __name__
)
if hasattr(trigger_platform, "async_validate_trigger_config"):
trigger = await trigger_platform.async_validate_trigger_config(
hass, trigger
)
triggers.append(trigger)
config[CONF_TRIGGER] = triggers
except (vol.Invalid, HomeAssistantError, IntegrationNotFound) as ex:
async_log_exception(ex, DOMAIN, full_config or config, hass)
return None
return config
async def async_validate_config(hass, config):
"""Validate config."""
automations = []
validated_automations = await asyncio.gather(
*(
async_validate_config_item(hass, p_config, config)
for _, p_config in config_per_platform(config, DOMAIN)
)
)
for validated_automation in validated_automations:
if validated_automation is not None:
automations.append(validated_automation)
# Create a copy of the configuration with all config for current
# component removed and add validated config back in.
config = config_without_domain(config, DOMAIN)
config[DOMAIN] = automations
return config

View file

@ -1,20 +1,24 @@
"""Offer device oriented automation.""" """Offer device oriented automation."""
import voluptuous as vol import voluptuous as vol
from homeassistant.const import CONF_DOMAIN, CONF_PLATFORM from homeassistant.components.device_automation import (
from homeassistant.loader import async_get_integration TRIGGER_BASE_SCHEMA,
async_get_device_automation_platform,
)
# mypy: allow-untyped-defs, no-check-untyped-defs # mypy: allow-untyped-defs, no-check-untyped-defs
TRIGGER_SCHEMA = vol.Schema( TRIGGER_SCHEMA = TRIGGER_BASE_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA)
{vol.Required(CONF_PLATFORM): "device", vol.Required(CONF_DOMAIN): str},
extra=vol.ALLOW_EXTRA,
) async def async_validate_trigger_config(hass, config):
"""Validate config."""
platform = await async_get_device_automation_platform(hass, config, "trigger")
return platform.TRIGGER_SCHEMA(config)
async def async_attach_trigger(hass, config, action, automation_info): async def async_attach_trigger(hass, config, action, automation_info):
"""Listen for trigger.""" """Listen for trigger."""
integration = await async_get_integration(hass, config[CONF_DOMAIN]) platform = await async_get_device_automation_platform(hass, config, "trigger")
platform = integration.get_platform("device_trigger")
return await platform.async_attach_trigger(hass, config, action, automation_info) return await platform.async_attach_trigger(hass, config, action, automation_info)

View file

@ -5,10 +5,11 @@ import os
import voluptuous as vol import voluptuous as vol
from homeassistant.core import callback
from homeassistant.const import EVENT_COMPONENT_LOADED, CONF_ID
from homeassistant.setup import ATTR_COMPONENT
from homeassistant.components.http import HomeAssistantView from homeassistant.components.http import HomeAssistantView
from homeassistant.const import EVENT_COMPONENT_LOADED, CONF_ID
from homeassistant.core import callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.setup import ATTR_COMPONENT
from homeassistant.util.yaml import load_yaml, dump from homeassistant.util.yaml import load_yaml, dump
DOMAIN = "config" DOMAIN = "config"
@ -80,6 +81,7 @@ class BaseEditConfigView(HomeAssistantView):
data_schema, data_schema,
*, *,
post_write_hook=None, post_write_hook=None,
data_validator=None,
): ):
"""Initialize a config view.""" """Initialize a config view."""
self.url = f"/api/config/{component}/{config_type}/{{config_key}}" self.url = f"/api/config/{component}/{config_type}/{{config_key}}"
@ -88,6 +90,7 @@ class BaseEditConfigView(HomeAssistantView):
self.key_schema = key_schema self.key_schema = key_schema
self.data_schema = data_schema self.data_schema = data_schema
self.post_write_hook = post_write_hook self.post_write_hook = post_write_hook
self.data_validator = data_validator
def _empty_config(self): def _empty_config(self):
"""Empty config if file not found.""" """Empty config if file not found."""
@ -128,14 +131,18 @@ class BaseEditConfigView(HomeAssistantView):
except vol.Invalid as err: except vol.Invalid as err:
return self.json_message(f"Key malformed: {err}", 400) return self.json_message(f"Key malformed: {err}", 400)
hass = request.app["hass"]
try: try:
# We just validate, we don't store that data because # We just validate, we don't store that data because
# we don't want to store the defaults. # we don't want to store the defaults.
self.data_schema(data) if self.data_validator:
except vol.Invalid as err: await self.data_validator(hass, data)
else:
self.data_schema(data)
except (vol.Invalid, HomeAssistantError) as err:
return self.json_message(f"Message malformed: {err}", 400) return self.json_message(f"Message malformed: {err}", 400)
hass = request.app["hass"]
path = hass.config.path(self.path) path = hass.config.path(self.path)
current = await self.read_config(hass) current = await self.read_config(hass)

View file

@ -3,6 +3,7 @@ from collections import OrderedDict
import uuid import uuid
from homeassistant.components.automation import DOMAIN, PLATFORM_SCHEMA from homeassistant.components.automation import DOMAIN, PLATFORM_SCHEMA
from homeassistant.components.automation.config import async_validate_config_item
from homeassistant.const import CONF_ID, SERVICE_RELOAD from homeassistant.const import CONF_ID, SERVICE_RELOAD
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
@ -26,6 +27,7 @@ async def async_setup(hass):
cv.string, cv.string,
PLATFORM_SCHEMA, PLATFORM_SCHEMA,
post_write_hook=hook, post_write_hook=hook,
data_validator=async_validate_config_item,
) )
) )
return True return True

View file

@ -9,6 +9,8 @@ from homeassistant.components import websocket_api
from homeassistant.helpers.entity_registry import async_entries_for_device from homeassistant.helpers.entity_registry import async_entries_for_device
from homeassistant.loader import async_get_integration, IntegrationNotFound from homeassistant.loader import async_get_integration, IntegrationNotFound
from .exceptions import InvalidDeviceAutomationConfig
DOMAIN = "device_automation" DOMAIN = "device_automation"
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -43,6 +45,27 @@ async def async_setup(hass, config):
return True return True
async def async_get_device_automation_platform(hass, config, automation_type):
"""Load device automation platform for integration.
Throws InvalidDeviceAutomationConfig if the integration is not found or does not support device automation.
"""
platform_name, _ = TYPES[automation_type]
try:
integration = await async_get_integration(hass, config[CONF_DOMAIN])
platform = integration.get_platform(platform_name)
except IntegrationNotFound:
raise InvalidDeviceAutomationConfig(
f"Integration '{config[CONF_DOMAIN]}' not found"
)
except ImportError:
raise InvalidDeviceAutomationConfig(
f"Integration '{config[CONF_DOMAIN]}' does not support device automation {automation_type}s"
)
return platform
async def _async_get_device_automations_from_domain( async def _async_get_device_automations_from_domain(
hass, domain, automation_type, device_id hass, domain, automation_type, device_id
): ):

View file

@ -416,7 +416,7 @@ def process_ha_config_upgrade(hass: HomeAssistant) -> None:
@callback @callback
def async_log_exception( def async_log_exception(
ex: vol.Invalid, domain: str, config: Dict, hass: HomeAssistant ex: Exception, domain: str, config: Dict, hass: HomeAssistant
) -> None: ) -> None:
"""Log an error for configuration validation. """Log an error for configuration validation.
@ -428,23 +428,26 @@ def async_log_exception(
@callback @callback
def _format_config_error(ex: vol.Invalid, domain: str, config: Dict) -> str: def _format_config_error(ex: Exception, domain: str, config: Dict) -> str:
"""Generate log exception for configuration validation. """Generate log exception for configuration validation.
This method must be run in the event loop. This method must be run in the event loop.
""" """
message = f"Invalid config for [{domain}]: " message = f"Invalid config for [{domain}]: "
if "extra keys not allowed" in ex.error_message: if isinstance(ex, vol.Invalid):
message += ( if "extra keys not allowed" in ex.error_message:
"[{option}] is an invalid option for [{domain}]. " message += (
"Check: {domain}->{path}.".format( "[{option}] is an invalid option for [{domain}]. "
option=ex.path[-1], "Check: {domain}->{path}.".format(
domain=domain, option=ex.path[-1],
path="->".join(str(m) for m in ex.path), domain=domain,
path="->".join(str(m) for m in ex.path),
)
) )
) else:
message += "{}.".format(humanize_error(config, ex))
else: else:
message += "{}.".format(humanize_error(config, ex)) message += str(ex)
try: try:
domain_config = config.get(domain, config) domain_config = config.get(domain, config)
@ -717,6 +720,24 @@ async def async_process_component_config(
_LOGGER.error("Unable to import %s: %s", domain, ex) _LOGGER.error("Unable to import %s: %s", domain, ex)
return None return None
# Check if the integration has a custom config validator
config_validator = None
try:
config_validator = integration.get_platform("config")
except ImportError:
pass
if config_validator is not None and hasattr(
config_validator, "async_validate_config"
):
try:
return await config_validator.async_validate_config( # type: ignore
hass, config
)
except (vol.Invalid, HomeAssistantError) as ex:
async_log_exception(ex, domain, config, hass)
return None
# No custom config validator, proceed with schema validation
if hasattr(component, "CONFIG_SCHEMA"): if hasattr(component, "CONFIG_SCHEMA"):
try: try:
return component.CONFIG_SCHEMA(config) # type: ignore return component.CONFIG_SCHEMA(config) # type: ignore

View file

@ -90,7 +90,7 @@ def has_at_least_one_key(*keys: str) -> Callable:
for k in obj.keys(): for k in obj.keys():
if k in keys: if k in keys:
return obj return obj
raise vol.Invalid("must contain one of {}.".format(", ".join(keys))) raise vol.Invalid("must contain at least one of {}.".format(", ".join(keys)))
return validate return validate

View file

@ -307,7 +307,7 @@ class IntegrationNotFound(LoaderError):
def __init__(self, domain: str) -> None: def __init__(self, domain: str) -> None:
"""Initialize a component not found error.""" """Initialize a component not found error."""
super().__init__(f"Integration {domain} not found.") super().__init__(f"Integration '{domain}' not found.")
self.domain = domain self.domain = domain

View file

@ -2,6 +2,7 @@
import pytest import pytest
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
import homeassistant.components.automation as automation
from homeassistant.components.websocket_api.const import TYPE_RESULT from homeassistant.components.websocket_api.const import TYPE_RESULT
from homeassistant.helpers import device_registry from homeassistant.helpers import device_registry
@ -161,3 +162,64 @@ async def test_websocket_get_triggers(hass, hass_ws_client, device_reg, entity_r
assert msg["success"] assert msg["success"]
triggers = msg["result"] triggers = msg["result"]
assert _same_lists(triggers, expected_triggers) assert _same_lists(triggers, expected_triggers)
async def test_automation_with_non_existing_integration(hass, caplog):
"""Test device automation with non existing integration."""
assert await async_setup_component(
hass,
automation.DOMAIN,
{
automation.DOMAIN: {
"alias": "hello",
"trigger": {
"platform": "device",
"device_id": "none",
"domain": "beer",
},
"action": {"service": "test.automation", "entity_id": "hello.world"},
}
},
)
assert "Integration 'beer' not found" in caplog.text
async def test_automation_with_integration_without_device_trigger(hass, caplog):
"""Test automation with integration without device trigger support."""
assert await async_setup_component(
hass,
automation.DOMAIN,
{
automation.DOMAIN: {
"alias": "hello",
"trigger": {
"platform": "device",
"device_id": "none",
"domain": "test",
},
"action": {"service": "test.automation", "entity_id": "hello.world"},
}
},
)
assert (
"Integration 'test' does not support device automation triggers" in caplog.text
)
async def test_automation_with_bad_trigger(hass, caplog):
"""Test automation with bad device trigger."""
assert await async_setup_component(
hass,
automation.DOMAIN,
{
automation.DOMAIN: {
"alias": "hello",
"trigger": {"platform": "device", "domain": "light"},
"action": {"service": "test.automation", "entity_id": "hello.world"},
}
},
)
assert "required key not provided" in caplog.text

View file

@ -75,7 +75,7 @@ async def test_component_platform_not_found(hass, loop):
assert res.keys() == {"homeassistant"} assert res.keys() == {"homeassistant"}
assert res.errors[0] == CheckConfigError( assert res.errors[0] == CheckConfigError(
"Component error: beer - Integration beer not found.", None, None "Component error: beer - Integration 'beer' not found.", None, None
) )
# Only 1 error expected # Only 1 error expected
@ -95,7 +95,7 @@ async def test_component_platform_not_found_2(hass, loop):
assert res["light"] == [] assert res["light"] == []
assert res.errors[0] == CheckConfigError( assert res.errors[0] == CheckConfigError(
"Platform error light.beer - Integration beer not found.", None, None "Platform error light.beer - Integration 'beer' not found.", None, None
) )
# Only 1 error expected # Only 1 error expected

View file

@ -63,7 +63,7 @@ def test_component_platform_not_found(isfile_patch, loop):
assert res["components"].keys() == {"homeassistant"} assert res["components"].keys() == {"homeassistant"}
assert res["except"] == { assert res["except"] == {
check_config.ERROR_STR: [ check_config.ERROR_STR: [
"Component error: beer - Integration beer not found." "Component error: beer - Integration 'beer' not found."
] ]
} }
assert res["secret_cache"] == {} assert res["secret_cache"] == {}
@ -77,7 +77,7 @@ def test_component_platform_not_found(isfile_patch, loop):
assert res["components"]["light"] == [] assert res["components"]["light"] == []
assert res["except"] == { assert res["except"] == {
check_config.ERROR_STR: [ check_config.ERROR_STR: [
"Platform error light.beer - Integration beer not found." "Platform error light.beer - Integration 'beer' not found."
] ]
} }
assert res["secret_cache"] == {} assert res["secret_cache"] == {}