Improve MQTT type hints part 5 (#80979)

* Improve typing scene

* Improve typing select

* Improve typing sensor

* move expire_after - and class level attrs

* Follow up comment

* Solve type confict

* Remove stale sentinel const

* Update homeassistant/components/mqtt/sensor.py

Co-authored-by: epenet <6771947+epenet@users.noreply.github.com>

* Make _expire_after a class attribute

* Code styling

Co-authored-by: epenet <6771947+epenet@users.noreply.github.com>
This commit is contained in:
Jan Bouwhuis 2022-11-08 12:55:41 +01:00 committed by GitHub
parent d6c10cd887
commit 47dba6f6bc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 103 additions and 72 deletions

View file

@ -1,9 +1,11 @@
"""Support for MQTT sensors."""
from __future__ import annotations
from datetime import timedelta
from collections.abc import Callable
from datetime import datetime, timedelta
import functools
import logging
from typing import Any
import voluptuous as vol
@ -15,6 +17,7 @@ from homeassistant.components.sensor import (
STATE_CLASSES_SCHEMA,
RestoreSensor,
SensorDeviceClass,
SensorExtraStoredData,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import (
@ -26,7 +29,7 @@ from homeassistant.const import (
STATE_UNAVAILABLE,
STATE_UNKNOWN,
)
from homeassistant.core import HomeAssistant, callback
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, State, callback
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.event import async_track_point_in_utc_time
@ -45,7 +48,12 @@ from .mixins import (
async_setup_platform_helper,
warn_for_legacy_schema,
)
from .models import MqttValueTemplate, PayloadSentinel, ReceiveMessage
from .models import (
MqttValueTemplate,
PayloadSentinel,
ReceiveMessage,
ReceivePayloadType,
)
from .util import get_mqtt_data, valid_subscribe_topic
_LOGGER = logging.getLogger(__name__)
@ -65,7 +73,7 @@ DEFAULT_NAME = "MQTT Sensor"
DEFAULT_FORCE_UPDATE = False
def validate_options(conf):
def validate_options(conf: ConfigType) -> ConfigType:
"""Validate options.
If last reset topic is present it must be same as the state topic.
@ -155,8 +163,8 @@ async def _async_setup_entity(
hass: HomeAssistant,
async_add_entities: AddEntitiesCallback,
config: ConfigType,
config_entry: ConfigEntry | None = None,
discovery_data: dict | None = None,
config_entry: ConfigEntry,
discovery_data: DiscoveryInfoType | None = None,
) -> None:
"""Set up MQTT sensor."""
async_add_entities([MqttSensor(hass, config, config_entry, discovery_data)])
@ -168,24 +176,29 @@ class MqttSensor(MqttEntity, RestoreSensor):
_entity_id_format = ENTITY_ID_FORMAT
_attr_last_reset = None
_attributes_extra_blocked = MQTT_SENSOR_ATTRIBUTES_BLOCKED
_expire_after: int | None
_expired: bool | None
_template: Callable[[ReceivePayloadType, PayloadSentinel], ReceivePayloadType]
_last_reset_template: Callable[[ReceivePayloadType], ReceivePayloadType]
def __init__(self, hass, config, config_entry, discovery_data):
def __init__(
self,
hass: HomeAssistant,
config: ConfigType,
config_entry: ConfigEntry,
discovery_data: DiscoveryInfoType | None,
) -> None:
"""Initialize the sensor."""
self._expiration_trigger = None
expire_after = config.get(CONF_EXPIRE_AFTER)
if expire_after is not None and expire_after > 0:
self._expired = True
else:
self._expired = None
self._expiration_trigger: CALLBACK_TYPE | None = None
MqttEntity.__init__(self, hass, config, config_entry, discovery_data)
async def mqtt_async_added_to_hass(self) -> None:
"""Restore state for entities with expire_after set."""
last_state: State | None
last_sensor_data: SensorExtraStoredData | None
if (
(expire_after := self._config.get(CONF_EXPIRE_AFTER)) is not None
and expire_after > 0
(_expire_after := self._expire_after) is not None
and _expire_after > 0
and (last_state := await self.async_get_last_state()) is not None
and last_state.state not in [STATE_UNKNOWN, STATE_UNAVAILABLE]
and (last_sensor_data := await self.async_get_last_sensor_data())
@ -194,7 +207,7 @@ class MqttSensor(MqttEntity, RestoreSensor):
# MqttEntity.async_added_to_hass(), then we should not restore state
and not self._expiration_trigger
):
expiration_at = last_state.last_changed + timedelta(seconds=expire_after)
expiration_at = last_state.last_changed + timedelta(seconds=_expire_after)
if expiration_at < (time_now := dt_util.utcnow()):
# Skip reactivating the sensor
_LOGGER.debug("Skip state recovery after reload for %s", self.entity_id)
@ -222,7 +235,7 @@ class MqttSensor(MqttEntity, RestoreSensor):
await MqttEntity.async_will_remove_from_hass(self)
@staticmethod
def config_schema():
def config_schema() -> vol.Schema:
"""Return the config schema."""
return DISCOVERY_SCHEMA
@ -233,6 +246,12 @@ class MqttSensor(MqttEntity, RestoreSensor):
self._attr_native_unit_of_measurement = config.get(CONF_UNIT_OF_MEASUREMENT)
self._attr_state_class = config.get(CONF_STATE_CLASS)
self._expire_after = config.get(CONF_EXPIRE_AFTER)
if self._expire_after is not None and self._expire_after > 0:
self._expired = True
else:
self._expired = None
self._template = MqttValueTemplate(
self._config.get(CONF_VALUE_TEMPLATE), entity=self
).async_render_with_possible_json_value
@ -240,15 +259,14 @@ class MqttSensor(MqttEntity, RestoreSensor):
self._config.get(CONF_LAST_RESET_VALUE_TEMPLATE), entity=self
).async_render_with_possible_json_value
def _prepare_subscribe_topics(self):
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
topics = {}
topics: dict[str, dict[str, Any]] = {}
def _update_state(msg: ReceiveMessage) -> None:
# auto-expire enabled?
expire_after = self._config.get(CONF_EXPIRE_AFTER)
if expire_after is not None and expire_after > 0:
# When expire_after is set, and we receive a message, assume device is not expired since it has to be to receive the message
if self._expire_after is not None and self._expire_after > 0:
# When self._expire_after is set, and we receive a message, assume device is not expired since it has to be to receive the message
self._expired = False
# Reset old trigger
@ -256,13 +274,13 @@ class MqttSensor(MqttEntity, RestoreSensor):
self._expiration_trigger()
# Set new trigger
expiration_at = dt_util.utcnow() + timedelta(seconds=expire_after)
expiration_at = dt_util.utcnow() + timedelta(seconds=self._expire_after)
self._expiration_trigger = async_track_point_in_utc_time(
self.hass, self._value_is_expired, expiration_at
)
payload = self._template(msg.payload, default=PayloadSentinel.DEFAULT)
payload = self._template(msg.payload, PayloadSentinel.DEFAULT)
if payload is PayloadSentinel.DEFAULT:
return
if self.device_class not in {
@ -282,14 +300,14 @@ class MqttSensor(MqttEntity, RestoreSensor):
return
self._attr_native_value = payload_datetime
def _update_last_reset(msg):
def _update_last_reset(msg: ReceiveMessage) -> None:
payload = self._last_reset_template(msg.payload)
if not payload:
_LOGGER.debug("Ignoring empty last_reset message from '%s'", msg.topic)
return
try:
last_reset = dt_util.parse_datetime(payload)
last_reset = dt_util.parse_datetime(str(payload))
if last_reset is None:
raise ValueError
self._attr_last_reset = last_reset
@ -300,7 +318,7 @@ class MqttSensor(MqttEntity, RestoreSensor):
@callback
@log_messages(self.hass, self.entity_id)
def message_received(msg):
def message_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
_update_state(msg)
if CONF_LAST_RESET_VALUE_TEMPLATE in self._config and (
@ -319,7 +337,7 @@ class MqttSensor(MqttEntity, RestoreSensor):
@callback
@log_messages(self.hass, self.entity_id)
def last_reset_message_received(msg):
def last_reset_message_received(msg: ReceiveMessage) -> None:
"""Handle new last_reset messages."""
_update_last_reset(msg)
get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
@ -339,12 +357,12 @@ class MqttSensor(MqttEntity, RestoreSensor):
self.hass, self._sub_state, topics
)
async def _subscribe_topics(self):
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
@callback
def _value_is_expired(self, *_):
def _value_is_expired(self, *_: datetime) -> None:
"""Triggered when value is expired."""
self._expiration_trigger = None
self._expired = True
@ -353,8 +371,7 @@ class MqttSensor(MqttEntity, RestoreSensor):
@property
def available(self) -> bool:
"""Return true if the device is available and value has not expired."""
expire_after = self._config.get(CONF_EXPIRE_AFTER)
# mypy doesn't know about fget: https://github.com/python/mypy/issues/6185
return MqttAvailability.available.fget(self) and ( # type: ignore[attr-defined]
expire_after is None or not self._expired
self._expire_after is None or not self._expired
)