Keep ZHA entity enabled setting in sync with lib (#125472)

* Add ability to enable / disable entities in the ZHA lib

* disable entities at startup that are not enabled in HA

* fix IEEE lookup

* wrap in async_on_unload

* add test and correct lookup
This commit is contained in:
David F. Mulcahey 2024-10-17 07:16:48 -04:00 committed by GitHub
parent 8533f853c8
commit 065577c9ca
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 71 additions and 3 deletions

View file

@ -104,7 +104,7 @@ from homeassistant.const import (
ATTR_NAME,
Platform,
)
from homeassistant.core import HomeAssistant, callback
from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import (
config_validation as cv,
@ -495,7 +495,7 @@ class ZHAGatewayProxy(EventBase):
self.hass = hass
self.config_entry = config_entry
self.gateway = gateway
self.device_proxies: dict[str, ZHADeviceProxy] = {}
self.device_proxies: dict[EUI64, ZHADeviceProxy] = {}
self.group_proxies: dict[int, ZHAGroupProxy] = {}
self._ha_entity_refs: collections.defaultdict[EUI64, list[EntityReference]] = (
collections.defaultdict(list)
@ -509,6 +509,12 @@ class ZHAGatewayProxy(EventBase):
self._unsubs: list[Callable[[], None]] = []
self._unsubs.append(self.gateway.on_all_events(self._handle_event_protocol))
self._reload_task: asyncio.Task | None = None
config_entry.async_on_unload(
self.hass.bus.async_listen(
er.EVENT_ENTITY_REGISTRY_UPDATED,
self._handle_entity_registry_updated,
)
)
@property
def ha_entity_refs(self) -> collections.defaultdict[EUI64, list[EntityReference]]:
@ -532,6 +538,46 @@ class ZHAGatewayProxy(EventBase):
)
)
async def _handle_entity_registry_updated(
self, event: Event[er.EventEntityRegistryUpdatedData]
) -> None:
"""Handle when entity registry updated."""
entity_id = event.data["entity_id"]
entity_entry: er.RegistryEntry | None = er.async_get(self.hass).async_get(
entity_id
)
if (
entity_entry is None
or entity_entry.config_entry_id != self.config_entry.entry_id
or entity_entry.device_id is None
):
return
device_entry: dr.DeviceEntry | None = dr.async_get(self.hass).async_get(
entity_entry.device_id
)
assert device_entry
ieee_address = next(
identifier
for domain, identifier in device_entry.identifiers
if domain == DOMAIN
)
assert ieee_address
ieee = EUI64.convert(ieee_address)
assert ieee in self.device_proxies
zha_device_proxy = self.device_proxies[ieee]
entity_key = (entity_entry.domain, entity_entry.unique_id)
if entity_key not in zha_device_proxy.device.platform_entities:
return
platform_entity = zha_device_proxy.device.platform_entities[entity_key]
if entity_entry.disabled:
platform_entity.disable()
else:
platform_entity.enable()
async def async_initialize_devices_and_entities(self) -> None:
"""Initialize devices and entities."""
for device in self.gateway.devices.values():
@ -1117,7 +1163,7 @@ def async_add_entities(
if not entities:
return
entities_to_add = []
entities_to_add: list[ZHAEntity] = []
for entity_data in entities:
try:
entities_to_add.append(entity_class(entity_data))
@ -1129,6 +1175,9 @@ def async_add_entities(
"Error while adding entity from entity data: %s", entity_data
)
_async_add_entities(entities_to_add, update_before_add=False)
for entity in entities_to_add:
if not entity.enabled:
entity.entity_data.entity.disable()
entities.clear()

View file

@ -14,6 +14,7 @@ from homeassistant.components.zha.helpers import (
)
from homeassistant.const import STATE_OFF, STATE_ON, Platform
from homeassistant.core import HomeAssistant
from homeassistant.helpers import entity_registry as er
from .common import find_entity_id, send_attributes_report
from .conftest import SIG_EP_INPUT, SIG_EP_OUTPUT, SIG_EP_PROFILE, SIG_EP_TYPE
@ -37,6 +38,7 @@ def binary_sensor_platform_only():
async def test_binary_sensor(
hass: HomeAssistant,
entity_registry: er.EntityRegistry,
setup_zha,
zigpy_device_mock,
) -> None:
@ -77,3 +79,20 @@ async def test_binary_sensor(
hass, cluster, {general.OnOff.AttributeDefs.on_off.id: OFF}
)
assert hass.states.get(entity_id).state == STATE_OFF
# test enable / disable sync w/ ZHA library
entity_entry = entity_registry.async_get(entity_id)
entity_key = (Platform.BINARY_SENSOR, entity_entry.unique_id)
assert zha_device_proxy.device.platform_entities.get(entity_key).enabled
entity_registry.async_update_entity(
entity_id=entity_id, disabled_by=er.RegistryEntryDisabler.USER
)
await hass.async_block_till_done()
assert not zha_device_proxy.device.platform_entities.get(entity_key).enabled
entity_registry.async_update_entity(entity_id=entity_id, disabled_by=None)
await hass.async_block_till_done()
assert zha_device_proxy.device.platform_entities.get(entity_key).enabled