From 47dba6f6bc7ab1a1877b53818120f1698952c1fe Mon Sep 17 00:00:00 2001 From: Jan Bouwhuis Date: Tue, 8 Nov 2022 12:55:41 +0100 Subject: [PATCH] Improve MQTT type hints part 5 (#80979) * Improve typing scene * Improve typing select * Improve typing sensor * move expire_after - and class level attrs * Follow up comment * Solve type confict * Remove stale sentinel const * Update homeassistant/components/mqtt/sensor.py Co-authored-by: epenet <6771947+epenet@users.noreply.github.com> * Make _expire_after a class attribute * Code styling Co-authored-by: epenet <6771947+epenet@users.noreply.github.com> --- homeassistant/components/mqtt/scene.py | 20 ++++-- homeassistant/components/mqtt/select.py | 70 +++++++++++--------- homeassistant/components/mqtt/sensor.py | 85 +++++++++++++++---------- 3 files changed, 103 insertions(+), 72 deletions(-) diff --git a/homeassistant/components/mqtt/scene.py b/homeassistant/components/mqtt/scene.py index e237d70e903..9eafd0cdd99 100644 --- a/homeassistant/components/mqtt/scene.py +++ b/homeassistant/components/mqtt/scene.py @@ -88,8 +88,8 @@ async def _async_setup_entity( hass: HomeAssistant, async_add_entities: AddEntitiesCallback, config: ConfigType, - config_entry: ConfigEntry | None = None, - discovery_data: dict | None = None, + config_entry: ConfigEntry, + discovery_data: DiscoveryInfoType | None = None, ) -> None: """Set up the MQTT scene.""" async_add_entities([MqttScene(hass, config, config_entry, discovery_data)]) @@ -103,23 +103,29 @@ class MqttScene( _entity_id_format = scene.DOMAIN + ".{}" - def __init__(self, hass, config, config_entry, discovery_data): + def __init__( + self, + hass: HomeAssistant, + config: ConfigType, + config_entry: ConfigEntry, + discovery_data: DiscoveryInfoType | None, + ) -> None: """Initialize the MQTT scene.""" MqttEntity.__init__(self, hass, config, config_entry, discovery_data) @staticmethod - def config_schema(): + def config_schema() -> vol.Schema: """Return the config schema.""" return DISCOVERY_SCHEMA - def _setup_from_config(self, config): + def _setup_from_config(self, config: ConfigType) -> None: """(Re)Setup the entity.""" self._config = config - def _prepare_subscribe_topics(self): + def _prepare_subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - async def _subscribe_topics(self): + async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" async def async_activate(self, **kwargs: Any) -> None: diff --git a/homeassistant/components/mqtt/select.py b/homeassistant/components/mqtt/select.py index 12593550e2f..6dfe5081e74 100644 --- a/homeassistant/components/mqtt/select.py +++ b/homeassistant/components/mqtt/select.py @@ -1,6 +1,7 @@ """Configure select in a device through MQTT topic.""" from __future__ import annotations +from collections.abc import Callable import functools import logging @@ -34,7 +35,13 @@ from .mixins import ( async_setup_platform_helper, warn_for_legacy_schema, ) -from .models import MqttCommandTemplate, MqttValueTemplate +from .models import ( + MqttCommandTemplate, + MqttValueTemplate, + PublishPayloadType, + ReceiveMessage, + ReceivePayloadType, +) from .util import get_mqtt_data _LOGGER = logging.getLogger(__name__) @@ -103,8 +110,8 @@ async def _async_setup_entity( hass: HomeAssistant, async_add_entities: AddEntitiesCallback, config: ConfigType, - config_entry: ConfigEntry | None = None, - discovery_data: dict | None = None, + config_entry: ConfigEntry, + discovery_data: DiscoveryInfoType | None = None, ) -> None: """Set up the MQTT select.""" async_add_entities([MqttSelect(hass, config, config_entry, discovery_data)]) @@ -114,53 +121,55 @@ class MqttSelect(MqttEntity, SelectEntity, RestoreEntity): """representation of an MQTT select.""" _entity_id_format = select.ENTITY_ID_FORMAT - _attributes_extra_blocked = MQTT_SELECT_ATTRIBUTES_BLOCKED + _command_template: Callable[[PublishPayloadType], PublishPayloadType] + _value_template: Callable[[ReceivePayloadType], ReceivePayloadType] + _optimistic: bool = False - def __init__(self, hass, config, config_entry, discovery_data): + def __init__( + self, + hass: HomeAssistant, + config: ConfigType, + config_entry: ConfigEntry, + discovery_data: DiscoveryInfoType | None, + ) -> None: """Initialize the MQTT select.""" - self._config = config - self._optimistic = False - self._sub_state = None - - self._attr_current_option = None - SelectEntity.__init__(self) MqttEntity.__init__(self, hass, config, config_entry, discovery_data) @staticmethod - def config_schema(): + def config_schema() -> vol.Schema: """Return the config schema.""" return DISCOVERY_SCHEMA - def _setup_from_config(self, config): + def _setup_from_config(self, config: ConfigType) -> None: """(Re)Setup the entity.""" + self._attr_current_option = None self._optimistic = config[CONF_OPTIMISTIC] self._attr_options = config[CONF_OPTIONS] - self._templates = { - CONF_COMMAND_TEMPLATE: MqttCommandTemplate( - config.get(CONF_COMMAND_TEMPLATE), entity=self - ).async_render, - CONF_VALUE_TEMPLATE: MqttValueTemplate( - config.get(CONF_VALUE_TEMPLATE), - entity=self, - ).async_render_with_possible_json_value, - } + self._command_template = MqttCommandTemplate( + config.get(CONF_COMMAND_TEMPLATE), + entity=self, + ).async_render + self._value_template = MqttValueTemplate( + config.get(CONF_VALUE_TEMPLATE), entity=self + ).async_render_with_possible_json_value - def _prepare_subscribe_topics(self): + def _prepare_subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" @callback @log_messages(self.hass, self.entity_id) - def message_received(msg): + def message_received(msg: ReceiveMessage) -> None: """Handle new MQTT messages.""" - payload = self._templates[CONF_VALUE_TEMPLATE](msg.payload) - + payload = str(self._value_template(msg.payload)) if payload.lower() == "none": - payload = None + self._attr_current_option = None + get_mqtt_data(self.hass).state_write_requests.write_state_request(self) + return - if payload is not None and payload not in self.options: + if payload not in self.options: _LOGGER.error( "Invalid option for %s: '%s' (valid options: %s)", self.entity_id, @@ -168,7 +177,6 @@ class MqttSelect(MqttEntity, SelectEntity, RestoreEntity): self.options, ) return - self._attr_current_option = payload get_mqtt_data(self.hass).state_write_requests.write_state_request(self) @@ -189,7 +197,7 @@ class MqttSelect(MqttEntity, SelectEntity, RestoreEntity): }, ) - async def _subscribe_topics(self): + async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" await subscription.async_subscribe_topics(self.hass, self._sub_state) @@ -198,7 +206,7 @@ class MqttSelect(MqttEntity, SelectEntity, RestoreEntity): async def async_select_option(self, option: str) -> None: """Update the current value.""" - payload = self._templates[CONF_COMMAND_TEMPLATE](option) + payload = self._command_template(option) if self._optimistic: self._attr_current_option = option self.async_write_ha_state() diff --git a/homeassistant/components/mqtt/sensor.py b/homeassistant/components/mqtt/sensor.py index 4c6b5409962..ed65b5a42fe 100644 --- a/homeassistant/components/mqtt/sensor.py +++ b/homeassistant/components/mqtt/sensor.py @@ -1,9 +1,11 @@ """Support for MQTT sensors.""" from __future__ import annotations -from datetime import timedelta +from collections.abc import Callable +from datetime import datetime, timedelta import functools import logging +from typing import Any import voluptuous as vol @@ -15,6 +17,7 @@ from homeassistant.components.sensor import ( STATE_CLASSES_SCHEMA, RestoreSensor, SensorDeviceClass, + SensorExtraStoredData, ) from homeassistant.config_entries import ConfigEntry from homeassistant.const import ( @@ -26,7 +29,7 @@ from homeassistant.const import ( STATE_UNAVAILABLE, STATE_UNKNOWN, ) -from homeassistant.core import HomeAssistant, callback +from homeassistant.core import CALLBACK_TYPE, HomeAssistant, State, callback import homeassistant.helpers.config_validation as cv from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.event import async_track_point_in_utc_time @@ -45,7 +48,12 @@ from .mixins import ( async_setup_platform_helper, warn_for_legacy_schema, ) -from .models import MqttValueTemplate, PayloadSentinel, ReceiveMessage +from .models import ( + MqttValueTemplate, + PayloadSentinel, + ReceiveMessage, + ReceivePayloadType, +) from .util import get_mqtt_data, valid_subscribe_topic _LOGGER = logging.getLogger(__name__) @@ -65,7 +73,7 @@ DEFAULT_NAME = "MQTT Sensor" DEFAULT_FORCE_UPDATE = False -def validate_options(conf): +def validate_options(conf: ConfigType) -> ConfigType: """Validate options. If last reset topic is present it must be same as the state topic. @@ -155,8 +163,8 @@ async def _async_setup_entity( hass: HomeAssistant, async_add_entities: AddEntitiesCallback, config: ConfigType, - config_entry: ConfigEntry | None = None, - discovery_data: dict | None = None, + config_entry: ConfigEntry, + discovery_data: DiscoveryInfoType | None = None, ) -> None: """Set up MQTT sensor.""" async_add_entities([MqttSensor(hass, config, config_entry, discovery_data)]) @@ -168,24 +176,29 @@ class MqttSensor(MqttEntity, RestoreSensor): _entity_id_format = ENTITY_ID_FORMAT _attr_last_reset = None _attributes_extra_blocked = MQTT_SENSOR_ATTRIBUTES_BLOCKED + _expire_after: int | None + _expired: bool | None + _template: Callable[[ReceivePayloadType, PayloadSentinel], ReceivePayloadType] + _last_reset_template: Callable[[ReceivePayloadType], ReceivePayloadType] - def __init__(self, hass, config, config_entry, discovery_data): + def __init__( + self, + hass: HomeAssistant, + config: ConfigType, + config_entry: ConfigEntry, + discovery_data: DiscoveryInfoType | None, + ) -> None: """Initialize the sensor.""" - self._expiration_trigger = None - - expire_after = config.get(CONF_EXPIRE_AFTER) - if expire_after is not None and expire_after > 0: - self._expired = True - else: - self._expired = None - + self._expiration_trigger: CALLBACK_TYPE | None = None MqttEntity.__init__(self, hass, config, config_entry, discovery_data) async def mqtt_async_added_to_hass(self) -> None: """Restore state for entities with expire_after set.""" + last_state: State | None + last_sensor_data: SensorExtraStoredData | None if ( - (expire_after := self._config.get(CONF_EXPIRE_AFTER)) is not None - and expire_after > 0 + (_expire_after := self._expire_after) is not None + and _expire_after > 0 and (last_state := await self.async_get_last_state()) is not None and last_state.state not in [STATE_UNKNOWN, STATE_UNAVAILABLE] and (last_sensor_data := await self.async_get_last_sensor_data()) @@ -194,7 +207,7 @@ class MqttSensor(MqttEntity, RestoreSensor): # MqttEntity.async_added_to_hass(), then we should not restore state and not self._expiration_trigger ): - expiration_at = last_state.last_changed + timedelta(seconds=expire_after) + expiration_at = last_state.last_changed + timedelta(seconds=_expire_after) if expiration_at < (time_now := dt_util.utcnow()): # Skip reactivating the sensor _LOGGER.debug("Skip state recovery after reload for %s", self.entity_id) @@ -222,7 +235,7 @@ class MqttSensor(MqttEntity, RestoreSensor): await MqttEntity.async_will_remove_from_hass(self) @staticmethod - def config_schema(): + def config_schema() -> vol.Schema: """Return the config schema.""" return DISCOVERY_SCHEMA @@ -233,6 +246,12 @@ class MqttSensor(MqttEntity, RestoreSensor): self._attr_native_unit_of_measurement = config.get(CONF_UNIT_OF_MEASUREMENT) self._attr_state_class = config.get(CONF_STATE_CLASS) + self._expire_after = config.get(CONF_EXPIRE_AFTER) + if self._expire_after is not None and self._expire_after > 0: + self._expired = True + else: + self._expired = None + self._template = MqttValueTemplate( self._config.get(CONF_VALUE_TEMPLATE), entity=self ).async_render_with_possible_json_value @@ -240,15 +259,14 @@ class MqttSensor(MqttEntity, RestoreSensor): self._config.get(CONF_LAST_RESET_VALUE_TEMPLATE), entity=self ).async_render_with_possible_json_value - def _prepare_subscribe_topics(self): + def _prepare_subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - topics = {} + topics: dict[str, dict[str, Any]] = {} def _update_state(msg: ReceiveMessage) -> None: # auto-expire enabled? - expire_after = self._config.get(CONF_EXPIRE_AFTER) - if expire_after is not None and expire_after > 0: - # When expire_after is set, and we receive a message, assume device is not expired since it has to be to receive the message + if self._expire_after is not None and self._expire_after > 0: + # When self._expire_after is set, and we receive a message, assume device is not expired since it has to be to receive the message self._expired = False # Reset old trigger @@ -256,13 +274,13 @@ class MqttSensor(MqttEntity, RestoreSensor): self._expiration_trigger() # Set new trigger - expiration_at = dt_util.utcnow() + timedelta(seconds=expire_after) + expiration_at = dt_util.utcnow() + timedelta(seconds=self._expire_after) self._expiration_trigger = async_track_point_in_utc_time( self.hass, self._value_is_expired, expiration_at ) - payload = self._template(msg.payload, default=PayloadSentinel.DEFAULT) + payload = self._template(msg.payload, PayloadSentinel.DEFAULT) if payload is PayloadSentinel.DEFAULT: return if self.device_class not in { @@ -282,14 +300,14 @@ class MqttSensor(MqttEntity, RestoreSensor): return self._attr_native_value = payload_datetime - def _update_last_reset(msg): + def _update_last_reset(msg: ReceiveMessage) -> None: payload = self._last_reset_template(msg.payload) if not payload: _LOGGER.debug("Ignoring empty last_reset message from '%s'", msg.topic) return try: - last_reset = dt_util.parse_datetime(payload) + last_reset = dt_util.parse_datetime(str(payload)) if last_reset is None: raise ValueError self._attr_last_reset = last_reset @@ -300,7 +318,7 @@ class MqttSensor(MqttEntity, RestoreSensor): @callback @log_messages(self.hass, self.entity_id) - def message_received(msg): + def message_received(msg: ReceiveMessage) -> None: """Handle new MQTT messages.""" _update_state(msg) if CONF_LAST_RESET_VALUE_TEMPLATE in self._config and ( @@ -319,7 +337,7 @@ class MqttSensor(MqttEntity, RestoreSensor): @callback @log_messages(self.hass, self.entity_id) - def last_reset_message_received(msg): + def last_reset_message_received(msg: ReceiveMessage) -> None: """Handle new last_reset messages.""" _update_last_reset(msg) get_mqtt_data(self.hass).state_write_requests.write_state_request(self) @@ -339,12 +357,12 @@ class MqttSensor(MqttEntity, RestoreSensor): self.hass, self._sub_state, topics ) - async def _subscribe_topics(self): + async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" await subscription.async_subscribe_topics(self.hass, self._sub_state) @callback - def _value_is_expired(self, *_): + def _value_is_expired(self, *_: datetime) -> None: """Triggered when value is expired.""" self._expiration_trigger = None self._expired = True @@ -353,8 +371,7 @@ class MqttSensor(MqttEntity, RestoreSensor): @property def available(self) -> bool: """Return true if the device is available and value has not expired.""" - expire_after = self._config.get(CONF_EXPIRE_AFTER) # mypy doesn't know about fget: https://github.com/python/mypy/issues/6185 return MqttAvailability.available.fget(self) and ( # type: ignore[attr-defined] - expire_after is None or not self._expired + self._expire_after is None or not self._expired )