Keep capabilities up to date in the entity registry (#101748)

* Keep capabilities up to date in the entity registry

* Warn if entities update their capabilities very often

* Fix updating of device class

* Stop tracking capability updates once flooding is logged

* Only sync registry if state changed

* Improve test

* Revert "Only sync registry if state changed"

This reverts commit 1c52571596c06444df234d4b088242b494b630f2.

* Avoid calculating device class twice

* Address review comments

* Revert using dataclass

* Fix unintended revert

* Add helper method
This commit is contained in:
Erik Montnemery 2023-12-13 17:27:26 +01:00 committed by GitHub
parent 4f9f548929
commit dd5a48996a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 257 additions and 15 deletions

View file

@ -509,7 +509,8 @@ class GroupEntity(Entity):
self.async_update_supported_features(
event.data["entity_id"], event.data["new_state"]
)
preview_callback(*self._async_generate_attributes())
calculated_state = self._async_calculate_state()
preview_callback(calculated_state.state, calculated_state.attributes)
async_state_changed_listener(None)
return async_track_state_change_event(

View file

@ -236,7 +236,8 @@ class MediaPlayerGroup(MediaPlayerEntity):
) -> None:
"""Handle child updates."""
self.async_update_group_state()
preview_callback(*self._async_generate_attributes())
calculated_state = self._async_calculate_state()
preview_callback(calculated_state.state, calculated_state.attributes)
async_state_changed_listener(None)
return async_track_state_change_event(

View file

@ -430,14 +430,17 @@ class TemplateEntity(Entity):
return
try:
state, attrs = self._async_generate_attributes()
validate_state(state)
calculated_state = self._async_calculate_state()
validate_state(calculated_state.state)
except Exception as err: # pylint: disable=broad-exception-caught
self._preview_callback(None, None, None, str(err))
else:
assert self._template_result_info
self._preview_callback(
state, attrs, self._template_result_info.listeners, None
calculated_state.state,
calculated_state.attributes,
self._template_result_info.listeners,
None,
)
@callback

View file

@ -3,6 +3,7 @@ from __future__ import annotations
from abc import ABC
import asyncio
from collections import deque
from collections.abc import Coroutine, Iterable, Mapping, MutableMapping
import dataclasses
from datetime import timedelta
@ -75,6 +76,9 @@ DATA_ENTITY_SOURCE = "entity_info"
# epsilon to make the string representation readable
FLOAT_PRECISION = abs(int(math.floor(math.log10(abs(sys.float_info.epsilon))))) - 1
# How many times per hour we allow capabilities to be updated before logging a warning
CAPABILITIES_UPDATE_LIMIT = 100
@callback
def async_setup(hass: HomeAssistant) -> None:
@ -237,6 +241,22 @@ class EntityDescription(metaclass=FrozenOrThawed, frozen_or_thawed=True):
unit_of_measurement: str | None = None
@dataclasses.dataclass(frozen=True, slots=True)
class CalculatedState:
"""Container with state and attributes.
Returned by Entity._async_calculate_state.
"""
state: str
# The union of all attributes, after overriding with entity registry settings
attributes: dict[str, Any]
# Capability attributes returned by the capability_attributes property
capability_attributes: Mapping[str, Any] | None
# Attributes which may be overridden by the entity registry
shadowed_attributes: Mapping[str, Any]
class Entity(ABC):
"""An abstract class for Home Assistant entities."""
@ -311,6 +331,8 @@ class Entity(ABC):
# and removes the need for constant None checks or asserts.
_state_info: StateInfo = None # type: ignore[assignment]
__capabilities_updated_at: deque[float]
__capabilities_updated_at_reported: bool = False
__remove_event: asyncio.Event | None = None
# Entity Properties
@ -775,12 +797,29 @@ class Entity(ABC):
return f"{device_name} {name}" if device_name else name
@callback
def _async_generate_attributes(self) -> tuple[str, dict[str, Any]]:
def _async_calculate_state(self) -> CalculatedState:
"""Calculate state string and attribute mapping."""
return CalculatedState(*self.__async_calculate_state())
def __async_calculate_state(
self,
) -> tuple[str, dict[str, Any], Mapping[str, Any] | None, Mapping[str, Any]]:
"""Calculate state string and attribute mapping.
Returns a tuple (state, attr, capability_attr, shadowed_attr).
state - the stringified state
attr - the attribute dictionary
capability_attr - a mapping with capability attributes
shadowed_attr - a mapping with attributes which may be overridden
This method is called when writing the state to avoid the overhead of creating
a dataclass object.
"""
entry = self.registry_entry
attr = self.capability_attributes
attr = dict(attr) if attr else {}
capability_attr = self.capability_attributes
attr = dict(capability_attr) if capability_attr else {}
shadowed_attr = {}
available = self.available # only call self.available once per update cycle
state = self._stringify_state(available)
@ -797,26 +836,30 @@ class Entity(ABC):
if (attribution := self.attribution) is not None:
attr[ATTR_ATTRIBUTION] = attribution
shadowed_attr[ATTR_DEVICE_CLASS] = self.device_class
if (
device_class := (entry and entry.device_class) or self.device_class
device_class := (entry and entry.device_class)
or shadowed_attr[ATTR_DEVICE_CLASS]
) is not None:
attr[ATTR_DEVICE_CLASS] = str(device_class)
if (entity_picture := self.entity_picture) is not None:
attr[ATTR_ENTITY_PICTURE] = entity_picture
if (icon := (entry and entry.icon) or self.icon) is not None:
shadowed_attr[ATTR_ICON] = self.icon
if (icon := (entry and entry.icon) or shadowed_attr[ATTR_ICON]) is not None:
attr[ATTR_ICON] = icon
shadowed_attr[ATTR_FRIENDLY_NAME] = self._friendly_name_internal()
if (
name := (entry and entry.name) or self._friendly_name_internal()
name := (entry and entry.name) or shadowed_attr[ATTR_FRIENDLY_NAME]
) is not None:
attr[ATTR_FRIENDLY_NAME] = name
if (supported_features := self.supported_features) is not None:
attr[ATTR_SUPPORTED_FEATURES] = supported_features
return (state, attr)
return (state, attr, capability_attr, shadowed_attr)
@callback
def _async_write_ha_state(self) -> None:
@ -842,9 +885,45 @@ class Entity(ABC):
return
start = timer()
state, attr = self._async_generate_attributes()
state, attr, capabilities, shadowed_attr = self.__async_calculate_state()
end = timer()
if entry:
# Make sure capabilities in the entity registry are up to date. Capabilities
# include capability attributes, device class and supported features
original_device_class: str | None = shadowed_attr[ATTR_DEVICE_CLASS]
supported_features: int = attr.get(ATTR_SUPPORTED_FEATURES) or 0
if (
capabilities != entry.capabilities
or original_device_class != entry.original_device_class
or supported_features != entry.supported_features
):
if not self.__capabilities_updated_at_reported:
time_now = hass.loop.time()
capabilities_updated_at = self.__capabilities_updated_at
capabilities_updated_at.append(time_now)
while time_now - capabilities_updated_at[0] > 3600:
capabilities_updated_at.popleft()
if len(capabilities_updated_at) > CAPABILITIES_UPDATE_LIMIT:
self.__capabilities_updated_at_reported = True
report_issue = self._suggest_report_issue()
_LOGGER.warning(
(
"Entity %s (%s) is updating its capabilities too often,"
" please %s"
),
entity_id,
type(self),
report_issue,
)
entity_registry = er.async_get(self.hass)
self.registry_entry = entity_registry.async_update_entity(
self.entity_id,
capabilities=capabilities,
original_device_class=original_device_class,
supported_features=supported_features,
)
if end - start > 0.4 and not self._slow_reported:
self._slow_reported = True
report_issue = self._suggest_report_issue()
@ -1118,6 +1197,8 @@ class Entity(ABC):
)
self._async_subscribe_device_updates()
self.__capabilities_updated_at = deque(maxlen=CAPABILITIES_UPDATE_LIMIT + 1)
async def async_internal_will_remove_from_hass(self) -> None:
"""Run when entity will be removed from hass.

View file

@ -8,6 +8,7 @@ import threading
from typing import Any
from unittest.mock import MagicMock, PropertyMock, patch
from freezegun.api import FrozenDateTimeFactory
import pytest
from syrupy.assertion import SnapshotAssertion
import voluptuous as vol
@ -1412,8 +1413,8 @@ async def test_repr_using_stringify_state() -> None:
"""Return the state."""
raise ValueError("Boom")
entity = MyEntity(entity_id="test.test", available=False)
assert str(entity) == "<entity test.test=unavailable>"
my_entity = MyEntity(entity_id="test.test", available=False)
assert str(my_entity) == "<entity test.test=unavailable>"
async def test_warn_using_async_update_ha_state(
@ -1761,3 +1762,158 @@ def test_extending_entity_description(snapshot: SnapshotAssertion):
assert obj == snapshot
assert obj == CustomInitEntityDescription(key="blah", extra="foo", name="name")
assert repr(obj) == snapshot
async def test_update_capabilities(
hass: HomeAssistant,
entity_registry: er.EntityRegistry,
) -> None:
"""Test entity capabilities are updated automatically."""
platform = MockEntityPlatform(hass)
ent = MockEntity(unique_id="qwer")
await platform.async_add_entities([ent])
entry = entity_registry.async_get(ent.entity_id)
assert entry.capabilities is None
assert entry.device_class is None
assert entry.supported_features == 0
ent._values["capability_attributes"] = {"bla": "blu"}
ent._values["device_class"] = "some_class"
ent._values["supported_features"] = 127
ent.async_write_ha_state()
entry = entity_registry.async_get(ent.entity_id)
assert entry.capabilities == {"bla": "blu"}
assert entry.original_device_class == "some_class"
assert entry.supported_features == 127
ent._values["capability_attributes"] = None
ent._values["device_class"] = None
ent._values["supported_features"] = None
ent.async_write_ha_state()
entry = entity_registry.async_get(ent.entity_id)
assert entry.capabilities is None
assert entry.original_device_class is None
assert entry.supported_features == 0
# Device class can be overridden by user, make sure that does not break the
# automatic updating.
entity_registry.async_update_entity(ent.entity_id, device_class="set_by_user")
await hass.async_block_till_done()
entry = entity_registry.async_get(ent.entity_id)
assert entry.capabilities is None
assert entry.original_device_class is None
assert entry.supported_features == 0
# This will not trigger a state change because the device class is shadowed
# by the entity registry
ent._values["device_class"] = "some_class"
ent.async_write_ha_state()
entry = entity_registry.async_get(ent.entity_id)
assert entry.capabilities is None
assert entry.original_device_class == "some_class"
assert entry.supported_features == 0
async def test_update_capabilities_no_unique_id(
hass: HomeAssistant,
entity_registry: er.EntityRegistry,
) -> None:
"""Test entity capabilities are updated automatically."""
platform = MockEntityPlatform(hass)
ent = MockEntity()
await platform.async_add_entities([ent])
assert entity_registry.async_get(ent.entity_id) is None
ent._values["capability_attributes"] = {"bla": "blu"}
ent._values["supported_features"] = 127
ent.async_write_ha_state()
assert entity_registry.async_get(ent.entity_id) is None
async def test_update_capabilities_too_often(
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,
entity_registry: er.EntityRegistry,
) -> None:
"""Test entity capabilities are updated automatically."""
capabilities_too_often_warning = "is updating its capabilities too often"
platform = MockEntityPlatform(hass)
ent = MockEntity(unique_id="qwer")
await platform.async_add_entities([ent])
entry = entity_registry.async_get(ent.entity_id)
assert entry.capabilities is None
assert entry.device_class is None
assert entry.supported_features == 0
for supported_features in range(1, entity.CAPABILITIES_UPDATE_LIMIT + 1):
ent._values["capability_attributes"] = {"bla": "blu"}
ent._values["device_class"] = "some_class"
ent._values["supported_features"] = supported_features
ent.async_write_ha_state()
entry = entity_registry.async_get(ent.entity_id)
assert entry.capabilities == {"bla": "blu"}
assert entry.original_device_class == "some_class"
assert entry.supported_features == supported_features
assert capabilities_too_often_warning not in caplog.text
ent._values["capability_attributes"] = {"bla": "blu"}
ent._values["device_class"] = "some_class"
ent._values["supported_features"] = supported_features + 1
ent.async_write_ha_state()
entry = entity_registry.async_get(ent.entity_id)
assert entry.capabilities == {"bla": "blu"}
assert entry.original_device_class == "some_class"
assert entry.supported_features == supported_features + 1
assert capabilities_too_often_warning in caplog.text
async def test_update_capabilities_too_often_cooldown(
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,
entity_registry: er.EntityRegistry,
freezer: FrozenDateTimeFactory,
) -> None:
"""Test entity capabilities are updated automatically."""
capabilities_too_often_warning = "is updating its capabilities too often"
platform = MockEntityPlatform(hass)
ent = MockEntity(unique_id="qwer")
await platform.async_add_entities([ent])
entry = entity_registry.async_get(ent.entity_id)
assert entry.capabilities is None
assert entry.device_class is None
assert entry.supported_features == 0
for supported_features in range(1, entity.CAPABILITIES_UPDATE_LIMIT + 1):
ent._values["capability_attributes"] = {"bla": "blu"}
ent._values["device_class"] = "some_class"
ent._values["supported_features"] = supported_features
ent.async_write_ha_state()
entry = entity_registry.async_get(ent.entity_id)
assert entry.capabilities == {"bla": "blu"}
assert entry.original_device_class == "some_class"
assert entry.supported_features == supported_features
assert capabilities_too_often_warning not in caplog.text
freezer.tick(timedelta(minutes=60) + timedelta(seconds=1))
ent._values["capability_attributes"] = {"bla": "blu"}
ent._values["device_class"] = "some_class"
ent._values["supported_features"] = supported_features + 1
ent.async_write_ha_state()
entry = entity_registry.async_get(ent.entity_id)
assert entry.capabilities == {"bla": "blu"}
assert entry.original_device_class == "some_class"
assert entry.supported_features == supported_features + 1
assert capabilities_too_often_warning not in caplog.text