From 2f1138562720cd50343d2fedd4981913a9ef6bd9 Mon Sep 17 00:00:00 2001 From: Jan Bouwhuis Date: Mon, 24 Oct 2022 15:00:37 +0200 Subject: [PATCH] Add typing hints for MQTT mixins (#80702) * Add typing hints for MQTT mixins * Follow up comments * config_entry is always set * typing discovery_data - substate None assignment * Rename `config[CONF_DEVICE]` -> specifications --- homeassistant/components/mqtt/cover.py | 14 +- homeassistant/components/mqtt/discovery.py | 13 +- homeassistant/components/mqtt/mixins.py | 205 ++++++++++++--------- homeassistant/components/mqtt/models.py | 4 +- homeassistant/components/mqtt/update.py | 2 - 5 files changed, 132 insertions(+), 106 deletions(-) diff --git a/homeassistant/components/mqtt/cover.py b/homeassistant/components/mqtt/cover.py index 11901f15054..7d7d4f61c4a 100644 --- a/homeassistant/components/mqtt/cover.py +++ b/homeassistant/components/mqtt/cover.py @@ -552,7 +552,7 @@ class MqttCover(MqttEntity, CoverEntity): This method is a coroutine. """ await self.async_publish( - self._config.get(CONF_COMMAND_TOPIC), + self._config[CONF_COMMAND_TOPIC], self._config[CONF_PAYLOAD_OPEN], self._config[CONF_QOS], self._config[CONF_RETAIN], @@ -573,7 +573,7 @@ class MqttCover(MqttEntity, CoverEntity): This method is a coroutine. """ await self.async_publish( - self._config.get(CONF_COMMAND_TOPIC), + self._config[CONF_COMMAND_TOPIC], self._config[CONF_PAYLOAD_CLOSE], self._config[CONF_QOS], self._config[CONF_RETAIN], @@ -594,7 +594,7 @@ class MqttCover(MqttEntity, CoverEntity): This method is a coroutine. """ await self.async_publish( - self._config.get(CONF_COMMAND_TOPIC), + self._config[CONF_COMMAND_TOPIC], self._config[CONF_PAYLOAD_STOP], self._config[CONF_QOS], self._config[CONF_RETAIN], @@ -614,7 +614,7 @@ class MqttCover(MqttEntity, CoverEntity): } tilt_payload = self._set_tilt_template(tilt_open_position, variables=variables) await self.async_publish( - self._config.get(CONF_TILT_COMMAND_TOPIC), + self._config[CONF_TILT_COMMAND_TOPIC], tilt_payload, self._config[CONF_QOS], self._config[CONF_RETAIN], @@ -641,7 +641,7 @@ class MqttCover(MqttEntity, CoverEntity): tilt_closed_position, variables=variables ) await self.async_publish( - self._config.get(CONF_TILT_COMMAND_TOPIC), + self._config[CONF_TILT_COMMAND_TOPIC], tilt_payload, self._config[CONF_QOS], self._config[CONF_RETAIN], @@ -670,7 +670,7 @@ class MqttCover(MqttEntity, CoverEntity): tilt = self._set_tilt_template(tilt, variables=variables) await self.async_publish( - self._config.get(CONF_TILT_COMMAND_TOPIC), + self._config[CONF_TILT_COMMAND_TOPIC], tilt, self._config[CONF_QOS], self._config[CONF_RETAIN], @@ -697,7 +697,7 @@ class MqttCover(MqttEntity, CoverEntity): position = self._set_position_template(position, variables=variables) await self.async_publish( - self._config.get(CONF_SET_POSITION_TOPIC), + self._config[CONF_SET_POSITION_TOPIC], position, self._config[CONF_QOS], self._config[CONF_RETAIN], diff --git a/homeassistant/components/mqtt/discovery.py b/homeassistant/components/mqtt/discovery.py index 0aa288e700a..84f14d26146 100644 --- a/homeassistant/components/mqtt/discovery.py +++ b/homeassistant/components/mqtt/discovery.py @@ -7,6 +7,7 @@ import functools import logging import re import time +from typing import Any from homeassistant.config_entries import ConfigEntry from homeassistant.const import CONF_DEVICE, CONF_PLATFORM @@ -19,7 +20,7 @@ from homeassistant.helpers.dispatcher import ( ) from homeassistant.helpers.json import json_loads from homeassistant.helpers.service_info.mqtt import MqttServiceInfo -from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType +from homeassistant.helpers.typing import DiscoveryInfoType from homeassistant.loader import async_get_mqtt from .. import mqtt @@ -73,8 +74,8 @@ MQTT_DISCOVERY_DONE = "mqtt_discovery_done_{}" TOPIC_BASE = "~" -class MQTTConfig(dict): - """Dummy class to allow adding attributes.""" +class MQTTDiscoveryPayload(dict[str, Any]): + """Class to hold and MQTT discovery payload and discovery data.""" discovery_data: DiscoveryInfoType @@ -96,7 +97,7 @@ async def async_start( # noqa: C901 mqtt_data = get_mqtt_data(hass) mqtt_integrations = {} - async def async_discovery_message_received(msg): + async def async_discovery_message_received(msg) -> None: """Process the received message.""" mqtt_data.last_discovery = time.time() payload = msg.payload @@ -126,7 +127,7 @@ async def async_start( # noqa: C901 _LOGGER.warning("Unable to parse JSON %s: '%s'", object_id, payload) return - payload = MQTTConfig(payload) + payload = MQTTDiscoveryPayload(payload) for key in list(payload): abbreviated_key = key @@ -195,7 +196,7 @@ async def async_start( # noqa: C901 await async_process_discovery_payload(component, discovery_id, payload) async def async_process_discovery_payload( - component: str, discovery_id: str, payload: ConfigType + component: str, discovery_id: str, payload: MQTTDiscoveryPayload ) -> None: """Process the payload of a new discovery.""" diff --git a/homeassistant/components/mqtt/mixins.py b/homeassistant/components/mqtt/mixins.py index b5c870a196e..7866e3cf6d6 100644 --- a/homeassistant/components/mqtt/mixins.py +++ b/homeassistant/components/mqtt/mixins.py @@ -34,7 +34,10 @@ from homeassistant.helpers import ( device_registry as dr, entity_registry as er, ) -from homeassistant.helpers.device_registry import EVENT_DEVICE_REGISTRY_UPDATED +from homeassistant.helpers.device_registry import ( + EVENT_DEVICE_REGISTRY_UPDATED, + DeviceEntry, +) from homeassistant.helpers.dispatcher import ( async_dispatcher_connect, async_dispatcher_send, @@ -50,6 +53,7 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.event import async_track_entity_registry_updated_event from homeassistant.helpers.issue_registry import IssueSeverity, async_create_issue from homeassistant.helpers.json import json_loads +from homeassistant.helpers.service_info.mqtt import ReceivePayloadType from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from . import debug_info, subscription @@ -74,11 +78,13 @@ from .discovery import ( MQTT_DISCOVERY_DONE, MQTT_DISCOVERY_NEW, MQTT_DISCOVERY_UPDATED, + MQTTDiscoveryPayload, clear_discovery_hash, set_discovery_hash, ) from .models import MqttValueTemplate, PublishPayloadType, ReceiveMessage from .subscription import ( + EntitySubscription, async_prepare_subscribe_topics, async_subscribe_topics, async_unsubscribe_topics, @@ -222,7 +228,7 @@ MQTT_ENTITY_COMMON_SCHEMA = MQTT_AVAILABILITY_SCHEMA.extend( ) -def warn_for_legacy_schema(domain: str) -> Callable: +def warn_for_legacy_schema(domain: str) -> Callable[[ConfigType], ConfigType]: """Warn once when a legacy platform schema is used.""" warned = set() @@ -269,8 +275,8 @@ class SetupEntity(Protocol): hass: HomeAssistant, async_add_entities: AddEntitiesCallback, config: ConfigType, - config_entry: ConfigEntry | None = None, - discovery_data: dict[str, Any] | None = None, + config_entry: ConfigEntry, + discovery_data: DiscoveryInfoType | None = None, ) -> None: """Define setup_entities type.""" @@ -294,13 +300,13 @@ async def async_get_platform_config_from_yaml( async def async_setup_entry_helper( hass: HomeAssistant, domain: str, - async_setup: partial[Coroutine[HomeAssistant, str, None]], + async_setup: partial[Coroutine[Any, Any, None]], discovery_schema: vol.Schema, ) -> None: """Set up entity, automation or tag creation dynamically through MQTT discovery.""" mqtt_data = get_mqtt_data(hass) - async def async_discover(discovery_payload): + async def async_discover(discovery_payload: MQTTDiscoveryPayload) -> None: """Discover and add an MQTT entity, automation or tag.""" if not mqtt_config_entry_enabled(hass): _LOGGER.warning( @@ -312,10 +318,10 @@ async def async_setup_entry_helper( return discovery_data = discovery_payload.discovery_data try: - config = discovery_schema(discovery_payload) + config: DiscoveryInfoType = discovery_schema(discovery_payload) await async_setup(config, discovery_data=discovery_data) except Exception: - discovery_hash = discovery_data[ATTR_DISCOVERY_HASH] + discovery_hash: tuple[str, str] = discovery_data[ATTR_DISCOVERY_HASH] clear_discovery_hash(hass, discovery_hash) async_dispatcher_send( hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None @@ -357,7 +363,7 @@ async def async_setup_entry_helper( async def async_setup_platform_helper( hass: HomeAssistant, platform_domain: str, - config: ConfigType | DiscoveryInfoType, + config: ConfigType, async_add_entities: AddEntitiesCallback, async_setup_entities: SetupEntity, ) -> None: @@ -381,7 +387,9 @@ async def async_setup_platform_helper( await async_setup_entities(hass, async_add_entities, config, config_entry) -def init_entity_id_from_config(hass, entity, config, entity_id_format): +def init_entity_id_from_config( + hass: HomeAssistant, entity: Entity, config: ConfigType, entity_id_format: str +) -> None: """Set entity_id from object_id if defined in config.""" if CONF_OBJECT_ID in config: entity.entity_id = async_generate_entity_id( @@ -394,10 +402,10 @@ class MqttAttributes(Entity): _attributes_extra_blocked: frozenset[str] = frozenset() - def __init__(self, config: dict) -> None: + def __init__(self, config: ConfigType) -> None: """Initialize the JSON attributes mixin.""" self._attributes: dict[str, Any] | None = None - self._attributes_sub_state = None + self._attributes_sub_state: dict[str, EntitySubscription] = {} self._attributes_config = config async def async_added_to_hass(self) -> None: @@ -406,16 +414,16 @@ class MqttAttributes(Entity): self._attributes_prepare_subscribe_topics() await self._attributes_subscribe_topics() - def attributes_prepare_discovery_update(self, config: dict): + def attributes_prepare_discovery_update(self, config: DiscoveryInfoType) -> None: """Handle updated discovery message.""" self._attributes_config = config self._attributes_prepare_subscribe_topics() - async def attributes_discovery_update(self, config: dict): + async def attributes_discovery_update(self, config: DiscoveryInfoType) -> None: """Handle updated discovery message.""" await self._attributes_subscribe_topics() - def _attributes_prepare_subscribe_topics(self): + def _attributes_prepare_subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" attr_tpl = MqttValueTemplate( self._attributes_config.get(CONF_JSON_ATTRS_TEMPLATE), entity=self @@ -458,11 +466,11 @@ class MqttAttributes(Entity): }, ) - async def _attributes_subscribe_topics(self): + async def _attributes_subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" await async_subscribe_topics(self.hass, self._attributes_sub_state) - async def async_will_remove_from_hass(self): + async def async_will_remove_from_hass(self) -> None: """Unsubscribe when removed.""" self._attributes_sub_state = async_unsubscribe_topics( self.hass, self._attributes_sub_state @@ -477,11 +485,11 @@ class MqttAttributes(Entity): class MqttAvailability(Entity): """Mixin used for platforms that report availability.""" - def __init__(self, config: dict) -> None: + def __init__(self, config: ConfigType) -> None: """Initialize the availability mixin.""" - self._availability_sub_state = None - self._available: dict = {} - self._available_latest = False + self._availability_sub_state: dict[str, EntitySubscription] = {} + self._available: dict[str, str | bool] = {} + self._available_latest: bool = False self._availability_setup_from_config(config) async def async_added_to_hass(self) -> None: @@ -498,18 +506,18 @@ class MqttAvailability(Entity): ) ) - def availability_prepare_discovery_update(self, config: dict): + def availability_prepare_discovery_update(self, config: DiscoveryInfoType) -> None: """Handle updated discovery message.""" self._availability_setup_from_config(config) self._availability_prepare_subscribe_topics() - async def availability_discovery_update(self, config: dict): + async def availability_discovery_update(self, config: DiscoveryInfoType) -> None: """Handle updated discovery message.""" await self._availability_subscribe_topics() - def _availability_setup_from_config(self, config): + def _availability_setup_from_config(self, config: ConfigType) -> None: """(Re)Setup.""" - self._avail_topics = {} + self._avail_topics: dict[str, dict[str, Any]] = {} if CONF_AVAILABILITY_TOPIC in config: self._avail_topics[config[CONF_AVAILABILITY_TOPIC]] = { CONF_PAYLOAD_AVAILABLE: config[CONF_PAYLOAD_AVAILABLE], @@ -518,6 +526,7 @@ class MqttAvailability(Entity): } if CONF_AVAILABILITY in config: + avail: dict[str, Any] for avail in config[CONF_AVAILABILITY]: self._avail_topics[avail[CONF_TOPIC]] = { CONF_PAYLOAD_AVAILABLE: avail[CONF_PAYLOAD_AVAILABLE], @@ -533,7 +542,7 @@ class MqttAvailability(Entity): self._avail_config = config - def _availability_prepare_subscribe_topics(self): + def _availability_prepare_subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" @callback @@ -541,6 +550,7 @@ class MqttAvailability(Entity): def availability_message_received(msg: ReceiveMessage) -> None: """Handle a new received MQTT availability message.""" topic = msg.topic + payload: ReceivePayloadType payload = self._avail_topics[topic][CONF_AVAILABILITY_TEMPLATE](msg.payload) if payload == self._avail_topics[topic][CONF_PAYLOAD_AVAILABLE]: self._available[topic] = True @@ -555,7 +565,7 @@ class MqttAvailability(Entity): topic: (self._available[topic] if topic in self._available else False) for topic in self._avail_topics } - topics = { + topics: dict[str, dict[str, Any]] = { f"availability_{topic}": { "topic": topic, "msg_callback": availability_message_received, @@ -571,17 +581,17 @@ class MqttAvailability(Entity): topics, ) - async def _availability_subscribe_topics(self): + async def _availability_subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" await async_subscribe_topics(self.hass, self._availability_sub_state) @callback - def async_mqtt_connect(self): + def async_mqtt_connect(self) -> None: """Update state on connection/disconnection to MQTT broker.""" if not self.hass.is_stopping: self.async_write_ha_state() - async def async_will_remove_from_hass(self): + async def async_will_remove_from_hass(self) -> None: """Unsubscribe when removed.""" self._availability_sub_state = async_unsubscribe_topics( self.hass, self._availability_sub_state @@ -628,12 +638,12 @@ async def cleanup_device_registry( ) -def get_discovery_hash(discovery_data: dict) -> tuple[str, str]: +def get_discovery_hash(discovery_data: DiscoveryInfoType) -> tuple[str, str]: """Get the discovery hash from the discovery data.""" return discovery_data[ATTR_DISCOVERY_HASH] -def send_discovery_done(hass: HomeAssistant, discovery_data: dict) -> None: +def send_discovery_done(hass: HomeAssistant, discovery_data: DiscoveryInfoType) -> None: """Acknowledge a discovery message has been handled.""" discovery_hash = get_discovery_hash(discovery_data) async_dispatcher_send(hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None) @@ -641,7 +651,7 @@ def send_discovery_done(hass: HomeAssistant, discovery_data: dict) -> None: def stop_discovery_updates( hass: HomeAssistant, - discovery_data: dict, + discovery_data: DiscoveryInfoType, remove_discovery_updated: Callable[[], None] | None = None, ) -> None: """Stop discovery updates of being sent.""" @@ -660,7 +670,7 @@ async def async_remove_discovery_payload(hass: HomeAssistant, discovery_data: di async def async_clear_discovery_topic_if_entity_removed( hass: HomeAssistant, - discovery_data: dict[str, Any], + discovery_data: DiscoveryInfoType, event: Event, ) -> None: """Clear the discovery topic if the entity is removed.""" @@ -675,7 +685,7 @@ class MqttDiscoveryDeviceUpdate: def __init__( self, hass: HomeAssistant, - discovery_data: dict, + discovery_data: DiscoveryInfoType, device_id: str | None, config_entry: ConfigEntry, log_name: str, @@ -718,7 +728,7 @@ class MqttDiscoveryDeviceUpdate: async def async_discovery_update( self, - discovery_payload: DiscoveryInfoType | None, + discovery_payload: MQTTDiscoveryPayload, ) -> None: """Handle discovery update.""" discovery_hash = get_discovery_hash(self._discovery_data) @@ -789,7 +799,7 @@ class MqttDiscoveryDeviceUpdate: self.hass, self._device_id, self._config_entry_id ) - async def async_update(self, discovery_data: dict) -> None: + async def async_update(self, discovery_data: MQTTDiscoveryPayload) -> None: """Handle the update of platform specific parts, extend to the platform.""" @abstractmethod @@ -803,8 +813,9 @@ class MqttDiscoveryUpdate(Entity): def __init__( self, hass: HomeAssistant, - discovery_data: dict | None, - discovery_update: Callable | None = None, + discovery_data: DiscoveryInfoType | None, + discovery_update: Callable[[MQTTDiscoveryPayload], Coroutine[Any, Any, None]] + | None = None, ) -> None: """Initialize the discovery update mixin.""" self._discovery_data = discovery_data @@ -823,11 +834,13 @@ class MqttDiscoveryUpdate(Entity): """Subscribe to discovery updates.""" await super().async_added_to_hass() self._removed_from_hass = False - discovery_hash = ( + discovery_hash: tuple[str, str] | None = ( self._discovery_data[ATTR_DISCOVERY_HASH] if self._discovery_data else None ) - async def _async_remove_state_and_registry_entry(self) -> None: + async def _async_remove_state_and_registry_entry( + self: MqttDiscoveryUpdate, + ) -> None: """Remove entity's state and entity registry entry. Remove entity from entity registry if it is registered, this also removes the state. @@ -842,13 +855,15 @@ class MqttDiscoveryUpdate(Entity): else: await self.async_remove(force_remove=True) - async def discovery_callback(payload): + async def discovery_callback(payload: MQTTDiscoveryPayload) -> None: """Handle discovery update.""" _LOGGER.info( "Got update for entity with hash: %s '%s'", discovery_hash, payload, ) + assert self._discovery_data + old_payload: DiscoveryInfoType old_payload = self._discovery_data[ATTR_DISCOVERY_PAYLOAD] debug_info.update_entity_discovery_data(self.hass, payload, self.entity_id) if not payload: @@ -923,39 +938,43 @@ class MqttDiscoveryUpdate(Entity): self._removed_from_hass = True -def device_info_from_config(config) -> DeviceInfo | None: +def device_info_from_specifications( + specifications: dict[str, Any] | None +) -> DeviceInfo | None: """Return a device description for device registry.""" - if not config: + if not specifications: return None info = DeviceInfo( - identifiers={(DOMAIN, id_) for id_ in config[CONF_IDENTIFIERS]}, - connections={(conn_[0], conn_[1]) for conn_ in config[CONF_CONNECTIONS]}, + identifiers={(DOMAIN, id_) for id_ in specifications[CONF_IDENTIFIERS]}, + connections={ + (conn_[0], conn_[1]) for conn_ in specifications[CONF_CONNECTIONS] + }, ) - if CONF_MANUFACTURER in config: - info[ATTR_MANUFACTURER] = config[CONF_MANUFACTURER] + if CONF_MANUFACTURER in specifications: + info[ATTR_MANUFACTURER] = specifications[CONF_MANUFACTURER] - if CONF_MODEL in config: - info[ATTR_MODEL] = config[CONF_MODEL] + if CONF_MODEL in specifications: + info[ATTR_MODEL] = specifications[CONF_MODEL] - if CONF_NAME in config: - info[ATTR_NAME] = config[CONF_NAME] + if CONF_NAME in specifications: + info[ATTR_NAME] = specifications[CONF_NAME] - if CONF_HW_VERSION in config: - info[ATTR_HW_VERSION] = config[CONF_HW_VERSION] + if CONF_HW_VERSION in specifications: + info[ATTR_HW_VERSION] = specifications[CONF_HW_VERSION] - if CONF_SW_VERSION in config: - info[ATTR_SW_VERSION] = config[CONF_SW_VERSION] + if CONF_SW_VERSION in specifications: + info[ATTR_SW_VERSION] = specifications[CONF_SW_VERSION] - if CONF_VIA_DEVICE in config: - info[ATTR_VIA_DEVICE] = (DOMAIN, config[CONF_VIA_DEVICE]) + if CONF_VIA_DEVICE in specifications: + info[ATTR_VIA_DEVICE] = (DOMAIN, specifications[CONF_VIA_DEVICE]) - if CONF_SUGGESTED_AREA in config: - info[ATTR_SUGGESTED_AREA] = config[CONF_SUGGESTED_AREA] + if CONF_SUGGESTED_AREA in specifications: + info[ATTR_SUGGESTED_AREA] = specifications[CONF_SUGGESTED_AREA] - if CONF_CONFIGURATION_URL in config: - info[ATTR_CONFIGURATION_URL] = config[CONF_CONFIGURATION_URL] + if CONF_CONFIGURATION_URL in specifications: + info[ATTR_CONFIGURATION_URL] = specifications[CONF_CONFIGURATION_URL] return info @@ -963,19 +982,21 @@ def device_info_from_config(config) -> DeviceInfo | None: class MqttEntityDeviceInfo(Entity): """Mixin used for mqtt platforms that support the device registry.""" - def __init__(self, device_config: ConfigType | None, config_entry=None) -> None: + def __init__( + self, specifications: dict[str, Any] | None, config_entry: ConfigEntry + ) -> None: """Initialize the device mixin.""" - self._device_config = device_config + self._device_specifications = specifications self._config_entry = config_entry - def device_info_discovery_update(self, config: dict): + def device_info_discovery_update(self, config: DiscoveryInfoType) -> None: """Handle updated discovery message.""" - self._device_config = config.get(CONF_DEVICE) + self._device_specifications = config.get(CONF_DEVICE) device_registry = dr.async_get(self.hass) config_entry_id = self._config_entry.entry_id device_info = self.device_info - if config_entry_id is not None and device_info is not None: + if device_info is not None: device_registry.async_get_or_create( config_entry_id=config_entry_id, **device_info ) @@ -983,7 +1004,7 @@ class MqttEntityDeviceInfo(Entity): @property def device_info(self) -> DeviceInfo | None: """Return a device description for device registry.""" - return device_info_from_config(self._device_config) + return device_info_from_specifications(self._device_specifications) class MqttEntity( @@ -997,12 +1018,18 @@ class MqttEntity( _attr_should_poll = False _entity_id_format: str - def __init__(self, hass, config, config_entry, discovery_data): + def __init__( + self, + hass: HomeAssistant, + config: ConfigType, + config_entry: ConfigEntry, + discovery_data: DiscoveryInfoType | None, + ) -> None: """Init the MQTT Entity.""" self.hass = hass - self._config = config - self._unique_id = config.get(CONF_UNIQUE_ID) - self._sub_state = None + self._config: ConfigType = config + self._unique_id: str | None = config.get(CONF_UNIQUE_ID) + self._sub_state: dict[str, EntitySubscription] = {} # Load config self._setup_from_config(self._config) @@ -1016,14 +1043,14 @@ class MqttEntity( MqttDiscoveryUpdate.__init__(self, hass, discovery_data, self.discovery_update) MqttEntityDeviceInfo.__init__(self, config.get(CONF_DEVICE), config_entry) - def _init_entity_id(self): + def _init_entity_id(self) -> None: """Set entity_id from object_id if defined in config.""" init_entity_id_from_config( self.hass, self, self._config, self._entity_id_format ) @final - async def async_added_to_hass(self): + async def async_added_to_hass(self) -> None: """Subscribe to MQTT events.""" await super().async_added_to_hass() self._prepare_subscribe_topics() @@ -1032,15 +1059,15 @@ class MqttEntity( if self._discovery_data is not None: send_discovery_done(self.hass, self._discovery_data) - async def mqtt_async_added_to_hass(self): + async def mqtt_async_added_to_hass(self) -> None: """Call before the discovery message is acknowledged. To be extended by subclasses. """ - async def discovery_update(self, discovery_payload): + async def discovery_update(self, discovery_payload: MQTTDiscoveryPayload) -> None: """Handle updated discovery message.""" - config = self.config_schema()(discovery_payload) + config: DiscoveryInfoType = self.config_schema()(discovery_payload) self._config = config self._setup_from_config(self._config) @@ -1056,7 +1083,7 @@ class MqttEntity( await self._subscribe_topics() self.async_write_ha_state() - async def async_will_remove_from_hass(self): + async def async_will_remove_from_hass(self) -> None: """Unsubscribe when removed.""" self._sub_state = subscription.async_unsubscribe_topics( self.hass, self._sub_state @@ -1073,7 +1100,7 @@ class MqttEntity( qos: int = 0, retain: bool = False, encoding: str = DEFAULT_ENCODING, - ): + ) -> None: """Publish message to an MQTT topic.""" log_message(self.hass, self.entity_id, topic, payload, qos, retain) await async_publish( @@ -1087,18 +1114,18 @@ class MqttEntity( @staticmethod @abstractmethod - def config_schema(): + def config_schema() -> vol.Schema: """Return the config schema.""" - def _setup_from_config(self, config): + def _setup_from_config(self, config: ConfigType) -> None: """(Re)Setup the entity.""" @abstractmethod - def _prepare_subscribe_topics(self): + def _prepare_subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" @abstractmethod - async def _subscribe_topics(self): + async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" @property @@ -1112,17 +1139,17 @@ class MqttEntity( return self._config.get(CONF_ENTITY_CATEGORY) @property - def icon(self): + def icon(self) -> str | None: """Return icon of the entity if any.""" return self._config.get(CONF_ICON) @property - def name(self): + def name(self) -> str | None: """Return the name of the device if any.""" return self._config.get(CONF_NAME) @property - def unique_id(self): + def unique_id(self) -> str | None: """Return a unique ID.""" return self._unique_id @@ -1136,10 +1163,10 @@ def update_device( if CONF_DEVICE not in config: return None - device = None + device: DeviceEntry | None = None device_registry = dr.async_get(hass) config_entry_id = config_entry.entry_id - device_info = device_info_from_config(config[CONF_DEVICE]) + device_info = device_info_from_specifications(config[CONF_DEVICE]) if config_entry_id is not None and device_info is not None: update_device_info = cast(dict, device_info) @@ -1154,7 +1181,7 @@ def async_removed_from_device( hass: HomeAssistant, event: Event, mqtt_device_id: str, config_entry_id: str ) -> bool: """Check if the passed event indicates MQTT was removed from a device.""" - device_id = event.data["device_id"] + device_id: str = event.data["device_id"] if event.data["action"] not in ("remove", "update"): return False diff --git a/homeassistant/components/mqtt/models.py b/homeassistant/components/mqtt/models.py index f2f30419b4c..363956cc732 100644 --- a/homeassistant/components/mqtt/models.py +++ b/homeassistant/components/mqtt/models.py @@ -23,7 +23,7 @@ if TYPE_CHECKING: from .client import MQTT, Subscription from .debug_info import TimestampedPublishMessage from .device_trigger import Trigger - from .discovery import MQTTConfig + from .discovery import MQTTDiscoveryPayload from .tag import MQTTTagScanner _SENTINEL = object() @@ -86,7 +86,7 @@ class TriggerDebugInfo(TypedDict): class PendingDiscovered(TypedDict): """Pending discovered items.""" - pending: deque[MQTTConfig] + pending: deque[MQTTDiscoveryPayload] unsub: CALLBACK_TYPE diff --git a/homeassistant/components/mqtt/update.py b/homeassistant/components/mqtt/update.py index ac8b5431a59..8fdc6393e0b 100644 --- a/homeassistant/components/mqtt/update.py +++ b/homeassistant/components/mqtt/update.py @@ -98,8 +98,6 @@ class MqttUpdate(MqttEntity, UpdateEntity, RestoreEntity): ) -> None: """Initialize the MQTT update.""" self._config = config - self._sub_state = None - self._attr_device_class = self._config.get(CONF_DEVICE_CLASS) UpdateEntity.__init__(self)