Entity to handle updates via events (#24733)
* Entity to handle updates via events * Fix a bug * Update entity.py
This commit is contained in:
parent
9e0636eefa
commit
06af6f19a3
5 changed files with 71 additions and 66 deletions
|
@ -10,6 +10,7 @@ from homeassistant.const import (
|
||||||
ATTR_UNIT_OF_MEASUREMENT, DEVICE_DEFAULT_NAME, STATE_OFF, STATE_ON,
|
ATTR_UNIT_OF_MEASUREMENT, DEVICE_DEFAULT_NAME, STATE_OFF, STATE_ON,
|
||||||
STATE_UNAVAILABLE, STATE_UNKNOWN, TEMP_CELSIUS, TEMP_FAHRENHEIT,
|
STATE_UNAVAILABLE, STATE_UNKNOWN, TEMP_CELSIUS, TEMP_FAHRENHEIT,
|
||||||
ATTR_ENTITY_PICTURE, ATTR_SUPPORTED_FEATURES, ATTR_DEVICE_CLASS)
|
ATTR_ENTITY_PICTURE, ATTR_SUPPORTED_FEATURES, ATTR_DEVICE_CLASS)
|
||||||
|
from homeassistant.helpers.entity_registry import EVENT_ENTITY_REGISTRY_UPDATED
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.config import DATA_CUSTOMIZE
|
from homeassistant.config import DATA_CUSTOMIZE
|
||||||
from homeassistant.exceptions import NoEntitySpecifiedError
|
from homeassistant.exceptions import NoEntitySpecifiedError
|
||||||
|
@ -78,8 +79,8 @@ class Entity:
|
||||||
# Process updates in parallel
|
# Process updates in parallel
|
||||||
parallel_updates = None
|
parallel_updates = None
|
||||||
|
|
||||||
# Name in the entity registry
|
# Entry in the entity registry
|
||||||
registry_name = None
|
registry_entry = None
|
||||||
|
|
||||||
# Hold list for functions to call on remove.
|
# Hold list for functions to call on remove.
|
||||||
_on_remove = None
|
_on_remove = None
|
||||||
|
@ -259,7 +260,9 @@ class Entity:
|
||||||
if unit_of_measurement is not None:
|
if unit_of_measurement is not None:
|
||||||
attr[ATTR_UNIT_OF_MEASUREMENT] = unit_of_measurement
|
attr[ATTR_UNIT_OF_MEASUREMENT] = unit_of_measurement
|
||||||
|
|
||||||
name = self.registry_name or self.name
|
entry = self.registry_entry
|
||||||
|
# pylint: disable=consider-using-ternary
|
||||||
|
name = (entry and entry.name) or self.name
|
||||||
if name is not None:
|
if name is not None:
|
||||||
attr[ATTR_FRIENDLY_NAME] = name
|
attr[ATTR_FRIENDLY_NAME] = name
|
||||||
|
|
||||||
|
@ -391,6 +394,7 @@ class Entity:
|
||||||
|
|
||||||
async def async_remove(self):
|
async def async_remove(self):
|
||||||
"""Remove entity from Home Assistant."""
|
"""Remove entity from Home Assistant."""
|
||||||
|
await self.async_internal_will_remove_from_hass()
|
||||||
await self.async_will_remove_from_hass()
|
await self.async_will_remove_from_hass()
|
||||||
|
|
||||||
if self._on_remove is not None:
|
if self._on_remove is not None:
|
||||||
|
@ -399,27 +403,52 @@ class Entity:
|
||||||
|
|
||||||
self.hass.states.async_remove(self.entity_id)
|
self.hass.states.async_remove(self.entity_id)
|
||||||
|
|
||||||
@callback
|
|
||||||
def async_registry_updated(self, old, new):
|
|
||||||
"""Handle entity registry update."""
|
|
||||||
self.registry_name = new.name
|
|
||||||
|
|
||||||
if new.entity_id == self.entity_id:
|
|
||||||
self.async_schedule_update_ha_state()
|
|
||||||
return
|
|
||||||
|
|
||||||
async def readd():
|
|
||||||
"""Remove and add entity again."""
|
|
||||||
await self.async_remove()
|
|
||||||
await self.platform.async_add_entities([self])
|
|
||||||
|
|
||||||
self.hass.async_create_task(readd())
|
|
||||||
|
|
||||||
async def async_added_to_hass(self) -> None:
|
async def async_added_to_hass(self) -> None:
|
||||||
"""Run when entity about to be added to hass."""
|
"""Run when entity about to be added to hass.
|
||||||
|
|
||||||
|
To be extended by integrations.
|
||||||
|
"""
|
||||||
|
|
||||||
async def async_will_remove_from_hass(self) -> None:
|
async def async_will_remove_from_hass(self) -> None:
|
||||||
"""Run when entity will be removed from hass."""
|
"""Run when entity will be removed from hass.
|
||||||
|
|
||||||
|
To be extended by integrations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def async_internal_added_to_hass(self) -> None:
|
||||||
|
"""Run when entity about to be added to hass.
|
||||||
|
|
||||||
|
Not to be extended by integrations.
|
||||||
|
"""
|
||||||
|
if self.registry_entry is not None:
|
||||||
|
self.async_on_remove(self.hass.bus.async_listen(
|
||||||
|
EVENT_ENTITY_REGISTRY_UPDATED, self._async_registry_updated))
|
||||||
|
|
||||||
|
async def async_internal_will_remove_from_hass(self) -> None:
|
||||||
|
"""Run when entity will be removed from hass.
|
||||||
|
|
||||||
|
Not to be extended by integrations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def _async_registry_updated(self, event):
|
||||||
|
"""Handle entity registry update."""
|
||||||
|
data = event.data
|
||||||
|
if data['action'] != 'update' and data.get(
|
||||||
|
'old_entity_id', data['entity_id']) != self.entity_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
ent_reg = await self.hass.helpers.entity_registry.async_get_registry()
|
||||||
|
old = self.registry_entry
|
||||||
|
self.registry_entry = ent_reg.async_get(data['entity_id'])
|
||||||
|
|
||||||
|
if self.registry_entry.entity_id == old.entity_id:
|
||||||
|
self.async_write_ha_state()
|
||||||
|
return
|
||||||
|
|
||||||
|
await self.async_remove()
|
||||||
|
|
||||||
|
self.entity_id = self.registry_entry.entity_id
|
||||||
|
await self.platform.async_add_entities([self])
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
"""Return the comparison."""
|
"""Return the comparison."""
|
||||||
|
|
|
@ -320,9 +320,8 @@ class EntityPlatform:
|
||||||
'"{} {}"'.format(self.platform_name, entity.unique_id))
|
'"{} {}"'.format(self.platform_name, entity.unique_id))
|
||||||
return
|
return
|
||||||
|
|
||||||
|
entity.registry_entry = entry
|
||||||
entity.entity_id = entry.entity_id
|
entity.entity_id = entry.entity_id
|
||||||
entity.registry_name = entry.name
|
|
||||||
entity.async_on_remove(entry.add_update_listener(entity))
|
|
||||||
|
|
||||||
# We won't generate an entity ID if the platform has already set one
|
# We won't generate an entity ID if the platform has already set one
|
||||||
# We will however make sure that platform cannot pick a registered ID
|
# We will however make sure that platform cannot pick a registered ID
|
||||||
|
@ -360,6 +359,7 @@ class EntityPlatform:
|
||||||
self.entities[entity_id] = entity
|
self.entities[entity_id] = entity
|
||||||
entity.async_on_remove(lambda: self.entities.pop(entity_id))
|
entity.async_on_remove(lambda: self.entities.pop(entity_id))
|
||||||
|
|
||||||
|
await entity.async_internal_added_to_hass()
|
||||||
await entity.async_added_to_hass()
|
await entity.async_added_to_hass()
|
||||||
|
|
||||||
await entity.async_update_ha_state()
|
await entity.async_update_ha_state()
|
||||||
|
|
|
@ -12,7 +12,6 @@ from collections import OrderedDict
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
import logging
|
import logging
|
||||||
from typing import List, Optional, cast
|
from typing import List, Optional, cast
|
||||||
import weakref
|
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
|
@ -50,8 +49,6 @@ class RegistryEntry:
|
||||||
disabled_by = attr.ib(
|
disabled_by = attr.ib(
|
||||||
type=str, default=None,
|
type=str, default=None,
|
||||||
validator=attr.validators.in_((DISABLED_HASS, DISABLED_USER, None)))
|
validator=attr.validators.in_((DISABLED_HASS, DISABLED_USER, None)))
|
||||||
update_listeners = attr.ib(type=list, default=attr.Factory(list),
|
|
||||||
repr=False)
|
|
||||||
domain = attr.ib(type=str, init=False, repr=False)
|
domain = attr.ib(type=str, init=False, repr=False)
|
||||||
|
|
||||||
@domain.default
|
@domain.default
|
||||||
|
@ -64,18 +61,6 @@ class RegistryEntry:
|
||||||
"""Return if entry is disabled."""
|
"""Return if entry is disabled."""
|
||||||
return self.disabled_by is not None
|
return self.disabled_by is not None
|
||||||
|
|
||||||
def add_update_listener(self, listener):
|
|
||||||
"""Listen for when entry is updated.
|
|
||||||
|
|
||||||
Listener: Callback function(old_entry, new_entry)
|
|
||||||
|
|
||||||
Returns function to unlisten.
|
|
||||||
"""
|
|
||||||
weak_listener = weakref.ref(listener)
|
|
||||||
self.update_listeners.append(weak_listener)
|
|
||||||
|
|
||||||
return lambda: self.update_listeners.remove(weak_listener)
|
|
||||||
|
|
||||||
|
|
||||||
class EntityRegistry:
|
class EntityRegistry:
|
||||||
"""Class to hold a registry of entities."""
|
"""Class to hold a registry of entities."""
|
||||||
|
@ -247,26 +232,17 @@ class EntityRegistry:
|
||||||
|
|
||||||
new = self.entities[entity_id] = attr.evolve(old, **changes)
|
new = self.entities[entity_id] = attr.evolve(old, **changes)
|
||||||
|
|
||||||
to_remove = []
|
|
||||||
for listener_ref in new.update_listeners:
|
|
||||||
listener = listener_ref()
|
|
||||||
if listener is None:
|
|
||||||
to_remove.append(listener_ref)
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
listener.async_registry_updated(old, new)
|
|
||||||
except Exception: # pylint: disable=broad-except
|
|
||||||
_LOGGER.exception('Error calling update listener')
|
|
||||||
|
|
||||||
for ref in to_remove:
|
|
||||||
new.update_listeners.remove(ref)
|
|
||||||
|
|
||||||
self.async_schedule_save()
|
self.async_schedule_save()
|
||||||
|
|
||||||
self.hass.bus.async_fire(EVENT_ENTITY_REGISTRY_UPDATED, {
|
data = {
|
||||||
'action': 'update',
|
'action': 'update',
|
||||||
'entity_id': entity_id
|
'entity_id': entity_id,
|
||||||
})
|
}
|
||||||
|
|
||||||
|
if old.entity_id != entity_id:
|
||||||
|
data['old_entity_id'] = old.entity_id
|
||||||
|
|
||||||
|
self.hass.bus.async_fire(EVENT_ENTITY_REGISTRY_UPDATED, data)
|
||||||
|
|
||||||
return new
|
return new
|
||||||
|
|
||||||
|
|
|
@ -186,18 +186,18 @@ class RestoreStateData():
|
||||||
class RestoreEntity(Entity):
|
class RestoreEntity(Entity):
|
||||||
"""Mixin class for restoring previous entity state."""
|
"""Mixin class for restoring previous entity state."""
|
||||||
|
|
||||||
async def async_added_to_hass(self) -> None:
|
async def async_internal_added_to_hass(self) -> None:
|
||||||
"""Register this entity as a restorable entity."""
|
"""Register this entity as a restorable entity."""
|
||||||
_, data = await asyncio.gather(
|
_, data = await asyncio.gather(
|
||||||
super().async_added_to_hass(),
|
super().async_internal_added_to_hass(),
|
||||||
RestoreStateData.async_get_instance(self.hass),
|
RestoreStateData.async_get_instance(self.hass),
|
||||||
)
|
)
|
||||||
data.async_restore_entity_added(self.entity_id)
|
data.async_restore_entity_added(self.entity_id)
|
||||||
|
|
||||||
async def async_will_remove_from_hass(self) -> None:
|
async def async_internal_will_remove_from_hass(self) -> None:
|
||||||
"""Run when entity will be removed from hass."""
|
"""Run when entity will be removed from hass."""
|
||||||
_, data = await asyncio.gather(
|
_, data = await asyncio.gather(
|
||||||
super().async_will_remove_from_hass(),
|
super().async_internal_will_remove_from_hass(),
|
||||||
RestoreStateData.async_get_instance(self.hass),
|
RestoreStateData.async_get_instance(self.hass),
|
||||||
)
|
)
|
||||||
data.async_restore_entity_removed(self.entity_id)
|
data.async_restore_entity_removed(self.entity_id)
|
||||||
|
|
|
@ -104,12 +104,12 @@ async def test_dump_data(hass):
|
||||||
entity = Entity()
|
entity = Entity()
|
||||||
entity.hass = hass
|
entity.hass = hass
|
||||||
entity.entity_id = 'input_boolean.b0'
|
entity.entity_id = 'input_boolean.b0'
|
||||||
await entity.async_added_to_hass()
|
await entity.async_internal_added_to_hass()
|
||||||
|
|
||||||
entity = RestoreEntity()
|
entity = RestoreEntity()
|
||||||
entity.hass = hass
|
entity.hass = hass
|
||||||
entity.entity_id = 'input_boolean.b1'
|
entity.entity_id = 'input_boolean.b1'
|
||||||
await entity.async_added_to_hass()
|
await entity.async_internal_added_to_hass()
|
||||||
|
|
||||||
data = await RestoreStateData.async_get_instance(hass)
|
data = await RestoreStateData.async_get_instance(hass)
|
||||||
now = dt_util.utcnow()
|
now = dt_util.utcnow()
|
||||||
|
@ -144,7 +144,7 @@ async def test_dump_data(hass):
|
||||||
assert written_states[1]['state']['state'] == 'off'
|
assert written_states[1]['state']['state'] == 'off'
|
||||||
|
|
||||||
# Test that removed entities are not persisted
|
# Test that removed entities are not persisted
|
||||||
await entity.async_will_remove_from_hass()
|
await entity.async_remove()
|
||||||
|
|
||||||
with patch('homeassistant.helpers.restore_state.Store.async_save'
|
with patch('homeassistant.helpers.restore_state.Store.async_save'
|
||||||
) as mock_write_data, patch.object(
|
) as mock_write_data, patch.object(
|
||||||
|
@ -170,12 +170,12 @@ async def test_dump_error(hass):
|
||||||
entity = Entity()
|
entity = Entity()
|
||||||
entity.hass = hass
|
entity.hass = hass
|
||||||
entity.entity_id = 'input_boolean.b0'
|
entity.entity_id = 'input_boolean.b0'
|
||||||
await entity.async_added_to_hass()
|
await entity.async_internal_added_to_hass()
|
||||||
|
|
||||||
entity = RestoreEntity()
|
entity = RestoreEntity()
|
||||||
entity.hass = hass
|
entity.hass = hass
|
||||||
entity.entity_id = 'input_boolean.b1'
|
entity.entity_id = 'input_boolean.b1'
|
||||||
await entity.async_added_to_hass()
|
await entity.async_internal_added_to_hass()
|
||||||
|
|
||||||
data = await RestoreStateData.async_get_instance(hass)
|
data = await RestoreStateData.async_get_instance(hass)
|
||||||
|
|
||||||
|
@ -206,7 +206,7 @@ async def test_state_saved_on_remove(hass):
|
||||||
entity = RestoreEntity()
|
entity = RestoreEntity()
|
||||||
entity.hass = hass
|
entity.hass = hass
|
||||||
entity.entity_id = 'input_boolean.b0'
|
entity.entity_id = 'input_boolean.b0'
|
||||||
await entity.async_added_to_hass()
|
await entity.async_internal_added_to_hass()
|
||||||
|
|
||||||
hass.states.async_set('input_boolean.b0', 'on')
|
hass.states.async_set('input_boolean.b0', 'on')
|
||||||
|
|
||||||
|
@ -215,7 +215,7 @@ async def test_state_saved_on_remove(hass):
|
||||||
# No last states should currently be saved
|
# No last states should currently be saved
|
||||||
assert not data.last_states
|
assert not data.last_states
|
||||||
|
|
||||||
await entity.async_will_remove_from_hass()
|
await entity.async_remove()
|
||||||
|
|
||||||
# We should store the input boolean state when it is removed
|
# We should store the input boolean state when it is removed
|
||||||
assert data.last_states['input_boolean.b0'].state.state == 'on'
|
assert data.last_states['input_boolean.b0'].state.state == 'on'
|
||||||
|
|
Loading…
Add table
Reference in a new issue