diff --git a/homeassistant/components/alarm_control_panel/device_trigger.py b/homeassistant/components/alarm_control_panel/device_trigger.py index 92c73b07bbd..9eea745862a 100644 --- a/homeassistant/components/alarm_control_panel/device_trigger.py +++ b/homeassistant/components/alarm_control_panel/device_trigger.py @@ -157,7 +157,7 @@ async def async_attach_trigger( } if CONF_FOR in config: state_config[CONF_FOR] = config[CONF_FOR] - state_config = state_trigger.TRIGGER_SCHEMA(state_config) + state_config = await state_trigger.async_validate_trigger_config(hass, state_config) return await state_trigger.async_attach_trigger( hass, state_config, action, automation_info, platform_type="device" ) diff --git a/homeassistant/components/binary_sensor/device_trigger.py b/homeassistant/components/binary_sensor/device_trigger.py index 72cd885d467..0f2c7a836a2 100644 --- a/homeassistant/components/binary_sensor/device_trigger.py +++ b/homeassistant/components/binary_sensor/device_trigger.py @@ -220,7 +220,7 @@ async def async_attach_trigger(hass, config, action, automation_info): if CONF_FOR in config: state_config[CONF_FOR] = config[CONF_FOR] - state_config = state_trigger.TRIGGER_SCHEMA(state_config) + state_config = await state_trigger.async_validate_trigger_config(hass, state_config) return await state_trigger.async_attach_trigger( hass, state_config, action, automation_info, platform_type="device" ) diff --git a/homeassistant/components/button/device_trigger.py b/homeassistant/components/button/device_trigger.py index 2005bc82194..6d4692234f7 100644 --- a/homeassistant/components/button/device_trigger.py +++ b/homeassistant/components/button/device_trigger.py @@ -11,8 +11,8 @@ from homeassistant.components.automation import ( ) from homeassistant.components.device_automation import DEVICE_TRIGGER_BASE_SCHEMA from homeassistant.components.homeassistant.triggers.state import ( - TRIGGER_SCHEMA as STATE_TRIGGER_SCHEMA, async_attach_trigger as async_attach_state_trigger, + async_validate_trigger_config as async_validate_state_trigger_config, ) from homeassistant.const import ( CONF_DEVICE_ID, @@ -67,7 +67,7 @@ async def async_attach_trigger( CONF_ENTITY_ID: config[CONF_ENTITY_ID], } - state_config = STATE_TRIGGER_SCHEMA(state_config) + state_config = await async_validate_state_trigger_config(hass, state_config) return await async_attach_state_trigger( hass, state_config, action, automation_info, platform_type="device" ) diff --git a/homeassistant/components/climate/device_trigger.py b/homeassistant/components/climate/device_trigger.py index 05212e6ab99..3a9e0e45900 100644 --- a/homeassistant/components/climate/device_trigger.py +++ b/homeassistant/components/climate/device_trigger.py @@ -131,7 +131,9 @@ async def async_attach_trigger( } if CONF_FOR in config: state_config[CONF_FOR] = config[CONF_FOR] - state_config = state_trigger.TRIGGER_SCHEMA(state_config) + state_config = await state_trigger.async_validate_trigger_config( + hass, state_config + ) return await state_trigger.async_attach_trigger( hass, state_config, action, automation_info, platform_type="device" ) diff --git a/homeassistant/components/cover/device_trigger.py b/homeassistant/components/cover/device_trigger.py index f4a2f4443d1..b9a0aefb7a2 100644 --- a/homeassistant/components/cover/device_trigger.py +++ b/homeassistant/components/cover/device_trigger.py @@ -170,7 +170,9 @@ async def async_attach_trigger( } if CONF_FOR in config: state_config[CONF_FOR] = config[CONF_FOR] - state_config = state_trigger.TRIGGER_SCHEMA(state_config) + state_config = await state_trigger.async_validate_trigger_config( + hass, state_config + ) return await state_trigger.async_attach_trigger( hass, state_config, action, automation_info, platform_type="device" ) diff --git a/homeassistant/components/device_automation/toggle_entity.py b/homeassistant/components/device_automation/toggle_entity.py index a1ee84da2fb..f9f7555eeb6 100644 --- a/homeassistant/components/device_automation/toggle_entity.py +++ b/homeassistant/components/device_automation/toggle_entity.py @@ -164,7 +164,7 @@ async def async_attach_trigger( if CONF_FOR in config: state_config[CONF_FOR] = config[CONF_FOR] - state_config = state_trigger.TRIGGER_SCHEMA(state_config) + state_config = await state_trigger.async_validate_trigger_config(hass, state_config) return await state_trigger.async_attach_trigger( hass, state_config, action, automation_info, platform_type="device" ) diff --git a/homeassistant/components/homeassistant/triggers/state.py b/homeassistant/components/homeassistant/triggers/state.py index f1e2bbf2c09..e16416c2f13 100644 --- a/homeassistant/components/homeassistant/triggers/state.py +++ b/homeassistant/components/homeassistant/triggers/state.py @@ -3,20 +3,24 @@ from __future__ import annotations from datetime import timedelta import logging -from typing import Any import voluptuous as vol from homeassistant import exceptions from homeassistant.const import CONF_ATTRIBUTE, CONF_FOR, CONF_PLATFORM, MATCH_ALL from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, State, callback -from homeassistant.helpers import config_validation as cv, template +from homeassistant.helpers import ( + config_validation as cv, + entity_registry as er, + template, +) from homeassistant.helpers.event import ( Event, async_track_same_state, async_track_state_change_event, process_state_match, ) +from homeassistant.helpers.typing import ConfigType # mypy: allow-incomplete-defs, allow-untyped-calls, allow-untyped-defs # mypy: no-check-untyped-defs @@ -30,7 +34,7 @@ CONF_TO = "to" BASE_SCHEMA = cv.TRIGGER_BASE_SCHEMA.extend( { vol.Required(CONF_PLATFORM): "state", - vol.Required(CONF_ENTITY_ID): cv.entity_ids, + vol.Required(CONF_ENTITY_ID): cv.entity_ids_or_uuids, vol.Optional(CONF_FOR): cv.positive_time_period_template, vol.Optional(CONF_ATTRIBUTE): cv.match_all, } @@ -52,17 +56,26 @@ TRIGGER_ATTRIBUTE_SCHEMA = BASE_SCHEMA.extend( ) -def TRIGGER_SCHEMA(value: Any) -> dict: # pylint: disable=invalid-name - """Validate trigger.""" - if not isinstance(value, dict): +async def async_validate_trigger_config( + hass: HomeAssistant, config: ConfigType +) -> ConfigType: + """Validate trigger config.""" + if not isinstance(config, dict): raise vol.Invalid("Expected a dictionary") # We use this approach instead of vol.Any because # this gives better error messages. - if CONF_ATTRIBUTE in value: - return TRIGGER_ATTRIBUTE_SCHEMA(value) + if CONF_ATTRIBUTE in config: + config = TRIGGER_ATTRIBUTE_SCHEMA(config) + else: + config = TRIGGER_STATE_SCHEMA(config) - return TRIGGER_STATE_SCHEMA(value) + registry = er.async_get(hass) + config[CONF_ENTITY_ID] = er.async_resolve_entity_ids( + registry, cv.entity_ids_or_uuids(config[CONF_ENTITY_ID]) + ) + + return config async def async_attach_trigger( @@ -74,7 +87,7 @@ async def async_attach_trigger( platform_type: str = "state", ) -> CALLBACK_TYPE: """Listen for state changes based on configuration.""" - entity_id = config.get(CONF_ENTITY_ID) + entity_ids = config[CONF_ENTITY_ID] if (from_state := config.get(CONF_FROM)) is None: from_state = MATCH_ALL if (to_state := config.get(CONF_TO)) is None: @@ -196,7 +209,7 @@ async def async_attach_trigger( entity_ids=entity, ) - unsub = async_track_state_change_event(hass, entity_id, state_automation_listener) + unsub = async_track_state_change_event(hass, entity_ids, state_automation_listener) @callback def async_remove(): diff --git a/homeassistant/components/lock/device_trigger.py b/homeassistant/components/lock/device_trigger.py index cbdab7abb3d..75415bbf3e1 100644 --- a/homeassistant/components/lock/device_trigger.py +++ b/homeassistant/components/lock/device_trigger.py @@ -104,7 +104,7 @@ async def async_attach_trigger( } if CONF_FOR in config: state_config[CONF_FOR] = config[CONF_FOR] - state_config = state_trigger.TRIGGER_SCHEMA(state_config) + state_config = await state_trigger.async_validate_trigger_config(hass, state_config) return await state_trigger.async_attach_trigger( hass, state_config, action, automation_info, platform_type="device" ) diff --git a/homeassistant/components/media_player/device_trigger.py b/homeassistant/components/media_player/device_trigger.py index 9aa75ab935c..d48a657794b 100644 --- a/homeassistant/components/media_player/device_trigger.py +++ b/homeassistant/components/media_player/device_trigger.py @@ -104,7 +104,7 @@ async def async_attach_trigger( } if CONF_FOR in config: state_config[CONF_FOR] = config[CONF_FOR] - state_config = state_trigger.TRIGGER_SCHEMA(state_config) + state_config = await state_trigger.async_validate_trigger_config(hass, state_config) return await state_trigger.async_attach_trigger( hass, state_config, action, automation_info, platform_type="device" ) diff --git a/homeassistant/components/select/device_trigger.py b/homeassistant/components/select/device_trigger.py index 6dabacf34e5..2c05b59c5d5 100644 --- a/homeassistant/components/select/device_trigger.py +++ b/homeassistant/components/select/device_trigger.py @@ -14,8 +14,8 @@ from homeassistant.components.homeassistant.triggers.state import ( CONF_FOR, CONF_FROM, CONF_TO, - TRIGGER_SCHEMA as STATE_TRIGGER_SCHEMA, async_attach_trigger as async_attach_state_trigger, + async_validate_trigger_config as async_validate_state_trigger_config, ) from homeassistant.components.select.const import ATTR_OPTIONS from homeassistant.const import ( @@ -84,7 +84,7 @@ async def async_attach_trigger( if CONF_FOR in config: state_config[CONF_FOR] = config[CONF_FOR] - state_config = STATE_TRIGGER_SCHEMA(state_config) + state_config = await async_validate_state_trigger_config(hass, state_config) return await async_attach_state_trigger( hass, state_config, action, automation_info, platform_type="device" ) diff --git a/homeassistant/components/vacuum/device_trigger.py b/homeassistant/components/vacuum/device_trigger.py index f4fdbcf972e..25a874a1e69 100644 --- a/homeassistant/components/vacuum/device_trigger.py +++ b/homeassistant/components/vacuum/device_trigger.py @@ -92,7 +92,7 @@ async def async_attach_trigger( } if CONF_FOR in config: state_config[CONF_FOR] = config[CONF_FOR] - state_config = state_trigger.TRIGGER_SCHEMA(state_config) + state_config = await state_trigger.async_validate_trigger_config(hass, state_config) return await state_trigger.async_attach_trigger( hass, state_config, action, automation_info, platform_type="device" ) diff --git a/homeassistant/components/zwave_js/device_trigger.py b/homeassistant/components/zwave_js/device_trigger.py index 368226d36a5..481fc429cb0 100644 --- a/homeassistant/components/zwave_js/device_trigger.py +++ b/homeassistant/components/zwave_js/device_trigger.py @@ -415,7 +415,7 @@ async def async_attach_trigger( else: raise HomeAssistantError(f"Unhandled trigger type {trigger_type}") - state_config = state.TRIGGER_SCHEMA(state_config) + state_config = await state.async_validate_trigger_config(hass, state_config) return await state.async_attach_trigger( hass, state_config, action, automation_info, platform_type="device" ) diff --git a/homeassistant/helpers/config_validation.py b/homeassistant/helpers/config_validation.py index 2d38acafadf..8357746c2cd 100644 --- a/homeassistant/helpers/config_validation.py +++ b/homeassistant/helpers/config_validation.py @@ -2,6 +2,7 @@ from __future__ import annotations from collections.abc import Callable, Hashable +import contextlib from datetime import ( date as date_sys, datetime as datetime_sys, @@ -262,14 +263,34 @@ def entity_id(value: Any) -> str: raise vol.Invalid(f"Entity ID {value} is an invalid entity ID") -def entity_ids(value: str | list) -> list[str]: - """Validate Entity IDs.""" +def entity_id_or_uuid(value: Any) -> str: + """Validate Entity specified by entity_id or uuid.""" + with contextlib.suppress(vol.Invalid): + return entity_id(value) + with contextlib.suppress(vol.Invalid): + return fake_uuid4_hex(value) + raise vol.Invalid(f"Entity {value} is neither a valid entity ID nor a valid UUID") + + +def _entity_ids(value: str | list, allow_uuid: bool) -> list[str]: + """Help validate entity IDs or UUIDs.""" if value is None: raise vol.Invalid("Entity IDs can not be None") if isinstance(value, str): value = [ent_id.strip() for ent_id in value.split(",")] - return [entity_id(ent_id) for ent_id in value] + validator = entity_id_or_uuid if allow_uuid else entity_id + return [validator(ent_id) for ent_id in value] + + +def entity_ids(value: str | list) -> list[str]: + """Validate Entity IDs.""" + return _entity_ids(value, False) + + +def entity_ids_or_uuids(value: str | list) -> list[str]: + """Validate entities specified by entity IDs or UUIDs.""" + return _entity_ids(value, True) comp_entity_ids = vol.Any( @@ -682,6 +703,16 @@ def uuid4_hex(value: Any) -> str: return result.hex +_FAKE_UUID_4_HEX = re.compile(r"^[0-9a-f]{32}$") + + +def fake_uuid4_hex(value: Any) -> str: + """Validate a fake v4 UUID generated by random_uuid_hex.""" + if not _FAKE_UUID_4_HEX.match(value): + raise vol.Invalid("Invalid UUID") + return cast(str, value) # Pattern.match throws if input is not a string + + def ensure_list_csv(value: Any) -> list: """Ensure that input is a list or make one from comma-separated string.""" if isinstance(value, str): diff --git a/homeassistant/helpers/entity_registry.py b/homeassistant/helpers/entity_registry.py index 036f235e132..da8223dfec9 100644 --- a/homeassistant/helpers/entity_registry.py +++ b/homeassistant/helpers/entity_registry.py @@ -9,12 +9,13 @@ timer. """ from __future__ import annotations -from collections import OrderedDict +from collections import UserDict from collections.abc import Callable, Iterable, Mapping import logging from typing import TYPE_CHECKING, Any, cast import attr +import voluptuous as vol from homeassistant.const import ( ATTR_DEVICE_CLASS, @@ -161,14 +162,57 @@ class EntityRegistryStore(storage.Store): return await _async_migrate(old_major_version, old_minor_version, old_data) +class EntityRegistryItems(UserDict): + """Container for entity registry items, maps entity_id -> entry. + + Maintains two additional indexes: + - id -> entry + - (domain, platform, unique_id) -> entry + """ + + def __init__(self) -> None: + """Initialize the container.""" + super().__init__() + self._entry_ids: dict[str, RegistryEntry] = {} + self._index: dict[tuple[str, str, str], str] = {} + + def __setitem__(self, key: str, entry: RegistryEntry) -> None: + """Add an item.""" + if key in self: + old_entry = self[key] + del self._entry_ids[old_entry.id] + del self._index[(old_entry.domain, old_entry.platform, old_entry.unique_id)] + super().__setitem__(key, entry) + self._entry_ids.__setitem__(entry.id, entry) + self._index[(entry.domain, entry.platform, entry.unique_id)] = entry.entity_id + + def __delitem__(self, key: str) -> None: + """Remove an item.""" + entry = self[key] + self._entry_ids.__delitem__(entry.id) + self._index.__delitem__((entry.domain, entry.platform, entry.unique_id)) + super().__delitem__(key) + + def __getitem__(self, key: str) -> RegistryEntry: + """Get an item.""" + return cast(RegistryEntry, super().__getitem__(key)) + + def get_entity_id(self, key: tuple[str, str, str]) -> str | None: + """Get entity_id from (domain, platform, unique_id).""" + return self._index.get(key) + + def get_entry(self, key: str) -> RegistryEntry | None: + """Get entry from id.""" + return self._entry_ids.get(key) + + class EntityRegistry: """Class to hold a registry of entities.""" def __init__(self, hass: HomeAssistant) -> None: """Initialize the registry.""" self.hass = hass - self.entities: dict[str, RegistryEntry] - self._index: dict[tuple[str, str, str], str] = {} + self.entities: EntityRegistryItems self._store = EntityRegistryStore( hass, STORAGE_VERSION_MAJOR, @@ -218,7 +262,7 @@ class EntityRegistry: self, domain: str, platform: str, unique_id: str ) -> str | None: """Check if an entity_id is currently registered.""" - return self._index.get((domain, platform, unique_id)) + return self.entities.get_entity_id((domain, platform, unique_id)) @callback def async_generate_entity_id( @@ -320,7 +364,7 @@ class EntityRegistry: ): disabled_by = DISABLED_INTEGRATION - entity = RegistryEntry( + entry = RegistryEntry( area_id=area_id, capabilities=capabilities, config_entry_id=config_entry_id, @@ -336,7 +380,7 @@ class EntityRegistry: unique_id=unique_id, unit_of_measurement=unit_of_measurement, ) - self._register_entry(entity) + self.entities[entity_id] = entry _LOGGER.info("Registered new %s.%s entity: %s", domain, platform, entity_id) self.async_schedule_save() @@ -344,12 +388,12 @@ class EntityRegistry: EVENT_ENTITY_REGISTRY_UPDATED, {"action": "create", "entity_id": entity_id} ) - return entity + return entry @callback def async_remove(self, entity_id: str) -> None: """Remove an entity from registry.""" - self._unregister_entry(self.entities[entity_id]) + self.entities.pop(entity_id) self.hass.bus.async_fire( EVENT_ENTITY_REGISTRY_UPDATED, {"action": "remove", "entity_id": entity_id} ) @@ -513,9 +557,7 @@ class EntityRegistry: if not new_values: return old - self._remove_index(old) - new = attr.evolve(old, **new_values) - self._register_entry(new) + new = self.entities[entity_id] = attr.evolve(old, **new_values) self.async_schedule_save() @@ -539,7 +581,7 @@ class EntityRegistry: old_conf_load_func=load_yaml, old_conf_migrate_func=_async_migrate_yaml_to_json, ) - entities: dict[str, RegistryEntry] = OrderedDict() + entities = EntityRegistryItems() if data is not None: for entity in data["entities"]: @@ -571,7 +613,6 @@ class EntityRegistry: ) self.entities = entities - self._rebuild_index() @callback def async_schedule_save(self) -> None: @@ -626,25 +667,6 @@ class EntityRegistry: if area_id == entry.area_id: self._async_update_entity(entity_id, area_id=None) - def _register_entry(self, entry: RegistryEntry) -> None: - self.entities[entry.entity_id] = entry - self._add_index(entry) - - def _add_index(self, entry: RegistryEntry) -> None: - self._index[(entry.domain, entry.platform, entry.unique_id)] = entry.entity_id - - def _unregister_entry(self, entry: RegistryEntry) -> None: - self._remove_index(entry) - del self.entities[entry.entity_id] - - def _remove_index(self, entry: RegistryEntry) -> None: - del self._index[(entry.domain, entry.platform, entry.unique_id)] - - def _rebuild_index(self) -> None: - self._index = {} - for entry in self.entities.values(): - self._add_index(entry) - @callback def async_get(hass: HomeAssistant) -> EntityRegistry: @@ -841,3 +863,25 @@ async def async_migrate_entries( if updates is not None: ent_reg.async_update_entity(entry.entity_id, **updates) + + +@callback +def async_resolve_entity_ids( + registry: EntityRegistry, entity_ids_or_uuids: list[str] +) -> list[str]: + """Resolve a list of entity ids or UUIDs to a list of entity ids.""" + + def resolve_entity(entity_id_or_uuid: str) -> str | None: + """Resolve an entity id or UUID to an entity id or None.""" + if valid_entity_id(entity_id_or_uuid): + return entity_id_or_uuid + if (entry := registry.entities.get_entry(entity_id_or_uuid)) is None: + raise vol.Invalid(f"Unknown entity registry entry {entity_id_or_uuid}") + return entry.entity_id + + tmp = [ + resolved_item + for item in entity_ids_or_uuids + if (resolved_item := resolve_entity(item)) is not None + ] + return tmp diff --git a/script/scaffold/templates/device_trigger/integration/device_trigger.py b/script/scaffold/templates/device_trigger/integration/device_trigger.py index 45c6adb4dcf..9082d27953a 100644 --- a/script/scaffold/templates/device_trigger/integration/device_trigger.py +++ b/script/scaffold/templates/device_trigger/integration/device_trigger.py @@ -10,7 +10,7 @@ from homeassistant.components.automation import ( AutomationTriggerInfo, ) from homeassistant.components.device_automation import DEVICE_TRIGGER_BASE_SCHEMA -from homeassistant.components.homeassistant.triggers import state +from homeassistant.components.homeassistant.triggers import state as state_trigger from homeassistant.const import ( CONF_DEVICE_ID, CONF_DOMAIN, @@ -86,11 +86,11 @@ async def async_attach_trigger( to_state = STATE_OFF state_config = { - state.CONF_PLATFORM: "state", + state_trigger.CONF_PLATFORM: "state", CONF_ENTITY_ID: config[CONF_ENTITY_ID], - state.CONF_TO: to_state, + state_trigger.CONF_TO: to_state, } - state_config = state.TRIGGER_SCHEMA(state_config) - return await state.async_attach_trigger( + state_config = await state_trigger.async_validate_trigger_config(hass, state_config) + return await state_trigger.async_attach_trigger( hass, state_config, action, automation_info, platform_type="device" ) diff --git a/tests/common.py b/tests/common.py index 19f0aaec44b..55c76e953cd 100644 --- a/tests/common.py +++ b/tests/common.py @@ -440,8 +440,11 @@ def mock_component(hass, component): def mock_registry(hass, mock_entries=None): """Mock the Entity Registry.""" registry = entity_registry.EntityRegistry(hass) - registry.entities = mock_entries or OrderedDict() - registry._rebuild_index() + if mock_entries is None: + mock_entries = {} + registry.entities = entity_registry.EntityRegistryItems() + for key, entry in mock_entries.items(): + registry.entities[key] = entry hass.data[entity_registry.DATA_REGISTRY] = registry return registry diff --git a/tests/components/homeassistant/triggers/test_state.py b/tests/components/homeassistant/triggers/test_state.py index c86bb0cc879..026f096022b 100644 --- a/tests/components/homeassistant/triggers/test_state.py +++ b/tests/components/homeassistant/triggers/test_state.py @@ -8,6 +8,7 @@ import homeassistant.components.automation as automation from homeassistant.components.homeassistant.triggers import state as state_trigger from homeassistant.const import ATTR_ENTITY_ID, ENTITY_MATCH_ALL, SERVICE_TURN_OFF from homeassistant.core import Context +from homeassistant.helpers import entity_registry as er from homeassistant.setup import async_setup_component import homeassistant.util.dt as dt_util @@ -82,6 +83,64 @@ async def test_if_fires_on_entity_change(hass, calls): assert len(calls) == 1 +async def test_if_fires_on_entity_change_uuid(hass, calls): + """Test for firing on entity change.""" + context = Context() + + registry = er.async_get(hass) + entry = registry.async_get_or_create( + "test", "hue", "1234", suggested_object_id="beer" + ) + + assert entry.entity_id == "test.beer" + + hass.states.async_set("test.beer", "hello") + await hass.async_block_till_done() + + assert await async_setup_component( + hass, + automation.DOMAIN, + { + automation.DOMAIN: { + "trigger": {"platform": "state", "entity_id": entry.id}, + "action": { + "service": "test.automation", + "data_template": { + "some": "{{ trigger.%s }}" + % "}} - {{ trigger.".join( + ( + "platform", + "entity_id", + "from_state.state", + "to_state.state", + "for", + "id", + ) + ) + }, + }, + } + }, + ) + await hass.async_block_till_done() + + hass.states.async_set("test.beer", "world", context=context) + await hass.async_block_till_done() + assert len(calls) == 1 + assert calls[0].context.parent_id == context.id + assert calls[0].data["some"] == "state - test.beer - hello - world - None - 0" + + await hass.services.async_call( + automation.DOMAIN, + SERVICE_TURN_OFF, + {ATTR_ENTITY_ID: ENTITY_MATCH_ALL}, + blocking=True, + ) + hass.states.async_set("test.beer", "planet") + await hass.async_block_till_done() + assert len(calls) == 1 + + async def test_if_fires_on_entity_change_with_from_filter(hass, calls): """Test for firing on entity change with filter.""" assert await async_setup_component( diff --git a/tests/helpers/test_config_validation.py b/tests/helpers/test_config_validation.py index 4b8e4fe9e49..8327eb2e320 100644 --- a/tests/helpers/test_config_validation.py +++ b/tests/helpers/test_config_validation.py @@ -172,9 +172,10 @@ def test_entity_id(): assert schema("sensor.LIGHT") == "sensor.light" -def test_entity_ids(): +@pytest.mark.parametrize("validator", [cv.entity_ids, cv.entity_ids_or_uuids]) +def test_entity_ids(validator): """Test entity ID validation.""" - schema = vol.Schema(cv.entity_ids) + schema = vol.Schema(validator) options = ( "invalid_entity", @@ -194,6 +195,32 @@ def test_entity_ids(): assert schema("sensor.LIGHT, light.kitchen ") == ["sensor.light", "light.kitchen"] +def test_entity_ids_or_uuids(): + """Test entity ID validation.""" + schema = vol.Schema(cv.entity_ids_or_uuids) + + valid_uuid = "a266a680b608c32770e6c45bfe6b8411" + valid_uuid2 = "a266a680b608c32770e6c45bfe6b8412" + invalid_uuid_capital_letters = "A266A680B608C32770E6C45bfE6B8412" + options = ( + "invalid_uuid", + invalid_uuid_capital_letters, + f"{valid_uuid},invalid_uuid", + ["invalid_uuid"], + [valid_uuid, "invalid_uuid"], + [f"{valid_uuid},invalid_uuid"], + ) + for value in options: + with pytest.raises(vol.MultipleInvalid): + schema(value) + + options = ([], [valid_uuid], valid_uuid) + for value in options: + schema(value) + + assert schema(f"{valid_uuid}, {valid_uuid2} ") == [valid_uuid, valid_uuid2] + + def test_entity_domain(): """Test entity domain validation.""" schema = vol.Schema(cv.entity_domain("sensor")) diff --git a/tests/helpers/test_entity_registry.py b/tests/helpers/test_entity_registry.py index 3dc9cf775c4..0bd0abbc92f 100644 --- a/tests/helpers/test_entity_registry.py +++ b/tests/helpers/test_entity_registry.py @@ -2,6 +2,7 @@ from unittest.mock import patch import pytest +import voluptuous as vol from homeassistant import config_entries from homeassistant.const import EVENT_HOMEASSISTANT_START, STATE_UNAVAILABLE @@ -1023,3 +1024,60 @@ async def test_entity_max_length_exceeded(hass, registry): assert exc_info.value.property_name == "generated_entity_id" assert exc_info.value.max_length == 255 assert exc_info.value.value == f"sensor.{long_entity_id_name}_2" + + +async def test_resolve_entity_ids(hass, registry): + """Test resolving entity IDs.""" + + entry1 = registry.async_get_or_create( + "light", "hue", "1234", suggested_object_id="beer" + ) + assert entry1.entity_id == "light.beer" + + entry2 = registry.async_get_or_create( + "light", "hue", "2345", suggested_object_id="milk" + ) + assert entry2.entity_id == "light.milk" + + expected = ["light.beer", "light.milk"] + assert er.async_resolve_entity_ids(registry, [entry1.id, entry2.id]) == expected + + expected = ["light.beer", "light.milk"] + assert er.async_resolve_entity_ids(registry, ["light.beer", entry2.id]) == expected + + with pytest.raises(vol.Invalid): + er.async_resolve_entity_ids(registry, ["light.beer", "bad_uuid"]) + + expected = ["light.unknown"] + assert er.async_resolve_entity_ids(registry, ["light.unknown"]) == expected + + with pytest.raises(vol.Invalid): + er.async_resolve_entity_ids(registry, ["unknown_uuid"]) + + +def test_entity_registry_items(): + """Test the EntityRegistryItems container.""" + entities = er.EntityRegistryItems() + assert entities.get_entity_id(("a", "b", "c")) is None + assert entities.get_entry("abc") is None + + entry1 = er.RegistryEntry("test.entity1", "1234", "hue") + entry2 = er.RegistryEntry("test.entity2", "2345", "hue") + entities["test.entity1"] = entry1 + entities["test.entity2"] = entry2 + + assert entities["test.entity1"] is entry1 + assert entities["test.entity2"] is entry2 + + assert entities.get_entity_id(("test", "hue", "1234")) is entry1.entity_id + assert entities.get_entry(entry1.id) is entry1 + assert entities.get_entity_id(("test", "hue", "2345")) is entry2.entity_id + assert entities.get_entry(entry2.id) is entry2 + + entities.pop("test.entity1") + del entities["test.entity2"] + + assert entities.get_entity_id(("test", "hue", "1234")) is None + assert entities.get_entry(entry1.id) is None + assert entities.get_entity_id(("test", "hue", "2345")) is None + assert entities.get_entry(entry2.id) is None diff --git a/tests/helpers/test_script.py b/tests/helpers/test_script.py index 6c64b64d6de..23768cd95ce 100644 --- a/tests/helpers/test_script.py +++ b/tests/helpers/test_script.py @@ -748,6 +748,7 @@ async def test_wait_basic(hass, action_type): "to": "off", } sequence = cv.SCRIPT_SCHEMA(action) + sequence = await script.async_validate_actions_config(hass, sequence) script_obj = script.Script(hass, sequence, "Test Name", "test_domain") wait_started_flag = async_watch_for_action(script_obj, wait_alias) @@ -848,6 +849,7 @@ async def test_wait_basic_times_out(hass, action_type): "to": "off", } sequence = cv.SCRIPT_SCHEMA(action) + sequence = await script.async_validate_actions_config(hass, sequence) script_obj = script.Script(hass, sequence, "Test Name", "test_domain") wait_started_flag = async_watch_for_action(script_obj, wait_alias) timed_out = False @@ -904,6 +906,7 @@ async def test_multiple_runs_wait(hass, action_type): {"event": event, "event_data": {"value": 2}}, ] ) + sequence = await script.async_validate_actions_config(hass, sequence) script_obj = script.Script( hass, sequence, "Test Name", "test_domain", script_mode="parallel", max_runs=2 ) @@ -952,6 +955,7 @@ async def test_cancel_wait(hass, action_type): } } sequence = cv.SCRIPT_SCHEMA([action, {"event": event}]) + sequence = await script.async_validate_actions_config(hass, sequence) script_obj = script.Script(hass, sequence, "Test Name", "test_domain") wait_started_flag = async_watch_for_action(script_obj, "wait") @@ -1049,6 +1053,7 @@ async def test_wait_timeout(hass, caplog, timeout_param, action_type): action["timeout"] = timeout_param action["continue_on_timeout"] = True sequence = cv.SCRIPT_SCHEMA([action, {"event": event}]) + sequence = await script.async_validate_actions_config(hass, sequence) script_obj = script.Script(hass, sequence, "Test Name", "test_domain") wait_started_flag = async_watch_for_action(script_obj, "wait") @@ -1116,6 +1121,7 @@ async def test_wait_continue_on_timeout( if continue_on_timeout is not None: action["continue_on_timeout"] = continue_on_timeout sequence = cv.SCRIPT_SCHEMA([action, {"event": event}]) + sequence = await script.async_validate_actions_config(hass, sequence) script_obj = script.Script(hass, sequence, "Test Name", "test_domain") wait_started_flag = async_watch_for_action(script_obj, "wait") @@ -1287,6 +1293,7 @@ async def test_wait_variables_out(hass, mode, action_type): }, ] sequence = cv.SCRIPT_SCHEMA(sequence) + sequence = await script.async_validate_actions_config(hass, sequence) script_obj = script.Script(hass, sequence, "Test Name", "test_domain") wait_started_flag = async_watch_for_action(script_obj, "wait") @@ -1326,11 +1333,13 @@ async def test_wait_variables_out(hass, mode, action_type): async def test_wait_for_trigger_bad(hass, caplog): """Test bad wait_for_trigger.""" + sequence = cv.SCRIPT_SCHEMA( + {"wait_for_trigger": {"platform": "state", "entity_id": "sensor.abc"}} + ) + sequence = await script.async_validate_actions_config(hass, sequence) script_obj = script.Script( hass, - cv.SCRIPT_SCHEMA( - {"wait_for_trigger": {"platform": "state", "entity_id": "sensor.abc"}} - ), + sequence, "Test Name", "test_domain", ) @@ -1356,11 +1365,13 @@ async def test_wait_for_trigger_bad(hass, caplog): async def test_wait_for_trigger_generated_exception(hass, caplog): """Test bad wait_for_trigger.""" + sequence = cv.SCRIPT_SCHEMA( + {"wait_for_trigger": {"platform": "state", "entity_id": "sensor.abc"}} + ) + sequence = await script.async_validate_actions_config(hass, sequence) script_obj = script.Script( hass, - cv.SCRIPT_SCHEMA( - {"wait_for_trigger": {"platform": "state", "entity_id": "sensor.abc"}} - ), + sequence, "Test Name", "test_domain", )