Improve type hints for MQTT climate (#81396)

* Improve typing climate

* Move climate type hints to class level

* Apply suggestions from code review

Co-authored-by: epenet <6771947+epenet@users.noreply.github.com>

* remove stale command after applying suggestions

* cleanup

* Update homeassistant/components/mqtt/climate.py

Co-authored-by: epenet <6771947+epenet@users.noreply.github.com>

Co-authored-by: epenet <6771947+epenet@users.noreply.github.com>
This commit is contained in:
Jan Bouwhuis 2022-11-08 09:17:03 +01:00 committed by GitHub
parent 23bed25e52
commit 88faf33cb8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -1,6 +1,7 @@
"""Support for MQTT climate devices."""
from __future__ import annotations
from collections.abc import Callable
import functools
import logging
from typing import Any
@ -41,6 +42,7 @@ from homeassistant.const import (
from homeassistant.core import HomeAssistant, callback
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.template import Template
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from . import subscription
@ -54,7 +56,13 @@ from .mixins import (
async_setup_platform_helper,
warn_for_legacy_schema,
)
from .models import MqttCommandTemplate, MqttValueTemplate
from .models import (
MqttCommandTemplate,
MqttValueTemplate,
PublishPayloadType,
ReceiveMessage,
ReceivePayloadType,
)
from .util import get_mqtt_data, valid_publish_topic, valid_subscribe_topic
_LOGGER = logging.getLogger(__name__)
@ -197,9 +205,9 @@ TOPIC_KEYS = (
)
def valid_preset_mode_configuration(config):
def valid_preset_mode_configuration(config: ConfigType) -> ConfigType:
"""Validate that the preset mode reset payload is not one of the preset modes."""
if PRESET_NONE in config.get(CONF_PRESET_MODES_LIST):
if PRESET_NONE in config[CONF_PRESET_MODES_LIST]:
raise ValueError("preset_modes must not include preset mode 'none'")
return config
@ -359,8 +367,8 @@ async def _async_setup_entity(
hass: HomeAssistant,
async_add_entities: AddEntitiesCallback,
config: ConfigType,
config_entry: ConfigEntry | None = None,
discovery_data: dict | None = None,
config_entry: ConfigEntry,
discovery_data: DiscoveryInfoType | None = None,
) -> None:
"""Set up the MQTT climate devices."""
async_add_entities([MqttClimate(hass, config, config_entry, discovery_data)])
@ -372,22 +380,28 @@ class MqttClimate(MqttEntity, ClimateEntity):
_entity_id_format = climate.ENTITY_ID_FORMAT
_attributes_extra_blocked = MQTT_CLIMATE_ATTRIBUTES_BLOCKED
def __init__(self, hass, config, config_entry, discovery_data):
"""Initialize the climate device."""
self._topic = None
self._value_templates = None
self._command_templates = None
self._feature_preset_mode = False
self._optimistic_preset_mode = None
_command_templates: dict[str, Callable[[PublishPayloadType], PublishPayloadType]]
_value_templates: dict[str, Callable[[ReceivePayloadType], ReceivePayloadType]]
_feature_preset_mode: bool
_optimistic_preset_mode: bool
_topic: dict[str, Any]
def __init__(
self,
hass: HomeAssistant,
config: ConfigType,
config_entry: ConfigEntry,
discovery_data: DiscoveryInfoType | None,
) -> None:
"""Initialize the climate device."""
MqttEntity.__init__(self, hass, config, config_entry, discovery_data)
@staticmethod
def config_schema():
def config_schema() -> vol.Schema:
"""Return the config schema."""
return DISCOVERY_SCHEMA
def _setup_from_config(self, config):
def _setup_from_config(self, config: ConfigType) -> None:
"""(Re)Setup the entity."""
self._attr_hvac_modes = config[CONF_MODE_LIST]
self._attr_min_temp = config[CONF_TEMP_MIN]
@ -438,7 +452,7 @@ class MqttClimate(MqttEntity, ClimateEntity):
self._attr_is_aux_heat = False
value_templates = {}
value_templates: dict[str, Template | None] = {}
for key in VALUE_TEMPLATE_KEYS:
value_templates[key] = None
if CONF_VALUE_TEMPLATE in config:
@ -455,14 +469,12 @@ class MqttClimate(MqttEntity, ClimateEntity):
for key, template in value_templates.items()
}
command_templates = {}
self._command_templates = {}
for key in COMMAND_TEMPLATE_KEYS:
command_templates[key] = MqttCommandTemplate(
self._command_templates[key] = MqttCommandTemplate(
config.get(key), entity=self
).async_render
self._command_templates = command_templates
support: int = 0
if (self._topic[CONF_TEMP_STATE_TOPIC] is not None) or (
self._topic[CONF_TEMP_COMMAND_TOPIC] is not None
@ -498,12 +510,16 @@ class MqttClimate(MqttEntity, ClimateEntity):
support |= ClimateEntityFeature.AUX_HEAT
self._attr_supported_features = support
def _prepare_subscribe_topics(self): # noqa: C901
def _prepare_subscribe_topics(self) -> None: # noqa: C901
"""(Re)Subscribe to topics."""
topics = {}
qos = self._config[CONF_QOS]
topics: dict[str, dict[str, Any]] = {}
qos: int = self._config[CONF_QOS]
def add_subscription(topics, topic, msg_callback):
def add_subscription(
topics: dict[str, dict[str, Any]],
topic: str,
msg_callback: Callable[[ReceiveMessage], None],
) -> None:
if self._topic[topic] is not None:
topics[topic] = {
"topic": self._topic[topic],
@ -512,13 +528,15 @@ class MqttClimate(MqttEntity, ClimateEntity):
"encoding": self._config[CONF_ENCODING] or None,
}
def render_template(msg, template_name):
def render_template(
msg: ReceiveMessage, template_name: str
) -> ReceivePayloadType:
template = self._value_templates[template_name]
return template(msg.payload)
@callback
@log_messages(self.hass, self.entity_id)
def handle_action_received(msg):
def handle_action_received(msg: ReceiveMessage) -> None:
"""Handle receiving action via MQTT."""
payload = render_template(msg, CONF_ACTION_TEMPLATE)
if not payload or payload == PAYLOAD_NONE:
@ -529,7 +547,7 @@ class MqttClimate(MqttEntity, ClimateEntity):
)
return
try:
self._attr_hvac_action = HVACAction(payload)
self._attr_hvac_action = HVACAction(str(payload))
except ValueError:
_LOGGER.warning(
"Invalid %s action: %s",
@ -542,7 +560,9 @@ class MqttClimate(MqttEntity, ClimateEntity):
add_subscription(topics, CONF_ACTION_TOPIC, handle_action_received)
@callback
def handle_temperature_received(msg, template_name, attr):
def handle_temperature_received(
msg: ReceiveMessage, template_name: str, attr: str
) -> None:
"""Handle temperature coming via MQTT."""
payload = render_template(msg, template_name)
@ -554,7 +574,7 @@ class MqttClimate(MqttEntity, ClimateEntity):
@callback
@log_messages(self.hass, self.entity_id)
def handle_current_temperature_received(msg):
def handle_current_temperature_received(msg: ReceiveMessage) -> None:
"""Handle current temperature coming via MQTT."""
handle_temperature_received(
msg, CONF_CURRENT_TEMP_TEMPLATE, "_attr_current_temperature"
@ -566,7 +586,7 @@ class MqttClimate(MqttEntity, ClimateEntity):
@callback
@log_messages(self.hass, self.entity_id)
def handle_target_temperature_received(msg):
def handle_target_temperature_received(msg: ReceiveMessage) -> None:
"""Handle target temperature coming via MQTT."""
handle_temperature_received(
msg, CONF_TEMP_STATE_TEMPLATE, "_attr_target_temperature"
@ -578,7 +598,7 @@ class MqttClimate(MqttEntity, ClimateEntity):
@callback
@log_messages(self.hass, self.entity_id)
def handle_temperature_low_received(msg):
def handle_temperature_low_received(msg: ReceiveMessage) -> None:
"""Handle target temperature low coming via MQTT."""
handle_temperature_received(
msg, CONF_TEMP_LOW_STATE_TEMPLATE, "_attr_target_temperature_low"
@ -590,7 +610,7 @@ class MqttClimate(MqttEntity, ClimateEntity):
@callback
@log_messages(self.hass, self.entity_id)
def handle_temperature_high_received(msg):
def handle_temperature_high_received(msg: ReceiveMessage) -> None:
"""Handle target temperature high coming via MQTT."""
handle_temperature_received(
msg, CONF_TEMP_HIGH_STATE_TEMPLATE, "_attr_target_temperature_high"
@ -601,7 +621,9 @@ class MqttClimate(MqttEntity, ClimateEntity):
)
@callback
def handle_mode_received(msg, template_name, attr, mode_list):
def handle_mode_received(
msg: ReceiveMessage, template_name: str, attr: str, mode_list: str
) -> None:
"""Handle receiving listed mode via MQTT."""
payload = render_template(msg, template_name)
@ -613,7 +635,7 @@ class MqttClimate(MqttEntity, ClimateEntity):
@callback
@log_messages(self.hass, self.entity_id)
def handle_current_mode_received(msg):
def handle_current_mode_received(msg: ReceiveMessage) -> None:
"""Handle receiving mode via MQTT."""
handle_mode_received(
msg, CONF_MODE_STATE_TEMPLATE, "_attr_hvac_mode", CONF_MODE_LIST
@ -623,7 +645,7 @@ class MqttClimate(MqttEntity, ClimateEntity):
@callback
@log_messages(self.hass, self.entity_id)
def handle_fan_mode_received(msg):
def handle_fan_mode_received(msg: ReceiveMessage) -> None:
"""Handle receiving fan mode via MQTT."""
handle_mode_received(
msg,
@ -636,7 +658,7 @@ class MqttClimate(MqttEntity, ClimateEntity):
@callback
@log_messages(self.hass, self.entity_id)
def handle_swing_mode_received(msg):
def handle_swing_mode_received(msg: ReceiveMessage) -> None:
"""Handle receiving swing mode via MQTT."""
handle_mode_received(
msg,
@ -650,11 +672,13 @@ class MqttClimate(MqttEntity, ClimateEntity):
)
@callback
def handle_onoff_mode_received(msg, template_name, attr):
def handle_onoff_mode_received(
msg: ReceiveMessage, template_name: str, attr: str
) -> None:
"""Handle receiving on/off mode via MQTT."""
payload = render_template(msg, template_name)
payload_on = self._config[CONF_PAYLOAD_ON]
payload_off = self._config[CONF_PAYLOAD_OFF]
payload_on: str = self._config[CONF_PAYLOAD_ON]
payload_off: str = self._config[CONF_PAYLOAD_OFF]
if payload == "True":
payload = payload_on
@ -672,7 +696,7 @@ class MqttClimate(MqttEntity, ClimateEntity):
@callback
@log_messages(self.hass, self.entity_id)
def handle_aux_mode_received(msg):
def handle_aux_mode_received(msg: ReceiveMessage) -> None:
"""Handle receiving aux mode via MQTT."""
handle_onoff_mode_received(
msg, CONF_AUX_STATE_TEMPLATE, "_attr_is_aux_heat"
@ -682,7 +706,7 @@ class MqttClimate(MqttEntity, ClimateEntity):
@callback
@log_messages(self.hass, self.entity_id)
def handle_preset_mode_received(msg):
def handle_preset_mode_received(msg: ReceiveMessage) -> None:
"""Handle receiving preset mode via MQTT."""
preset_mode = render_template(msg, CONF_PRESET_MODE_VALUE_TEMPLATE)
if preset_mode in [PRESET_NONE, PAYLOAD_NONE]:
@ -692,7 +716,7 @@ class MqttClimate(MqttEntity, ClimateEntity):
if not preset_mode:
_LOGGER.debug("Ignoring empty preset_mode from '%s'", msg.topic)
return
if preset_mode not in self.preset_modes:
if not self.preset_modes or preset_mode not in self.preset_modes:
_LOGGER.warning(
"'%s' received on topic %s. '%s' is not a valid preset mode",
msg.payload,
@ -700,7 +724,7 @@ class MqttClimate(MqttEntity, ClimateEntity):
preset_mode,
)
else:
self._attr_preset_mode = preset_mode
self._attr_preset_mode = str(preset_mode)
get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
@ -712,11 +736,11 @@ class MqttClimate(MqttEntity, ClimateEntity):
self.hass, self._sub_state, topics
)
async def _subscribe_topics(self):
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
async def _publish(self, topic, payload):
async def _publish(self, topic: str, payload: PublishPayloadType) -> None:
if self._topic[topic] is not None:
await self.async_publish(
self._topic[topic],
@ -727,8 +751,13 @@ class MqttClimate(MqttEntity, ClimateEntity):
)
async def _set_temperature(
self, temp, cmnd_topic, cmnd_template, state_topic, attr
):
self,
temp: float | None,
cmnd_topic: str,
cmnd_template: str,
state_topic: str,
attr: str,
) -> None:
if temp is not None:
if self._topic[state_topic] is None:
# optimistic mode
@ -822,7 +851,7 @@ class MqttClimate(MqttEntity, ClimateEntity):
return
async def _set_aux_heat(self, state):
async def _set_aux_heat(self, state: bool) -> None:
await self._publish(
CONF_AUX_COMMAND_TOPIC,
self._config[CONF_PAYLOAD_ON] if state else self._config[CONF_PAYLOAD_OFF],