Fix extracting entity and device IDs from scripts (#44048)

* Fix extracting entity and device IDs from scripts

* Fix extracting from data_template
This commit is contained in:
Paulus Schoutsen 2020-12-08 13:06:29 +01:00 committed by GitHub
parent 0b7b6b1d81
commit ac2af69d26
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 86 additions and 27 deletions

View file

@ -22,10 +22,10 @@ from async_timeout import timeout
import voluptuous as vol import voluptuous as vol
from homeassistant import exceptions from homeassistant import exceptions
import homeassistant.components.device_automation as device_automation from homeassistant.components import device_automation, scene
from homeassistant.components.logger import LOGSEVERITY from homeassistant.components.logger import LOGSEVERITY
import homeassistant.components.scene as scene
from homeassistant.const import ( from homeassistant.const import (
ATTR_DEVICE_ID,
ATTR_ENTITY_ID, ATTR_ENTITY_ID,
CONF_ALIAS, CONF_ALIAS,
CONF_CHOOSE, CONF_CHOOSE,
@ -44,6 +44,7 @@ from homeassistant.const import (
CONF_REPEAT, CONF_REPEAT,
CONF_SCENE, CONF_SCENE,
CONF_SEQUENCE, CONF_SEQUENCE,
CONF_TARGET,
CONF_TIMEOUT, CONF_TIMEOUT,
CONF_UNTIL, CONF_UNTIL,
CONF_VARIABLES, CONF_VARIABLES,
@ -60,13 +61,9 @@ from homeassistant.core import (
HomeAssistant, HomeAssistant,
callback, callback,
) )
from homeassistant.helpers import condition, config_validation as cv, template from homeassistant.helpers import condition, config_validation as cv, service, template
from homeassistant.helpers.event import async_call_later, async_track_template from homeassistant.helpers.event import async_call_later, async_track_template
from homeassistant.helpers.script_variables import ScriptVariables from homeassistant.helpers.script_variables import ScriptVariables
from homeassistant.helpers.service import (
CONF_SERVICE_DATA,
async_prepare_call_from_config,
)
from homeassistant.helpers.trigger import ( from homeassistant.helpers.trigger import (
async_initialize_triggers, async_initialize_triggers,
async_validate_trigger_config, async_validate_trigger_config,
@ -429,13 +426,13 @@ class _ScriptRun:
self._script.last_action = self._action.get(CONF_ALIAS, "call service") self._script.last_action = self._action.get(CONF_ALIAS, "call service")
self._log("Executing step %s", self._script.last_action) self._log("Executing step %s", self._script.last_action)
domain, service, service_data = async_prepare_call_from_config( domain, service_name, service_data = service.async_prepare_call_from_config(
self._hass, self._action, self._variables self._hass, self._action, self._variables
) )
running_script = ( running_script = (
domain == "automation" domain == "automation"
and service == "trigger" and service_name == "trigger"
or domain in ("python_script", "script") or domain in ("python_script", "script")
) )
# If this might start a script then disable the call timeout. # If this might start a script then disable the call timeout.
@ -448,7 +445,7 @@ class _ScriptRun:
service_task = self._hass.async_create_task( service_task = self._hass.async_create_task(
self._hass.services.async_call( self._hass.services.async_call(
domain, domain,
service, service_name,
service_data, service_data,
blocking=True, blocking=True,
context=self._context, context=self._context,
@ -755,6 +752,23 @@ async def _async_stop_scripts_at_shutdown(hass, event):
_VarsType = Union[Dict[str, Any], MappingProxyType] _VarsType = Union[Dict[str, Any], MappingProxyType]
def _referenced_extract_ids(data: Dict, key: str, found: Set[str]) -> None:
"""Extract referenced IDs."""
if not data:
return
item_ids = data.get(key)
if item_ids is None or isinstance(item_ids, template.Template):
return
if isinstance(item_ids, str):
item_ids = [item_ids]
for item_id in item_ids:
found.add(item_id)
class Script: class Script:
"""Representation of a script.""" """Representation of a script."""
@ -889,7 +903,16 @@ class Script:
for step in self.sequence: for step in self.sequence:
action = cv.determine_script_action(step) action = cv.determine_script_action(step)
if action == cv.SCRIPT_ACTION_CHECK_CONDITION: if action == cv.SCRIPT_ACTION_CALL_SERVICE:
for data in (
step,
step.get(CONF_TARGET),
step.get(service.CONF_SERVICE_DATA),
step.get(service.CONF_SERVICE_DATA_TEMPLATE),
):
_referenced_extract_ids(data, ATTR_DEVICE_ID, referenced)
elif action == cv.SCRIPT_ACTION_CHECK_CONDITION:
referenced |= condition.async_extract_devices(step) referenced |= condition.async_extract_devices(step)
elif action == cv.SCRIPT_ACTION_DEVICE_AUTOMATION: elif action == cv.SCRIPT_ACTION_DEVICE_AUTOMATION:
@ -910,20 +933,13 @@ class Script:
action = cv.determine_script_action(step) action = cv.determine_script_action(step)
if action == cv.SCRIPT_ACTION_CALL_SERVICE: if action == cv.SCRIPT_ACTION_CALL_SERVICE:
data = step.get(CONF_SERVICE_DATA) for data in (
if not data: step,
continue step.get(CONF_TARGET),
step.get(service.CONF_SERVICE_DATA),
entity_ids = data.get(ATTR_ENTITY_ID) step.get(service.CONF_SERVICE_DATA_TEMPLATE),
):
if entity_ids is None or isinstance(entity_ids, template.Template): _referenced_extract_ids(data, ATTR_ENTITY_ID, referenced)
continue
if isinstance(entity_ids, str):
entity_ids = [entity_ids]
for entity_id in entity_ids:
referenced.add(entity_id)
elif action == cv.SCRIPT_ACTION_CHECK_CONDITION: elif action == cv.SCRIPT_ACTION_CHECK_CONDITION:
referenced |= condition.async_extract_entities(step) referenced |= condition.async_extract_entities(step)

View file

@ -1254,3 +1254,6 @@ async def test_blueprint_automation(hass, calls):
hass.bus.async_fire("blueprint_event") hass.bus.async_fire("blueprint_event")
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(calls) == 1 assert len(calls) == 1
assert automation.entities_in_automation(hass, "automation.automation_0") == [
"light.kitchen"
]

View file

@ -124,7 +124,7 @@ async def test_save_blueprint(hass, aioclient_mock, hass_ws_client):
assert msg["success"] assert msg["success"]
assert write_mock.mock_calls assert write_mock.mock_calls
assert write_mock.call_args[0] == ( assert write_mock.call_args[0] == (
"blueprint:\n name: Call service based on event\n domain: automation\n input:\n trigger_event:\n service_to_call:\n source_url: https://github.com/balloob/home-assistant-config/blob/main/blueprints/automation/motion_light.yaml\ntrigger:\n platform: event\n event_type: !input 'trigger_event'\naction:\n service: !input 'service_to_call'\n", "blueprint:\n name: Call service based on event\n domain: automation\n input:\n trigger_event:\n service_to_call:\n source_url: https://github.com/balloob/home-assistant-config/blob/main/blueprints/automation/motion_light.yaml\ntrigger:\n platform: event\n event_type: !input 'trigger_event'\naction:\n service: !input 'service_to_call'\n entity_id: light.kitchen\n",
) )

View file

@ -1338,6 +1338,18 @@ async def test_referenced_entities(hass):
"service": "test.script", "service": "test.script",
"data": {"entity_id": "{{ 'light.service_template' }}"}, "data": {"entity_id": "{{ 'light.service_template' }}"},
}, },
{
"service": "test.script",
"entity_id": "light.direct_entity_referenced",
},
{
"service": "test.script",
"target": {"entity_id": "light.entity_in_target"},
},
{
"service": "test.script",
"data_template": {"entity_id": "light.entity_in_data_template"},
},
{ {
"condition": "state", "condition": "state",
"entity_id": "sensor.condition", "entity_id": "sensor.condition",
@ -1357,6 +1369,9 @@ async def test_referenced_entities(hass):
"light.service_list", "light.service_list",
"sensor.condition", "sensor.condition",
"scene.hello", "scene.hello",
"light.direct_entity_referenced",
"light.entity_in_target",
"light.entity_in_data_template",
} }
# Test we cache results. # Test we cache results.
assert script_obj.referenced_entities is script_obj.referenced_entities assert script_obj.referenced_entities is script_obj.referenced_entities
@ -1374,12 +1389,36 @@ async def test_referenced_devices(hass):
"device_id": "condition-dev-id", "device_id": "condition-dev-id",
"domain": "switch", "domain": "switch",
}, },
{
"service": "test.script",
"data": {"device_id": "data-string-id"},
},
{
"service": "test.script",
"data_template": {"device_id": "data-template-string-id"},
},
{
"service": "test.script",
"target": {"device_id": "target-string-id"},
},
{
"service": "test.script",
"target": {"device_id": ["target-list-id-1", "target-list-id-2"]},
},
] ]
), ),
"Test Name", "Test Name",
"test_domain", "test_domain",
) )
assert script_obj.referenced_devices == {"script-dev-id", "condition-dev-id"} assert script_obj.referenced_devices == {
"script-dev-id",
"condition-dev-id",
"data-string-id",
"data-template-string-id",
"target-string-id",
"target-list-id-1",
"target-list-id-2",
}
# Test we cache results. # Test we cache results.
assert script_obj.referenced_devices is script_obj.referenced_devices assert script_obj.referenced_devices is script_obj.referenced_devices

View file

@ -9,3 +9,4 @@ trigger:
event_type: !input trigger_event event_type: !input trigger_event
action: action:
service: !input service_to_call service: !input service_to_call
entity_id: light.kitchen