Simplify groups (#63477)

* Simplify group

* Rename async_update to async_update_group_state and mark it as callback

* Simplify _async_start
This commit is contained in:
Erik Montnemery 2022-01-07 08:58:45 +01:00 committed by GitHub
parent e222e1b6f0
commit 8bf8709d99
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 45 additions and 66 deletions

View file

@ -21,19 +21,13 @@ from homeassistant.const import (
CONF_NAME, CONF_NAME,
ENTITY_MATCH_ALL, ENTITY_MATCH_ALL,
ENTITY_MATCH_NONE, ENTITY_MATCH_NONE,
EVENT_HOMEASSISTANT_START,
SERVICE_RELOAD, SERVICE_RELOAD,
STATE_OFF, STATE_OFF,
STATE_ON, STATE_ON,
Platform, Platform,
) )
from homeassistant.core import ( from homeassistant.core import HomeAssistant, ServiceCall, callback, split_entity_id
CoreState, from homeassistant.helpers import start
HomeAssistant,
ServiceCall,
callback,
split_entity_id,
)
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity import Entity, async_generate_entity_id from homeassistant.helpers.entity import Entity, async_generate_entity_id
from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.entity_component import EntityComponent
@ -407,21 +401,22 @@ class GroupEntity(Entity):
"""Register listeners.""" """Register listeners."""
async def _update_at_start(_): async def _update_at_start(_):
await self.async_update() self.async_update_group_state()
self.async_write_ha_state() self.async_write_ha_state()
self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, _update_at_start) start.async_at_start(self.hass, _update_at_start)
async def async_defer_or_update_ha_state(self) -> None: @callback
def async_defer_or_update_ha_state(self) -> None:
"""Only update once at start.""" """Only update once at start."""
if self.hass.state != CoreState.running: if not self.hass.is_running:
return return
await self.async_update() self.async_update_group_state()
self.async_write_ha_state() self.async_write_ha_state()
@abstractmethod @abstractmethod
async def async_update(self) -> None: def async_update_group_state(self) -> None:
"""Abstract method to update the entity.""" """Abstract method to update the entity."""
@ -636,22 +631,15 @@ class Group(Entity):
self._async_unsub_state_changed() self._async_unsub_state_changed()
self._async_unsub_state_changed = None self._async_unsub_state_changed = None
async def async_update(self): @callback
def async_update_group_state(self):
"""Query all members and determine current group state.""" """Query all members and determine current group state."""
self._state = None self._state = None
self._async_update_group_state() self._async_update_group_state()
async def async_added_to_hass(self): async def async_added_to_hass(self):
"""Handle addition to Home Assistant.""" """Handle addition to Home Assistant."""
if self.hass.state != CoreState.running: start.async_at_start(self.hass, self._async_start)
self.hass.bus.async_listen_once(
EVENT_HOMEASSISTANT_START, self._async_start
)
return
if self.tracking:
self._reset_tracked_state()
self._async_start_tracking()
async def async_will_remove_from_hass(self): async def async_will_remove_from_hass(self):
"""Handle removal from Home Assistant.""" """Handle removal from Home Assistant."""

View file

@ -20,7 +20,7 @@ from homeassistant.const import (
STATE_ON, STATE_ON,
STATE_UNAVAILABLE, STATE_UNAVAILABLE,
) )
from homeassistant.core import CoreState, Event, HomeAssistant from homeassistant.core import Event, HomeAssistant, callback
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.event import async_track_state_change_event from homeassistant.helpers.event import async_track_state_change_event
@ -90,10 +90,11 @@ class BinarySensorGroup(GroupEntity, BinarySensorEntity):
async def async_added_to_hass(self) -> None: async def async_added_to_hass(self) -> None:
"""Register callbacks.""" """Register callbacks."""
async def async_state_changed_listener(event: Event) -> None: @callback
def async_state_changed_listener(event: Event) -> None:
"""Handle child updates.""" """Handle child updates."""
self.async_set_context(event.context) self.async_set_context(event.context)
await self.async_defer_or_update_ha_state() self.async_defer_or_update_ha_state()
self.async_on_remove( self.async_on_remove(
async_track_state_change_event( async_track_state_change_event(
@ -101,13 +102,10 @@ class BinarySensorGroup(GroupEntity, BinarySensorEntity):
) )
) )
if self.hass.state == CoreState.running:
await self.async_update()
return
await super().async_added_to_hass() await super().async_added_to_hass()
async def async_update(self) -> None: @callback
def async_update_group_state(self) -> None:
"""Query all members and determine the binary sensor group state.""" """Query all members and determine the binary sensor group state."""
all_states = [self.hass.states.get(x) for x in self._entity_ids] all_states = [self.hass.states.get(x) for x in self._entity_ids]
filtered_states: list[str] = [x.state for x in all_states if x is not None] filtered_states: list[str] = [x.state for x in all_states if x is not None]
@ -120,7 +118,6 @@ class BinarySensorGroup(GroupEntity, BinarySensorEntity):
states = list(map(lambda x: x == STATE_ON, filtered_states)) states = list(map(lambda x: x == STATE_ON, filtered_states))
state = self.mode(states) state = self.mode(states)
self._attr_is_on = state self._attr_is_on = state
self.async_write_ha_state()
@property @property
def device_class(self) -> str | None: def device_class(self) -> str | None:

View file

@ -42,7 +42,7 @@ from homeassistant.const import (
STATE_OPEN, STATE_OPEN,
STATE_OPENING, STATE_OPENING,
) )
from homeassistant.core import CoreState, Event, HomeAssistant, State from homeassistant.core import Event, HomeAssistant, State, callback
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.event import async_track_state_change_event from homeassistant.helpers.event import async_track_state_change_event
@ -110,14 +110,14 @@ class CoverGroup(GroupEntity, CoverEntity):
self._attr_extra_state_attributes = {ATTR_ENTITY_ID: entities} self._attr_extra_state_attributes = {ATTR_ENTITY_ID: entities}
self._attr_unique_id = unique_id self._attr_unique_id = unique_id
async def _update_supported_features_event(self, event: Event) -> None: @callback
def _update_supported_features_event(self, event: Event) -> None:
self.async_set_context(event.context) self.async_set_context(event.context)
if (entity := event.data.get("entity_id")) is not None: if (entity := event.data.get("entity_id")) is not None:
await self.async_update_supported_features( self.async_update_supported_features(entity, event.data.get("new_state"))
entity, event.data.get("new_state")
)
async def async_update_supported_features( @callback
def async_update_supported_features(
self, self,
entity_id: str, entity_id: str,
new_state: State | None, new_state: State | None,
@ -130,7 +130,7 @@ class CoverGroup(GroupEntity, CoverEntity):
for values in self._tilts.values(): for values in self._tilts.values():
values.discard(entity_id) values.discard(entity_id)
if update_state: if update_state:
await self.async_defer_or_update_ha_state() self.async_defer_or_update_ha_state()
return return
features = new_state.attributes.get(ATTR_SUPPORTED_FEATURES, 0) features = new_state.attributes.get(ATTR_SUPPORTED_FEATURES, 0)
@ -162,14 +162,14 @@ class CoverGroup(GroupEntity, CoverEntity):
self._tilts[KEY_POSITION].discard(entity_id) self._tilts[KEY_POSITION].discard(entity_id)
if update_state: if update_state:
await self.async_defer_or_update_ha_state() self.async_defer_or_update_ha_state()
async def async_added_to_hass(self) -> None: async def async_added_to_hass(self) -> None:
"""Register listeners.""" """Register listeners."""
for entity_id in self._entities: for entity_id in self._entities:
if (new_state := self.hass.states.get(entity_id)) is None: if (new_state := self.hass.states.get(entity_id)) is None:
continue continue
await self.async_update_supported_features( self.async_update_supported_features(
entity_id, new_state, update_state=False entity_id, new_state, update_state=False
) )
self.async_on_remove( self.async_on_remove(
@ -178,9 +178,6 @@ class CoverGroup(GroupEntity, CoverEntity):
) )
) )
if self.hass.state == CoreState.running:
await self.async_update()
return
await super().async_added_to_hass() await super().async_added_to_hass()
async def async_open_cover(self, **kwargs: Any) -> None: async def async_open_cover(self, **kwargs: Any) -> None:
@ -253,7 +250,8 @@ class CoverGroup(GroupEntity, CoverEntity):
context=self._context, context=self._context,
) )
async def async_update(self) -> None: @callback
def async_update_group_state(self) -> None:
"""Update state and attributes.""" """Update state and attributes."""
self._attr_assumed_state = False self._attr_assumed_state = False

View file

@ -34,7 +34,7 @@ from homeassistant.const import (
CONF_UNIQUE_ID, CONF_UNIQUE_ID,
STATE_ON, STATE_ON,
) )
from homeassistant.core import CoreState, Event, HomeAssistant, State from homeassistant.core import Event, HomeAssistant, State, callback
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.event import async_track_state_change_event from homeassistant.helpers.event import async_track_state_change_event
@ -125,14 +125,14 @@ class FanGroup(GroupEntity, FanEntity):
"""Return whether or not the fan is currently oscillating.""" """Return whether or not the fan is currently oscillating."""
return self._oscillating return self._oscillating
async def _update_supported_features_event(self, event: Event) -> None: @callback
def _update_supported_features_event(self, event: Event) -> None:
self.async_set_context(event.context) self.async_set_context(event.context)
if (entity := event.data.get("entity_id")) is not None: if (entity := event.data.get("entity_id")) is not None:
await self.async_update_supported_features( self.async_update_supported_features(entity, event.data.get("new_state"))
entity, event.data.get("new_state")
)
async def async_update_supported_features( @callback
def async_update_supported_features(
self, self,
entity_id: str, entity_id: str,
new_state: State | None, new_state: State | None,
@ -151,14 +151,14 @@ class FanGroup(GroupEntity, FanEntity):
self._fans[feature].discard(entity_id) self._fans[feature].discard(entity_id)
if update_state: if update_state:
await self.async_defer_or_update_ha_state() self.async_defer_or_update_ha_state()
async def async_added_to_hass(self) -> None: async def async_added_to_hass(self) -> None:
"""Register listeners.""" """Register listeners."""
for entity_id in self._entities: for entity_id in self._entities:
if (new_state := self.hass.states.get(entity_id)) is None: if (new_state := self.hass.states.get(entity_id)) is None:
continue continue
await self.async_update_supported_features( self.async_update_supported_features(
entity_id, new_state, update_state=False entity_id, new_state, update_state=False
) )
self.async_on_remove( self.async_on_remove(
@ -167,9 +167,6 @@ class FanGroup(GroupEntity, FanEntity):
) )
) )
if self.hass.state == CoreState.running:
await self.async_update()
return
await super().async_added_to_hass() await super().async_added_to_hass()
async def async_set_percentage(self, percentage: int) -> None: async def async_set_percentage(self, percentage: int) -> None:
@ -244,7 +241,8 @@ class FanGroup(GroupEntity, FanEntity):
setattr(self, attr, most_frequent_attribute(states, entity_attr)) setattr(self, attr, most_frequent_attribute(states, entity_attr))
self._attr_assumed_state |= not attribute_equal(states, entity_attr) self._attr_assumed_state |= not attribute_equal(states, entity_attr)
async def async_update(self) -> None: @callback
def async_update_group_state(self) -> None:
"""Update state and attributes.""" """Update state and attributes."""
self._attr_assumed_state = False self._attr_assumed_state = False

View file

@ -47,7 +47,7 @@ from homeassistant.const import (
STATE_ON, STATE_ON,
STATE_UNAVAILABLE, STATE_UNAVAILABLE,
) )
from homeassistant.core import CoreState, Event, HomeAssistant, State from homeassistant.core import Event, HomeAssistant, State, callback
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.event import async_track_state_change_event from homeassistant.helpers.event import async_track_state_change_event
@ -129,10 +129,11 @@ class LightGroup(GroupEntity, LightEntity):
async def async_added_to_hass(self) -> None: async def async_added_to_hass(self) -> None:
"""Register callbacks.""" """Register callbacks."""
async def async_state_changed_listener(event: Event) -> None: @callback
def async_state_changed_listener(event: Event) -> None:
"""Handle child updates.""" """Handle child updates."""
self.async_set_context(event.context) self.async_set_context(event.context)
await self.async_defer_or_update_ha_state() self.async_defer_or_update_ha_state()
self.async_on_remove( self.async_on_remove(
async_track_state_change_event( async_track_state_change_event(
@ -140,10 +141,6 @@ class LightGroup(GroupEntity, LightEntity):
) )
) )
if self.hass.state == CoreState.running:
await self.async_update()
return
await super().async_added_to_hass() await super().async_added_to_hass()
@property @property
@ -183,7 +180,8 @@ class LightGroup(GroupEntity, LightEntity):
context=self._context, context=self._context,
) )
async def async_update(self) -> None: @callback
def async_update_group_state(self) -> None:
"""Query all members and determine the light group state.""" """Query all members and determine the light group state."""
all_states = [self.hass.states.get(x) for x in self._entity_ids] all_states = [self.hass.states.get(x) for x in self._entity_ids]
states: list[State] = list(filter(None, all_states)) states: list[State] = list(filter(None, all_states))