Improve type hints in group (#78350)

This commit is contained in:
epenet 2022-09-14 11:36:28 +02:00 committed by GitHub
parent 03a24e3a05
commit 5cccb24830
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 99 additions and 77 deletions

View file

@ -3,10 +3,10 @@ from __future__ import annotations
from abc import abstractmethod
import asyncio
from collections.abc import Iterable
from collections.abc import Collection, Iterable
from contextvars import ContextVar
import logging
from typing import Any, Union, cast
from typing import Any, Protocol, Union, cast
import voluptuous as vol
@ -27,7 +27,15 @@ from homeassistant.const import (
STATE_ON,
Platform,
)
from homeassistant.core import HomeAssistant, ServiceCall, callback, split_entity_id
from homeassistant.core import (
CALLBACK_TYPE,
Event,
HomeAssistant,
ServiceCall,
State,
callback,
split_entity_id,
)
from homeassistant.helpers import config_validation as cv, entity_registry as er, start
from homeassistant.helpers.entity import Entity, async_generate_entity_id
from homeassistant.helpers.entity_component import EntityComponent
@ -42,8 +50,6 @@ from homeassistant.loader import bind_hass
from .const import CONF_HIDE_MEMBERS
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
DOMAIN = "group"
GROUP_ORDER = "group_order"
@ -79,10 +85,19 @@ _LOGGER = logging.getLogger(__name__)
current_domain: ContextVar[str] = ContextVar("current_domain")
def _conf_preprocess(value):
class GroupProtocol(Protocol):
"""Define the format of group platforms."""
def async_describe_on_off_states(
self, hass: HomeAssistant, registry: GroupIntegrationRegistry
) -> None:
"""Describe group on off states."""
def _conf_preprocess(value: Any) -> dict[str, Any]:
"""Preprocess alternative configuration formats."""
if not isinstance(value, dict):
value = {CONF_ENTITIES: value}
return {CONF_ENTITIES: value}
return value
@ -135,14 +150,15 @@ class GroupIntegrationRegistry:
@bind_hass
def is_on(hass, entity_id):
def is_on(hass: HomeAssistant, entity_id: str) -> bool:
"""Test if the group state is in its ON-state."""
if REG_KEY not in hass.data:
# Integration not setup yet, it cannot be on
return False
if (state := hass.states.get(entity_id)) is not None:
return state.state in hass.data[REG_KEY].on_off_mapping
registry: GroupIntegrationRegistry = hass.data[REG_KEY]
return state.state in registry.on_off_mapping
return False
@ -408,10 +424,13 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True
async def _process_group_platform(hass, domain, platform):
async def _process_group_platform(
hass: HomeAssistant, domain: str, platform: GroupProtocol
) -> None:
"""Process a group platform."""
current_domain.set(domain)
platform.async_describe_on_off_states(hass, hass.data[REG_KEY])
registry: GroupIntegrationRegistry = hass.data[REG_KEY]
platform.async_describe_on_off_states(hass, registry)
async def _async_process_config(hass: HomeAssistant, config: ConfigType) -> None:
@ -423,7 +442,7 @@ async def _async_process_config(hass: HomeAssistant, config: ConfigType) -> None
for object_id, conf in domain_config.items():
name: str = conf.get(CONF_NAME, object_id)
entity_ids: Iterable[str] = conf.get(CONF_ENTITIES) or []
entity_ids: Collection[str] = conf.get(CONF_ENTITIES) or []
icon: str | None = conf.get(CONF_ICON)
mode = bool(conf.get(CONF_ALL))
order: int = hass.data[GROUP_ORDER]
@ -456,15 +475,12 @@ async def _async_process_config(hass: HomeAssistant, config: ConfigType) -> None
class GroupEntity(Entity):
"""Representation of a Group of entities."""
@property
def should_poll(self) -> bool:
"""Disable polling for group."""
return False
_attr_should_poll = False
async def async_added_to_hass(self) -> None:
"""Register listeners."""
async def _update_at_start(_):
async def _update_at_start(_: HomeAssistant) -> None:
self.async_update_group_state()
self.async_write_ha_state()
@ -487,6 +503,10 @@ class GroupEntity(Entity):
class Group(Entity):
"""Track a group of entity ids."""
_attr_should_poll = False
tracking: tuple[str, ...]
trackable: tuple[str, ...]
def __init__(
self,
hass: HomeAssistant,
@ -494,7 +514,7 @@ class Group(Entity):
order: int | None = None,
icon: str | None = None,
user_defined: bool = True,
entity_ids: Iterable[str] | None = None,
entity_ids: Collection[str] | None = None,
mode: bool | None = None,
) -> None:
"""Initialize a group.
@ -503,25 +523,25 @@ class Group(Entity):
"""
self.hass = hass
self._name = name
self._state = None
self._state: str | None = None
self._icon = icon
self._set_tracked(entity_ids)
self._on_off = None
self._assumed = None
self._on_states = None
self._on_off: dict[str, bool] = {}
self._assumed: dict[str, bool] = {}
self._on_states: set[str] = set()
self.user_defined = user_defined
self.mode = any
if mode:
self.mode = all
self._order = order
self._assumed_state = False
self._async_unsub_state_changed = None
self._async_unsub_state_changed: CALLBACK_TYPE | None = None
@staticmethod
def create_group(
hass: HomeAssistant,
name: str,
entity_ids: Iterable[str] | None = None,
entity_ids: Collection[str] | None = None,
user_defined: bool = True,
icon: str | None = None,
object_id: str | None = None,
@ -541,7 +561,7 @@ class Group(Entity):
def async_create_group_entity(
hass: HomeAssistant,
name: str,
entity_ids: Iterable[str] | None = None,
entity_ids: Collection[str] | None = None,
user_defined: bool = True,
icon: str | None = None,
object_id: str | None = None,
@ -577,7 +597,7 @@ class Group(Entity):
async def async_create_group(
hass: HomeAssistant,
name: str,
entity_ids: Iterable[str] | None = None,
entity_ids: Collection[str] | None = None,
user_defined: bool = True,
icon: str | None = None,
object_id: str | None = None,
@ -597,37 +617,32 @@ class Group(Entity):
return group
@property
def should_poll(self):
"""No need to poll because groups will update themselves."""
return False
@property
def name(self):
def name(self) -> str:
"""Return the name of the group."""
return self._name
@name.setter
def name(self, value):
def name(self, value: str) -> None:
"""Set Group name."""
self._name = value
@property
def state(self):
def state(self) -> str | None:
"""Return the state of the group."""
return self._state
@property
def icon(self):
def icon(self) -> str | None:
"""Return the icon of the group."""
return self._icon
@icon.setter
def icon(self, value):
def icon(self, value: str | None) -> None:
"""Set Icon for group."""
self._icon = value
@property
def extra_state_attributes(self):
def extra_state_attributes(self) -> dict[str, Any]:
"""Return the state attributes for the group."""
data = {ATTR_ENTITY_ID: self.tracking, ATTR_ORDER: self._order}
if not self.user_defined:
@ -636,17 +651,19 @@ class Group(Entity):
return data
@property
def assumed_state(self):
def assumed_state(self) -> bool:
"""Test if any member has an assumed state."""
return self._assumed_state
def update_tracked_entity_ids(self, entity_ids):
def update_tracked_entity_ids(self, entity_ids: Collection[str] | None) -> None:
"""Update the member entity IDs."""
asyncio.run_coroutine_threadsafe(
self.async_update_tracked_entity_ids(entity_ids), self.hass.loop
).result()
async def async_update_tracked_entity_ids(self, entity_ids):
async def async_update_tracked_entity_ids(
self, entity_ids: Collection[str] | None
) -> None:
"""Update the member entity IDs.
This method must be run in the event loop.
@ -656,7 +673,7 @@ class Group(Entity):
self._reset_tracked_state()
self._async_start()
def _set_tracked(self, entity_ids):
def _set_tracked(self, entity_ids: Collection[str] | None) -> None:
"""Tuple of entities to be tracked."""
# tracking are the entities we want to track
# trackable are the entities we actually watch
@ -666,10 +683,11 @@ class Group(Entity):
self.trackable = ()
return
excluded_domains = self.hass.data[REG_KEY].exclude_domains
registry: GroupIntegrationRegistry = self.hass.data[REG_KEY]
excluded_domains = registry.exclude_domains
tracking = []
trackable = []
tracking: list[str] = []
trackable: list[str] = []
for ent_id in entity_ids:
ent_id_lower = ent_id.lower()
domain = split_entity_id(ent_id_lower)[0]
@ -681,14 +699,14 @@ class Group(Entity):
self.tracking = tuple(tracking)
@callback
def _async_start(self, *_):
def _async_start(self, _: HomeAssistant | None = None) -> None:
"""Start tracking members and write state."""
self._reset_tracked_state()
self._async_start_tracking()
self.async_write_ha_state()
@callback
def _async_start_tracking(self):
def _async_start_tracking(self) -> None:
"""Start tracking members.
This method must be run in the event loop.
@ -701,7 +719,7 @@ class Group(Entity):
self._async_update_group_state()
@callback
def _async_stop(self):
def _async_stop(self) -> None:
"""Unregister the group from Home Assistant.
This method must be run in the event loop.
@ -711,20 +729,20 @@ class Group(Entity):
self._async_unsub_state_changed = None
@callback
def async_update_group_state(self):
def async_update_group_state(self) -> None:
"""Query all members and determine current group state."""
self._state = None
self._async_update_group_state()
async def async_added_to_hass(self):
async def async_added_to_hass(self) -> None:
"""Handle addition to Home Assistant."""
self.async_on_remove(start.async_at_start(self.hass, self._async_start))
async def async_will_remove_from_hass(self):
async def async_will_remove_from_hass(self) -> None:
"""Handle removal from Home Assistant."""
self._async_stop()
async def _async_state_changed_listener(self, event):
async def _async_state_changed_listener(self, event: Event) -> None:
"""Respond to a member state changing.
This method must be run in the event loop.
@ -742,7 +760,7 @@ class Group(Entity):
self._async_update_group_state(new_state)
self.async_write_ha_state()
def _reset_tracked_state(self):
def _reset_tracked_state(self) -> None:
"""Reset tracked state."""
self._on_off = {}
self._assumed = {}
@ -752,13 +770,13 @@ class Group(Entity):
if (state := self.hass.states.get(entity_id)) is not None:
self._see_state(state)
def _see_state(self, new_state):
def _see_state(self, new_state: State) -> None:
"""Keep track of the the state."""
entity_id = new_state.entity_id
domain = new_state.domain
state = new_state.state
registry = self.hass.data[REG_KEY]
self._assumed[entity_id] = new_state.attributes.get(ATTR_ASSUMED_STATE)
registry: GroupIntegrationRegistry = self.hass.data[REG_KEY]
self._assumed[entity_id] = bool(new_state.attributes.get(ATTR_ASSUMED_STATE))
if domain not in registry.on_states_by_domain:
# Handle the group of a group case
@ -769,12 +787,12 @@ class Group(Entity):
self._on_off[entity_id] = state in registry.on_off_mapping
else:
entity_on_state = registry.on_states_by_domain[domain]
if domain in self.hass.data[REG_KEY].on_states_by_domain:
if domain in registry.on_states_by_domain:
self._on_states.update(entity_on_state)
self._on_off[entity_id] = state in entity_on_state
@callback
def _async_update_group_state(self, tr_state=None):
def _async_update_group_state(self, tr_state: State | None = None) -> None:
"""Update group state.
Optionally you can provide the only state changed since last update
@ -818,4 +836,5 @@ class Group(Entity):
if group_is_on:
self._state = on_state
else:
self._state = self.hass.data[REG_KEY].on_off_mapping[on_state]
registry: GroupIntegrationRegistry = self.hass.data[REG_KEY]
self._state = registry.on_off_mapping[on_state]

View file

@ -103,6 +103,7 @@ class MediaPlayerGroup(MediaPlayerEntity):
"""Representation of a Media Group."""
_attr_available: bool = False
_attr_should_poll = False
def __init__(self, unique_id: str | None, name: str, entities: list[str]) -> None:
"""Initialize a Media Group entity."""
@ -216,11 +217,6 @@ class MediaPlayerGroup(MediaPlayerEntity):
"""Flag supported features."""
return self._supported_features
@property
def should_poll(self) -> bool:
"""No polling needed for a media group."""
return False
@property
def extra_state_attributes(self) -> dict:
"""Return the state attributes for the media group."""

View file

@ -1,7 +1,10 @@
"""Group platform for notify component."""
from __future__ import annotations
import asyncio
from collections.abc import Mapping
from collections.abc import Coroutine, Mapping
from copy import deepcopy
from typing import Any
import voluptuous as vol
@ -13,9 +16,9 @@ from homeassistant.components.notify import (
BaseNotificationService,
)
from homeassistant.const import ATTR_SERVICE
from homeassistant.core import HomeAssistant
import homeassistant.helpers.config_validation as cv
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
CONF_SERVICES = "services"
@ -29,46 +32,50 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
)
def update(input_dict, update_source):
def update(input_dict: dict[str, Any], update_source: dict[str, Any]) -> dict[str, Any]:
"""Deep update a dictionary.
Async friendly.
"""
for key, val in update_source.items():
if isinstance(val, Mapping):
recurse = update(input_dict.get(key, {}), val)
recurse = update(input_dict.get(key, {}), val) # type: ignore[arg-type]
input_dict[key] = recurse
else:
input_dict[key] = update_source[key]
return input_dict
async def async_get_service(hass, config, discovery_info=None):
async def async_get_service(
hass: HomeAssistant,
config: ConfigType,
discovery_info: DiscoveryInfoType | None = None,
) -> GroupNotifyPlatform:
"""Get the Group notification service."""
return GroupNotifyPlatform(hass, config.get(CONF_SERVICES))
return GroupNotifyPlatform(hass, config[CONF_SERVICES])
class GroupNotifyPlatform(BaseNotificationService):
"""Implement the notification service for the group notify platform."""
def __init__(self, hass, entities):
def __init__(self, hass: HomeAssistant, entities: list[dict[str, Any]]) -> None:
"""Initialize the service."""
self.hass = hass
self.entities = entities
async def async_send_message(self, message="", **kwargs):
async def async_send_message(self, message: str = "", **kwargs: Any) -> None:
"""Send message to all entities in the group."""
payload = {ATTR_MESSAGE: message}
payload: dict[str, Any] = {ATTR_MESSAGE: message}
payload.update({key: val for key, val in kwargs.items() if val})
tasks = []
tasks: list[Coroutine[Any, Any, bool | None]] = []
for entity in self.entities:
sending_payload = deepcopy(payload.copy())
if entity.get(ATTR_DATA) is not None:
update(sending_payload, entity.get(ATTR_DATA))
if (data := entity.get(ATTR_DATA)) is not None:
update(sending_payload, data)
tasks.append(
self.hass.services.async_call(
DOMAIN, entity.get(ATTR_SERVICE), sending_payload
DOMAIN, entity[ATTR_SERVICE], sending_payload
)
)