Improve MQTT type hints part 2 (#80529)

* Improve typing camera

* Improve typing cover

* b64 encoding can be either bytes or a string.
This commit is contained in:
Jan Bouwhuis 2022-11-02 20:33:46 +01:00 committed by GitHub
parent b4ad03784f
commit bda7e416c4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 67 additions and 47 deletions

View file

@ -27,6 +27,7 @@ from .mixins import (
async_setup_platform_helper, async_setup_platform_helper,
warn_for_legacy_schema, warn_for_legacy_schema,
) )
from .models import ReceiveMessage
from .util import valid_subscribe_topic from .util import valid_subscribe_topic
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -114,8 +115,8 @@ async def _async_setup_entity(
hass: HomeAssistant, hass: HomeAssistant,
async_add_entities: AddEntitiesCallback, async_add_entities: AddEntitiesCallback,
config: ConfigType, config: ConfigType,
config_entry: ConfigEntry | None = None, config_entry: ConfigEntry,
discovery_data: dict | None = None, discovery_data: DiscoveryInfoType | None = None,
) -> None: ) -> None:
"""Set up the MQTT Camera.""" """Set up the MQTT Camera."""
async_add_entities([MqttCamera(hass, config, config_entry, discovery_data)]) async_add_entities([MqttCamera(hass, config, config_entry, discovery_data)])
@ -124,31 +125,38 @@ async def _async_setup_entity(
class MqttCamera(MqttEntity, Camera): class MqttCamera(MqttEntity, Camera):
"""representation of a MQTT camera.""" """representation of a MQTT camera."""
_entity_id_format = camera.ENTITY_ID_FORMAT _entity_id_format: str = camera.ENTITY_ID_FORMAT
_attributes_extra_blocked = MQTT_CAMERA_ATTRIBUTES_BLOCKED _attributes_extra_blocked: frozenset[str] = MQTT_CAMERA_ATTRIBUTES_BLOCKED
def __init__(self, hass, config, config_entry, discovery_data): def __init__(
self,
hass: HomeAssistant,
config: ConfigType,
config_entry: ConfigEntry,
discovery_data: DiscoveryInfoType | None,
) -> None:
"""Initialize the MQTT Camera.""" """Initialize the MQTT Camera."""
self._last_image = None self._last_image: bytes | None = None
Camera.__init__(self) Camera.__init__(self)
MqttEntity.__init__(self, hass, config, config_entry, discovery_data) MqttEntity.__init__(self, hass, config, config_entry, discovery_data)
@staticmethod @staticmethod
def config_schema(): def config_schema() -> vol.Schema:
"""Return the config schema.""" """Return the config schema."""
return DISCOVERY_SCHEMA return DISCOVERY_SCHEMA
def _prepare_subscribe_topics(self): def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
@callback @callback
@log_messages(self.hass, self.entity_id) @log_messages(self.hass, self.entity_id)
def message_received(msg): def message_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages.""" """Handle new MQTT messages."""
if CONF_IMAGE_ENCODING in self._config: if CONF_IMAGE_ENCODING in self._config:
self._last_image = b64decode(msg.payload) self._last_image = b64decode(msg.payload)
else: else:
assert isinstance(msg.payload, bytes)
self._last_image = msg.payload self._last_image = msg.payload
self._sub_state = subscription.async_prepare_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
@ -164,7 +172,7 @@ class MqttCamera(MqttEntity, Camera):
}, },
) )
async def _subscribe_topics(self): async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) await subscription.async_subscribe_topics(self.hass, self._sub_state)

View file

@ -31,6 +31,7 @@ from homeassistant.core import HomeAssistant, callback
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.json import JSON_DECODE_EXCEPTIONS, json_loads from homeassistant.helpers.json import JSON_DECODE_EXCEPTIONS, json_loads
from homeassistant.helpers.service_info.mqtt import ReceivePayloadType
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from . import subscription from . import subscription
@ -50,7 +51,7 @@ from .mixins import (
async_setup_platform_helper, async_setup_platform_helper,
warn_for_legacy_schema, warn_for_legacy_schema,
) )
from .models import MqttCommandTemplate, MqttValueTemplate from .models import MqttCommandTemplate, MqttValueTemplate, ReceiveMessage
from .util import get_mqtt_data, valid_publish_topic, valid_subscribe_topic from .util import get_mqtt_data, valid_publish_topic, valid_subscribe_topic
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -113,44 +114,44 @@ MQTT_COVER_ATTRIBUTES_BLOCKED = frozenset(
) )
def validate_options(value): def validate_options(config: ConfigType) -> ConfigType:
"""Validate options. """Validate options.
If set position topic is set then get position topic is set as well. If set position topic is set then get position topic is set as well.
""" """
if CONF_SET_POSITION_TOPIC in value and CONF_GET_POSITION_TOPIC not in value: if CONF_SET_POSITION_TOPIC in config and CONF_GET_POSITION_TOPIC not in config:
raise vol.Invalid( raise vol.Invalid(
f"'{CONF_SET_POSITION_TOPIC}' must be set together with '{CONF_GET_POSITION_TOPIC}'." f"'{CONF_SET_POSITION_TOPIC}' must be set together with '{CONF_GET_POSITION_TOPIC}'."
) )
# if templates are set make sure the topic for the template is also set # if templates are set make sure the topic for the template is also set
if CONF_VALUE_TEMPLATE in value and CONF_STATE_TOPIC not in value: if CONF_VALUE_TEMPLATE in config and CONF_STATE_TOPIC not in config:
raise vol.Invalid( raise vol.Invalid(
f"'{CONF_VALUE_TEMPLATE}' must be set together with '{CONF_STATE_TOPIC}'." f"'{CONF_VALUE_TEMPLATE}' must be set together with '{CONF_STATE_TOPIC}'."
) )
if CONF_GET_POSITION_TEMPLATE in value and CONF_GET_POSITION_TOPIC not in value: if CONF_GET_POSITION_TEMPLATE in config and CONF_GET_POSITION_TOPIC not in config:
raise vol.Invalid( raise vol.Invalid(
f"'{CONF_GET_POSITION_TEMPLATE}' must be set together with '{CONF_GET_POSITION_TOPIC}'." f"'{CONF_GET_POSITION_TEMPLATE}' must be set together with '{CONF_GET_POSITION_TOPIC}'."
) )
if CONF_SET_POSITION_TEMPLATE in value and CONF_SET_POSITION_TOPIC not in value: if CONF_SET_POSITION_TEMPLATE in config and CONF_SET_POSITION_TOPIC not in config:
raise vol.Invalid( raise vol.Invalid(
f"'{CONF_SET_POSITION_TEMPLATE}' must be set together with '{CONF_SET_POSITION_TOPIC}'." f"'{CONF_SET_POSITION_TEMPLATE}' must be set together with '{CONF_SET_POSITION_TOPIC}'."
) )
if CONF_TILT_COMMAND_TEMPLATE in value and CONF_TILT_COMMAND_TOPIC not in value: if CONF_TILT_COMMAND_TEMPLATE in config and CONF_TILT_COMMAND_TOPIC not in config:
raise vol.Invalid( raise vol.Invalid(
f"'{CONF_TILT_COMMAND_TEMPLATE}' must be set together with '{CONF_TILT_COMMAND_TOPIC}'." f"'{CONF_TILT_COMMAND_TEMPLATE}' must be set together with '{CONF_TILT_COMMAND_TOPIC}'."
) )
if CONF_TILT_STATUS_TEMPLATE in value and CONF_TILT_STATUS_TOPIC not in value: if CONF_TILT_STATUS_TEMPLATE in config and CONF_TILT_STATUS_TOPIC not in config:
raise vol.Invalid( raise vol.Invalid(
f"'{CONF_TILT_STATUS_TEMPLATE}' must be set together with '{CONF_TILT_STATUS_TOPIC}'." f"'{CONF_TILT_STATUS_TEMPLATE}' must be set together with '{CONF_TILT_STATUS_TOPIC}'."
) )
return value return config
_PLATFORM_SCHEMA_BASE = MQTT_BASE_SCHEMA.extend( _PLATFORM_SCHEMA_BASE = MQTT_BASE_SCHEMA.extend(
@ -251,8 +252,8 @@ async def _async_setup_entity(
hass: HomeAssistant, hass: HomeAssistant,
async_add_entities: AddEntitiesCallback, async_add_entities: AddEntitiesCallback,
config: ConfigType, config: ConfigType,
config_entry: ConfigEntry | None = None, config_entry: ConfigEntry,
discovery_data: dict | None = None, discovery_data: DiscoveryInfoType | None = None,
) -> None: ) -> None:
"""Set up the MQTT Cover.""" """Set up the MQTT Cover."""
async_add_entities([MqttCover(hass, config, config_entry, discovery_data)]) async_add_entities([MqttCover(hass, config, config_entry, discovery_data)])
@ -261,26 +262,32 @@ async def _async_setup_entity(
class MqttCover(MqttEntity, CoverEntity): class MqttCover(MqttEntity, CoverEntity):
"""Representation of a cover that can be controlled using MQTT.""" """Representation of a cover that can be controlled using MQTT."""
_entity_id_format = cover.ENTITY_ID_FORMAT _entity_id_format: str = cover.ENTITY_ID_FORMAT
_attributes_extra_blocked = MQTT_COVER_ATTRIBUTES_BLOCKED _attributes_extra_blocked: frozenset[str] = MQTT_COVER_ATTRIBUTES_BLOCKED
def __init__(self, hass, config, config_entry, discovery_data): def __init__(
self,
hass: HomeAssistant,
config: ConfigType,
config_entry: ConfigEntry,
discovery_data: DiscoveryInfoType | None,
) -> None:
"""Initialize the cover.""" """Initialize the cover."""
self._position = None self._position: int | None = None
self._state = None self._state: str | None = None
self._optimistic = None self._optimistic: bool | None = None
self._tilt_value = None self._tilt_value: int | None = None
self._tilt_optimistic = None self._tilt_optimistic: bool | None = None
MqttEntity.__init__(self, hass, config, config_entry, discovery_data) MqttEntity.__init__(self, hass, config, config_entry, discovery_data)
@staticmethod @staticmethod
def config_schema(): def config_schema() -> vol.Schema:
"""Return the config schema.""" """Return the config schema."""
return DISCOVERY_SCHEMA return DISCOVERY_SCHEMA
def _setup_from_config(self, config): def _setup_from_config(self, config: ConfigType) -> None:
no_position = ( no_position = (
config.get(CONF_SET_POSITION_TOPIC) is None config.get(CONF_SET_POSITION_TOPIC) is None
and config.get(CONF_GET_POSITION_TOPIC) is None and config.get(CONF_GET_POSITION_TOPIC) is None
@ -353,13 +360,13 @@ class MqttCover(MqttEntity, CoverEntity):
config_attributes=template_config_attributes, config_attributes=template_config_attributes,
).async_render_with_possible_json_value ).async_render_with_possible_json_value
def _prepare_subscribe_topics(self): def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
topics = {} topics = {}
@callback @callback
@log_messages(self.hass, self.entity_id) @log_messages(self.hass, self.entity_id)
def tilt_message_received(msg): def tilt_message_received(msg: ReceiveMessage) -> None:
"""Handle tilt updates.""" """Handle tilt updates."""
payload = self._tilt_status_template(msg.payload) payload = self._tilt_status_template(msg.payload)
@ -371,7 +378,7 @@ class MqttCover(MqttEntity, CoverEntity):
@callback @callback
@log_messages(self.hass, self.entity_id) @log_messages(self.hass, self.entity_id)
def state_message_received(msg): def state_message_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT state messages.""" """Handle new MQTT state messages."""
payload = self._value_template(msg.payload) payload = self._value_template(msg.payload)
@ -409,31 +416,32 @@ class MqttCover(MqttEntity, CoverEntity):
@callback @callback
@log_messages(self.hass, self.entity_id) @log_messages(self.hass, self.entity_id)
def position_message_received(msg): def position_message_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT position messages.""" """Handle new MQTT position messages."""
payload = self._get_position_template(msg.payload) payload: ReceivePayloadType = self._get_position_template(msg.payload)
payload_dict: Any = None
if not payload: if not payload:
_LOGGER.debug("Ignoring empty position message from '%s'", msg.topic) _LOGGER.debug("Ignoring empty position message from '%s'", msg.topic)
return return
try: try:
payload = json_loads(payload) payload_dict = json_loads(payload)
except JSON_DECODE_EXCEPTIONS: except JSON_DECODE_EXCEPTIONS:
pass pass
if isinstance(payload, dict): if payload_dict and isinstance(payload_dict, dict):
if "position" not in payload: if "position" not in payload_dict:
_LOGGER.warning( _LOGGER.warning(
"Template (position_template) returned JSON without position attribute" "Template (position_template) returned JSON without position attribute"
) )
return return
if "tilt_position" in payload: if "tilt_position" in payload_dict:
if not self._config.get(CONF_TILT_STATE_OPTIMISTIC): if not self._config.get(CONF_TILT_STATE_OPTIMISTIC):
# reset forced set tilt optimistic # reset forced set tilt optimistic
self._tilt_optimistic = False self._tilt_optimistic = False
self.tilt_payload_received(payload["tilt_position"]) self.tilt_payload_received(payload_dict["tilt_position"])
payload = payload["position"] payload = payload_dict["position"]
try: try:
percentage_payload = self.find_percentage_in_range( percentage_payload = self.find_percentage_in_range(
@ -481,7 +489,7 @@ class MqttCover(MqttEntity, CoverEntity):
self.hass, self._sub_state, topics self.hass, self._sub_state, topics
) )
async def _subscribe_topics(self): async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) await subscription.async_subscribe_topics(self.hass, self._sub_state)
@ -719,13 +727,15 @@ class MqttCover(MqttEntity, CoverEntity):
else: else:
await self.async_close_cover_tilt(**kwargs) await self.async_close_cover_tilt(**kwargs)
def is_tilt_closed(self): def is_tilt_closed(self) -> bool:
"""Return if the cover is tilted closed.""" """Return if the cover is tilted closed."""
return self._tilt_value == self.find_percentage_in_range( return self._tilt_value == self.find_percentage_in_range(
float(self._config[CONF_TILT_CLOSED_POSITION]) float(self._config[CONF_TILT_CLOSED_POSITION])
) )
def find_percentage_in_range(self, position, range_type=TILT_PAYLOAD): def find_percentage_in_range(
self, position: float, range_type: str = TILT_PAYLOAD
) -> int:
"""Find the 0-100% value within the specified range.""" """Find the 0-100% value within the specified range."""
# the range of motion as defined by the min max values # the range of motion as defined by the min max values
if range_type == COVER_PAYLOAD: if range_type == COVER_PAYLOAD:
@ -745,7 +755,9 @@ class MqttCover(MqttEntity, CoverEntity):
return position_percentage return position_percentage
def find_in_range_from_percent(self, percentage, range_type=TILT_PAYLOAD): def find_in_range_from_percent(
self, percentage: float, range_type: str = TILT_PAYLOAD
) -> int:
""" """
Find the adjusted value for 0-100% within the specified range. Find the adjusted value for 0-100% within the specified range.
@ -768,7 +780,7 @@ class MqttCover(MqttEntity, CoverEntity):
return position return position
@callback @callback
def tilt_payload_received(self, _payload): def tilt_payload_received(self, _payload: Any) -> None:
"""Set the tilt value.""" """Set the tilt value."""
try: try: