Teach state trigger about entity registry ids (#60271)

* Teach state trigger about entity registry ids

* Tweak

* Add tests

* Tweak tests

* Fix tests

* Resolve entity ids during config validation

* Update device_triggers

* Fix mistake

* Tweak trigger validator to ensure we don't modify the original config

* Add index from entry id to entry

* Update scaffold

* Pre-compile UUID regex

* Address review comment

* Tweak mock_registry

* Tweak

* Apply suggestion from code review
This commit is contained in:
Erik Montnemery 2021-12-02 14:26:45 +01:00 committed by GitHub
parent c0fb1bffce
commit c85bb27d0d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
20 changed files with 324 additions and 74 deletions

View file

@ -157,7 +157,7 @@ async def async_attach_trigger(
} }
if CONF_FOR in config: if CONF_FOR in config:
state_config[CONF_FOR] = config[CONF_FOR] state_config[CONF_FOR] = config[CONF_FOR]
state_config = state_trigger.TRIGGER_SCHEMA(state_config) state_config = await state_trigger.async_validate_trigger_config(hass, state_config)
return await state_trigger.async_attach_trigger( return await state_trigger.async_attach_trigger(
hass, state_config, action, automation_info, platform_type="device" hass, state_config, action, automation_info, platform_type="device"
) )

View file

@ -220,7 +220,7 @@ async def async_attach_trigger(hass, config, action, automation_info):
if CONF_FOR in config: if CONF_FOR in config:
state_config[CONF_FOR] = config[CONF_FOR] state_config[CONF_FOR] = config[CONF_FOR]
state_config = state_trigger.TRIGGER_SCHEMA(state_config) state_config = await state_trigger.async_validate_trigger_config(hass, state_config)
return await state_trigger.async_attach_trigger( return await state_trigger.async_attach_trigger(
hass, state_config, action, automation_info, platform_type="device" hass, state_config, action, automation_info, platform_type="device"
) )

View file

@ -11,8 +11,8 @@ from homeassistant.components.automation import (
) )
from homeassistant.components.device_automation import DEVICE_TRIGGER_BASE_SCHEMA from homeassistant.components.device_automation import DEVICE_TRIGGER_BASE_SCHEMA
from homeassistant.components.homeassistant.triggers.state import ( from homeassistant.components.homeassistant.triggers.state import (
TRIGGER_SCHEMA as STATE_TRIGGER_SCHEMA,
async_attach_trigger as async_attach_state_trigger, async_attach_trigger as async_attach_state_trigger,
async_validate_trigger_config as async_validate_state_trigger_config,
) )
from homeassistant.const import ( from homeassistant.const import (
CONF_DEVICE_ID, CONF_DEVICE_ID,
@ -67,7 +67,7 @@ async def async_attach_trigger(
CONF_ENTITY_ID: config[CONF_ENTITY_ID], CONF_ENTITY_ID: config[CONF_ENTITY_ID],
} }
state_config = STATE_TRIGGER_SCHEMA(state_config) state_config = await async_validate_state_trigger_config(hass, state_config)
return await async_attach_state_trigger( return await async_attach_state_trigger(
hass, state_config, action, automation_info, platform_type="device" hass, state_config, action, automation_info, platform_type="device"
) )

View file

@ -131,7 +131,9 @@ async def async_attach_trigger(
} }
if CONF_FOR in config: if CONF_FOR in config:
state_config[CONF_FOR] = config[CONF_FOR] state_config[CONF_FOR] = config[CONF_FOR]
state_config = state_trigger.TRIGGER_SCHEMA(state_config) state_config = await state_trigger.async_validate_trigger_config(
hass, state_config
)
return await state_trigger.async_attach_trigger( return await state_trigger.async_attach_trigger(
hass, state_config, action, automation_info, platform_type="device" hass, state_config, action, automation_info, platform_type="device"
) )

View file

@ -170,7 +170,9 @@ async def async_attach_trigger(
} }
if CONF_FOR in config: if CONF_FOR in config:
state_config[CONF_FOR] = config[CONF_FOR] state_config[CONF_FOR] = config[CONF_FOR]
state_config = state_trigger.TRIGGER_SCHEMA(state_config) state_config = await state_trigger.async_validate_trigger_config(
hass, state_config
)
return await state_trigger.async_attach_trigger( return await state_trigger.async_attach_trigger(
hass, state_config, action, automation_info, platform_type="device" hass, state_config, action, automation_info, platform_type="device"
) )

View file

@ -164,7 +164,7 @@ async def async_attach_trigger(
if CONF_FOR in config: if CONF_FOR in config:
state_config[CONF_FOR] = config[CONF_FOR] state_config[CONF_FOR] = config[CONF_FOR]
state_config = state_trigger.TRIGGER_SCHEMA(state_config) state_config = await state_trigger.async_validate_trigger_config(hass, state_config)
return await state_trigger.async_attach_trigger( return await state_trigger.async_attach_trigger(
hass, state_config, action, automation_info, platform_type="device" hass, state_config, action, automation_info, platform_type="device"
) )

View file

@ -3,20 +3,24 @@ from __future__ import annotations
from datetime import timedelta from datetime import timedelta
import logging import logging
from typing import Any
import voluptuous as vol import voluptuous as vol
from homeassistant import exceptions from homeassistant import exceptions
from homeassistant.const import CONF_ATTRIBUTE, CONF_FOR, CONF_PLATFORM, MATCH_ALL from homeassistant.const import CONF_ATTRIBUTE, CONF_FOR, CONF_PLATFORM, MATCH_ALL
from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, State, callback from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, State, callback
from homeassistant.helpers import config_validation as cv, template from homeassistant.helpers import (
config_validation as cv,
entity_registry as er,
template,
)
from homeassistant.helpers.event import ( from homeassistant.helpers.event import (
Event, Event,
async_track_same_state, async_track_same_state,
async_track_state_change_event, async_track_state_change_event,
process_state_match, process_state_match,
) )
from homeassistant.helpers.typing import ConfigType
# mypy: allow-incomplete-defs, allow-untyped-calls, allow-untyped-defs # mypy: allow-incomplete-defs, allow-untyped-calls, allow-untyped-defs
# mypy: no-check-untyped-defs # mypy: no-check-untyped-defs
@ -30,7 +34,7 @@ CONF_TO = "to"
BASE_SCHEMA = cv.TRIGGER_BASE_SCHEMA.extend( BASE_SCHEMA = cv.TRIGGER_BASE_SCHEMA.extend(
{ {
vol.Required(CONF_PLATFORM): "state", vol.Required(CONF_PLATFORM): "state",
vol.Required(CONF_ENTITY_ID): cv.entity_ids, vol.Required(CONF_ENTITY_ID): cv.entity_ids_or_uuids,
vol.Optional(CONF_FOR): cv.positive_time_period_template, vol.Optional(CONF_FOR): cv.positive_time_period_template,
vol.Optional(CONF_ATTRIBUTE): cv.match_all, vol.Optional(CONF_ATTRIBUTE): cv.match_all,
} }
@ -52,17 +56,26 @@ TRIGGER_ATTRIBUTE_SCHEMA = BASE_SCHEMA.extend(
) )
def TRIGGER_SCHEMA(value: Any) -> dict: # pylint: disable=invalid-name async def async_validate_trigger_config(
"""Validate trigger.""" hass: HomeAssistant, config: ConfigType
if not isinstance(value, dict): ) -> ConfigType:
"""Validate trigger config."""
if not isinstance(config, dict):
raise vol.Invalid("Expected a dictionary") raise vol.Invalid("Expected a dictionary")
# We use this approach instead of vol.Any because # We use this approach instead of vol.Any because
# this gives better error messages. # this gives better error messages.
if CONF_ATTRIBUTE in value: if CONF_ATTRIBUTE in config:
return TRIGGER_ATTRIBUTE_SCHEMA(value) config = TRIGGER_ATTRIBUTE_SCHEMA(config)
else:
config = TRIGGER_STATE_SCHEMA(config)
return TRIGGER_STATE_SCHEMA(value) registry = er.async_get(hass)
config[CONF_ENTITY_ID] = er.async_resolve_entity_ids(
registry, cv.entity_ids_or_uuids(config[CONF_ENTITY_ID])
)
return config
async def async_attach_trigger( async def async_attach_trigger(
@ -74,7 +87,7 @@ async def async_attach_trigger(
platform_type: str = "state", platform_type: str = "state",
) -> CALLBACK_TYPE: ) -> CALLBACK_TYPE:
"""Listen for state changes based on configuration.""" """Listen for state changes based on configuration."""
entity_id = config.get(CONF_ENTITY_ID) entity_ids = config[CONF_ENTITY_ID]
if (from_state := config.get(CONF_FROM)) is None: if (from_state := config.get(CONF_FROM)) is None:
from_state = MATCH_ALL from_state = MATCH_ALL
if (to_state := config.get(CONF_TO)) is None: if (to_state := config.get(CONF_TO)) is None:
@ -196,7 +209,7 @@ async def async_attach_trigger(
entity_ids=entity, entity_ids=entity,
) )
unsub = async_track_state_change_event(hass, entity_id, state_automation_listener) unsub = async_track_state_change_event(hass, entity_ids, state_automation_listener)
@callback @callback
def async_remove(): def async_remove():

View file

@ -104,7 +104,7 @@ async def async_attach_trigger(
} }
if CONF_FOR in config: if CONF_FOR in config:
state_config[CONF_FOR] = config[CONF_FOR] state_config[CONF_FOR] = config[CONF_FOR]
state_config = state_trigger.TRIGGER_SCHEMA(state_config) state_config = await state_trigger.async_validate_trigger_config(hass, state_config)
return await state_trigger.async_attach_trigger( return await state_trigger.async_attach_trigger(
hass, state_config, action, automation_info, platform_type="device" hass, state_config, action, automation_info, platform_type="device"
) )

View file

@ -104,7 +104,7 @@ async def async_attach_trigger(
} }
if CONF_FOR in config: if CONF_FOR in config:
state_config[CONF_FOR] = config[CONF_FOR] state_config[CONF_FOR] = config[CONF_FOR]
state_config = state_trigger.TRIGGER_SCHEMA(state_config) state_config = await state_trigger.async_validate_trigger_config(hass, state_config)
return await state_trigger.async_attach_trigger( return await state_trigger.async_attach_trigger(
hass, state_config, action, automation_info, platform_type="device" hass, state_config, action, automation_info, platform_type="device"
) )

View file

@ -14,8 +14,8 @@ from homeassistant.components.homeassistant.triggers.state import (
CONF_FOR, CONF_FOR,
CONF_FROM, CONF_FROM,
CONF_TO, CONF_TO,
TRIGGER_SCHEMA as STATE_TRIGGER_SCHEMA,
async_attach_trigger as async_attach_state_trigger, async_attach_trigger as async_attach_state_trigger,
async_validate_trigger_config as async_validate_state_trigger_config,
) )
from homeassistant.components.select.const import ATTR_OPTIONS from homeassistant.components.select.const import ATTR_OPTIONS
from homeassistant.const import ( from homeassistant.const import (
@ -84,7 +84,7 @@ async def async_attach_trigger(
if CONF_FOR in config: if CONF_FOR in config:
state_config[CONF_FOR] = config[CONF_FOR] state_config[CONF_FOR] = config[CONF_FOR]
state_config = STATE_TRIGGER_SCHEMA(state_config) state_config = await async_validate_state_trigger_config(hass, state_config)
return await async_attach_state_trigger( return await async_attach_state_trigger(
hass, state_config, action, automation_info, platform_type="device" hass, state_config, action, automation_info, platform_type="device"
) )

View file

@ -92,7 +92,7 @@ async def async_attach_trigger(
} }
if CONF_FOR in config: if CONF_FOR in config:
state_config[CONF_FOR] = config[CONF_FOR] state_config[CONF_FOR] = config[CONF_FOR]
state_config = state_trigger.TRIGGER_SCHEMA(state_config) state_config = await state_trigger.async_validate_trigger_config(hass, state_config)
return await state_trigger.async_attach_trigger( return await state_trigger.async_attach_trigger(
hass, state_config, action, automation_info, platform_type="device" hass, state_config, action, automation_info, platform_type="device"
) )

View file

@ -415,7 +415,7 @@ async def async_attach_trigger(
else: else:
raise HomeAssistantError(f"Unhandled trigger type {trigger_type}") raise HomeAssistantError(f"Unhandled trigger type {trigger_type}")
state_config = state.TRIGGER_SCHEMA(state_config) state_config = await state.async_validate_trigger_config(hass, state_config)
return await state.async_attach_trigger( return await state.async_attach_trigger(
hass, state_config, action, automation_info, platform_type="device" hass, state_config, action, automation_info, platform_type="device"
) )

View file

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable, Hashable from collections.abc import Callable, Hashable
import contextlib
from datetime import ( from datetime import (
date as date_sys, date as date_sys,
datetime as datetime_sys, datetime as datetime_sys,
@ -262,14 +263,34 @@ def entity_id(value: Any) -> str:
raise vol.Invalid(f"Entity ID {value} is an invalid entity ID") raise vol.Invalid(f"Entity ID {value} is an invalid entity ID")
def entity_ids(value: str | list) -> list[str]: def entity_id_or_uuid(value: Any) -> str:
"""Validate Entity IDs.""" """Validate Entity specified by entity_id or uuid."""
with contextlib.suppress(vol.Invalid):
return entity_id(value)
with contextlib.suppress(vol.Invalid):
return fake_uuid4_hex(value)
raise vol.Invalid(f"Entity {value} is neither a valid entity ID nor a valid UUID")
def _entity_ids(value: str | list, allow_uuid: bool) -> list[str]:
"""Help validate entity IDs or UUIDs."""
if value is None: if value is None:
raise vol.Invalid("Entity IDs can not be None") raise vol.Invalid("Entity IDs can not be None")
if isinstance(value, str): if isinstance(value, str):
value = [ent_id.strip() for ent_id in value.split(",")] value = [ent_id.strip() for ent_id in value.split(",")]
return [entity_id(ent_id) for ent_id in value] validator = entity_id_or_uuid if allow_uuid else entity_id
return [validator(ent_id) for ent_id in value]
def entity_ids(value: str | list) -> list[str]:
"""Validate Entity IDs."""
return _entity_ids(value, False)
def entity_ids_or_uuids(value: str | list) -> list[str]:
"""Validate entities specified by entity IDs or UUIDs."""
return _entity_ids(value, True)
comp_entity_ids = vol.Any( comp_entity_ids = vol.Any(
@ -682,6 +703,16 @@ def uuid4_hex(value: Any) -> str:
return result.hex return result.hex
_FAKE_UUID_4_HEX = re.compile(r"^[0-9a-f]{32}$")
def fake_uuid4_hex(value: Any) -> str:
"""Validate a fake v4 UUID generated by random_uuid_hex."""
if not _FAKE_UUID_4_HEX.match(value):
raise vol.Invalid("Invalid UUID")
return cast(str, value) # Pattern.match throws if input is not a string
def ensure_list_csv(value: Any) -> list: def ensure_list_csv(value: Any) -> list:
"""Ensure that input is a list or make one from comma-separated string.""" """Ensure that input is a list or make one from comma-separated string."""
if isinstance(value, str): if isinstance(value, str):

View file

@ -9,12 +9,13 @@ timer.
""" """
from __future__ import annotations from __future__ import annotations
from collections import OrderedDict from collections import UserDict
from collections.abc import Callable, Iterable, Mapping from collections.abc import Callable, Iterable, Mapping
import logging import logging
from typing import TYPE_CHECKING, Any, cast from typing import TYPE_CHECKING, Any, cast
import attr import attr
import voluptuous as vol
from homeassistant.const import ( from homeassistant.const import (
ATTR_DEVICE_CLASS, ATTR_DEVICE_CLASS,
@ -161,14 +162,57 @@ class EntityRegistryStore(storage.Store):
return await _async_migrate(old_major_version, old_minor_version, old_data) return await _async_migrate(old_major_version, old_minor_version, old_data)
class EntityRegistryItems(UserDict):
"""Container for entity registry items, maps entity_id -> entry.
Maintains two additional indexes:
- id -> entry
- (domain, platform, unique_id) -> entry
"""
def __init__(self) -> None:
"""Initialize the container."""
super().__init__()
self._entry_ids: dict[str, RegistryEntry] = {}
self._index: dict[tuple[str, str, str], str] = {}
def __setitem__(self, key: str, entry: RegistryEntry) -> None:
"""Add an item."""
if key in self:
old_entry = self[key]
del self._entry_ids[old_entry.id]
del self._index[(old_entry.domain, old_entry.platform, old_entry.unique_id)]
super().__setitem__(key, entry)
self._entry_ids.__setitem__(entry.id, entry)
self._index[(entry.domain, entry.platform, entry.unique_id)] = entry.entity_id
def __delitem__(self, key: str) -> None:
"""Remove an item."""
entry = self[key]
self._entry_ids.__delitem__(entry.id)
self._index.__delitem__((entry.domain, entry.platform, entry.unique_id))
super().__delitem__(key)
def __getitem__(self, key: str) -> RegistryEntry:
"""Get an item."""
return cast(RegistryEntry, super().__getitem__(key))
def get_entity_id(self, key: tuple[str, str, str]) -> str | None:
"""Get entity_id from (domain, platform, unique_id)."""
return self._index.get(key)
def get_entry(self, key: str) -> RegistryEntry | None:
"""Get entry from id."""
return self._entry_ids.get(key)
class EntityRegistry: class EntityRegistry:
"""Class to hold a registry of entities.""" """Class to hold a registry of entities."""
def __init__(self, hass: HomeAssistant) -> None: def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the registry.""" """Initialize the registry."""
self.hass = hass self.hass = hass
self.entities: dict[str, RegistryEntry] self.entities: EntityRegistryItems
self._index: dict[tuple[str, str, str], str] = {}
self._store = EntityRegistryStore( self._store = EntityRegistryStore(
hass, hass,
STORAGE_VERSION_MAJOR, STORAGE_VERSION_MAJOR,
@ -218,7 +262,7 @@ class EntityRegistry:
self, domain: str, platform: str, unique_id: str self, domain: str, platform: str, unique_id: str
) -> str | None: ) -> str | None:
"""Check if an entity_id is currently registered.""" """Check if an entity_id is currently registered."""
return self._index.get((domain, platform, unique_id)) return self.entities.get_entity_id((domain, platform, unique_id))
@callback @callback
def async_generate_entity_id( def async_generate_entity_id(
@ -320,7 +364,7 @@ class EntityRegistry:
): ):
disabled_by = DISABLED_INTEGRATION disabled_by = DISABLED_INTEGRATION
entity = RegistryEntry( entry = RegistryEntry(
area_id=area_id, area_id=area_id,
capabilities=capabilities, capabilities=capabilities,
config_entry_id=config_entry_id, config_entry_id=config_entry_id,
@ -336,7 +380,7 @@ class EntityRegistry:
unique_id=unique_id, unique_id=unique_id,
unit_of_measurement=unit_of_measurement, unit_of_measurement=unit_of_measurement,
) )
self._register_entry(entity) self.entities[entity_id] = entry
_LOGGER.info("Registered new %s.%s entity: %s", domain, platform, entity_id) _LOGGER.info("Registered new %s.%s entity: %s", domain, platform, entity_id)
self.async_schedule_save() self.async_schedule_save()
@ -344,12 +388,12 @@ class EntityRegistry:
EVENT_ENTITY_REGISTRY_UPDATED, {"action": "create", "entity_id": entity_id} EVENT_ENTITY_REGISTRY_UPDATED, {"action": "create", "entity_id": entity_id}
) )
return entity return entry
@callback @callback
def async_remove(self, entity_id: str) -> None: def async_remove(self, entity_id: str) -> None:
"""Remove an entity from registry.""" """Remove an entity from registry."""
self._unregister_entry(self.entities[entity_id]) self.entities.pop(entity_id)
self.hass.bus.async_fire( self.hass.bus.async_fire(
EVENT_ENTITY_REGISTRY_UPDATED, {"action": "remove", "entity_id": entity_id} EVENT_ENTITY_REGISTRY_UPDATED, {"action": "remove", "entity_id": entity_id}
) )
@ -513,9 +557,7 @@ class EntityRegistry:
if not new_values: if not new_values:
return old return old
self._remove_index(old) new = self.entities[entity_id] = attr.evolve(old, **new_values)
new = attr.evolve(old, **new_values)
self._register_entry(new)
self.async_schedule_save() self.async_schedule_save()
@ -539,7 +581,7 @@ class EntityRegistry:
old_conf_load_func=load_yaml, old_conf_load_func=load_yaml,
old_conf_migrate_func=_async_migrate_yaml_to_json, old_conf_migrate_func=_async_migrate_yaml_to_json,
) )
entities: dict[str, RegistryEntry] = OrderedDict() entities = EntityRegistryItems()
if data is not None: if data is not None:
for entity in data["entities"]: for entity in data["entities"]:
@ -571,7 +613,6 @@ class EntityRegistry:
) )
self.entities = entities self.entities = entities
self._rebuild_index()
@callback @callback
def async_schedule_save(self) -> None: def async_schedule_save(self) -> None:
@ -626,25 +667,6 @@ class EntityRegistry:
if area_id == entry.area_id: if area_id == entry.area_id:
self._async_update_entity(entity_id, area_id=None) self._async_update_entity(entity_id, area_id=None)
def _register_entry(self, entry: RegistryEntry) -> None:
self.entities[entry.entity_id] = entry
self._add_index(entry)
def _add_index(self, entry: RegistryEntry) -> None:
self._index[(entry.domain, entry.platform, entry.unique_id)] = entry.entity_id
def _unregister_entry(self, entry: RegistryEntry) -> None:
self._remove_index(entry)
del self.entities[entry.entity_id]
def _remove_index(self, entry: RegistryEntry) -> None:
del self._index[(entry.domain, entry.platform, entry.unique_id)]
def _rebuild_index(self) -> None:
self._index = {}
for entry in self.entities.values():
self._add_index(entry)
@callback @callback
def async_get(hass: HomeAssistant) -> EntityRegistry: def async_get(hass: HomeAssistant) -> EntityRegistry:
@ -841,3 +863,25 @@ async def async_migrate_entries(
if updates is not None: if updates is not None:
ent_reg.async_update_entity(entry.entity_id, **updates) ent_reg.async_update_entity(entry.entity_id, **updates)
@callback
def async_resolve_entity_ids(
registry: EntityRegistry, entity_ids_or_uuids: list[str]
) -> list[str]:
"""Resolve a list of entity ids or UUIDs to a list of entity ids."""
def resolve_entity(entity_id_or_uuid: str) -> str | None:
"""Resolve an entity id or UUID to an entity id or None."""
if valid_entity_id(entity_id_or_uuid):
return entity_id_or_uuid
if (entry := registry.entities.get_entry(entity_id_or_uuid)) is None:
raise vol.Invalid(f"Unknown entity registry entry {entity_id_or_uuid}")
return entry.entity_id
tmp = [
resolved_item
for item in entity_ids_or_uuids
if (resolved_item := resolve_entity(item)) is not None
]
return tmp

View file

@ -10,7 +10,7 @@ from homeassistant.components.automation import (
AutomationTriggerInfo, AutomationTriggerInfo,
) )
from homeassistant.components.device_automation import DEVICE_TRIGGER_BASE_SCHEMA from homeassistant.components.device_automation import DEVICE_TRIGGER_BASE_SCHEMA
from homeassistant.components.homeassistant.triggers import state from homeassistant.components.homeassistant.triggers import state as state_trigger
from homeassistant.const import ( from homeassistant.const import (
CONF_DEVICE_ID, CONF_DEVICE_ID,
CONF_DOMAIN, CONF_DOMAIN,
@ -86,11 +86,11 @@ async def async_attach_trigger(
to_state = STATE_OFF to_state = STATE_OFF
state_config = { state_config = {
state.CONF_PLATFORM: "state", state_trigger.CONF_PLATFORM: "state",
CONF_ENTITY_ID: config[CONF_ENTITY_ID], CONF_ENTITY_ID: config[CONF_ENTITY_ID],
state.CONF_TO: to_state, state_trigger.CONF_TO: to_state,
} }
state_config = state.TRIGGER_SCHEMA(state_config) state_config = await state_trigger.async_validate_trigger_config(hass, state_config)
return await state.async_attach_trigger( return await state_trigger.async_attach_trigger(
hass, state_config, action, automation_info, platform_type="device" hass, state_config, action, automation_info, platform_type="device"
) )

View file

@ -440,8 +440,11 @@ def mock_component(hass, component):
def mock_registry(hass, mock_entries=None): def mock_registry(hass, mock_entries=None):
"""Mock the Entity Registry.""" """Mock the Entity Registry."""
registry = entity_registry.EntityRegistry(hass) registry = entity_registry.EntityRegistry(hass)
registry.entities = mock_entries or OrderedDict() if mock_entries is None:
registry._rebuild_index() mock_entries = {}
registry.entities = entity_registry.EntityRegistryItems()
for key, entry in mock_entries.items():
registry.entities[key] = entry
hass.data[entity_registry.DATA_REGISTRY] = registry hass.data[entity_registry.DATA_REGISTRY] = registry
return registry return registry

View file

@ -8,6 +8,7 @@ import homeassistant.components.automation as automation
from homeassistant.components.homeassistant.triggers import state as state_trigger from homeassistant.components.homeassistant.triggers import state as state_trigger
from homeassistant.const import ATTR_ENTITY_ID, ENTITY_MATCH_ALL, SERVICE_TURN_OFF from homeassistant.const import ATTR_ENTITY_ID, ENTITY_MATCH_ALL, SERVICE_TURN_OFF
from homeassistant.core import Context from homeassistant.core import Context
from homeassistant.helpers import entity_registry as er
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
@ -82,6 +83,64 @@ async def test_if_fires_on_entity_change(hass, calls):
assert len(calls) == 1 assert len(calls) == 1
async def test_if_fires_on_entity_change_uuid(hass, calls):
"""Test for firing on entity change."""
context = Context()
registry = er.async_get(hass)
entry = registry.async_get_or_create(
"test", "hue", "1234", suggested_object_id="beer"
)
assert entry.entity_id == "test.beer"
hass.states.async_set("test.beer", "hello")
await hass.async_block_till_done()
assert await async_setup_component(
hass,
automation.DOMAIN,
{
automation.DOMAIN: {
"trigger": {"platform": "state", "entity_id": entry.id},
"action": {
"service": "test.automation",
"data_template": {
"some": "{{ trigger.%s }}"
% "}} - {{ trigger.".join(
(
"platform",
"entity_id",
"from_state.state",
"to_state.state",
"for",
"id",
)
)
},
},
}
},
)
await hass.async_block_till_done()
hass.states.async_set("test.beer", "world", context=context)
await hass.async_block_till_done()
assert len(calls) == 1
assert calls[0].context.parent_id == context.id
assert calls[0].data["some"] == "state - test.beer - hello - world - None - 0"
await hass.services.async_call(
automation.DOMAIN,
SERVICE_TURN_OFF,
{ATTR_ENTITY_ID: ENTITY_MATCH_ALL},
blocking=True,
)
hass.states.async_set("test.beer", "planet")
await hass.async_block_till_done()
assert len(calls) == 1
async def test_if_fires_on_entity_change_with_from_filter(hass, calls): async def test_if_fires_on_entity_change_with_from_filter(hass, calls):
"""Test for firing on entity change with filter.""" """Test for firing on entity change with filter."""
assert await async_setup_component( assert await async_setup_component(

View file

@ -172,9 +172,10 @@ def test_entity_id():
assert schema("sensor.LIGHT") == "sensor.light" assert schema("sensor.LIGHT") == "sensor.light"
def test_entity_ids(): @pytest.mark.parametrize("validator", [cv.entity_ids, cv.entity_ids_or_uuids])
def test_entity_ids(validator):
"""Test entity ID validation.""" """Test entity ID validation."""
schema = vol.Schema(cv.entity_ids) schema = vol.Schema(validator)
options = ( options = (
"invalid_entity", "invalid_entity",
@ -194,6 +195,32 @@ def test_entity_ids():
assert schema("sensor.LIGHT, light.kitchen ") == ["sensor.light", "light.kitchen"] assert schema("sensor.LIGHT, light.kitchen ") == ["sensor.light", "light.kitchen"]
def test_entity_ids_or_uuids():
"""Test entity ID validation."""
schema = vol.Schema(cv.entity_ids_or_uuids)
valid_uuid = "a266a680b608c32770e6c45bfe6b8411"
valid_uuid2 = "a266a680b608c32770e6c45bfe6b8412"
invalid_uuid_capital_letters = "A266A680B608C32770E6C45bfE6B8412"
options = (
"invalid_uuid",
invalid_uuid_capital_letters,
f"{valid_uuid},invalid_uuid",
["invalid_uuid"],
[valid_uuid, "invalid_uuid"],
[f"{valid_uuid},invalid_uuid"],
)
for value in options:
with pytest.raises(vol.MultipleInvalid):
schema(value)
options = ([], [valid_uuid], valid_uuid)
for value in options:
schema(value)
assert schema(f"{valid_uuid}, {valid_uuid2} ") == [valid_uuid, valid_uuid2]
def test_entity_domain(): def test_entity_domain():
"""Test entity domain validation.""" """Test entity domain validation."""
schema = vol.Schema(cv.entity_domain("sensor")) schema = vol.Schema(cv.entity_domain("sensor"))

View file

@ -2,6 +2,7 @@
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
import voluptuous as vol
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.const import EVENT_HOMEASSISTANT_START, STATE_UNAVAILABLE from homeassistant.const import EVENT_HOMEASSISTANT_START, STATE_UNAVAILABLE
@ -1023,3 +1024,60 @@ async def test_entity_max_length_exceeded(hass, registry):
assert exc_info.value.property_name == "generated_entity_id" assert exc_info.value.property_name == "generated_entity_id"
assert exc_info.value.max_length == 255 assert exc_info.value.max_length == 255
assert exc_info.value.value == f"sensor.{long_entity_id_name}_2" assert exc_info.value.value == f"sensor.{long_entity_id_name}_2"
async def test_resolve_entity_ids(hass, registry):
"""Test resolving entity IDs."""
entry1 = registry.async_get_or_create(
"light", "hue", "1234", suggested_object_id="beer"
)
assert entry1.entity_id == "light.beer"
entry2 = registry.async_get_or_create(
"light", "hue", "2345", suggested_object_id="milk"
)
assert entry2.entity_id == "light.milk"
expected = ["light.beer", "light.milk"]
assert er.async_resolve_entity_ids(registry, [entry1.id, entry2.id]) == expected
expected = ["light.beer", "light.milk"]
assert er.async_resolve_entity_ids(registry, ["light.beer", entry2.id]) == expected
with pytest.raises(vol.Invalid):
er.async_resolve_entity_ids(registry, ["light.beer", "bad_uuid"])
expected = ["light.unknown"]
assert er.async_resolve_entity_ids(registry, ["light.unknown"]) == expected
with pytest.raises(vol.Invalid):
er.async_resolve_entity_ids(registry, ["unknown_uuid"])
def test_entity_registry_items():
"""Test the EntityRegistryItems container."""
entities = er.EntityRegistryItems()
assert entities.get_entity_id(("a", "b", "c")) is None
assert entities.get_entry("abc") is None
entry1 = er.RegistryEntry("test.entity1", "1234", "hue")
entry2 = er.RegistryEntry("test.entity2", "2345", "hue")
entities["test.entity1"] = entry1
entities["test.entity2"] = entry2
assert entities["test.entity1"] is entry1
assert entities["test.entity2"] is entry2
assert entities.get_entity_id(("test", "hue", "1234")) is entry1.entity_id
assert entities.get_entry(entry1.id) is entry1
assert entities.get_entity_id(("test", "hue", "2345")) is entry2.entity_id
assert entities.get_entry(entry2.id) is entry2
entities.pop("test.entity1")
del entities["test.entity2"]
assert entities.get_entity_id(("test", "hue", "1234")) is None
assert entities.get_entry(entry1.id) is None
assert entities.get_entity_id(("test", "hue", "2345")) is None
assert entities.get_entry(entry2.id) is None

View file

@ -748,6 +748,7 @@ async def test_wait_basic(hass, action_type):
"to": "off", "to": "off",
} }
sequence = cv.SCRIPT_SCHEMA(action) sequence = cv.SCRIPT_SCHEMA(action)
sequence = await script.async_validate_actions_config(hass, sequence)
script_obj = script.Script(hass, sequence, "Test Name", "test_domain") script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
wait_started_flag = async_watch_for_action(script_obj, wait_alias) wait_started_flag = async_watch_for_action(script_obj, wait_alias)
@ -848,6 +849,7 @@ async def test_wait_basic_times_out(hass, action_type):
"to": "off", "to": "off",
} }
sequence = cv.SCRIPT_SCHEMA(action) sequence = cv.SCRIPT_SCHEMA(action)
sequence = await script.async_validate_actions_config(hass, sequence)
script_obj = script.Script(hass, sequence, "Test Name", "test_domain") script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
wait_started_flag = async_watch_for_action(script_obj, wait_alias) wait_started_flag = async_watch_for_action(script_obj, wait_alias)
timed_out = False timed_out = False
@ -904,6 +906,7 @@ async def test_multiple_runs_wait(hass, action_type):
{"event": event, "event_data": {"value": 2}}, {"event": event, "event_data": {"value": 2}},
] ]
) )
sequence = await script.async_validate_actions_config(hass, sequence)
script_obj = script.Script( script_obj = script.Script(
hass, sequence, "Test Name", "test_domain", script_mode="parallel", max_runs=2 hass, sequence, "Test Name", "test_domain", script_mode="parallel", max_runs=2
) )
@ -952,6 +955,7 @@ async def test_cancel_wait(hass, action_type):
} }
} }
sequence = cv.SCRIPT_SCHEMA([action, {"event": event}]) sequence = cv.SCRIPT_SCHEMA([action, {"event": event}])
sequence = await script.async_validate_actions_config(hass, sequence)
script_obj = script.Script(hass, sequence, "Test Name", "test_domain") script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
wait_started_flag = async_watch_for_action(script_obj, "wait") wait_started_flag = async_watch_for_action(script_obj, "wait")
@ -1049,6 +1053,7 @@ async def test_wait_timeout(hass, caplog, timeout_param, action_type):
action["timeout"] = timeout_param action["timeout"] = timeout_param
action["continue_on_timeout"] = True action["continue_on_timeout"] = True
sequence = cv.SCRIPT_SCHEMA([action, {"event": event}]) sequence = cv.SCRIPT_SCHEMA([action, {"event": event}])
sequence = await script.async_validate_actions_config(hass, sequence)
script_obj = script.Script(hass, sequence, "Test Name", "test_domain") script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
wait_started_flag = async_watch_for_action(script_obj, "wait") wait_started_flag = async_watch_for_action(script_obj, "wait")
@ -1116,6 +1121,7 @@ async def test_wait_continue_on_timeout(
if continue_on_timeout is not None: if continue_on_timeout is not None:
action["continue_on_timeout"] = continue_on_timeout action["continue_on_timeout"] = continue_on_timeout
sequence = cv.SCRIPT_SCHEMA([action, {"event": event}]) sequence = cv.SCRIPT_SCHEMA([action, {"event": event}])
sequence = await script.async_validate_actions_config(hass, sequence)
script_obj = script.Script(hass, sequence, "Test Name", "test_domain") script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
wait_started_flag = async_watch_for_action(script_obj, "wait") wait_started_flag = async_watch_for_action(script_obj, "wait")
@ -1287,6 +1293,7 @@ async def test_wait_variables_out(hass, mode, action_type):
}, },
] ]
sequence = cv.SCRIPT_SCHEMA(sequence) sequence = cv.SCRIPT_SCHEMA(sequence)
sequence = await script.async_validate_actions_config(hass, sequence)
script_obj = script.Script(hass, sequence, "Test Name", "test_domain") script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
wait_started_flag = async_watch_for_action(script_obj, "wait") wait_started_flag = async_watch_for_action(script_obj, "wait")
@ -1326,11 +1333,13 @@ async def test_wait_variables_out(hass, mode, action_type):
async def test_wait_for_trigger_bad(hass, caplog): async def test_wait_for_trigger_bad(hass, caplog):
"""Test bad wait_for_trigger.""" """Test bad wait_for_trigger."""
sequence = cv.SCRIPT_SCHEMA(
{"wait_for_trigger": {"platform": "state", "entity_id": "sensor.abc"}}
)
sequence = await script.async_validate_actions_config(hass, sequence)
script_obj = script.Script( script_obj = script.Script(
hass, hass,
cv.SCRIPT_SCHEMA( sequence,
{"wait_for_trigger": {"platform": "state", "entity_id": "sensor.abc"}}
),
"Test Name", "Test Name",
"test_domain", "test_domain",
) )
@ -1356,11 +1365,13 @@ async def test_wait_for_trigger_bad(hass, caplog):
async def test_wait_for_trigger_generated_exception(hass, caplog): async def test_wait_for_trigger_generated_exception(hass, caplog):
"""Test bad wait_for_trigger.""" """Test bad wait_for_trigger."""
sequence = cv.SCRIPT_SCHEMA(
{"wait_for_trigger": {"platform": "state", "entity_id": "sensor.abc"}}
)
sequence = await script.async_validate_actions_config(hass, sequence)
script_obj = script.Script( script_obj = script.Script(
hass, hass,
cv.SCRIPT_SCHEMA( sequence,
{"wait_for_trigger": {"platform": "state", "entity_id": "sensor.abc"}}
),
"Test Name", "Test Name",
"test_domain", "test_domain",
) )