Entity to handle updates via events (#24733)

* Entity to handle updates via events

* Fix a bug

* Update entity.py
This commit is contained in:
Paulus Schoutsen 2019-06-26 09:22:51 -07:00 committed by GitHub
parent 9e0636eefa
commit 06af6f19a3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 71 additions and 66 deletions

View file

@ -10,6 +10,7 @@ from homeassistant.const import (
ATTR_UNIT_OF_MEASUREMENT, DEVICE_DEFAULT_NAME, STATE_OFF, STATE_ON, ATTR_UNIT_OF_MEASUREMENT, DEVICE_DEFAULT_NAME, STATE_OFF, STATE_ON,
STATE_UNAVAILABLE, STATE_UNKNOWN, TEMP_CELSIUS, TEMP_FAHRENHEIT, STATE_UNAVAILABLE, STATE_UNKNOWN, TEMP_CELSIUS, TEMP_FAHRENHEIT,
ATTR_ENTITY_PICTURE, ATTR_SUPPORTED_FEATURES, ATTR_DEVICE_CLASS) ATTR_ENTITY_PICTURE, ATTR_SUPPORTED_FEATURES, ATTR_DEVICE_CLASS)
from homeassistant.helpers.entity_registry import EVENT_ENTITY_REGISTRY_UPDATED
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.config import DATA_CUSTOMIZE from homeassistant.config import DATA_CUSTOMIZE
from homeassistant.exceptions import NoEntitySpecifiedError from homeassistant.exceptions import NoEntitySpecifiedError
@ -78,8 +79,8 @@ class Entity:
# Process updates in parallel # Process updates in parallel
parallel_updates = None parallel_updates = None
# Name in the entity registry # Entry in the entity registry
registry_name = None registry_entry = None
# Hold list for functions to call on remove. # Hold list for functions to call on remove.
_on_remove = None _on_remove = None
@ -259,7 +260,9 @@ class Entity:
if unit_of_measurement is not None: if unit_of_measurement is not None:
attr[ATTR_UNIT_OF_MEASUREMENT] = unit_of_measurement attr[ATTR_UNIT_OF_MEASUREMENT] = unit_of_measurement
name = self.registry_name or self.name entry = self.registry_entry
# pylint: disable=consider-using-ternary
name = (entry and entry.name) or self.name
if name is not None: if name is not None:
attr[ATTR_FRIENDLY_NAME] = name attr[ATTR_FRIENDLY_NAME] = name
@ -391,6 +394,7 @@ class Entity:
async def async_remove(self): async def async_remove(self):
"""Remove entity from Home Assistant.""" """Remove entity from Home Assistant."""
await self.async_internal_will_remove_from_hass()
await self.async_will_remove_from_hass() await self.async_will_remove_from_hass()
if self._on_remove is not None: if self._on_remove is not None:
@ -399,27 +403,52 @@ class Entity:
self.hass.states.async_remove(self.entity_id) self.hass.states.async_remove(self.entity_id)
@callback
def async_registry_updated(self, old, new):
"""Handle entity registry update."""
self.registry_name = new.name
if new.entity_id == self.entity_id:
self.async_schedule_update_ha_state()
return
async def readd():
"""Remove and add entity again."""
await self.async_remove()
await self.platform.async_add_entities([self])
self.hass.async_create_task(readd())
async def async_added_to_hass(self) -> None: async def async_added_to_hass(self) -> None:
"""Run when entity about to be added to hass.""" """Run when entity about to be added to hass.
To be extended by integrations.
"""
async def async_will_remove_from_hass(self) -> None: async def async_will_remove_from_hass(self) -> None:
"""Run when entity will be removed from hass.""" """Run when entity will be removed from hass.
To be extended by integrations.
"""
async def async_internal_added_to_hass(self) -> None:
"""Run when entity about to be added to hass.
Not to be extended by integrations.
"""
if self.registry_entry is not None:
self.async_on_remove(self.hass.bus.async_listen(
EVENT_ENTITY_REGISTRY_UPDATED, self._async_registry_updated))
async def async_internal_will_remove_from_hass(self) -> None:
"""Run when entity will be removed from hass.
Not to be extended by integrations.
"""
async def _async_registry_updated(self, event):
"""Handle entity registry update."""
data = event.data
if data['action'] != 'update' and data.get(
'old_entity_id', data['entity_id']) != self.entity_id:
return
ent_reg = await self.hass.helpers.entity_registry.async_get_registry()
old = self.registry_entry
self.registry_entry = ent_reg.async_get(data['entity_id'])
if self.registry_entry.entity_id == old.entity_id:
self.async_write_ha_state()
return
await self.async_remove()
self.entity_id = self.registry_entry.entity_id
await self.platform.async_add_entities([self])
def __eq__(self, other): def __eq__(self, other):
"""Return the comparison.""" """Return the comparison."""

View file

@ -320,9 +320,8 @@ class EntityPlatform:
'"{} {}"'.format(self.platform_name, entity.unique_id)) '"{} {}"'.format(self.platform_name, entity.unique_id))
return return
entity.registry_entry = entry
entity.entity_id = entry.entity_id entity.entity_id = entry.entity_id
entity.registry_name = entry.name
entity.async_on_remove(entry.add_update_listener(entity))
# We won't generate an entity ID if the platform has already set one # We won't generate an entity ID if the platform has already set one
# We will however make sure that platform cannot pick a registered ID # We will however make sure that platform cannot pick a registered ID
@ -360,6 +359,7 @@ class EntityPlatform:
self.entities[entity_id] = entity self.entities[entity_id] = entity
entity.async_on_remove(lambda: self.entities.pop(entity_id)) entity.async_on_remove(lambda: self.entities.pop(entity_id))
await entity.async_internal_added_to_hass()
await entity.async_added_to_hass() await entity.async_added_to_hass()
await entity.async_update_ha_state() await entity.async_update_ha_state()

View file

@ -12,7 +12,6 @@ from collections import OrderedDict
from itertools import chain from itertools import chain
import logging import logging
from typing import List, Optional, cast from typing import List, Optional, cast
import weakref
import attr import attr
@ -50,8 +49,6 @@ class RegistryEntry:
disabled_by = attr.ib( disabled_by = attr.ib(
type=str, default=None, type=str, default=None,
validator=attr.validators.in_((DISABLED_HASS, DISABLED_USER, None))) validator=attr.validators.in_((DISABLED_HASS, DISABLED_USER, None)))
update_listeners = attr.ib(type=list, default=attr.Factory(list),
repr=False)
domain = attr.ib(type=str, init=False, repr=False) domain = attr.ib(type=str, init=False, repr=False)
@domain.default @domain.default
@ -64,18 +61,6 @@ class RegistryEntry:
"""Return if entry is disabled.""" """Return if entry is disabled."""
return self.disabled_by is not None return self.disabled_by is not None
def add_update_listener(self, listener):
"""Listen for when entry is updated.
Listener: Callback function(old_entry, new_entry)
Returns function to unlisten.
"""
weak_listener = weakref.ref(listener)
self.update_listeners.append(weak_listener)
return lambda: self.update_listeners.remove(weak_listener)
class EntityRegistry: class EntityRegistry:
"""Class to hold a registry of entities.""" """Class to hold a registry of entities."""
@ -247,26 +232,17 @@ class EntityRegistry:
new = self.entities[entity_id] = attr.evolve(old, **changes) new = self.entities[entity_id] = attr.evolve(old, **changes)
to_remove = []
for listener_ref in new.update_listeners:
listener = listener_ref()
if listener is None:
to_remove.append(listener_ref)
else:
try:
listener.async_registry_updated(old, new)
except Exception: # pylint: disable=broad-except
_LOGGER.exception('Error calling update listener')
for ref in to_remove:
new.update_listeners.remove(ref)
self.async_schedule_save() self.async_schedule_save()
self.hass.bus.async_fire(EVENT_ENTITY_REGISTRY_UPDATED, { data = {
'action': 'update', 'action': 'update',
'entity_id': entity_id 'entity_id': entity_id,
}) }
if old.entity_id != entity_id:
data['old_entity_id'] = old.entity_id
self.hass.bus.async_fire(EVENT_ENTITY_REGISTRY_UPDATED, data)
return new return new

View file

@ -186,18 +186,18 @@ class RestoreStateData():
class RestoreEntity(Entity): class RestoreEntity(Entity):
"""Mixin class for restoring previous entity state.""" """Mixin class for restoring previous entity state."""
async def async_added_to_hass(self) -> None: async def async_internal_added_to_hass(self) -> None:
"""Register this entity as a restorable entity.""" """Register this entity as a restorable entity."""
_, data = await asyncio.gather( _, data = await asyncio.gather(
super().async_added_to_hass(), super().async_internal_added_to_hass(),
RestoreStateData.async_get_instance(self.hass), RestoreStateData.async_get_instance(self.hass),
) )
data.async_restore_entity_added(self.entity_id) data.async_restore_entity_added(self.entity_id)
async def async_will_remove_from_hass(self) -> None: async def async_internal_will_remove_from_hass(self) -> None:
"""Run when entity will be removed from hass.""" """Run when entity will be removed from hass."""
_, data = await asyncio.gather( _, data = await asyncio.gather(
super().async_will_remove_from_hass(), super().async_internal_will_remove_from_hass(),
RestoreStateData.async_get_instance(self.hass), RestoreStateData.async_get_instance(self.hass),
) )
data.async_restore_entity_removed(self.entity_id) data.async_restore_entity_removed(self.entity_id)

View file

@ -104,12 +104,12 @@ async def test_dump_data(hass):
entity = Entity() entity = Entity()
entity.hass = hass entity.hass = hass
entity.entity_id = 'input_boolean.b0' entity.entity_id = 'input_boolean.b0'
await entity.async_added_to_hass() await entity.async_internal_added_to_hass()
entity = RestoreEntity() entity = RestoreEntity()
entity.hass = hass entity.hass = hass
entity.entity_id = 'input_boolean.b1' entity.entity_id = 'input_boolean.b1'
await entity.async_added_to_hass() await entity.async_internal_added_to_hass()
data = await RestoreStateData.async_get_instance(hass) data = await RestoreStateData.async_get_instance(hass)
now = dt_util.utcnow() now = dt_util.utcnow()
@ -144,7 +144,7 @@ async def test_dump_data(hass):
assert written_states[1]['state']['state'] == 'off' assert written_states[1]['state']['state'] == 'off'
# Test that removed entities are not persisted # Test that removed entities are not persisted
await entity.async_will_remove_from_hass() await entity.async_remove()
with patch('homeassistant.helpers.restore_state.Store.async_save' with patch('homeassistant.helpers.restore_state.Store.async_save'
) as mock_write_data, patch.object( ) as mock_write_data, patch.object(
@ -170,12 +170,12 @@ async def test_dump_error(hass):
entity = Entity() entity = Entity()
entity.hass = hass entity.hass = hass
entity.entity_id = 'input_boolean.b0' entity.entity_id = 'input_boolean.b0'
await entity.async_added_to_hass() await entity.async_internal_added_to_hass()
entity = RestoreEntity() entity = RestoreEntity()
entity.hass = hass entity.hass = hass
entity.entity_id = 'input_boolean.b1' entity.entity_id = 'input_boolean.b1'
await entity.async_added_to_hass() await entity.async_internal_added_to_hass()
data = await RestoreStateData.async_get_instance(hass) data = await RestoreStateData.async_get_instance(hass)
@ -206,7 +206,7 @@ async def test_state_saved_on_remove(hass):
entity = RestoreEntity() entity = RestoreEntity()
entity.hass = hass entity.hass = hass
entity.entity_id = 'input_boolean.b0' entity.entity_id = 'input_boolean.b0'
await entity.async_added_to_hass() await entity.async_internal_added_to_hass()
hass.states.async_set('input_boolean.b0', 'on') hass.states.async_set('input_boolean.b0', 'on')
@ -215,7 +215,7 @@ async def test_state_saved_on_remove(hass):
# No last states should currently be saved # No last states should currently be saved
assert not data.last_states assert not data.last_states
await entity.async_will_remove_from_hass() await entity.async_remove()
# We should store the input boolean state when it is removed # We should store the input boolean state when it is removed
assert data.last_states['input_boolean.b0'].state.state == 'on' assert data.last_states['input_boolean.b0'].state.state == 'on'