Improve validation of device trigger config (#26910)

* Improve validation of device trigger config

* Remove action and condition checks

* Move config validation to own file

* Fix tests

* Fixes

* Fixes

* Small tweak
This commit is contained in:
Erik Montnemery 2019-09-27 17:48:48 +02:00 committed by GitHub
parent 588bc26661
commit e57e7e8449
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 210 additions and 31 deletions

View file

@ -0,0 +1,60 @@
"""Config validation helper for the automation integration."""
import asyncio
import importlib
import voluptuous as vol
from homeassistant.const import CONF_PLATFORM
from homeassistant.config import async_log_exception, config_without_domain
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_per_platform
from homeassistant.loader import IntegrationNotFound
from . import CONF_TRIGGER, DOMAIN, PLATFORM_SCHEMA
# mypy: allow-untyped-calls, allow-untyped-defs
# mypy: no-check-untyped-defs, no-warn-return-any
async def async_validate_config_item(hass, config, full_config=None):
"""Validate config item."""
try:
config = PLATFORM_SCHEMA(config)
triggers = []
for trigger in config[CONF_TRIGGER]:
trigger_platform = importlib.import_module(
"..{}".format(trigger[CONF_PLATFORM]), __name__
)
if hasattr(trigger_platform, "async_validate_trigger_config"):
trigger = await trigger_platform.async_validate_trigger_config(
hass, trigger
)
triggers.append(trigger)
config[CONF_TRIGGER] = triggers
except (vol.Invalid, HomeAssistantError, IntegrationNotFound) as ex:
async_log_exception(ex, DOMAIN, full_config or config, hass)
return None
return config
async def async_validate_config(hass, config):
"""Validate config."""
automations = []
validated_automations = await asyncio.gather(
*(
async_validate_config_item(hass, p_config, config)
for _, p_config in config_per_platform(config, DOMAIN)
)
)
for validated_automation in validated_automations:
if validated_automation is not None:
automations.append(validated_automation)
# Create a copy of the configuration with all config for current
# component removed and add validated config back in.
config = config_without_domain(config, DOMAIN)
config[DOMAIN] = automations
return config

View file

@ -1,20 +1,24 @@
"""Offer device oriented automation."""
import voluptuous as vol
from homeassistant.const import CONF_DOMAIN, CONF_PLATFORM
from homeassistant.loader import async_get_integration
from homeassistant.components.device_automation import (
TRIGGER_BASE_SCHEMA,
async_get_device_automation_platform,
)
# mypy: allow-untyped-defs, no-check-untyped-defs
TRIGGER_SCHEMA = vol.Schema(
{vol.Required(CONF_PLATFORM): "device", vol.Required(CONF_DOMAIN): str},
extra=vol.ALLOW_EXTRA,
)
TRIGGER_SCHEMA = TRIGGER_BASE_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA)
async def async_validate_trigger_config(hass, config):
"""Validate config."""
platform = await async_get_device_automation_platform(hass, config, "trigger")
return platform.TRIGGER_SCHEMA(config)
async def async_attach_trigger(hass, config, action, automation_info):
"""Listen for trigger."""
integration = await async_get_integration(hass, config[CONF_DOMAIN])
platform = integration.get_platform("device_trigger")
platform = await async_get_device_automation_platform(hass, config, "trigger")
return await platform.async_attach_trigger(hass, config, action, automation_info)

View file

@ -5,10 +5,11 @@ import os
import voluptuous as vol
from homeassistant.core import callback
from homeassistant.const import EVENT_COMPONENT_LOADED, CONF_ID
from homeassistant.setup import ATTR_COMPONENT
from homeassistant.components.http import HomeAssistantView
from homeassistant.const import EVENT_COMPONENT_LOADED, CONF_ID
from homeassistant.core import callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.setup import ATTR_COMPONENT
from homeassistant.util.yaml import load_yaml, dump
DOMAIN = "config"
@ -80,6 +81,7 @@ class BaseEditConfigView(HomeAssistantView):
data_schema,
*,
post_write_hook=None,
data_validator=None,
):
"""Initialize a config view."""
self.url = f"/api/config/{component}/{config_type}/{{config_key}}"
@ -88,6 +90,7 @@ class BaseEditConfigView(HomeAssistantView):
self.key_schema = key_schema
self.data_schema = data_schema
self.post_write_hook = post_write_hook
self.data_validator = data_validator
def _empty_config(self):
"""Empty config if file not found."""
@ -128,14 +131,18 @@ class BaseEditConfigView(HomeAssistantView):
except vol.Invalid as err:
return self.json_message(f"Key malformed: {err}", 400)
hass = request.app["hass"]
try:
# We just validate, we don't store that data because
# we don't want to store the defaults.
self.data_schema(data)
except vol.Invalid as err:
if self.data_validator:
await self.data_validator(hass, data)
else:
self.data_schema(data)
except (vol.Invalid, HomeAssistantError) as err:
return self.json_message(f"Message malformed: {err}", 400)
hass = request.app["hass"]
path = hass.config.path(self.path)
current = await self.read_config(hass)

View file

@ -3,6 +3,7 @@ from collections import OrderedDict
import uuid
from homeassistant.components.automation import DOMAIN, PLATFORM_SCHEMA
from homeassistant.components.automation.config import async_validate_config_item
from homeassistant.const import CONF_ID, SERVICE_RELOAD
import homeassistant.helpers.config_validation as cv
@ -26,6 +27,7 @@ async def async_setup(hass):
cv.string,
PLATFORM_SCHEMA,
post_write_hook=hook,
data_validator=async_validate_config_item,
)
)
return True

View file

@ -9,6 +9,8 @@ from homeassistant.components import websocket_api
from homeassistant.helpers.entity_registry import async_entries_for_device
from homeassistant.loader import async_get_integration, IntegrationNotFound
from .exceptions import InvalidDeviceAutomationConfig
DOMAIN = "device_automation"
_LOGGER = logging.getLogger(__name__)
@ -43,6 +45,27 @@ async def async_setup(hass, config):
return True
async def async_get_device_automation_platform(hass, config, automation_type):
"""Load device automation platform for integration.
Throws InvalidDeviceAutomationConfig if the integration is not found or does not support device automation.
"""
platform_name, _ = TYPES[automation_type]
try:
integration = await async_get_integration(hass, config[CONF_DOMAIN])
platform = integration.get_platform(platform_name)
except IntegrationNotFound:
raise InvalidDeviceAutomationConfig(
f"Integration '{config[CONF_DOMAIN]}' not found"
)
except ImportError:
raise InvalidDeviceAutomationConfig(
f"Integration '{config[CONF_DOMAIN]}' does not support device automation {automation_type}s"
)
return platform
async def _async_get_device_automations_from_domain(
hass, domain, automation_type, device_id
):

View file

@ -416,7 +416,7 @@ def process_ha_config_upgrade(hass: HomeAssistant) -> None:
@callback
def async_log_exception(
ex: vol.Invalid, domain: str, config: Dict, hass: HomeAssistant
ex: Exception, domain: str, config: Dict, hass: HomeAssistant
) -> None:
"""Log an error for configuration validation.
@ -428,23 +428,26 @@ def async_log_exception(
@callback
def _format_config_error(ex: vol.Invalid, domain: str, config: Dict) -> str:
def _format_config_error(ex: Exception, domain: str, config: Dict) -> str:
"""Generate log exception for configuration validation.
This method must be run in the event loop.
"""
message = f"Invalid config for [{domain}]: "
if "extra keys not allowed" in ex.error_message:
message += (
"[{option}] is an invalid option for [{domain}]. "
"Check: {domain}->{path}.".format(
option=ex.path[-1],
domain=domain,
path="->".join(str(m) for m in ex.path),
if isinstance(ex, vol.Invalid):
if "extra keys not allowed" in ex.error_message:
message += (
"[{option}] is an invalid option for [{domain}]. "
"Check: {domain}->{path}.".format(
option=ex.path[-1],
domain=domain,
path="->".join(str(m) for m in ex.path),
)
)
)
else:
message += "{}.".format(humanize_error(config, ex))
else:
message += "{}.".format(humanize_error(config, ex))
message += str(ex)
try:
domain_config = config.get(domain, config)
@ -717,6 +720,24 @@ async def async_process_component_config(
_LOGGER.error("Unable to import %s: %s", domain, ex)
return None
# Check if the integration has a custom config validator
config_validator = None
try:
config_validator = integration.get_platform("config")
except ImportError:
pass
if config_validator is not None and hasattr(
config_validator, "async_validate_config"
):
try:
return await config_validator.async_validate_config( # type: ignore
hass, config
)
except (vol.Invalid, HomeAssistantError) as ex:
async_log_exception(ex, domain, config, hass)
return None
# No custom config validator, proceed with schema validation
if hasattr(component, "CONFIG_SCHEMA"):
try:
return component.CONFIG_SCHEMA(config) # type: ignore

View file

@ -90,7 +90,7 @@ def has_at_least_one_key(*keys: str) -> Callable:
for k in obj.keys():
if k in keys:
return obj
raise vol.Invalid("must contain one of {}.".format(", ".join(keys)))
raise vol.Invalid("must contain at least one of {}.".format(", ".join(keys)))
return validate

View file

@ -307,7 +307,7 @@ class IntegrationNotFound(LoaderError):
def __init__(self, domain: str) -> None:
"""Initialize a component not found error."""
super().__init__(f"Integration {domain} not found.")
super().__init__(f"Integration '{domain}' not found.")
self.domain = domain

View file

@ -2,6 +2,7 @@
import pytest
from homeassistant.setup import async_setup_component
import homeassistant.components.automation as automation
from homeassistant.components.websocket_api.const import TYPE_RESULT
from homeassistant.helpers import device_registry
@ -161,3 +162,64 @@ async def test_websocket_get_triggers(hass, hass_ws_client, device_reg, entity_r
assert msg["success"]
triggers = msg["result"]
assert _same_lists(triggers, expected_triggers)
async def test_automation_with_non_existing_integration(hass, caplog):
"""Test device automation with non existing integration."""
assert await async_setup_component(
hass,
automation.DOMAIN,
{
automation.DOMAIN: {
"alias": "hello",
"trigger": {
"platform": "device",
"device_id": "none",
"domain": "beer",
},
"action": {"service": "test.automation", "entity_id": "hello.world"},
}
},
)
assert "Integration 'beer' not found" in caplog.text
async def test_automation_with_integration_without_device_trigger(hass, caplog):
"""Test automation with integration without device trigger support."""
assert await async_setup_component(
hass,
automation.DOMAIN,
{
automation.DOMAIN: {
"alias": "hello",
"trigger": {
"platform": "device",
"device_id": "none",
"domain": "test",
},
"action": {"service": "test.automation", "entity_id": "hello.world"},
}
},
)
assert (
"Integration 'test' does not support device automation triggers" in caplog.text
)
async def test_automation_with_bad_trigger(hass, caplog):
"""Test automation with bad device trigger."""
assert await async_setup_component(
hass,
automation.DOMAIN,
{
automation.DOMAIN: {
"alias": "hello",
"trigger": {"platform": "device", "domain": "light"},
"action": {"service": "test.automation", "entity_id": "hello.world"},
}
},
)
assert "required key not provided" in caplog.text

View file

@ -75,7 +75,7 @@ async def test_component_platform_not_found(hass, loop):
assert res.keys() == {"homeassistant"}
assert res.errors[0] == CheckConfigError(
"Component error: beer - Integration beer not found.", None, None
"Component error: beer - Integration 'beer' not found.", None, None
)
# Only 1 error expected
@ -95,7 +95,7 @@ async def test_component_platform_not_found_2(hass, loop):
assert res["light"] == []
assert res.errors[0] == CheckConfigError(
"Platform error light.beer - Integration beer not found.", None, None
"Platform error light.beer - Integration 'beer' not found.", None, None
)
# Only 1 error expected

View file

@ -63,7 +63,7 @@ def test_component_platform_not_found(isfile_patch, loop):
assert res["components"].keys() == {"homeassistant"}
assert res["except"] == {
check_config.ERROR_STR: [
"Component error: beer - Integration beer not found."
"Component error: beer - Integration 'beer' not found."
]
}
assert res["secret_cache"] == {}
@ -77,7 +77,7 @@ def test_component_platform_not_found(isfile_patch, loop):
assert res["components"]["light"] == []
assert res["except"] == {
check_config.ERROR_STR: [
"Platform error light.beer - Integration beer not found."
"Platform error light.beer - Integration 'beer' not found."
]
}
assert res["secret_cache"] == {}