Improve device automation validation (#86143)

This commit is contained in:
Erik Montnemery 2023-01-21 00:44:17 +01:00 committed by GitHub
parent 0c8b6c13fc
commit 1e2f00e186
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 255 additions and 116 deletions

View file

@ -1,13 +1,12 @@
"""Offer device oriented automation."""
from __future__ import annotations
from typing import Any, Protocol, cast
from typing import Any, Protocol
import voluptuous as vol
from homeassistant.const import CONF_DEVICE_ID, CONF_DOMAIN
from homeassistant.const import CONF_DOMAIN
from homeassistant.core import CALLBACK_TYPE, HomeAssistant
from homeassistant.helpers import device_registry as dr
from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo
from homeassistant.helpers.typing import ConfigType
@ -16,7 +15,7 @@ from . import (
DeviceAutomationType,
async_get_device_automation_platform,
)
from .exceptions import InvalidDeviceAutomationConfig
from .helpers import async_validate_device_automation_config
TRIGGER_SCHEMA = DEVICE_TRIGGER_BASE_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA)
@ -58,36 +57,9 @@ async def async_validate_trigger_config(
hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate config."""
try:
platform = await async_get_device_automation_platform(
hass, config[CONF_DOMAIN], DeviceAutomationType.TRIGGER
)
if not hasattr(platform, "async_validate_trigger_config"):
return cast(ConfigType, platform.TRIGGER_SCHEMA(config))
# Only call the dynamic validator if the relevant config entry is loaded
registry = dr.async_get(hass)
if not (device := registry.async_get(config[CONF_DEVICE_ID])):
return config
device_config_entry = None
for entry_id in device.config_entries:
if not (entry := hass.config_entries.async_get_entry(entry_id)):
continue
if entry.domain != config[CONF_DOMAIN]:
continue
device_config_entry = entry
break
if not device_config_entry:
return config
if not await hass.config_entries.async_wait_component(device_config_entry):
return config
return await platform.async_validate_trigger_config(hass, config)
except InvalidDeviceAutomationConfig as err:
raise vol.Invalid(str(err) or "Invalid trigger configuration") from err
return await async_validate_device_automation_config(
hass, config, TRIGGER_SCHEMA, DeviceAutomationType.TRIGGER
)
async def async_attach_trigger(