Improve script validation (#32461)

This commit is contained in:
Paulus Schoutsen 2020-03-05 11:44:42 -08:00 committed by GitHub
parent da7c5518f3
commit 6a21afa2a8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 141 additions and 89 deletions

View file

@ -15,9 +15,16 @@ import homeassistant.components.scene as scene
from homeassistant.const import (
ATTR_ENTITY_ID,
CONF_CONDITION,
CONF_CONTINUE_ON_TIMEOUT,
CONF_DELAY,
CONF_DEVICE_ID,
CONF_DOMAIN,
CONF_EVENT,
CONF_EVENT_DATA,
CONF_EVENT_DATA_TEMPLATE,
CONF_SCENE,
CONF_TIMEOUT,
CONF_WAIT_TEMPLATE,
SERVICE_TURN_ON,
)
from homeassistant.core import CALLBACK_TYPE, Context, HomeAssistant, callback
@ -37,24 +44,6 @@ from homeassistant.util.dt import utcnow
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
CONF_ALIAS = "alias"
CONF_SERVICE = "service"
CONF_SERVICE_DATA = "data"
CONF_SEQUENCE = "sequence"
CONF_EVENT = "event"
CONF_EVENT_DATA = "event_data"
CONF_EVENT_DATA_TEMPLATE = "event_data_template"
CONF_DELAY = "delay"
CONF_WAIT_TEMPLATE = "wait_template"
CONF_CONTINUE = "continue_on_timeout"
CONF_SCENE = "scene"
ACTION_DELAY = "delay"
ACTION_WAIT_TEMPLATE = "wait_template"
ACTION_CHECK_CONDITION = "condition"
ACTION_FIRE_EVENT = "event"
ACTION_CALL_SERVICE = "call_service"
ACTION_DEVICE_AUTOMATION = "device"
ACTION_ACTIVATE_SCENE = "scene"
IF_RUNNING_ERROR = "error"
IF_RUNNING_IGNORE = "ignore"
@ -82,41 +71,21 @@ _LOG_EXCEPTION = logging.ERROR + 1
_TIMEOUT_MSG = "Timeout reached, abort script."
def _determine_action(action):
"""Determine action type."""
if CONF_DELAY in action:
return ACTION_DELAY
if CONF_WAIT_TEMPLATE in action:
return ACTION_WAIT_TEMPLATE
if CONF_CONDITION in action:
return ACTION_CHECK_CONDITION
if CONF_EVENT in action:
return ACTION_FIRE_EVENT
if CONF_DEVICE_ID in action:
return ACTION_DEVICE_AUTOMATION
if CONF_SCENE in action:
return ACTION_ACTIVATE_SCENE
return ACTION_CALL_SERVICE
async def async_validate_action_config(
hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate config."""
action_type = _determine_action(config)
action_type = cv.determine_script_action(config)
if action_type == ACTION_DEVICE_AUTOMATION:
if action_type == cv.SCRIPT_ACTION_DEVICE_AUTOMATION:
platform = await device_automation.async_get_device_automation_platform(
hass, config[CONF_DOMAIN], "action"
)
config = platform.ACTION_SCHEMA(config) # type: ignore
if action_type == ACTION_CHECK_CONDITION and config[CONF_CONDITION] == "device":
if (
action_type == cv.SCRIPT_ACTION_CHECK_CONDITION
and config[CONF_CONDITION] == "device"
):
platform = await device_automation.async_get_device_automation_platform(
hass, config[CONF_DOMAIN], "condition"
)
@ -165,7 +134,9 @@ class _ScriptRunBase(ABC):
async def _async_step(self, log_exceptions):
try:
await getattr(self, f"_async_{_determine_action(self._action)}_step")()
await getattr(
self, f"_async_{cv.determine_script_action(self._action)}_step"
)()
except Exception as err:
if not isinstance(err, (_SuspendScript, _StopScript)) and (
self._log_exceptions or log_exceptions
@ -178,7 +149,7 @@ class _ScriptRunBase(ABC):
"""Stop script run."""
def _log_exception(self, exception):
action_type = _determine_action(self._action)
action_type = cv.determine_script_action(self._action)
error = str(exception)
level = logging.ERROR
@ -406,7 +377,7 @@ class _ScriptRun(_ScriptRunBase):
timeout,
)
except asyncio.TimeoutError:
if not self._action.get(CONF_CONTINUE, True):
if not self._action.get(CONF_CONTINUE_ON_TIMEOUT, True):
self._log(_TIMEOUT_MSG)
raise _StopScript
finally:
@ -547,7 +518,7 @@ class _LegacyScriptRun(_ScriptRunBase):
# Check if we want to continue to execute
# the script after the timeout
if self._action.get(CONF_CONTINUE, True):
if self._action.get(CONF_CONTINUE_ON_TIMEOUT, True):
self._hass.async_create_task(self._async_run(False))
else:
self._log(_TIMEOUT_MSG)
@ -632,12 +603,12 @@ class Script:
referenced = set()
for step in self.sequence:
action = _determine_action(step)
action = cv.determine_script_action(step)
if action == ACTION_CHECK_CONDITION:
if action == cv.SCRIPT_ACTION_CHECK_CONDITION:
referenced |= condition.async_extract_devices(step)
elif action == ACTION_DEVICE_AUTOMATION:
elif action == cv.SCRIPT_ACTION_DEVICE_AUTOMATION:
referenced.add(step[CONF_DEVICE_ID])
self._referenced_devices = referenced
@ -652,9 +623,9 @@ class Script:
referenced = set()
for step in self.sequence:
action = _determine_action(step)
action = cv.determine_script_action(step)
if action == ACTION_CALL_SERVICE:
if action == cv.SCRIPT_ACTION_CALL_SERVICE:
data = step.get(service.CONF_SERVICE_DATA)
if not data:
continue
@ -670,10 +641,10 @@ class Script:
for entity_id in entity_ids:
referenced.add(entity_id)
elif action == ACTION_CHECK_CONDITION:
elif action == cv.SCRIPT_ACTION_CHECK_CONDITION:
referenced |= condition.async_extract_entities(step)
elif action == ACTION_ACTIVATE_SCENE:
elif action == cv.SCRIPT_ACTION_ACTIVATE_SCENE:
referenced.add(step[CONF_SCENE])
self._referenced_entities = referenced