diff --git a/homeassistant/components/group/__init__.py b/homeassistant/components/group/__init__.py index 04bc109d15b..c1759432ade 100644 --- a/homeassistant/components/group/__init__.py +++ b/homeassistant/components/group/__init__.py @@ -3,10 +3,10 @@ from __future__ import annotations from abc import abstractmethod import asyncio -from collections.abc import Iterable +from collections.abc import Collection, Iterable from contextvars import ContextVar import logging -from typing import Any, Union, cast +from typing import Any, Protocol, Union, cast import voluptuous as vol @@ -27,7 +27,15 @@ from homeassistant.const import ( STATE_ON, Platform, ) -from homeassistant.core import HomeAssistant, ServiceCall, callback, split_entity_id +from homeassistant.core import ( + CALLBACK_TYPE, + Event, + HomeAssistant, + ServiceCall, + State, + callback, + split_entity_id, +) from homeassistant.helpers import config_validation as cv, entity_registry as er, start from homeassistant.helpers.entity import Entity, async_generate_entity_id from homeassistant.helpers.entity_component import EntityComponent @@ -42,8 +50,6 @@ from homeassistant.loader import bind_hass from .const import CONF_HIDE_MEMBERS -# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs - DOMAIN = "group" GROUP_ORDER = "group_order" @@ -79,10 +85,19 @@ _LOGGER = logging.getLogger(__name__) current_domain: ContextVar[str] = ContextVar("current_domain") -def _conf_preprocess(value): +class GroupProtocol(Protocol): + """Define the format of group platforms.""" + + def async_describe_on_off_states( + self, hass: HomeAssistant, registry: GroupIntegrationRegistry + ) -> None: + """Describe group on off states.""" + + +def _conf_preprocess(value: Any) -> dict[str, Any]: """Preprocess alternative configuration formats.""" if not isinstance(value, dict): - value = {CONF_ENTITIES: value} + return {CONF_ENTITIES: value} return value @@ -135,14 +150,15 @@ class GroupIntegrationRegistry: @bind_hass -def is_on(hass, entity_id): +def is_on(hass: HomeAssistant, entity_id: str) -> bool: """Test if the group state is in its ON-state.""" if REG_KEY not in hass.data: # Integration not setup yet, it cannot be on return False if (state := hass.states.get(entity_id)) is not None: - return state.state in hass.data[REG_KEY].on_off_mapping + registry: GroupIntegrationRegistry = hass.data[REG_KEY] + return state.state in registry.on_off_mapping return False @@ -408,10 +424,13 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: return True -async def _process_group_platform(hass, domain, platform): +async def _process_group_platform( + hass: HomeAssistant, domain: str, platform: GroupProtocol +) -> None: """Process a group platform.""" current_domain.set(domain) - platform.async_describe_on_off_states(hass, hass.data[REG_KEY]) + registry: GroupIntegrationRegistry = hass.data[REG_KEY] + platform.async_describe_on_off_states(hass, registry) async def _async_process_config(hass: HomeAssistant, config: ConfigType) -> None: @@ -423,7 +442,7 @@ async def _async_process_config(hass: HomeAssistant, config: ConfigType) -> None for object_id, conf in domain_config.items(): name: str = conf.get(CONF_NAME, object_id) - entity_ids: Iterable[str] = conf.get(CONF_ENTITIES) or [] + entity_ids: Collection[str] = conf.get(CONF_ENTITIES) or [] icon: str | None = conf.get(CONF_ICON) mode = bool(conf.get(CONF_ALL)) order: int = hass.data[GROUP_ORDER] @@ -456,15 +475,12 @@ async def _async_process_config(hass: HomeAssistant, config: ConfigType) -> None class GroupEntity(Entity): """Representation of a Group of entities.""" - @property - def should_poll(self) -> bool: - """Disable polling for group.""" - return False + _attr_should_poll = False async def async_added_to_hass(self) -> None: """Register listeners.""" - async def _update_at_start(_): + async def _update_at_start(_: HomeAssistant) -> None: self.async_update_group_state() self.async_write_ha_state() @@ -487,6 +503,10 @@ class GroupEntity(Entity): class Group(Entity): """Track a group of entity ids.""" + _attr_should_poll = False + tracking: tuple[str, ...] + trackable: tuple[str, ...] + def __init__( self, hass: HomeAssistant, @@ -494,7 +514,7 @@ class Group(Entity): order: int | None = None, icon: str | None = None, user_defined: bool = True, - entity_ids: Iterable[str] | None = None, + entity_ids: Collection[str] | None = None, mode: bool | None = None, ) -> None: """Initialize a group. @@ -503,25 +523,25 @@ class Group(Entity): """ self.hass = hass self._name = name - self._state = None + self._state: str | None = None self._icon = icon self._set_tracked(entity_ids) - self._on_off = None - self._assumed = None - self._on_states = None + self._on_off: dict[str, bool] = {} + self._assumed: dict[str, bool] = {} + self._on_states: set[str] = set() self.user_defined = user_defined self.mode = any if mode: self.mode = all self._order = order self._assumed_state = False - self._async_unsub_state_changed = None + self._async_unsub_state_changed: CALLBACK_TYPE | None = None @staticmethod def create_group( hass: HomeAssistant, name: str, - entity_ids: Iterable[str] | None = None, + entity_ids: Collection[str] | None = None, user_defined: bool = True, icon: str | None = None, object_id: str | None = None, @@ -541,7 +561,7 @@ class Group(Entity): def async_create_group_entity( hass: HomeAssistant, name: str, - entity_ids: Iterable[str] | None = None, + entity_ids: Collection[str] | None = None, user_defined: bool = True, icon: str | None = None, object_id: str | None = None, @@ -577,7 +597,7 @@ class Group(Entity): async def async_create_group( hass: HomeAssistant, name: str, - entity_ids: Iterable[str] | None = None, + entity_ids: Collection[str] | None = None, user_defined: bool = True, icon: str | None = None, object_id: str | None = None, @@ -597,37 +617,32 @@ class Group(Entity): return group @property - def should_poll(self): - """No need to poll because groups will update themselves.""" - return False - - @property - def name(self): + def name(self) -> str: """Return the name of the group.""" return self._name @name.setter - def name(self, value): + def name(self, value: str) -> None: """Set Group name.""" self._name = value @property - def state(self): + def state(self) -> str | None: """Return the state of the group.""" return self._state @property - def icon(self): + def icon(self) -> str | None: """Return the icon of the group.""" return self._icon @icon.setter - def icon(self, value): + def icon(self, value: str | None) -> None: """Set Icon for group.""" self._icon = value @property - def extra_state_attributes(self): + def extra_state_attributes(self) -> dict[str, Any]: """Return the state attributes for the group.""" data = {ATTR_ENTITY_ID: self.tracking, ATTR_ORDER: self._order} if not self.user_defined: @@ -636,17 +651,19 @@ class Group(Entity): return data @property - def assumed_state(self): + def assumed_state(self) -> bool: """Test if any member has an assumed state.""" return self._assumed_state - def update_tracked_entity_ids(self, entity_ids): + def update_tracked_entity_ids(self, entity_ids: Collection[str] | None) -> None: """Update the member entity IDs.""" asyncio.run_coroutine_threadsafe( self.async_update_tracked_entity_ids(entity_ids), self.hass.loop ).result() - async def async_update_tracked_entity_ids(self, entity_ids): + async def async_update_tracked_entity_ids( + self, entity_ids: Collection[str] | None + ) -> None: """Update the member entity IDs. This method must be run in the event loop. @@ -656,7 +673,7 @@ class Group(Entity): self._reset_tracked_state() self._async_start() - def _set_tracked(self, entity_ids): + def _set_tracked(self, entity_ids: Collection[str] | None) -> None: """Tuple of entities to be tracked.""" # tracking are the entities we want to track # trackable are the entities we actually watch @@ -666,10 +683,11 @@ class Group(Entity): self.trackable = () return - excluded_domains = self.hass.data[REG_KEY].exclude_domains + registry: GroupIntegrationRegistry = self.hass.data[REG_KEY] + excluded_domains = registry.exclude_domains - tracking = [] - trackable = [] + tracking: list[str] = [] + trackable: list[str] = [] for ent_id in entity_ids: ent_id_lower = ent_id.lower() domain = split_entity_id(ent_id_lower)[0] @@ -681,14 +699,14 @@ class Group(Entity): self.tracking = tuple(tracking) @callback - def _async_start(self, *_): + def _async_start(self, _: HomeAssistant | None = None) -> None: """Start tracking members and write state.""" self._reset_tracked_state() self._async_start_tracking() self.async_write_ha_state() @callback - def _async_start_tracking(self): + def _async_start_tracking(self) -> None: """Start tracking members. This method must be run in the event loop. @@ -701,7 +719,7 @@ class Group(Entity): self._async_update_group_state() @callback - def _async_stop(self): + def _async_stop(self) -> None: """Unregister the group from Home Assistant. This method must be run in the event loop. @@ -711,20 +729,20 @@ class Group(Entity): self._async_unsub_state_changed = None @callback - def async_update_group_state(self): + def async_update_group_state(self) -> None: """Query all members and determine current group state.""" self._state = None self._async_update_group_state() - async def async_added_to_hass(self): + async def async_added_to_hass(self) -> None: """Handle addition to Home Assistant.""" self.async_on_remove(start.async_at_start(self.hass, self._async_start)) - async def async_will_remove_from_hass(self): + async def async_will_remove_from_hass(self) -> None: """Handle removal from Home Assistant.""" self._async_stop() - async def _async_state_changed_listener(self, event): + async def _async_state_changed_listener(self, event: Event) -> None: """Respond to a member state changing. This method must be run in the event loop. @@ -742,7 +760,7 @@ class Group(Entity): self._async_update_group_state(new_state) self.async_write_ha_state() - def _reset_tracked_state(self): + def _reset_tracked_state(self) -> None: """Reset tracked state.""" self._on_off = {} self._assumed = {} @@ -752,13 +770,13 @@ class Group(Entity): if (state := self.hass.states.get(entity_id)) is not None: self._see_state(state) - def _see_state(self, new_state): + def _see_state(self, new_state: State) -> None: """Keep track of the the state.""" entity_id = new_state.entity_id domain = new_state.domain state = new_state.state - registry = self.hass.data[REG_KEY] - self._assumed[entity_id] = new_state.attributes.get(ATTR_ASSUMED_STATE) + registry: GroupIntegrationRegistry = self.hass.data[REG_KEY] + self._assumed[entity_id] = bool(new_state.attributes.get(ATTR_ASSUMED_STATE)) if domain not in registry.on_states_by_domain: # Handle the group of a group case @@ -769,12 +787,12 @@ class Group(Entity): self._on_off[entity_id] = state in registry.on_off_mapping else: entity_on_state = registry.on_states_by_domain[domain] - if domain in self.hass.data[REG_KEY].on_states_by_domain: + if domain in registry.on_states_by_domain: self._on_states.update(entity_on_state) self._on_off[entity_id] = state in entity_on_state @callback - def _async_update_group_state(self, tr_state=None): + def _async_update_group_state(self, tr_state: State | None = None) -> None: """Update group state. Optionally you can provide the only state changed since last update @@ -818,4 +836,5 @@ class Group(Entity): if group_is_on: self._state = on_state else: - self._state = self.hass.data[REG_KEY].on_off_mapping[on_state] + registry: GroupIntegrationRegistry = self.hass.data[REG_KEY] + self._state = registry.on_off_mapping[on_state] diff --git a/homeassistant/components/group/media_player.py b/homeassistant/components/group/media_player.py index cbce44a359a..ddb44072080 100644 --- a/homeassistant/components/group/media_player.py +++ b/homeassistant/components/group/media_player.py @@ -103,6 +103,7 @@ class MediaPlayerGroup(MediaPlayerEntity): """Representation of a Media Group.""" _attr_available: bool = False + _attr_should_poll = False def __init__(self, unique_id: str | None, name: str, entities: list[str]) -> None: """Initialize a Media Group entity.""" @@ -216,11 +217,6 @@ class MediaPlayerGroup(MediaPlayerEntity): """Flag supported features.""" return self._supported_features - @property - def should_poll(self) -> bool: - """No polling needed for a media group.""" - return False - @property def extra_state_attributes(self) -> dict: """Return the state attributes for the media group.""" diff --git a/homeassistant/components/group/notify.py b/homeassistant/components/group/notify.py index 97c019163e8..7c4bc0c65c4 100644 --- a/homeassistant/components/group/notify.py +++ b/homeassistant/components/group/notify.py @@ -1,7 +1,10 @@ """Group platform for notify component.""" +from __future__ import annotations + import asyncio -from collections.abc import Mapping +from collections.abc import Coroutine, Mapping from copy import deepcopy +from typing import Any import voluptuous as vol @@ -13,9 +16,9 @@ from homeassistant.components.notify import ( BaseNotificationService, ) from homeassistant.const import ATTR_SERVICE +from homeassistant.core import HomeAssistant import homeassistant.helpers.config_validation as cv - -# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs +from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType CONF_SERVICES = "services" @@ -29,46 +32,50 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend( ) -def update(input_dict, update_source): +def update(input_dict: dict[str, Any], update_source: dict[str, Any]) -> dict[str, Any]: """Deep update a dictionary. Async friendly. """ for key, val in update_source.items(): if isinstance(val, Mapping): - recurse = update(input_dict.get(key, {}), val) + recurse = update(input_dict.get(key, {}), val) # type: ignore[arg-type] input_dict[key] = recurse else: input_dict[key] = update_source[key] return input_dict -async def async_get_service(hass, config, discovery_info=None): +async def async_get_service( + hass: HomeAssistant, + config: ConfigType, + discovery_info: DiscoveryInfoType | None = None, +) -> GroupNotifyPlatform: """Get the Group notification service.""" - return GroupNotifyPlatform(hass, config.get(CONF_SERVICES)) + return GroupNotifyPlatform(hass, config[CONF_SERVICES]) class GroupNotifyPlatform(BaseNotificationService): """Implement the notification service for the group notify platform.""" - def __init__(self, hass, entities): + def __init__(self, hass: HomeAssistant, entities: list[dict[str, Any]]) -> None: """Initialize the service.""" self.hass = hass self.entities = entities - async def async_send_message(self, message="", **kwargs): + async def async_send_message(self, message: str = "", **kwargs: Any) -> None: """Send message to all entities in the group.""" - payload = {ATTR_MESSAGE: message} + payload: dict[str, Any] = {ATTR_MESSAGE: message} payload.update({key: val for key, val in kwargs.items() if val}) - tasks = [] + tasks: list[Coroutine[Any, Any, bool | None]] = [] for entity in self.entities: sending_payload = deepcopy(payload.copy()) - if entity.get(ATTR_DATA) is not None: - update(sending_payload, entity.get(ATTR_DATA)) + if (data := entity.get(ATTR_DATA)) is not None: + update(sending_payload, data) tasks.append( self.hass.services.async_call( - DOMAIN, entity.get(ATTR_SERVICE), sending_payload + DOMAIN, entity[ATTR_SERVICE], sending_payload ) )