Teach state trigger about entity registry ids (#60271)
* Teach state trigger about entity registry ids * Tweak * Add tests * Tweak tests * Fix tests * Resolve entity ids during config validation * Update device_triggers * Fix mistake * Tweak trigger validator to ensure we don't modify the original config * Add index from entry id to entry * Update scaffold * Pre-compile UUID regex * Address review comment * Tweak mock_registry * Tweak * Apply suggestion from code review
This commit is contained in:
parent
c0fb1bffce
commit
c85bb27d0d
20 changed files with 324 additions and 74 deletions
|
@ -157,7 +157,7 @@ async def async_attach_trigger(
|
||||||
}
|
}
|
||||||
if CONF_FOR in config:
|
if CONF_FOR in config:
|
||||||
state_config[CONF_FOR] = config[CONF_FOR]
|
state_config[CONF_FOR] = config[CONF_FOR]
|
||||||
state_config = state_trigger.TRIGGER_SCHEMA(state_config)
|
state_config = await state_trigger.async_validate_trigger_config(hass, state_config)
|
||||||
return await state_trigger.async_attach_trigger(
|
return await state_trigger.async_attach_trigger(
|
||||||
hass, state_config, action, automation_info, platform_type="device"
|
hass, state_config, action, automation_info, platform_type="device"
|
||||||
)
|
)
|
||||||
|
|
|
@ -220,7 +220,7 @@ async def async_attach_trigger(hass, config, action, automation_info):
|
||||||
if CONF_FOR in config:
|
if CONF_FOR in config:
|
||||||
state_config[CONF_FOR] = config[CONF_FOR]
|
state_config[CONF_FOR] = config[CONF_FOR]
|
||||||
|
|
||||||
state_config = state_trigger.TRIGGER_SCHEMA(state_config)
|
state_config = await state_trigger.async_validate_trigger_config(hass, state_config)
|
||||||
return await state_trigger.async_attach_trigger(
|
return await state_trigger.async_attach_trigger(
|
||||||
hass, state_config, action, automation_info, platform_type="device"
|
hass, state_config, action, automation_info, platform_type="device"
|
||||||
)
|
)
|
||||||
|
|
|
@ -11,8 +11,8 @@ from homeassistant.components.automation import (
|
||||||
)
|
)
|
||||||
from homeassistant.components.device_automation import DEVICE_TRIGGER_BASE_SCHEMA
|
from homeassistant.components.device_automation import DEVICE_TRIGGER_BASE_SCHEMA
|
||||||
from homeassistant.components.homeassistant.triggers.state import (
|
from homeassistant.components.homeassistant.triggers.state import (
|
||||||
TRIGGER_SCHEMA as STATE_TRIGGER_SCHEMA,
|
|
||||||
async_attach_trigger as async_attach_state_trigger,
|
async_attach_trigger as async_attach_state_trigger,
|
||||||
|
async_validate_trigger_config as async_validate_state_trigger_config,
|
||||||
)
|
)
|
||||||
from homeassistant.const import (
|
from homeassistant.const import (
|
||||||
CONF_DEVICE_ID,
|
CONF_DEVICE_ID,
|
||||||
|
@ -67,7 +67,7 @@ async def async_attach_trigger(
|
||||||
CONF_ENTITY_ID: config[CONF_ENTITY_ID],
|
CONF_ENTITY_ID: config[CONF_ENTITY_ID],
|
||||||
}
|
}
|
||||||
|
|
||||||
state_config = STATE_TRIGGER_SCHEMA(state_config)
|
state_config = await async_validate_state_trigger_config(hass, state_config)
|
||||||
return await async_attach_state_trigger(
|
return await async_attach_state_trigger(
|
||||||
hass, state_config, action, automation_info, platform_type="device"
|
hass, state_config, action, automation_info, platform_type="device"
|
||||||
)
|
)
|
||||||
|
|
|
@ -131,7 +131,9 @@ async def async_attach_trigger(
|
||||||
}
|
}
|
||||||
if CONF_FOR in config:
|
if CONF_FOR in config:
|
||||||
state_config[CONF_FOR] = config[CONF_FOR]
|
state_config[CONF_FOR] = config[CONF_FOR]
|
||||||
state_config = state_trigger.TRIGGER_SCHEMA(state_config)
|
state_config = await state_trigger.async_validate_trigger_config(
|
||||||
|
hass, state_config
|
||||||
|
)
|
||||||
return await state_trigger.async_attach_trigger(
|
return await state_trigger.async_attach_trigger(
|
||||||
hass, state_config, action, automation_info, platform_type="device"
|
hass, state_config, action, automation_info, platform_type="device"
|
||||||
)
|
)
|
||||||
|
|
|
@ -170,7 +170,9 @@ async def async_attach_trigger(
|
||||||
}
|
}
|
||||||
if CONF_FOR in config:
|
if CONF_FOR in config:
|
||||||
state_config[CONF_FOR] = config[CONF_FOR]
|
state_config[CONF_FOR] = config[CONF_FOR]
|
||||||
state_config = state_trigger.TRIGGER_SCHEMA(state_config)
|
state_config = await state_trigger.async_validate_trigger_config(
|
||||||
|
hass, state_config
|
||||||
|
)
|
||||||
return await state_trigger.async_attach_trigger(
|
return await state_trigger.async_attach_trigger(
|
||||||
hass, state_config, action, automation_info, platform_type="device"
|
hass, state_config, action, automation_info, platform_type="device"
|
||||||
)
|
)
|
||||||
|
|
|
@ -164,7 +164,7 @@ async def async_attach_trigger(
|
||||||
if CONF_FOR in config:
|
if CONF_FOR in config:
|
||||||
state_config[CONF_FOR] = config[CONF_FOR]
|
state_config[CONF_FOR] = config[CONF_FOR]
|
||||||
|
|
||||||
state_config = state_trigger.TRIGGER_SCHEMA(state_config)
|
state_config = await state_trigger.async_validate_trigger_config(hass, state_config)
|
||||||
return await state_trigger.async_attach_trigger(
|
return await state_trigger.async_attach_trigger(
|
||||||
hass, state_config, action, automation_info, platform_type="device"
|
hass, state_config, action, automation_info, platform_type="device"
|
||||||
)
|
)
|
||||||
|
|
|
@ -3,20 +3,24 @@ from __future__ import annotations
|
||||||
|
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant import exceptions
|
from homeassistant import exceptions
|
||||||
from homeassistant.const import CONF_ATTRIBUTE, CONF_FOR, CONF_PLATFORM, MATCH_ALL
|
from homeassistant.const import CONF_ATTRIBUTE, CONF_FOR, CONF_PLATFORM, MATCH_ALL
|
||||||
from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, State, callback
|
from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, State, callback
|
||||||
from homeassistant.helpers import config_validation as cv, template
|
from homeassistant.helpers import (
|
||||||
|
config_validation as cv,
|
||||||
|
entity_registry as er,
|
||||||
|
template,
|
||||||
|
)
|
||||||
from homeassistant.helpers.event import (
|
from homeassistant.helpers.event import (
|
||||||
Event,
|
Event,
|
||||||
async_track_same_state,
|
async_track_same_state,
|
||||||
async_track_state_change_event,
|
async_track_state_change_event,
|
||||||
process_state_match,
|
process_state_match,
|
||||||
)
|
)
|
||||||
|
from homeassistant.helpers.typing import ConfigType
|
||||||
|
|
||||||
# mypy: allow-incomplete-defs, allow-untyped-calls, allow-untyped-defs
|
# mypy: allow-incomplete-defs, allow-untyped-calls, allow-untyped-defs
|
||||||
# mypy: no-check-untyped-defs
|
# mypy: no-check-untyped-defs
|
||||||
|
@ -30,7 +34,7 @@ CONF_TO = "to"
|
||||||
BASE_SCHEMA = cv.TRIGGER_BASE_SCHEMA.extend(
|
BASE_SCHEMA = cv.TRIGGER_BASE_SCHEMA.extend(
|
||||||
{
|
{
|
||||||
vol.Required(CONF_PLATFORM): "state",
|
vol.Required(CONF_PLATFORM): "state",
|
||||||
vol.Required(CONF_ENTITY_ID): cv.entity_ids,
|
vol.Required(CONF_ENTITY_ID): cv.entity_ids_or_uuids,
|
||||||
vol.Optional(CONF_FOR): cv.positive_time_period_template,
|
vol.Optional(CONF_FOR): cv.positive_time_period_template,
|
||||||
vol.Optional(CONF_ATTRIBUTE): cv.match_all,
|
vol.Optional(CONF_ATTRIBUTE): cv.match_all,
|
||||||
}
|
}
|
||||||
|
@ -52,17 +56,26 @@ TRIGGER_ATTRIBUTE_SCHEMA = BASE_SCHEMA.extend(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def TRIGGER_SCHEMA(value: Any) -> dict: # pylint: disable=invalid-name
|
async def async_validate_trigger_config(
|
||||||
"""Validate trigger."""
|
hass: HomeAssistant, config: ConfigType
|
||||||
if not isinstance(value, dict):
|
) -> ConfigType:
|
||||||
|
"""Validate trigger config."""
|
||||||
|
if not isinstance(config, dict):
|
||||||
raise vol.Invalid("Expected a dictionary")
|
raise vol.Invalid("Expected a dictionary")
|
||||||
|
|
||||||
# We use this approach instead of vol.Any because
|
# We use this approach instead of vol.Any because
|
||||||
# this gives better error messages.
|
# this gives better error messages.
|
||||||
if CONF_ATTRIBUTE in value:
|
if CONF_ATTRIBUTE in config:
|
||||||
return TRIGGER_ATTRIBUTE_SCHEMA(value)
|
config = TRIGGER_ATTRIBUTE_SCHEMA(config)
|
||||||
|
else:
|
||||||
|
config = TRIGGER_STATE_SCHEMA(config)
|
||||||
|
|
||||||
return TRIGGER_STATE_SCHEMA(value)
|
registry = er.async_get(hass)
|
||||||
|
config[CONF_ENTITY_ID] = er.async_resolve_entity_ids(
|
||||||
|
registry, cv.entity_ids_or_uuids(config[CONF_ENTITY_ID])
|
||||||
|
)
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
async def async_attach_trigger(
|
async def async_attach_trigger(
|
||||||
|
@ -74,7 +87,7 @@ async def async_attach_trigger(
|
||||||
platform_type: str = "state",
|
platform_type: str = "state",
|
||||||
) -> CALLBACK_TYPE:
|
) -> CALLBACK_TYPE:
|
||||||
"""Listen for state changes based on configuration."""
|
"""Listen for state changes based on configuration."""
|
||||||
entity_id = config.get(CONF_ENTITY_ID)
|
entity_ids = config[CONF_ENTITY_ID]
|
||||||
if (from_state := config.get(CONF_FROM)) is None:
|
if (from_state := config.get(CONF_FROM)) is None:
|
||||||
from_state = MATCH_ALL
|
from_state = MATCH_ALL
|
||||||
if (to_state := config.get(CONF_TO)) is None:
|
if (to_state := config.get(CONF_TO)) is None:
|
||||||
|
@ -196,7 +209,7 @@ async def async_attach_trigger(
|
||||||
entity_ids=entity,
|
entity_ids=entity,
|
||||||
)
|
)
|
||||||
|
|
||||||
unsub = async_track_state_change_event(hass, entity_id, state_automation_listener)
|
unsub = async_track_state_change_event(hass, entity_ids, state_automation_listener)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_remove():
|
def async_remove():
|
||||||
|
|
|
@ -104,7 +104,7 @@ async def async_attach_trigger(
|
||||||
}
|
}
|
||||||
if CONF_FOR in config:
|
if CONF_FOR in config:
|
||||||
state_config[CONF_FOR] = config[CONF_FOR]
|
state_config[CONF_FOR] = config[CONF_FOR]
|
||||||
state_config = state_trigger.TRIGGER_SCHEMA(state_config)
|
state_config = await state_trigger.async_validate_trigger_config(hass, state_config)
|
||||||
return await state_trigger.async_attach_trigger(
|
return await state_trigger.async_attach_trigger(
|
||||||
hass, state_config, action, automation_info, platform_type="device"
|
hass, state_config, action, automation_info, platform_type="device"
|
||||||
)
|
)
|
||||||
|
|
|
@ -104,7 +104,7 @@ async def async_attach_trigger(
|
||||||
}
|
}
|
||||||
if CONF_FOR in config:
|
if CONF_FOR in config:
|
||||||
state_config[CONF_FOR] = config[CONF_FOR]
|
state_config[CONF_FOR] = config[CONF_FOR]
|
||||||
state_config = state_trigger.TRIGGER_SCHEMA(state_config)
|
state_config = await state_trigger.async_validate_trigger_config(hass, state_config)
|
||||||
return await state_trigger.async_attach_trigger(
|
return await state_trigger.async_attach_trigger(
|
||||||
hass, state_config, action, automation_info, platform_type="device"
|
hass, state_config, action, automation_info, platform_type="device"
|
||||||
)
|
)
|
||||||
|
|
|
@ -14,8 +14,8 @@ from homeassistant.components.homeassistant.triggers.state import (
|
||||||
CONF_FOR,
|
CONF_FOR,
|
||||||
CONF_FROM,
|
CONF_FROM,
|
||||||
CONF_TO,
|
CONF_TO,
|
||||||
TRIGGER_SCHEMA as STATE_TRIGGER_SCHEMA,
|
|
||||||
async_attach_trigger as async_attach_state_trigger,
|
async_attach_trigger as async_attach_state_trigger,
|
||||||
|
async_validate_trigger_config as async_validate_state_trigger_config,
|
||||||
)
|
)
|
||||||
from homeassistant.components.select.const import ATTR_OPTIONS
|
from homeassistant.components.select.const import ATTR_OPTIONS
|
||||||
from homeassistant.const import (
|
from homeassistant.const import (
|
||||||
|
@ -84,7 +84,7 @@ async def async_attach_trigger(
|
||||||
if CONF_FOR in config:
|
if CONF_FOR in config:
|
||||||
state_config[CONF_FOR] = config[CONF_FOR]
|
state_config[CONF_FOR] = config[CONF_FOR]
|
||||||
|
|
||||||
state_config = STATE_TRIGGER_SCHEMA(state_config)
|
state_config = await async_validate_state_trigger_config(hass, state_config)
|
||||||
return await async_attach_state_trigger(
|
return await async_attach_state_trigger(
|
||||||
hass, state_config, action, automation_info, platform_type="device"
|
hass, state_config, action, automation_info, platform_type="device"
|
||||||
)
|
)
|
||||||
|
|
|
@ -92,7 +92,7 @@ async def async_attach_trigger(
|
||||||
}
|
}
|
||||||
if CONF_FOR in config:
|
if CONF_FOR in config:
|
||||||
state_config[CONF_FOR] = config[CONF_FOR]
|
state_config[CONF_FOR] = config[CONF_FOR]
|
||||||
state_config = state_trigger.TRIGGER_SCHEMA(state_config)
|
state_config = await state_trigger.async_validate_trigger_config(hass, state_config)
|
||||||
return await state_trigger.async_attach_trigger(
|
return await state_trigger.async_attach_trigger(
|
||||||
hass, state_config, action, automation_info, platform_type="device"
|
hass, state_config, action, automation_info, platform_type="device"
|
||||||
)
|
)
|
||||||
|
|
|
@ -415,7 +415,7 @@ async def async_attach_trigger(
|
||||||
else:
|
else:
|
||||||
raise HomeAssistantError(f"Unhandled trigger type {trigger_type}")
|
raise HomeAssistantError(f"Unhandled trigger type {trigger_type}")
|
||||||
|
|
||||||
state_config = state.TRIGGER_SCHEMA(state_config)
|
state_config = await state.async_validate_trigger_config(hass, state_config)
|
||||||
return await state.async_attach_trigger(
|
return await state.async_attach_trigger(
|
||||||
hass, state_config, action, automation_info, platform_type="device"
|
hass, state_config, action, automation_info, platform_type="device"
|
||||||
)
|
)
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Callable, Hashable
|
from collections.abc import Callable, Hashable
|
||||||
|
import contextlib
|
||||||
from datetime import (
|
from datetime import (
|
||||||
date as date_sys,
|
date as date_sys,
|
||||||
datetime as datetime_sys,
|
datetime as datetime_sys,
|
||||||
|
@ -262,14 +263,34 @@ def entity_id(value: Any) -> str:
|
||||||
raise vol.Invalid(f"Entity ID {value} is an invalid entity ID")
|
raise vol.Invalid(f"Entity ID {value} is an invalid entity ID")
|
||||||
|
|
||||||
|
|
||||||
def entity_ids(value: str | list) -> list[str]:
|
def entity_id_or_uuid(value: Any) -> str:
|
||||||
"""Validate Entity IDs."""
|
"""Validate Entity specified by entity_id or uuid."""
|
||||||
|
with contextlib.suppress(vol.Invalid):
|
||||||
|
return entity_id(value)
|
||||||
|
with contextlib.suppress(vol.Invalid):
|
||||||
|
return fake_uuid4_hex(value)
|
||||||
|
raise vol.Invalid(f"Entity {value} is neither a valid entity ID nor a valid UUID")
|
||||||
|
|
||||||
|
|
||||||
|
def _entity_ids(value: str | list, allow_uuid: bool) -> list[str]:
|
||||||
|
"""Help validate entity IDs or UUIDs."""
|
||||||
if value is None:
|
if value is None:
|
||||||
raise vol.Invalid("Entity IDs can not be None")
|
raise vol.Invalid("Entity IDs can not be None")
|
||||||
if isinstance(value, str):
|
if isinstance(value, str):
|
||||||
value = [ent_id.strip() for ent_id in value.split(",")]
|
value = [ent_id.strip() for ent_id in value.split(",")]
|
||||||
|
|
||||||
return [entity_id(ent_id) for ent_id in value]
|
validator = entity_id_or_uuid if allow_uuid else entity_id
|
||||||
|
return [validator(ent_id) for ent_id in value]
|
||||||
|
|
||||||
|
|
||||||
|
def entity_ids(value: str | list) -> list[str]:
|
||||||
|
"""Validate Entity IDs."""
|
||||||
|
return _entity_ids(value, False)
|
||||||
|
|
||||||
|
|
||||||
|
def entity_ids_or_uuids(value: str | list) -> list[str]:
|
||||||
|
"""Validate entities specified by entity IDs or UUIDs."""
|
||||||
|
return _entity_ids(value, True)
|
||||||
|
|
||||||
|
|
||||||
comp_entity_ids = vol.Any(
|
comp_entity_ids = vol.Any(
|
||||||
|
@ -682,6 +703,16 @@ def uuid4_hex(value: Any) -> str:
|
||||||
return result.hex
|
return result.hex
|
||||||
|
|
||||||
|
|
||||||
|
_FAKE_UUID_4_HEX = re.compile(r"^[0-9a-f]{32}$")
|
||||||
|
|
||||||
|
|
||||||
|
def fake_uuid4_hex(value: Any) -> str:
|
||||||
|
"""Validate a fake v4 UUID generated by random_uuid_hex."""
|
||||||
|
if not _FAKE_UUID_4_HEX.match(value):
|
||||||
|
raise vol.Invalid("Invalid UUID")
|
||||||
|
return cast(str, value) # Pattern.match throws if input is not a string
|
||||||
|
|
||||||
|
|
||||||
def ensure_list_csv(value: Any) -> list:
|
def ensure_list_csv(value: Any) -> list:
|
||||||
"""Ensure that input is a list or make one from comma-separated string."""
|
"""Ensure that input is a list or make one from comma-separated string."""
|
||||||
if isinstance(value, str):
|
if isinstance(value, str):
|
||||||
|
|
|
@ -9,12 +9,13 @@ timer.
|
||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections import OrderedDict
|
from collections import UserDict
|
||||||
from collections.abc import Callable, Iterable, Mapping
|
from collections.abc import Callable, Iterable, Mapping
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Any, cast
|
from typing import TYPE_CHECKING, Any, cast
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.const import (
|
from homeassistant.const import (
|
||||||
ATTR_DEVICE_CLASS,
|
ATTR_DEVICE_CLASS,
|
||||||
|
@ -161,14 +162,57 @@ class EntityRegistryStore(storage.Store):
|
||||||
return await _async_migrate(old_major_version, old_minor_version, old_data)
|
return await _async_migrate(old_major_version, old_minor_version, old_data)
|
||||||
|
|
||||||
|
|
||||||
|
class EntityRegistryItems(UserDict):
|
||||||
|
"""Container for entity registry items, maps entity_id -> entry.
|
||||||
|
|
||||||
|
Maintains two additional indexes:
|
||||||
|
- id -> entry
|
||||||
|
- (domain, platform, unique_id) -> entry
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
"""Initialize the container."""
|
||||||
|
super().__init__()
|
||||||
|
self._entry_ids: dict[str, RegistryEntry] = {}
|
||||||
|
self._index: dict[tuple[str, str, str], str] = {}
|
||||||
|
|
||||||
|
def __setitem__(self, key: str, entry: RegistryEntry) -> None:
|
||||||
|
"""Add an item."""
|
||||||
|
if key in self:
|
||||||
|
old_entry = self[key]
|
||||||
|
del self._entry_ids[old_entry.id]
|
||||||
|
del self._index[(old_entry.domain, old_entry.platform, old_entry.unique_id)]
|
||||||
|
super().__setitem__(key, entry)
|
||||||
|
self._entry_ids.__setitem__(entry.id, entry)
|
||||||
|
self._index[(entry.domain, entry.platform, entry.unique_id)] = entry.entity_id
|
||||||
|
|
||||||
|
def __delitem__(self, key: str) -> None:
|
||||||
|
"""Remove an item."""
|
||||||
|
entry = self[key]
|
||||||
|
self._entry_ids.__delitem__(entry.id)
|
||||||
|
self._index.__delitem__((entry.domain, entry.platform, entry.unique_id))
|
||||||
|
super().__delitem__(key)
|
||||||
|
|
||||||
|
def __getitem__(self, key: str) -> RegistryEntry:
|
||||||
|
"""Get an item."""
|
||||||
|
return cast(RegistryEntry, super().__getitem__(key))
|
||||||
|
|
||||||
|
def get_entity_id(self, key: tuple[str, str, str]) -> str | None:
|
||||||
|
"""Get entity_id from (domain, platform, unique_id)."""
|
||||||
|
return self._index.get(key)
|
||||||
|
|
||||||
|
def get_entry(self, key: str) -> RegistryEntry | None:
|
||||||
|
"""Get entry from id."""
|
||||||
|
return self._entry_ids.get(key)
|
||||||
|
|
||||||
|
|
||||||
class EntityRegistry:
|
class EntityRegistry:
|
||||||
"""Class to hold a registry of entities."""
|
"""Class to hold a registry of entities."""
|
||||||
|
|
||||||
def __init__(self, hass: HomeAssistant) -> None:
|
def __init__(self, hass: HomeAssistant) -> None:
|
||||||
"""Initialize the registry."""
|
"""Initialize the registry."""
|
||||||
self.hass = hass
|
self.hass = hass
|
||||||
self.entities: dict[str, RegistryEntry]
|
self.entities: EntityRegistryItems
|
||||||
self._index: dict[tuple[str, str, str], str] = {}
|
|
||||||
self._store = EntityRegistryStore(
|
self._store = EntityRegistryStore(
|
||||||
hass,
|
hass,
|
||||||
STORAGE_VERSION_MAJOR,
|
STORAGE_VERSION_MAJOR,
|
||||||
|
@ -218,7 +262,7 @@ class EntityRegistry:
|
||||||
self, domain: str, platform: str, unique_id: str
|
self, domain: str, platform: str, unique_id: str
|
||||||
) -> str | None:
|
) -> str | None:
|
||||||
"""Check if an entity_id is currently registered."""
|
"""Check if an entity_id is currently registered."""
|
||||||
return self._index.get((domain, platform, unique_id))
|
return self.entities.get_entity_id((domain, platform, unique_id))
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_generate_entity_id(
|
def async_generate_entity_id(
|
||||||
|
@ -320,7 +364,7 @@ class EntityRegistry:
|
||||||
):
|
):
|
||||||
disabled_by = DISABLED_INTEGRATION
|
disabled_by = DISABLED_INTEGRATION
|
||||||
|
|
||||||
entity = RegistryEntry(
|
entry = RegistryEntry(
|
||||||
area_id=area_id,
|
area_id=area_id,
|
||||||
capabilities=capabilities,
|
capabilities=capabilities,
|
||||||
config_entry_id=config_entry_id,
|
config_entry_id=config_entry_id,
|
||||||
|
@ -336,7 +380,7 @@ class EntityRegistry:
|
||||||
unique_id=unique_id,
|
unique_id=unique_id,
|
||||||
unit_of_measurement=unit_of_measurement,
|
unit_of_measurement=unit_of_measurement,
|
||||||
)
|
)
|
||||||
self._register_entry(entity)
|
self.entities[entity_id] = entry
|
||||||
_LOGGER.info("Registered new %s.%s entity: %s", domain, platform, entity_id)
|
_LOGGER.info("Registered new %s.%s entity: %s", domain, platform, entity_id)
|
||||||
self.async_schedule_save()
|
self.async_schedule_save()
|
||||||
|
|
||||||
|
@ -344,12 +388,12 @@ class EntityRegistry:
|
||||||
EVENT_ENTITY_REGISTRY_UPDATED, {"action": "create", "entity_id": entity_id}
|
EVENT_ENTITY_REGISTRY_UPDATED, {"action": "create", "entity_id": entity_id}
|
||||||
)
|
)
|
||||||
|
|
||||||
return entity
|
return entry
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_remove(self, entity_id: str) -> None:
|
def async_remove(self, entity_id: str) -> None:
|
||||||
"""Remove an entity from registry."""
|
"""Remove an entity from registry."""
|
||||||
self._unregister_entry(self.entities[entity_id])
|
self.entities.pop(entity_id)
|
||||||
self.hass.bus.async_fire(
|
self.hass.bus.async_fire(
|
||||||
EVENT_ENTITY_REGISTRY_UPDATED, {"action": "remove", "entity_id": entity_id}
|
EVENT_ENTITY_REGISTRY_UPDATED, {"action": "remove", "entity_id": entity_id}
|
||||||
)
|
)
|
||||||
|
@ -513,9 +557,7 @@ class EntityRegistry:
|
||||||
if not new_values:
|
if not new_values:
|
||||||
return old
|
return old
|
||||||
|
|
||||||
self._remove_index(old)
|
new = self.entities[entity_id] = attr.evolve(old, **new_values)
|
||||||
new = attr.evolve(old, **new_values)
|
|
||||||
self._register_entry(new)
|
|
||||||
|
|
||||||
self.async_schedule_save()
|
self.async_schedule_save()
|
||||||
|
|
||||||
|
@ -539,7 +581,7 @@ class EntityRegistry:
|
||||||
old_conf_load_func=load_yaml,
|
old_conf_load_func=load_yaml,
|
||||||
old_conf_migrate_func=_async_migrate_yaml_to_json,
|
old_conf_migrate_func=_async_migrate_yaml_to_json,
|
||||||
)
|
)
|
||||||
entities: dict[str, RegistryEntry] = OrderedDict()
|
entities = EntityRegistryItems()
|
||||||
|
|
||||||
if data is not None:
|
if data is not None:
|
||||||
for entity in data["entities"]:
|
for entity in data["entities"]:
|
||||||
|
@ -571,7 +613,6 @@ class EntityRegistry:
|
||||||
)
|
)
|
||||||
|
|
||||||
self.entities = entities
|
self.entities = entities
|
||||||
self._rebuild_index()
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_schedule_save(self) -> None:
|
def async_schedule_save(self) -> None:
|
||||||
|
@ -626,25 +667,6 @@ class EntityRegistry:
|
||||||
if area_id == entry.area_id:
|
if area_id == entry.area_id:
|
||||||
self._async_update_entity(entity_id, area_id=None)
|
self._async_update_entity(entity_id, area_id=None)
|
||||||
|
|
||||||
def _register_entry(self, entry: RegistryEntry) -> None:
|
|
||||||
self.entities[entry.entity_id] = entry
|
|
||||||
self._add_index(entry)
|
|
||||||
|
|
||||||
def _add_index(self, entry: RegistryEntry) -> None:
|
|
||||||
self._index[(entry.domain, entry.platform, entry.unique_id)] = entry.entity_id
|
|
||||||
|
|
||||||
def _unregister_entry(self, entry: RegistryEntry) -> None:
|
|
||||||
self._remove_index(entry)
|
|
||||||
del self.entities[entry.entity_id]
|
|
||||||
|
|
||||||
def _remove_index(self, entry: RegistryEntry) -> None:
|
|
||||||
del self._index[(entry.domain, entry.platform, entry.unique_id)]
|
|
||||||
|
|
||||||
def _rebuild_index(self) -> None:
|
|
||||||
self._index = {}
|
|
||||||
for entry in self.entities.values():
|
|
||||||
self._add_index(entry)
|
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_get(hass: HomeAssistant) -> EntityRegistry:
|
def async_get(hass: HomeAssistant) -> EntityRegistry:
|
||||||
|
@ -841,3 +863,25 @@ async def async_migrate_entries(
|
||||||
|
|
||||||
if updates is not None:
|
if updates is not None:
|
||||||
ent_reg.async_update_entity(entry.entity_id, **updates)
|
ent_reg.async_update_entity(entry.entity_id, **updates)
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_resolve_entity_ids(
|
||||||
|
registry: EntityRegistry, entity_ids_or_uuids: list[str]
|
||||||
|
) -> list[str]:
|
||||||
|
"""Resolve a list of entity ids or UUIDs to a list of entity ids."""
|
||||||
|
|
||||||
|
def resolve_entity(entity_id_or_uuid: str) -> str | None:
|
||||||
|
"""Resolve an entity id or UUID to an entity id or None."""
|
||||||
|
if valid_entity_id(entity_id_or_uuid):
|
||||||
|
return entity_id_or_uuid
|
||||||
|
if (entry := registry.entities.get_entry(entity_id_or_uuid)) is None:
|
||||||
|
raise vol.Invalid(f"Unknown entity registry entry {entity_id_or_uuid}")
|
||||||
|
return entry.entity_id
|
||||||
|
|
||||||
|
tmp = [
|
||||||
|
resolved_item
|
||||||
|
for item in entity_ids_or_uuids
|
||||||
|
if (resolved_item := resolve_entity(item)) is not None
|
||||||
|
]
|
||||||
|
return tmp
|
||||||
|
|
|
@ -10,7 +10,7 @@ from homeassistant.components.automation import (
|
||||||
AutomationTriggerInfo,
|
AutomationTriggerInfo,
|
||||||
)
|
)
|
||||||
from homeassistant.components.device_automation import DEVICE_TRIGGER_BASE_SCHEMA
|
from homeassistant.components.device_automation import DEVICE_TRIGGER_BASE_SCHEMA
|
||||||
from homeassistant.components.homeassistant.triggers import state
|
from homeassistant.components.homeassistant.triggers import state as state_trigger
|
||||||
from homeassistant.const import (
|
from homeassistant.const import (
|
||||||
CONF_DEVICE_ID,
|
CONF_DEVICE_ID,
|
||||||
CONF_DOMAIN,
|
CONF_DOMAIN,
|
||||||
|
@ -86,11 +86,11 @@ async def async_attach_trigger(
|
||||||
to_state = STATE_OFF
|
to_state = STATE_OFF
|
||||||
|
|
||||||
state_config = {
|
state_config = {
|
||||||
state.CONF_PLATFORM: "state",
|
state_trigger.CONF_PLATFORM: "state",
|
||||||
CONF_ENTITY_ID: config[CONF_ENTITY_ID],
|
CONF_ENTITY_ID: config[CONF_ENTITY_ID],
|
||||||
state.CONF_TO: to_state,
|
state_trigger.CONF_TO: to_state,
|
||||||
}
|
}
|
||||||
state_config = state.TRIGGER_SCHEMA(state_config)
|
state_config = await state_trigger.async_validate_trigger_config(hass, state_config)
|
||||||
return await state.async_attach_trigger(
|
return await state_trigger.async_attach_trigger(
|
||||||
hass, state_config, action, automation_info, platform_type="device"
|
hass, state_config, action, automation_info, platform_type="device"
|
||||||
)
|
)
|
||||||
|
|
|
@ -440,8 +440,11 @@ def mock_component(hass, component):
|
||||||
def mock_registry(hass, mock_entries=None):
|
def mock_registry(hass, mock_entries=None):
|
||||||
"""Mock the Entity Registry."""
|
"""Mock the Entity Registry."""
|
||||||
registry = entity_registry.EntityRegistry(hass)
|
registry = entity_registry.EntityRegistry(hass)
|
||||||
registry.entities = mock_entries or OrderedDict()
|
if mock_entries is None:
|
||||||
registry._rebuild_index()
|
mock_entries = {}
|
||||||
|
registry.entities = entity_registry.EntityRegistryItems()
|
||||||
|
for key, entry in mock_entries.items():
|
||||||
|
registry.entities[key] = entry
|
||||||
|
|
||||||
hass.data[entity_registry.DATA_REGISTRY] = registry
|
hass.data[entity_registry.DATA_REGISTRY] = registry
|
||||||
return registry
|
return registry
|
||||||
|
|
|
@ -8,6 +8,7 @@ import homeassistant.components.automation as automation
|
||||||
from homeassistant.components.homeassistant.triggers import state as state_trigger
|
from homeassistant.components.homeassistant.triggers import state as state_trigger
|
||||||
from homeassistant.const import ATTR_ENTITY_ID, ENTITY_MATCH_ALL, SERVICE_TURN_OFF
|
from homeassistant.const import ATTR_ENTITY_ID, ENTITY_MATCH_ALL, SERVICE_TURN_OFF
|
||||||
from homeassistant.core import Context
|
from homeassistant.core import Context
|
||||||
|
from homeassistant.helpers import entity_registry as er
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
import homeassistant.util.dt as dt_util
|
import homeassistant.util.dt as dt_util
|
||||||
|
|
||||||
|
@ -82,6 +83,64 @@ async def test_if_fires_on_entity_change(hass, calls):
|
||||||
assert len(calls) == 1
|
assert len(calls) == 1
|
||||||
|
|
||||||
|
|
||||||
|
async def test_if_fires_on_entity_change_uuid(hass, calls):
|
||||||
|
"""Test for firing on entity change."""
|
||||||
|
context = Context()
|
||||||
|
|
||||||
|
registry = er.async_get(hass)
|
||||||
|
entry = registry.async_get_or_create(
|
||||||
|
"test", "hue", "1234", suggested_object_id="beer"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert entry.entity_id == "test.beer"
|
||||||
|
|
||||||
|
hass.states.async_set("test.beer", "hello")
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert await async_setup_component(
|
||||||
|
hass,
|
||||||
|
automation.DOMAIN,
|
||||||
|
{
|
||||||
|
automation.DOMAIN: {
|
||||||
|
"trigger": {"platform": "state", "entity_id": entry.id},
|
||||||
|
"action": {
|
||||||
|
"service": "test.automation",
|
||||||
|
"data_template": {
|
||||||
|
"some": "{{ trigger.%s }}"
|
||||||
|
% "}} - {{ trigger.".join(
|
||||||
|
(
|
||||||
|
"platform",
|
||||||
|
"entity_id",
|
||||||
|
"from_state.state",
|
||||||
|
"to_state.state",
|
||||||
|
"for",
|
||||||
|
"id",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
hass.states.async_set("test.beer", "world", context=context)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert len(calls) == 1
|
||||||
|
assert calls[0].context.parent_id == context.id
|
||||||
|
assert calls[0].data["some"] == "state - test.beer - hello - world - None - 0"
|
||||||
|
|
||||||
|
await hass.services.async_call(
|
||||||
|
automation.DOMAIN,
|
||||||
|
SERVICE_TURN_OFF,
|
||||||
|
{ATTR_ENTITY_ID: ENTITY_MATCH_ALL},
|
||||||
|
blocking=True,
|
||||||
|
)
|
||||||
|
hass.states.async_set("test.beer", "planet")
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert len(calls) == 1
|
||||||
|
|
||||||
|
|
||||||
async def test_if_fires_on_entity_change_with_from_filter(hass, calls):
|
async def test_if_fires_on_entity_change_with_from_filter(hass, calls):
|
||||||
"""Test for firing on entity change with filter."""
|
"""Test for firing on entity change with filter."""
|
||||||
assert await async_setup_component(
|
assert await async_setup_component(
|
||||||
|
|
|
@ -172,9 +172,10 @@ def test_entity_id():
|
||||||
assert schema("sensor.LIGHT") == "sensor.light"
|
assert schema("sensor.LIGHT") == "sensor.light"
|
||||||
|
|
||||||
|
|
||||||
def test_entity_ids():
|
@pytest.mark.parametrize("validator", [cv.entity_ids, cv.entity_ids_or_uuids])
|
||||||
|
def test_entity_ids(validator):
|
||||||
"""Test entity ID validation."""
|
"""Test entity ID validation."""
|
||||||
schema = vol.Schema(cv.entity_ids)
|
schema = vol.Schema(validator)
|
||||||
|
|
||||||
options = (
|
options = (
|
||||||
"invalid_entity",
|
"invalid_entity",
|
||||||
|
@ -194,6 +195,32 @@ def test_entity_ids():
|
||||||
assert schema("sensor.LIGHT, light.kitchen ") == ["sensor.light", "light.kitchen"]
|
assert schema("sensor.LIGHT, light.kitchen ") == ["sensor.light", "light.kitchen"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_entity_ids_or_uuids():
|
||||||
|
"""Test entity ID validation."""
|
||||||
|
schema = vol.Schema(cv.entity_ids_or_uuids)
|
||||||
|
|
||||||
|
valid_uuid = "a266a680b608c32770e6c45bfe6b8411"
|
||||||
|
valid_uuid2 = "a266a680b608c32770e6c45bfe6b8412"
|
||||||
|
invalid_uuid_capital_letters = "A266A680B608C32770E6C45bfE6B8412"
|
||||||
|
options = (
|
||||||
|
"invalid_uuid",
|
||||||
|
invalid_uuid_capital_letters,
|
||||||
|
f"{valid_uuid},invalid_uuid",
|
||||||
|
["invalid_uuid"],
|
||||||
|
[valid_uuid, "invalid_uuid"],
|
||||||
|
[f"{valid_uuid},invalid_uuid"],
|
||||||
|
)
|
||||||
|
for value in options:
|
||||||
|
with pytest.raises(vol.MultipleInvalid):
|
||||||
|
schema(value)
|
||||||
|
|
||||||
|
options = ([], [valid_uuid], valid_uuid)
|
||||||
|
for value in options:
|
||||||
|
schema(value)
|
||||||
|
|
||||||
|
assert schema(f"{valid_uuid}, {valid_uuid2} ") == [valid_uuid, valid_uuid2]
|
||||||
|
|
||||||
|
|
||||||
def test_entity_domain():
|
def test_entity_domain():
|
||||||
"""Test entity domain validation."""
|
"""Test entity domain validation."""
|
||||||
schema = vol.Schema(cv.entity_domain("sensor"))
|
schema = vol.Schema(cv.entity_domain("sensor"))
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant import config_entries
|
from homeassistant import config_entries
|
||||||
from homeassistant.const import EVENT_HOMEASSISTANT_START, STATE_UNAVAILABLE
|
from homeassistant.const import EVENT_HOMEASSISTANT_START, STATE_UNAVAILABLE
|
||||||
|
@ -1023,3 +1024,60 @@ async def test_entity_max_length_exceeded(hass, registry):
|
||||||
assert exc_info.value.property_name == "generated_entity_id"
|
assert exc_info.value.property_name == "generated_entity_id"
|
||||||
assert exc_info.value.max_length == 255
|
assert exc_info.value.max_length == 255
|
||||||
assert exc_info.value.value == f"sensor.{long_entity_id_name}_2"
|
assert exc_info.value.value == f"sensor.{long_entity_id_name}_2"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_resolve_entity_ids(hass, registry):
|
||||||
|
"""Test resolving entity IDs."""
|
||||||
|
|
||||||
|
entry1 = registry.async_get_or_create(
|
||||||
|
"light", "hue", "1234", suggested_object_id="beer"
|
||||||
|
)
|
||||||
|
assert entry1.entity_id == "light.beer"
|
||||||
|
|
||||||
|
entry2 = registry.async_get_or_create(
|
||||||
|
"light", "hue", "2345", suggested_object_id="milk"
|
||||||
|
)
|
||||||
|
assert entry2.entity_id == "light.milk"
|
||||||
|
|
||||||
|
expected = ["light.beer", "light.milk"]
|
||||||
|
assert er.async_resolve_entity_ids(registry, [entry1.id, entry2.id]) == expected
|
||||||
|
|
||||||
|
expected = ["light.beer", "light.milk"]
|
||||||
|
assert er.async_resolve_entity_ids(registry, ["light.beer", entry2.id]) == expected
|
||||||
|
|
||||||
|
with pytest.raises(vol.Invalid):
|
||||||
|
er.async_resolve_entity_ids(registry, ["light.beer", "bad_uuid"])
|
||||||
|
|
||||||
|
expected = ["light.unknown"]
|
||||||
|
assert er.async_resolve_entity_ids(registry, ["light.unknown"]) == expected
|
||||||
|
|
||||||
|
with pytest.raises(vol.Invalid):
|
||||||
|
er.async_resolve_entity_ids(registry, ["unknown_uuid"])
|
||||||
|
|
||||||
|
|
||||||
|
def test_entity_registry_items():
|
||||||
|
"""Test the EntityRegistryItems container."""
|
||||||
|
entities = er.EntityRegistryItems()
|
||||||
|
assert entities.get_entity_id(("a", "b", "c")) is None
|
||||||
|
assert entities.get_entry("abc") is None
|
||||||
|
|
||||||
|
entry1 = er.RegistryEntry("test.entity1", "1234", "hue")
|
||||||
|
entry2 = er.RegistryEntry("test.entity2", "2345", "hue")
|
||||||
|
entities["test.entity1"] = entry1
|
||||||
|
entities["test.entity2"] = entry2
|
||||||
|
|
||||||
|
assert entities["test.entity1"] is entry1
|
||||||
|
assert entities["test.entity2"] is entry2
|
||||||
|
|
||||||
|
assert entities.get_entity_id(("test", "hue", "1234")) is entry1.entity_id
|
||||||
|
assert entities.get_entry(entry1.id) is entry1
|
||||||
|
assert entities.get_entity_id(("test", "hue", "2345")) is entry2.entity_id
|
||||||
|
assert entities.get_entry(entry2.id) is entry2
|
||||||
|
|
||||||
|
entities.pop("test.entity1")
|
||||||
|
del entities["test.entity2"]
|
||||||
|
|
||||||
|
assert entities.get_entity_id(("test", "hue", "1234")) is None
|
||||||
|
assert entities.get_entry(entry1.id) is None
|
||||||
|
assert entities.get_entity_id(("test", "hue", "2345")) is None
|
||||||
|
assert entities.get_entry(entry2.id) is None
|
||||||
|
|
|
@ -748,6 +748,7 @@ async def test_wait_basic(hass, action_type):
|
||||||
"to": "off",
|
"to": "off",
|
||||||
}
|
}
|
||||||
sequence = cv.SCRIPT_SCHEMA(action)
|
sequence = cv.SCRIPT_SCHEMA(action)
|
||||||
|
sequence = await script.async_validate_actions_config(hass, sequence)
|
||||||
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
|
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
|
||||||
wait_started_flag = async_watch_for_action(script_obj, wait_alias)
|
wait_started_flag = async_watch_for_action(script_obj, wait_alias)
|
||||||
|
|
||||||
|
@ -848,6 +849,7 @@ async def test_wait_basic_times_out(hass, action_type):
|
||||||
"to": "off",
|
"to": "off",
|
||||||
}
|
}
|
||||||
sequence = cv.SCRIPT_SCHEMA(action)
|
sequence = cv.SCRIPT_SCHEMA(action)
|
||||||
|
sequence = await script.async_validate_actions_config(hass, sequence)
|
||||||
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
|
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
|
||||||
wait_started_flag = async_watch_for_action(script_obj, wait_alias)
|
wait_started_flag = async_watch_for_action(script_obj, wait_alias)
|
||||||
timed_out = False
|
timed_out = False
|
||||||
|
@ -904,6 +906,7 @@ async def test_multiple_runs_wait(hass, action_type):
|
||||||
{"event": event, "event_data": {"value": 2}},
|
{"event": event, "event_data": {"value": 2}},
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
sequence = await script.async_validate_actions_config(hass, sequence)
|
||||||
script_obj = script.Script(
|
script_obj = script.Script(
|
||||||
hass, sequence, "Test Name", "test_domain", script_mode="parallel", max_runs=2
|
hass, sequence, "Test Name", "test_domain", script_mode="parallel", max_runs=2
|
||||||
)
|
)
|
||||||
|
@ -952,6 +955,7 @@ async def test_cancel_wait(hass, action_type):
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
sequence = cv.SCRIPT_SCHEMA([action, {"event": event}])
|
sequence = cv.SCRIPT_SCHEMA([action, {"event": event}])
|
||||||
|
sequence = await script.async_validate_actions_config(hass, sequence)
|
||||||
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
|
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
|
||||||
wait_started_flag = async_watch_for_action(script_obj, "wait")
|
wait_started_flag = async_watch_for_action(script_obj, "wait")
|
||||||
|
|
||||||
|
@ -1049,6 +1053,7 @@ async def test_wait_timeout(hass, caplog, timeout_param, action_type):
|
||||||
action["timeout"] = timeout_param
|
action["timeout"] = timeout_param
|
||||||
action["continue_on_timeout"] = True
|
action["continue_on_timeout"] = True
|
||||||
sequence = cv.SCRIPT_SCHEMA([action, {"event": event}])
|
sequence = cv.SCRIPT_SCHEMA([action, {"event": event}])
|
||||||
|
sequence = await script.async_validate_actions_config(hass, sequence)
|
||||||
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
|
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
|
||||||
wait_started_flag = async_watch_for_action(script_obj, "wait")
|
wait_started_flag = async_watch_for_action(script_obj, "wait")
|
||||||
|
|
||||||
|
@ -1116,6 +1121,7 @@ async def test_wait_continue_on_timeout(
|
||||||
if continue_on_timeout is not None:
|
if continue_on_timeout is not None:
|
||||||
action["continue_on_timeout"] = continue_on_timeout
|
action["continue_on_timeout"] = continue_on_timeout
|
||||||
sequence = cv.SCRIPT_SCHEMA([action, {"event": event}])
|
sequence = cv.SCRIPT_SCHEMA([action, {"event": event}])
|
||||||
|
sequence = await script.async_validate_actions_config(hass, sequence)
|
||||||
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
|
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
|
||||||
wait_started_flag = async_watch_for_action(script_obj, "wait")
|
wait_started_flag = async_watch_for_action(script_obj, "wait")
|
||||||
|
|
||||||
|
@ -1287,6 +1293,7 @@ async def test_wait_variables_out(hass, mode, action_type):
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
sequence = cv.SCRIPT_SCHEMA(sequence)
|
sequence = cv.SCRIPT_SCHEMA(sequence)
|
||||||
|
sequence = await script.async_validate_actions_config(hass, sequence)
|
||||||
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
|
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
|
||||||
wait_started_flag = async_watch_for_action(script_obj, "wait")
|
wait_started_flag = async_watch_for_action(script_obj, "wait")
|
||||||
|
|
||||||
|
@ -1326,11 +1333,13 @@ async def test_wait_variables_out(hass, mode, action_type):
|
||||||
|
|
||||||
async def test_wait_for_trigger_bad(hass, caplog):
|
async def test_wait_for_trigger_bad(hass, caplog):
|
||||||
"""Test bad wait_for_trigger."""
|
"""Test bad wait_for_trigger."""
|
||||||
|
sequence = cv.SCRIPT_SCHEMA(
|
||||||
|
{"wait_for_trigger": {"platform": "state", "entity_id": "sensor.abc"}}
|
||||||
|
)
|
||||||
|
sequence = await script.async_validate_actions_config(hass, sequence)
|
||||||
script_obj = script.Script(
|
script_obj = script.Script(
|
||||||
hass,
|
hass,
|
||||||
cv.SCRIPT_SCHEMA(
|
sequence,
|
||||||
{"wait_for_trigger": {"platform": "state", "entity_id": "sensor.abc"}}
|
|
||||||
),
|
|
||||||
"Test Name",
|
"Test Name",
|
||||||
"test_domain",
|
"test_domain",
|
||||||
)
|
)
|
||||||
|
@ -1356,11 +1365,13 @@ async def test_wait_for_trigger_bad(hass, caplog):
|
||||||
|
|
||||||
async def test_wait_for_trigger_generated_exception(hass, caplog):
|
async def test_wait_for_trigger_generated_exception(hass, caplog):
|
||||||
"""Test bad wait_for_trigger."""
|
"""Test bad wait_for_trigger."""
|
||||||
|
sequence = cv.SCRIPT_SCHEMA(
|
||||||
|
{"wait_for_trigger": {"platform": "state", "entity_id": "sensor.abc"}}
|
||||||
|
)
|
||||||
|
sequence = await script.async_validate_actions_config(hass, sequence)
|
||||||
script_obj = script.Script(
|
script_obj = script.Script(
|
||||||
hass,
|
hass,
|
||||||
cv.SCRIPT_SCHEMA(
|
sequence,
|
||||||
{"wait_for_trigger": {"platform": "state", "entity_id": "sensor.abc"}}
|
|
||||||
),
|
|
||||||
"Test Name",
|
"Test Name",
|
||||||
"test_domain",
|
"test_domain",
|
||||||
)
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue