Improve MQTT type hints part 5 (#80979)

* Improve typing scene

* Improve typing select

* Improve typing sensor

* move expire_after - and class level attrs

* Follow up comment

* Solve type confict

* Remove stale sentinel const

* Update homeassistant/components/mqtt/sensor.py

Co-authored-by: epenet <6771947+epenet@users.noreply.github.com>

* Make _expire_after a class attribute

* Code styling

Co-authored-by: epenet <6771947+epenet@users.noreply.github.com>
This commit is contained in:
Jan Bouwhuis 2022-11-08 12:55:41 +01:00 committed by GitHub
parent d6c10cd887
commit 47dba6f6bc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 103 additions and 72 deletions

View file

@ -88,8 +88,8 @@ async def _async_setup_entity(
hass: HomeAssistant, hass: HomeAssistant,
async_add_entities: AddEntitiesCallback, async_add_entities: AddEntitiesCallback,
config: ConfigType, config: ConfigType,
config_entry: ConfigEntry | None = None, config_entry: ConfigEntry,
discovery_data: dict | None = None, discovery_data: DiscoveryInfoType | None = None,
) -> None: ) -> None:
"""Set up the MQTT scene.""" """Set up the MQTT scene."""
async_add_entities([MqttScene(hass, config, config_entry, discovery_data)]) async_add_entities([MqttScene(hass, config, config_entry, discovery_data)])
@ -103,23 +103,29 @@ class MqttScene(
_entity_id_format = scene.DOMAIN + ".{}" _entity_id_format = scene.DOMAIN + ".{}"
def __init__(self, hass, config, config_entry, discovery_data): def __init__(
self,
hass: HomeAssistant,
config: ConfigType,
config_entry: ConfigEntry,
discovery_data: DiscoveryInfoType | None,
) -> None:
"""Initialize the MQTT scene.""" """Initialize the MQTT scene."""
MqttEntity.__init__(self, hass, config, config_entry, discovery_data) MqttEntity.__init__(self, hass, config, config_entry, discovery_data)
@staticmethod @staticmethod
def config_schema(): def config_schema() -> vol.Schema:
"""Return the config schema.""" """Return the config schema."""
return DISCOVERY_SCHEMA return DISCOVERY_SCHEMA
def _setup_from_config(self, config): def _setup_from_config(self, config: ConfigType) -> None:
"""(Re)Setup the entity.""" """(Re)Setup the entity."""
self._config = config self._config = config
def _prepare_subscribe_topics(self): def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
async def _subscribe_topics(self): async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
async def async_activate(self, **kwargs: Any) -> None: async def async_activate(self, **kwargs: Any) -> None:

View file

@ -1,6 +1,7 @@
"""Configure select in a device through MQTT topic.""" """Configure select in a device through MQTT topic."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable
import functools import functools
import logging import logging
@ -34,7 +35,13 @@ from .mixins import (
async_setup_platform_helper, async_setup_platform_helper,
warn_for_legacy_schema, warn_for_legacy_schema,
) )
from .models import MqttCommandTemplate, MqttValueTemplate from .models import (
MqttCommandTemplate,
MqttValueTemplate,
PublishPayloadType,
ReceiveMessage,
ReceivePayloadType,
)
from .util import get_mqtt_data from .util import get_mqtt_data
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -103,8 +110,8 @@ async def _async_setup_entity(
hass: HomeAssistant, hass: HomeAssistant,
async_add_entities: AddEntitiesCallback, async_add_entities: AddEntitiesCallback,
config: ConfigType, config: ConfigType,
config_entry: ConfigEntry | None = None, config_entry: ConfigEntry,
discovery_data: dict | None = None, discovery_data: DiscoveryInfoType | None = None,
) -> None: ) -> None:
"""Set up the MQTT select.""" """Set up the MQTT select."""
async_add_entities([MqttSelect(hass, config, config_entry, discovery_data)]) async_add_entities([MqttSelect(hass, config, config_entry, discovery_data)])
@ -114,53 +121,55 @@ class MqttSelect(MqttEntity, SelectEntity, RestoreEntity):
"""representation of an MQTT select.""" """representation of an MQTT select."""
_entity_id_format = select.ENTITY_ID_FORMAT _entity_id_format = select.ENTITY_ID_FORMAT
_attributes_extra_blocked = MQTT_SELECT_ATTRIBUTES_BLOCKED _attributes_extra_blocked = MQTT_SELECT_ATTRIBUTES_BLOCKED
_command_template: Callable[[PublishPayloadType], PublishPayloadType]
_value_template: Callable[[ReceivePayloadType], ReceivePayloadType]
_optimistic: bool = False
def __init__(self, hass, config, config_entry, discovery_data): def __init__(
self,
hass: HomeAssistant,
config: ConfigType,
config_entry: ConfigEntry,
discovery_data: DiscoveryInfoType | None,
) -> None:
"""Initialize the MQTT select.""" """Initialize the MQTT select."""
self._config = config
self._optimistic = False
self._sub_state = None
self._attr_current_option = None
SelectEntity.__init__(self) SelectEntity.__init__(self)
MqttEntity.__init__(self, hass, config, config_entry, discovery_data) MqttEntity.__init__(self, hass, config, config_entry, discovery_data)
@staticmethod @staticmethod
def config_schema(): def config_schema() -> vol.Schema:
"""Return the config schema.""" """Return the config schema."""
return DISCOVERY_SCHEMA return DISCOVERY_SCHEMA
def _setup_from_config(self, config): def _setup_from_config(self, config: ConfigType) -> None:
"""(Re)Setup the entity.""" """(Re)Setup the entity."""
self._attr_current_option = None
self._optimistic = config[CONF_OPTIMISTIC] self._optimistic = config[CONF_OPTIMISTIC]
self._attr_options = config[CONF_OPTIONS] self._attr_options = config[CONF_OPTIONS]
self._templates = { self._command_template = MqttCommandTemplate(
CONF_COMMAND_TEMPLATE: MqttCommandTemplate( config.get(CONF_COMMAND_TEMPLATE),
config.get(CONF_COMMAND_TEMPLATE), entity=self
).async_render,
CONF_VALUE_TEMPLATE: MqttValueTemplate(
config.get(CONF_VALUE_TEMPLATE),
entity=self, entity=self,
).async_render_with_possible_json_value, ).async_render
} self._value_template = MqttValueTemplate(
config.get(CONF_VALUE_TEMPLATE), entity=self
).async_render_with_possible_json_value
def _prepare_subscribe_topics(self): def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
@callback @callback
@log_messages(self.hass, self.entity_id) @log_messages(self.hass, self.entity_id)
def message_received(msg): def message_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages.""" """Handle new MQTT messages."""
payload = self._templates[CONF_VALUE_TEMPLATE](msg.payload) payload = str(self._value_template(msg.payload))
if payload.lower() == "none": if payload.lower() == "none":
payload = None self._attr_current_option = None
get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
return
if payload is not None and payload not in self.options: if payload not in self.options:
_LOGGER.error( _LOGGER.error(
"Invalid option for %s: '%s' (valid options: %s)", "Invalid option for %s: '%s' (valid options: %s)",
self.entity_id, self.entity_id,
@ -168,7 +177,6 @@ class MqttSelect(MqttEntity, SelectEntity, RestoreEntity):
self.options, self.options,
) )
return return
self._attr_current_option = payload self._attr_current_option = payload
get_mqtt_data(self.hass).state_write_requests.write_state_request(self) get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
@ -189,7 +197,7 @@ class MqttSelect(MqttEntity, SelectEntity, RestoreEntity):
}, },
) )
async def _subscribe_topics(self): async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) await subscription.async_subscribe_topics(self.hass, self._sub_state)
@ -198,7 +206,7 @@ class MqttSelect(MqttEntity, SelectEntity, RestoreEntity):
async def async_select_option(self, option: str) -> None: async def async_select_option(self, option: str) -> None:
"""Update the current value.""" """Update the current value."""
payload = self._templates[CONF_COMMAND_TEMPLATE](option) payload = self._command_template(option)
if self._optimistic: if self._optimistic:
self._attr_current_option = option self._attr_current_option = option
self.async_write_ha_state() self.async_write_ha_state()

View file

@ -1,9 +1,11 @@
"""Support for MQTT sensors.""" """Support for MQTT sensors."""
from __future__ import annotations from __future__ import annotations
from datetime import timedelta from collections.abc import Callable
from datetime import datetime, timedelta
import functools import functools
import logging import logging
from typing import Any
import voluptuous as vol import voluptuous as vol
@ -15,6 +17,7 @@ from homeassistant.components.sensor import (
STATE_CLASSES_SCHEMA, STATE_CLASSES_SCHEMA,
RestoreSensor, RestoreSensor,
SensorDeviceClass, SensorDeviceClass,
SensorExtraStoredData,
) )
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ( from homeassistant.const import (
@ -26,7 +29,7 @@ from homeassistant.const import (
STATE_UNAVAILABLE, STATE_UNAVAILABLE,
STATE_UNKNOWN, STATE_UNKNOWN,
) )
from homeassistant.core import HomeAssistant, callback from homeassistant.core import CALLBACK_TYPE, HomeAssistant, State, callback
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.event import async_track_point_in_utc_time from homeassistant.helpers.event import async_track_point_in_utc_time
@ -45,7 +48,12 @@ from .mixins import (
async_setup_platform_helper, async_setup_platform_helper,
warn_for_legacy_schema, warn_for_legacy_schema,
) )
from .models import MqttValueTemplate, PayloadSentinel, ReceiveMessage from .models import (
MqttValueTemplate,
PayloadSentinel,
ReceiveMessage,
ReceivePayloadType,
)
from .util import get_mqtt_data, valid_subscribe_topic from .util import get_mqtt_data, valid_subscribe_topic
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -65,7 +73,7 @@ DEFAULT_NAME = "MQTT Sensor"
DEFAULT_FORCE_UPDATE = False DEFAULT_FORCE_UPDATE = False
def validate_options(conf): def validate_options(conf: ConfigType) -> ConfigType:
"""Validate options. """Validate options.
If last reset topic is present it must be same as the state topic. If last reset topic is present it must be same as the state topic.
@ -155,8 +163,8 @@ async def _async_setup_entity(
hass: HomeAssistant, hass: HomeAssistant,
async_add_entities: AddEntitiesCallback, async_add_entities: AddEntitiesCallback,
config: ConfigType, config: ConfigType,
config_entry: ConfigEntry | None = None, config_entry: ConfigEntry,
discovery_data: dict | None = None, discovery_data: DiscoveryInfoType | None = None,
) -> None: ) -> None:
"""Set up MQTT sensor.""" """Set up MQTT sensor."""
async_add_entities([MqttSensor(hass, config, config_entry, discovery_data)]) async_add_entities([MqttSensor(hass, config, config_entry, discovery_data)])
@ -168,24 +176,29 @@ class MqttSensor(MqttEntity, RestoreSensor):
_entity_id_format = ENTITY_ID_FORMAT _entity_id_format = ENTITY_ID_FORMAT
_attr_last_reset = None _attr_last_reset = None
_attributes_extra_blocked = MQTT_SENSOR_ATTRIBUTES_BLOCKED _attributes_extra_blocked = MQTT_SENSOR_ATTRIBUTES_BLOCKED
_expire_after: int | None
_expired: bool | None
_template: Callable[[ReceivePayloadType, PayloadSentinel], ReceivePayloadType]
_last_reset_template: Callable[[ReceivePayloadType], ReceivePayloadType]
def __init__(self, hass, config, config_entry, discovery_data): def __init__(
self,
hass: HomeAssistant,
config: ConfigType,
config_entry: ConfigEntry,
discovery_data: DiscoveryInfoType | None,
) -> None:
"""Initialize the sensor.""" """Initialize the sensor."""
self._expiration_trigger = None self._expiration_trigger: CALLBACK_TYPE | None = None
expire_after = config.get(CONF_EXPIRE_AFTER)
if expire_after is not None and expire_after > 0:
self._expired = True
else:
self._expired = None
MqttEntity.__init__(self, hass, config, config_entry, discovery_data) MqttEntity.__init__(self, hass, config, config_entry, discovery_data)
async def mqtt_async_added_to_hass(self) -> None: async def mqtt_async_added_to_hass(self) -> None:
"""Restore state for entities with expire_after set.""" """Restore state for entities with expire_after set."""
last_state: State | None
last_sensor_data: SensorExtraStoredData | None
if ( if (
(expire_after := self._config.get(CONF_EXPIRE_AFTER)) is not None (_expire_after := self._expire_after) is not None
and expire_after > 0 and _expire_after > 0
and (last_state := await self.async_get_last_state()) is not None and (last_state := await self.async_get_last_state()) is not None
and last_state.state not in [STATE_UNKNOWN, STATE_UNAVAILABLE] and last_state.state not in [STATE_UNKNOWN, STATE_UNAVAILABLE]
and (last_sensor_data := await self.async_get_last_sensor_data()) and (last_sensor_data := await self.async_get_last_sensor_data())
@ -194,7 +207,7 @@ class MqttSensor(MqttEntity, RestoreSensor):
# MqttEntity.async_added_to_hass(), then we should not restore state # MqttEntity.async_added_to_hass(), then we should not restore state
and not self._expiration_trigger and not self._expiration_trigger
): ):
expiration_at = last_state.last_changed + timedelta(seconds=expire_after) expiration_at = last_state.last_changed + timedelta(seconds=_expire_after)
if expiration_at < (time_now := dt_util.utcnow()): if expiration_at < (time_now := dt_util.utcnow()):
# Skip reactivating the sensor # Skip reactivating the sensor
_LOGGER.debug("Skip state recovery after reload for %s", self.entity_id) _LOGGER.debug("Skip state recovery after reload for %s", self.entity_id)
@ -222,7 +235,7 @@ class MqttSensor(MqttEntity, RestoreSensor):
await MqttEntity.async_will_remove_from_hass(self) await MqttEntity.async_will_remove_from_hass(self)
@staticmethod @staticmethod
def config_schema(): def config_schema() -> vol.Schema:
"""Return the config schema.""" """Return the config schema."""
return DISCOVERY_SCHEMA return DISCOVERY_SCHEMA
@ -233,6 +246,12 @@ class MqttSensor(MqttEntity, RestoreSensor):
self._attr_native_unit_of_measurement = config.get(CONF_UNIT_OF_MEASUREMENT) self._attr_native_unit_of_measurement = config.get(CONF_UNIT_OF_MEASUREMENT)
self._attr_state_class = config.get(CONF_STATE_CLASS) self._attr_state_class = config.get(CONF_STATE_CLASS)
self._expire_after = config.get(CONF_EXPIRE_AFTER)
if self._expire_after is not None and self._expire_after > 0:
self._expired = True
else:
self._expired = None
self._template = MqttValueTemplate( self._template = MqttValueTemplate(
self._config.get(CONF_VALUE_TEMPLATE), entity=self self._config.get(CONF_VALUE_TEMPLATE), entity=self
).async_render_with_possible_json_value ).async_render_with_possible_json_value
@ -240,15 +259,14 @@ class MqttSensor(MqttEntity, RestoreSensor):
self._config.get(CONF_LAST_RESET_VALUE_TEMPLATE), entity=self self._config.get(CONF_LAST_RESET_VALUE_TEMPLATE), entity=self
).async_render_with_possible_json_value ).async_render_with_possible_json_value
def _prepare_subscribe_topics(self): def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
topics = {} topics: dict[str, dict[str, Any]] = {}
def _update_state(msg: ReceiveMessage) -> None: def _update_state(msg: ReceiveMessage) -> None:
# auto-expire enabled? # auto-expire enabled?
expire_after = self._config.get(CONF_EXPIRE_AFTER) if self._expire_after is not None and self._expire_after > 0:
if expire_after is not None and expire_after > 0: # When self._expire_after is set, and we receive a message, assume device is not expired since it has to be to receive the message
# When expire_after is set, and we receive a message, assume device is not expired since it has to be to receive the message
self._expired = False self._expired = False
# Reset old trigger # Reset old trigger
@ -256,13 +274,13 @@ class MqttSensor(MqttEntity, RestoreSensor):
self._expiration_trigger() self._expiration_trigger()
# Set new trigger # Set new trigger
expiration_at = dt_util.utcnow() + timedelta(seconds=expire_after) expiration_at = dt_util.utcnow() + timedelta(seconds=self._expire_after)
self._expiration_trigger = async_track_point_in_utc_time( self._expiration_trigger = async_track_point_in_utc_time(
self.hass, self._value_is_expired, expiration_at self.hass, self._value_is_expired, expiration_at
) )
payload = self._template(msg.payload, default=PayloadSentinel.DEFAULT) payload = self._template(msg.payload, PayloadSentinel.DEFAULT)
if payload is PayloadSentinel.DEFAULT: if payload is PayloadSentinel.DEFAULT:
return return
if self.device_class not in { if self.device_class not in {
@ -282,14 +300,14 @@ class MqttSensor(MqttEntity, RestoreSensor):
return return
self._attr_native_value = payload_datetime self._attr_native_value = payload_datetime
def _update_last_reset(msg): def _update_last_reset(msg: ReceiveMessage) -> None:
payload = self._last_reset_template(msg.payload) payload = self._last_reset_template(msg.payload)
if not payload: if not payload:
_LOGGER.debug("Ignoring empty last_reset message from '%s'", msg.topic) _LOGGER.debug("Ignoring empty last_reset message from '%s'", msg.topic)
return return
try: try:
last_reset = dt_util.parse_datetime(payload) last_reset = dt_util.parse_datetime(str(payload))
if last_reset is None: if last_reset is None:
raise ValueError raise ValueError
self._attr_last_reset = last_reset self._attr_last_reset = last_reset
@ -300,7 +318,7 @@ class MqttSensor(MqttEntity, RestoreSensor):
@callback @callback
@log_messages(self.hass, self.entity_id) @log_messages(self.hass, self.entity_id)
def message_received(msg): def message_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages.""" """Handle new MQTT messages."""
_update_state(msg) _update_state(msg)
if CONF_LAST_RESET_VALUE_TEMPLATE in self._config and ( if CONF_LAST_RESET_VALUE_TEMPLATE in self._config and (
@ -319,7 +337,7 @@ class MqttSensor(MqttEntity, RestoreSensor):
@callback @callback
@log_messages(self.hass, self.entity_id) @log_messages(self.hass, self.entity_id)
def last_reset_message_received(msg): def last_reset_message_received(msg: ReceiveMessage) -> None:
"""Handle new last_reset messages.""" """Handle new last_reset messages."""
_update_last_reset(msg) _update_last_reset(msg)
get_mqtt_data(self.hass).state_write_requests.write_state_request(self) get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
@ -339,12 +357,12 @@ class MqttSensor(MqttEntity, RestoreSensor):
self.hass, self._sub_state, topics self.hass, self._sub_state, topics
) )
async def _subscribe_topics(self): async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) await subscription.async_subscribe_topics(self.hass, self._sub_state)
@callback @callback
def _value_is_expired(self, *_): def _value_is_expired(self, *_: datetime) -> None:
"""Triggered when value is expired.""" """Triggered when value is expired."""
self._expiration_trigger = None self._expiration_trigger = None
self._expired = True self._expired = True
@ -353,8 +371,7 @@ class MqttSensor(MqttEntity, RestoreSensor):
@property @property
def available(self) -> bool: def available(self) -> bool:
"""Return true if the device is available and value has not expired.""" """Return true if the device is available and value has not expired."""
expire_after = self._config.get(CONF_EXPIRE_AFTER)
# mypy doesn't know about fget: https://github.com/python/mypy/issues/6185 # mypy doesn't know about fget: https://github.com/python/mypy/issues/6185
return MqttAvailability.available.fget(self) and ( # type: ignore[attr-defined] return MqttAvailability.available.fget(self) and ( # type: ignore[attr-defined]
expire_after is None or not self._expired self._expire_after is None or not self._expired
) )