Improve type hints in group (#78350)
This commit is contained in:
parent
03a24e3a05
commit
5cccb24830
3 changed files with 99 additions and 77 deletions
|
@ -3,10 +3,10 @@ from __future__ import annotations
|
|||
|
||||
from abc import abstractmethod
|
||||
import asyncio
|
||||
from collections.abc import Iterable
|
||||
from collections.abc import Collection, Iterable
|
||||
from contextvars import ContextVar
|
||||
import logging
|
||||
from typing import Any, Union, cast
|
||||
from typing import Any, Protocol, Union, cast
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
|
@ -27,7 +27,15 @@ from homeassistant.const import (
|
|||
STATE_ON,
|
||||
Platform,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant, ServiceCall, callback, split_entity_id
|
||||
from homeassistant.core import (
|
||||
CALLBACK_TYPE,
|
||||
Event,
|
||||
HomeAssistant,
|
||||
ServiceCall,
|
||||
State,
|
||||
callback,
|
||||
split_entity_id,
|
||||
)
|
||||
from homeassistant.helpers import config_validation as cv, entity_registry as er, start
|
||||
from homeassistant.helpers.entity import Entity, async_generate_entity_id
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
|
@ -42,8 +50,6 @@ from homeassistant.loader import bind_hass
|
|||
|
||||
from .const import CONF_HIDE_MEMBERS
|
||||
|
||||
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
|
||||
|
||||
DOMAIN = "group"
|
||||
GROUP_ORDER = "group_order"
|
||||
|
||||
|
@ -79,10 +85,19 @@ _LOGGER = logging.getLogger(__name__)
|
|||
current_domain: ContextVar[str] = ContextVar("current_domain")
|
||||
|
||||
|
||||
def _conf_preprocess(value):
|
||||
class GroupProtocol(Protocol):
|
||||
"""Define the format of group platforms."""
|
||||
|
||||
def async_describe_on_off_states(
|
||||
self, hass: HomeAssistant, registry: GroupIntegrationRegistry
|
||||
) -> None:
|
||||
"""Describe group on off states."""
|
||||
|
||||
|
||||
def _conf_preprocess(value: Any) -> dict[str, Any]:
|
||||
"""Preprocess alternative configuration formats."""
|
||||
if not isinstance(value, dict):
|
||||
value = {CONF_ENTITIES: value}
|
||||
return {CONF_ENTITIES: value}
|
||||
|
||||
return value
|
||||
|
||||
|
@ -135,14 +150,15 @@ class GroupIntegrationRegistry:
|
|||
|
||||
|
||||
@bind_hass
|
||||
def is_on(hass, entity_id):
|
||||
def is_on(hass: HomeAssistant, entity_id: str) -> bool:
|
||||
"""Test if the group state is in its ON-state."""
|
||||
if REG_KEY not in hass.data:
|
||||
# Integration not setup yet, it cannot be on
|
||||
return False
|
||||
|
||||
if (state := hass.states.get(entity_id)) is not None:
|
||||
return state.state in hass.data[REG_KEY].on_off_mapping
|
||||
registry: GroupIntegrationRegistry = hass.data[REG_KEY]
|
||||
return state.state in registry.on_off_mapping
|
||||
|
||||
return False
|
||||
|
||||
|
@ -408,10 +424,13 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
return True
|
||||
|
||||
|
||||
async def _process_group_platform(hass, domain, platform):
|
||||
async def _process_group_platform(
|
||||
hass: HomeAssistant, domain: str, platform: GroupProtocol
|
||||
) -> None:
|
||||
"""Process a group platform."""
|
||||
current_domain.set(domain)
|
||||
platform.async_describe_on_off_states(hass, hass.data[REG_KEY])
|
||||
registry: GroupIntegrationRegistry = hass.data[REG_KEY]
|
||||
platform.async_describe_on_off_states(hass, registry)
|
||||
|
||||
|
||||
async def _async_process_config(hass: HomeAssistant, config: ConfigType) -> None:
|
||||
|
@ -423,7 +442,7 @@ async def _async_process_config(hass: HomeAssistant, config: ConfigType) -> None
|
|||
|
||||
for object_id, conf in domain_config.items():
|
||||
name: str = conf.get(CONF_NAME, object_id)
|
||||
entity_ids: Iterable[str] = conf.get(CONF_ENTITIES) or []
|
||||
entity_ids: Collection[str] = conf.get(CONF_ENTITIES) or []
|
||||
icon: str | None = conf.get(CONF_ICON)
|
||||
mode = bool(conf.get(CONF_ALL))
|
||||
order: int = hass.data[GROUP_ORDER]
|
||||
|
@ -456,15 +475,12 @@ async def _async_process_config(hass: HomeAssistant, config: ConfigType) -> None
|
|||
class GroupEntity(Entity):
|
||||
"""Representation of a Group of entities."""
|
||||
|
||||
@property
|
||||
def should_poll(self) -> bool:
|
||||
"""Disable polling for group."""
|
||||
return False
|
||||
_attr_should_poll = False
|
||||
|
||||
async def async_added_to_hass(self) -> None:
|
||||
"""Register listeners."""
|
||||
|
||||
async def _update_at_start(_):
|
||||
async def _update_at_start(_: HomeAssistant) -> None:
|
||||
self.async_update_group_state()
|
||||
self.async_write_ha_state()
|
||||
|
||||
|
@ -487,6 +503,10 @@ class GroupEntity(Entity):
|
|||
class Group(Entity):
|
||||
"""Track a group of entity ids."""
|
||||
|
||||
_attr_should_poll = False
|
||||
tracking: tuple[str, ...]
|
||||
trackable: tuple[str, ...]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
|
@ -494,7 +514,7 @@ class Group(Entity):
|
|||
order: int | None = None,
|
||||
icon: str | None = None,
|
||||
user_defined: bool = True,
|
||||
entity_ids: Iterable[str] | None = None,
|
||||
entity_ids: Collection[str] | None = None,
|
||||
mode: bool | None = None,
|
||||
) -> None:
|
||||
"""Initialize a group.
|
||||
|
@ -503,25 +523,25 @@ class Group(Entity):
|
|||
"""
|
||||
self.hass = hass
|
||||
self._name = name
|
||||
self._state = None
|
||||
self._state: str | None = None
|
||||
self._icon = icon
|
||||
self._set_tracked(entity_ids)
|
||||
self._on_off = None
|
||||
self._assumed = None
|
||||
self._on_states = None
|
||||
self._on_off: dict[str, bool] = {}
|
||||
self._assumed: dict[str, bool] = {}
|
||||
self._on_states: set[str] = set()
|
||||
self.user_defined = user_defined
|
||||
self.mode = any
|
||||
if mode:
|
||||
self.mode = all
|
||||
self._order = order
|
||||
self._assumed_state = False
|
||||
self._async_unsub_state_changed = None
|
||||
self._async_unsub_state_changed: CALLBACK_TYPE | None = None
|
||||
|
||||
@staticmethod
|
||||
def create_group(
|
||||
hass: HomeAssistant,
|
||||
name: str,
|
||||
entity_ids: Iterable[str] | None = None,
|
||||
entity_ids: Collection[str] | None = None,
|
||||
user_defined: bool = True,
|
||||
icon: str | None = None,
|
||||
object_id: str | None = None,
|
||||
|
@ -541,7 +561,7 @@ class Group(Entity):
|
|||
def async_create_group_entity(
|
||||
hass: HomeAssistant,
|
||||
name: str,
|
||||
entity_ids: Iterable[str] | None = None,
|
||||
entity_ids: Collection[str] | None = None,
|
||||
user_defined: bool = True,
|
||||
icon: str | None = None,
|
||||
object_id: str | None = None,
|
||||
|
@ -577,7 +597,7 @@ class Group(Entity):
|
|||
async def async_create_group(
|
||||
hass: HomeAssistant,
|
||||
name: str,
|
||||
entity_ids: Iterable[str] | None = None,
|
||||
entity_ids: Collection[str] | None = None,
|
||||
user_defined: bool = True,
|
||||
icon: str | None = None,
|
||||
object_id: str | None = None,
|
||||
|
@ -597,37 +617,32 @@ class Group(Entity):
|
|||
return group
|
||||
|
||||
@property
|
||||
def should_poll(self):
|
||||
"""No need to poll because groups will update themselves."""
|
||||
return False
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
def name(self) -> str:
|
||||
"""Return the name of the group."""
|
||||
return self._name
|
||||
|
||||
@name.setter
|
||||
def name(self, value):
|
||||
def name(self, value: str) -> None:
|
||||
"""Set Group name."""
|
||||
self._name = value
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
def state(self) -> str | None:
|
||||
"""Return the state of the group."""
|
||||
return self._state
|
||||
|
||||
@property
|
||||
def icon(self):
|
||||
def icon(self) -> str | None:
|
||||
"""Return the icon of the group."""
|
||||
return self._icon
|
||||
|
||||
@icon.setter
|
||||
def icon(self, value):
|
||||
def icon(self, value: str | None) -> None:
|
||||
"""Set Icon for group."""
|
||||
self._icon = value
|
||||
|
||||
@property
|
||||
def extra_state_attributes(self):
|
||||
def extra_state_attributes(self) -> dict[str, Any]:
|
||||
"""Return the state attributes for the group."""
|
||||
data = {ATTR_ENTITY_ID: self.tracking, ATTR_ORDER: self._order}
|
||||
if not self.user_defined:
|
||||
|
@ -636,17 +651,19 @@ class Group(Entity):
|
|||
return data
|
||||
|
||||
@property
|
||||
def assumed_state(self):
|
||||
def assumed_state(self) -> bool:
|
||||
"""Test if any member has an assumed state."""
|
||||
return self._assumed_state
|
||||
|
||||
def update_tracked_entity_ids(self, entity_ids):
|
||||
def update_tracked_entity_ids(self, entity_ids: Collection[str] | None) -> None:
|
||||
"""Update the member entity IDs."""
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.async_update_tracked_entity_ids(entity_ids), self.hass.loop
|
||||
).result()
|
||||
|
||||
async def async_update_tracked_entity_ids(self, entity_ids):
|
||||
async def async_update_tracked_entity_ids(
|
||||
self, entity_ids: Collection[str] | None
|
||||
) -> None:
|
||||
"""Update the member entity IDs.
|
||||
|
||||
This method must be run in the event loop.
|
||||
|
@ -656,7 +673,7 @@ class Group(Entity):
|
|||
self._reset_tracked_state()
|
||||
self._async_start()
|
||||
|
||||
def _set_tracked(self, entity_ids):
|
||||
def _set_tracked(self, entity_ids: Collection[str] | None) -> None:
|
||||
"""Tuple of entities to be tracked."""
|
||||
# tracking are the entities we want to track
|
||||
# trackable are the entities we actually watch
|
||||
|
@ -666,10 +683,11 @@ class Group(Entity):
|
|||
self.trackable = ()
|
||||
return
|
||||
|
||||
excluded_domains = self.hass.data[REG_KEY].exclude_domains
|
||||
registry: GroupIntegrationRegistry = self.hass.data[REG_KEY]
|
||||
excluded_domains = registry.exclude_domains
|
||||
|
||||
tracking = []
|
||||
trackable = []
|
||||
tracking: list[str] = []
|
||||
trackable: list[str] = []
|
||||
for ent_id in entity_ids:
|
||||
ent_id_lower = ent_id.lower()
|
||||
domain = split_entity_id(ent_id_lower)[0]
|
||||
|
@ -681,14 +699,14 @@ class Group(Entity):
|
|||
self.tracking = tuple(tracking)
|
||||
|
||||
@callback
|
||||
def _async_start(self, *_):
|
||||
def _async_start(self, _: HomeAssistant | None = None) -> None:
|
||||
"""Start tracking members and write state."""
|
||||
self._reset_tracked_state()
|
||||
self._async_start_tracking()
|
||||
self.async_write_ha_state()
|
||||
|
||||
@callback
|
||||
def _async_start_tracking(self):
|
||||
def _async_start_tracking(self) -> None:
|
||||
"""Start tracking members.
|
||||
|
||||
This method must be run in the event loop.
|
||||
|
@ -701,7 +719,7 @@ class Group(Entity):
|
|||
self._async_update_group_state()
|
||||
|
||||
@callback
|
||||
def _async_stop(self):
|
||||
def _async_stop(self) -> None:
|
||||
"""Unregister the group from Home Assistant.
|
||||
|
||||
This method must be run in the event loop.
|
||||
|
@ -711,20 +729,20 @@ class Group(Entity):
|
|||
self._async_unsub_state_changed = None
|
||||
|
||||
@callback
|
||||
def async_update_group_state(self):
|
||||
def async_update_group_state(self) -> None:
|
||||
"""Query all members and determine current group state."""
|
||||
self._state = None
|
||||
self._async_update_group_state()
|
||||
|
||||
async def async_added_to_hass(self):
|
||||
async def async_added_to_hass(self) -> None:
|
||||
"""Handle addition to Home Assistant."""
|
||||
self.async_on_remove(start.async_at_start(self.hass, self._async_start))
|
||||
|
||||
async def async_will_remove_from_hass(self):
|
||||
async def async_will_remove_from_hass(self) -> None:
|
||||
"""Handle removal from Home Assistant."""
|
||||
self._async_stop()
|
||||
|
||||
async def _async_state_changed_listener(self, event):
|
||||
async def _async_state_changed_listener(self, event: Event) -> None:
|
||||
"""Respond to a member state changing.
|
||||
|
||||
This method must be run in the event loop.
|
||||
|
@ -742,7 +760,7 @@ class Group(Entity):
|
|||
self._async_update_group_state(new_state)
|
||||
self.async_write_ha_state()
|
||||
|
||||
def _reset_tracked_state(self):
|
||||
def _reset_tracked_state(self) -> None:
|
||||
"""Reset tracked state."""
|
||||
self._on_off = {}
|
||||
self._assumed = {}
|
||||
|
@ -752,13 +770,13 @@ class Group(Entity):
|
|||
if (state := self.hass.states.get(entity_id)) is not None:
|
||||
self._see_state(state)
|
||||
|
||||
def _see_state(self, new_state):
|
||||
def _see_state(self, new_state: State) -> None:
|
||||
"""Keep track of the the state."""
|
||||
entity_id = new_state.entity_id
|
||||
domain = new_state.domain
|
||||
state = new_state.state
|
||||
registry = self.hass.data[REG_KEY]
|
||||
self._assumed[entity_id] = new_state.attributes.get(ATTR_ASSUMED_STATE)
|
||||
registry: GroupIntegrationRegistry = self.hass.data[REG_KEY]
|
||||
self._assumed[entity_id] = bool(new_state.attributes.get(ATTR_ASSUMED_STATE))
|
||||
|
||||
if domain not in registry.on_states_by_domain:
|
||||
# Handle the group of a group case
|
||||
|
@ -769,12 +787,12 @@ class Group(Entity):
|
|||
self._on_off[entity_id] = state in registry.on_off_mapping
|
||||
else:
|
||||
entity_on_state = registry.on_states_by_domain[domain]
|
||||
if domain in self.hass.data[REG_KEY].on_states_by_domain:
|
||||
if domain in registry.on_states_by_domain:
|
||||
self._on_states.update(entity_on_state)
|
||||
self._on_off[entity_id] = state in entity_on_state
|
||||
|
||||
@callback
|
||||
def _async_update_group_state(self, tr_state=None):
|
||||
def _async_update_group_state(self, tr_state: State | None = None) -> None:
|
||||
"""Update group state.
|
||||
|
||||
Optionally you can provide the only state changed since last update
|
||||
|
@ -818,4 +836,5 @@ class Group(Entity):
|
|||
if group_is_on:
|
||||
self._state = on_state
|
||||
else:
|
||||
self._state = self.hass.data[REG_KEY].on_off_mapping[on_state]
|
||||
registry: GroupIntegrationRegistry = self.hass.data[REG_KEY]
|
||||
self._state = registry.on_off_mapping[on_state]
|
||||
|
|
|
@ -103,6 +103,7 @@ class MediaPlayerGroup(MediaPlayerEntity):
|
|||
"""Representation of a Media Group."""
|
||||
|
||||
_attr_available: bool = False
|
||||
_attr_should_poll = False
|
||||
|
||||
def __init__(self, unique_id: str | None, name: str, entities: list[str]) -> None:
|
||||
"""Initialize a Media Group entity."""
|
||||
|
@ -216,11 +217,6 @@ class MediaPlayerGroup(MediaPlayerEntity):
|
|||
"""Flag supported features."""
|
||||
return self._supported_features
|
||||
|
||||
@property
|
||||
def should_poll(self) -> bool:
|
||||
"""No polling needed for a media group."""
|
||||
return False
|
||||
|
||||
@property
|
||||
def extra_state_attributes(self) -> dict:
|
||||
"""Return the state attributes for the media group."""
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
"""Group platform for notify component."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Coroutine, Mapping
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
|
@ -13,9 +16,9 @@ from homeassistant.components.notify import (
|
|||
BaseNotificationService,
|
||||
)
|
||||
from homeassistant.const import ATTR_SERVICE
|
||||
from homeassistant.core import HomeAssistant
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
|
||||
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
|
||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||
|
||||
CONF_SERVICES = "services"
|
||||
|
||||
|
@ -29,46 +32,50 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
|
|||
)
|
||||
|
||||
|
||||
def update(input_dict, update_source):
|
||||
def update(input_dict: dict[str, Any], update_source: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Deep update a dictionary.
|
||||
|
||||
Async friendly.
|
||||
"""
|
||||
for key, val in update_source.items():
|
||||
if isinstance(val, Mapping):
|
||||
recurse = update(input_dict.get(key, {}), val)
|
||||
recurse = update(input_dict.get(key, {}), val) # type: ignore[arg-type]
|
||||
input_dict[key] = recurse
|
||||
else:
|
||||
input_dict[key] = update_source[key]
|
||||
return input_dict
|
||||
|
||||
|
||||
async def async_get_service(hass, config, discovery_info=None):
|
||||
async def async_get_service(
|
||||
hass: HomeAssistant,
|
||||
config: ConfigType,
|
||||
discovery_info: DiscoveryInfoType | None = None,
|
||||
) -> GroupNotifyPlatform:
|
||||
"""Get the Group notification service."""
|
||||
return GroupNotifyPlatform(hass, config.get(CONF_SERVICES))
|
||||
return GroupNotifyPlatform(hass, config[CONF_SERVICES])
|
||||
|
||||
|
||||
class GroupNotifyPlatform(BaseNotificationService):
|
||||
"""Implement the notification service for the group notify platform."""
|
||||
|
||||
def __init__(self, hass, entities):
|
||||
def __init__(self, hass: HomeAssistant, entities: list[dict[str, Any]]) -> None:
|
||||
"""Initialize the service."""
|
||||
self.hass = hass
|
||||
self.entities = entities
|
||||
|
||||
async def async_send_message(self, message="", **kwargs):
|
||||
async def async_send_message(self, message: str = "", **kwargs: Any) -> None:
|
||||
"""Send message to all entities in the group."""
|
||||
payload = {ATTR_MESSAGE: message}
|
||||
payload: dict[str, Any] = {ATTR_MESSAGE: message}
|
||||
payload.update({key: val for key, val in kwargs.items() if val})
|
||||
|
||||
tasks = []
|
||||
tasks: list[Coroutine[Any, Any, bool | None]] = []
|
||||
for entity in self.entities:
|
||||
sending_payload = deepcopy(payload.copy())
|
||||
if entity.get(ATTR_DATA) is not None:
|
||||
update(sending_payload, entity.get(ATTR_DATA))
|
||||
if (data := entity.get(ATTR_DATA)) is not None:
|
||||
update(sending_payload, data)
|
||||
tasks.append(
|
||||
self.hass.services.async_call(
|
||||
DOMAIN, entity.get(ATTR_SERVICE), sending_payload
|
||||
DOMAIN, entity[ATTR_SERVICE], sending_payload
|
||||
)
|
||||
)
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue