Improve MQTT type hints part 2 (#80529)
* Improve typing camera * Improve typing cover * b64 encoding can be either bytes or a string.
This commit is contained in:
parent
b4ad03784f
commit
bda7e416c4
2 changed files with 67 additions and 47 deletions
|
@ -27,6 +27,7 @@ from .mixins import (
|
||||||
async_setup_platform_helper,
|
async_setup_platform_helper,
|
||||||
warn_for_legacy_schema,
|
warn_for_legacy_schema,
|
||||||
)
|
)
|
||||||
|
from .models import ReceiveMessage
|
||||||
from .util import valid_subscribe_topic
|
from .util import valid_subscribe_topic
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
@ -114,8 +115,8 @@ async def _async_setup_entity(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
async_add_entities: AddEntitiesCallback,
|
async_add_entities: AddEntitiesCallback,
|
||||||
config: ConfigType,
|
config: ConfigType,
|
||||||
config_entry: ConfigEntry | None = None,
|
config_entry: ConfigEntry,
|
||||||
discovery_data: dict | None = None,
|
discovery_data: DiscoveryInfoType | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Set up the MQTT Camera."""
|
"""Set up the MQTT Camera."""
|
||||||
async_add_entities([MqttCamera(hass, config, config_entry, discovery_data)])
|
async_add_entities([MqttCamera(hass, config, config_entry, discovery_data)])
|
||||||
|
@ -124,31 +125,38 @@ async def _async_setup_entity(
|
||||||
class MqttCamera(MqttEntity, Camera):
|
class MqttCamera(MqttEntity, Camera):
|
||||||
"""representation of a MQTT camera."""
|
"""representation of a MQTT camera."""
|
||||||
|
|
||||||
_entity_id_format = camera.ENTITY_ID_FORMAT
|
_entity_id_format: str = camera.ENTITY_ID_FORMAT
|
||||||
_attributes_extra_blocked = MQTT_CAMERA_ATTRIBUTES_BLOCKED
|
_attributes_extra_blocked: frozenset[str] = MQTT_CAMERA_ATTRIBUTES_BLOCKED
|
||||||
|
|
||||||
def __init__(self, hass, config, config_entry, discovery_data):
|
def __init__(
|
||||||
|
self,
|
||||||
|
hass: HomeAssistant,
|
||||||
|
config: ConfigType,
|
||||||
|
config_entry: ConfigEntry,
|
||||||
|
discovery_data: DiscoveryInfoType | None,
|
||||||
|
) -> None:
|
||||||
"""Initialize the MQTT Camera."""
|
"""Initialize the MQTT Camera."""
|
||||||
self._last_image = None
|
self._last_image: bytes | None = None
|
||||||
|
|
||||||
Camera.__init__(self)
|
Camera.__init__(self)
|
||||||
MqttEntity.__init__(self, hass, config, config_entry, discovery_data)
|
MqttEntity.__init__(self, hass, config, config_entry, discovery_data)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def config_schema():
|
def config_schema() -> vol.Schema:
|
||||||
"""Return the config schema."""
|
"""Return the config schema."""
|
||||||
return DISCOVERY_SCHEMA
|
return DISCOVERY_SCHEMA
|
||||||
|
|
||||||
def _prepare_subscribe_topics(self):
|
def _prepare_subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@log_messages(self.hass, self.entity_id)
|
@log_messages(self.hass, self.entity_id)
|
||||||
def message_received(msg):
|
def message_received(msg: ReceiveMessage) -> None:
|
||||||
"""Handle new MQTT messages."""
|
"""Handle new MQTT messages."""
|
||||||
if CONF_IMAGE_ENCODING in self._config:
|
if CONF_IMAGE_ENCODING in self._config:
|
||||||
self._last_image = b64decode(msg.payload)
|
self._last_image = b64decode(msg.payload)
|
||||||
else:
|
else:
|
||||||
|
assert isinstance(msg.payload, bytes)
|
||||||
self._last_image = msg.payload
|
self._last_image = msg.payload
|
||||||
|
|
||||||
self._sub_state = subscription.async_prepare_subscribe_topics(
|
self._sub_state = subscription.async_prepare_subscribe_topics(
|
||||||
|
@ -164,7 +172,7 @@ class MqttCamera(MqttEntity, Camera):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _subscribe_topics(self):
|
async def _subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
||||||
|
|
||||||
|
|
|
@ -31,6 +31,7 @@ from homeassistant.core import HomeAssistant, callback
|
||||||
import homeassistant.helpers.config_validation as cv
|
import homeassistant.helpers.config_validation as cv
|
||||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||||
from homeassistant.helpers.json import JSON_DECODE_EXCEPTIONS, json_loads
|
from homeassistant.helpers.json import JSON_DECODE_EXCEPTIONS, json_loads
|
||||||
|
from homeassistant.helpers.service_info.mqtt import ReceivePayloadType
|
||||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||||
|
|
||||||
from . import subscription
|
from . import subscription
|
||||||
|
@ -50,7 +51,7 @@ from .mixins import (
|
||||||
async_setup_platform_helper,
|
async_setup_platform_helper,
|
||||||
warn_for_legacy_schema,
|
warn_for_legacy_schema,
|
||||||
)
|
)
|
||||||
from .models import MqttCommandTemplate, MqttValueTemplate
|
from .models import MqttCommandTemplate, MqttValueTemplate, ReceiveMessage
|
||||||
from .util import get_mqtt_data, valid_publish_topic, valid_subscribe_topic
|
from .util import get_mqtt_data, valid_publish_topic, valid_subscribe_topic
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
@ -113,44 +114,44 @@ MQTT_COVER_ATTRIBUTES_BLOCKED = frozenset(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def validate_options(value):
|
def validate_options(config: ConfigType) -> ConfigType:
|
||||||
"""Validate options.
|
"""Validate options.
|
||||||
|
|
||||||
If set position topic is set then get position topic is set as well.
|
If set position topic is set then get position topic is set as well.
|
||||||
"""
|
"""
|
||||||
if CONF_SET_POSITION_TOPIC in value and CONF_GET_POSITION_TOPIC not in value:
|
if CONF_SET_POSITION_TOPIC in config and CONF_GET_POSITION_TOPIC not in config:
|
||||||
raise vol.Invalid(
|
raise vol.Invalid(
|
||||||
f"'{CONF_SET_POSITION_TOPIC}' must be set together with '{CONF_GET_POSITION_TOPIC}'."
|
f"'{CONF_SET_POSITION_TOPIC}' must be set together with '{CONF_GET_POSITION_TOPIC}'."
|
||||||
)
|
)
|
||||||
|
|
||||||
# if templates are set make sure the topic for the template is also set
|
# if templates are set make sure the topic for the template is also set
|
||||||
|
|
||||||
if CONF_VALUE_TEMPLATE in value and CONF_STATE_TOPIC not in value:
|
if CONF_VALUE_TEMPLATE in config and CONF_STATE_TOPIC not in config:
|
||||||
raise vol.Invalid(
|
raise vol.Invalid(
|
||||||
f"'{CONF_VALUE_TEMPLATE}' must be set together with '{CONF_STATE_TOPIC}'."
|
f"'{CONF_VALUE_TEMPLATE}' must be set together with '{CONF_STATE_TOPIC}'."
|
||||||
)
|
)
|
||||||
|
|
||||||
if CONF_GET_POSITION_TEMPLATE in value and CONF_GET_POSITION_TOPIC not in value:
|
if CONF_GET_POSITION_TEMPLATE in config and CONF_GET_POSITION_TOPIC not in config:
|
||||||
raise vol.Invalid(
|
raise vol.Invalid(
|
||||||
f"'{CONF_GET_POSITION_TEMPLATE}' must be set together with '{CONF_GET_POSITION_TOPIC}'."
|
f"'{CONF_GET_POSITION_TEMPLATE}' must be set together with '{CONF_GET_POSITION_TOPIC}'."
|
||||||
)
|
)
|
||||||
|
|
||||||
if CONF_SET_POSITION_TEMPLATE in value and CONF_SET_POSITION_TOPIC not in value:
|
if CONF_SET_POSITION_TEMPLATE in config and CONF_SET_POSITION_TOPIC not in config:
|
||||||
raise vol.Invalid(
|
raise vol.Invalid(
|
||||||
f"'{CONF_SET_POSITION_TEMPLATE}' must be set together with '{CONF_SET_POSITION_TOPIC}'."
|
f"'{CONF_SET_POSITION_TEMPLATE}' must be set together with '{CONF_SET_POSITION_TOPIC}'."
|
||||||
)
|
)
|
||||||
|
|
||||||
if CONF_TILT_COMMAND_TEMPLATE in value and CONF_TILT_COMMAND_TOPIC not in value:
|
if CONF_TILT_COMMAND_TEMPLATE in config and CONF_TILT_COMMAND_TOPIC not in config:
|
||||||
raise vol.Invalid(
|
raise vol.Invalid(
|
||||||
f"'{CONF_TILT_COMMAND_TEMPLATE}' must be set together with '{CONF_TILT_COMMAND_TOPIC}'."
|
f"'{CONF_TILT_COMMAND_TEMPLATE}' must be set together with '{CONF_TILT_COMMAND_TOPIC}'."
|
||||||
)
|
)
|
||||||
|
|
||||||
if CONF_TILT_STATUS_TEMPLATE in value and CONF_TILT_STATUS_TOPIC not in value:
|
if CONF_TILT_STATUS_TEMPLATE in config and CONF_TILT_STATUS_TOPIC not in config:
|
||||||
raise vol.Invalid(
|
raise vol.Invalid(
|
||||||
f"'{CONF_TILT_STATUS_TEMPLATE}' must be set together with '{CONF_TILT_STATUS_TOPIC}'."
|
f"'{CONF_TILT_STATUS_TEMPLATE}' must be set together with '{CONF_TILT_STATUS_TOPIC}'."
|
||||||
)
|
)
|
||||||
|
|
||||||
return value
|
return config
|
||||||
|
|
||||||
|
|
||||||
_PLATFORM_SCHEMA_BASE = MQTT_BASE_SCHEMA.extend(
|
_PLATFORM_SCHEMA_BASE = MQTT_BASE_SCHEMA.extend(
|
||||||
|
@ -251,8 +252,8 @@ async def _async_setup_entity(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
async_add_entities: AddEntitiesCallback,
|
async_add_entities: AddEntitiesCallback,
|
||||||
config: ConfigType,
|
config: ConfigType,
|
||||||
config_entry: ConfigEntry | None = None,
|
config_entry: ConfigEntry,
|
||||||
discovery_data: dict | None = None,
|
discovery_data: DiscoveryInfoType | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Set up the MQTT Cover."""
|
"""Set up the MQTT Cover."""
|
||||||
async_add_entities([MqttCover(hass, config, config_entry, discovery_data)])
|
async_add_entities([MqttCover(hass, config, config_entry, discovery_data)])
|
||||||
|
@ -261,26 +262,32 @@ async def _async_setup_entity(
|
||||||
class MqttCover(MqttEntity, CoverEntity):
|
class MqttCover(MqttEntity, CoverEntity):
|
||||||
"""Representation of a cover that can be controlled using MQTT."""
|
"""Representation of a cover that can be controlled using MQTT."""
|
||||||
|
|
||||||
_entity_id_format = cover.ENTITY_ID_FORMAT
|
_entity_id_format: str = cover.ENTITY_ID_FORMAT
|
||||||
_attributes_extra_blocked = MQTT_COVER_ATTRIBUTES_BLOCKED
|
_attributes_extra_blocked: frozenset[str] = MQTT_COVER_ATTRIBUTES_BLOCKED
|
||||||
|
|
||||||
def __init__(self, hass, config, config_entry, discovery_data):
|
def __init__(
|
||||||
|
self,
|
||||||
|
hass: HomeAssistant,
|
||||||
|
config: ConfigType,
|
||||||
|
config_entry: ConfigEntry,
|
||||||
|
discovery_data: DiscoveryInfoType | None,
|
||||||
|
) -> None:
|
||||||
"""Initialize the cover."""
|
"""Initialize the cover."""
|
||||||
self._position = None
|
self._position: int | None = None
|
||||||
self._state = None
|
self._state: str | None = None
|
||||||
|
|
||||||
self._optimistic = None
|
self._optimistic: bool | None = None
|
||||||
self._tilt_value = None
|
self._tilt_value: int | None = None
|
||||||
self._tilt_optimistic = None
|
self._tilt_optimistic: bool | None = None
|
||||||
|
|
||||||
MqttEntity.__init__(self, hass, config, config_entry, discovery_data)
|
MqttEntity.__init__(self, hass, config, config_entry, discovery_data)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def config_schema():
|
def config_schema() -> vol.Schema:
|
||||||
"""Return the config schema."""
|
"""Return the config schema."""
|
||||||
return DISCOVERY_SCHEMA
|
return DISCOVERY_SCHEMA
|
||||||
|
|
||||||
def _setup_from_config(self, config):
|
def _setup_from_config(self, config: ConfigType) -> None:
|
||||||
no_position = (
|
no_position = (
|
||||||
config.get(CONF_SET_POSITION_TOPIC) is None
|
config.get(CONF_SET_POSITION_TOPIC) is None
|
||||||
and config.get(CONF_GET_POSITION_TOPIC) is None
|
and config.get(CONF_GET_POSITION_TOPIC) is None
|
||||||
|
@ -353,13 +360,13 @@ class MqttCover(MqttEntity, CoverEntity):
|
||||||
config_attributes=template_config_attributes,
|
config_attributes=template_config_attributes,
|
||||||
).async_render_with_possible_json_value
|
).async_render_with_possible_json_value
|
||||||
|
|
||||||
def _prepare_subscribe_topics(self):
|
def _prepare_subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
topics = {}
|
topics = {}
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@log_messages(self.hass, self.entity_id)
|
@log_messages(self.hass, self.entity_id)
|
||||||
def tilt_message_received(msg):
|
def tilt_message_received(msg: ReceiveMessage) -> None:
|
||||||
"""Handle tilt updates."""
|
"""Handle tilt updates."""
|
||||||
payload = self._tilt_status_template(msg.payload)
|
payload = self._tilt_status_template(msg.payload)
|
||||||
|
|
||||||
|
@ -371,7 +378,7 @@ class MqttCover(MqttEntity, CoverEntity):
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@log_messages(self.hass, self.entity_id)
|
@log_messages(self.hass, self.entity_id)
|
||||||
def state_message_received(msg):
|
def state_message_received(msg: ReceiveMessage) -> None:
|
||||||
"""Handle new MQTT state messages."""
|
"""Handle new MQTT state messages."""
|
||||||
payload = self._value_template(msg.payload)
|
payload = self._value_template(msg.payload)
|
||||||
|
|
||||||
|
@ -409,31 +416,32 @@ class MqttCover(MqttEntity, CoverEntity):
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@log_messages(self.hass, self.entity_id)
|
@log_messages(self.hass, self.entity_id)
|
||||||
def position_message_received(msg):
|
def position_message_received(msg: ReceiveMessage) -> None:
|
||||||
"""Handle new MQTT position messages."""
|
"""Handle new MQTT position messages."""
|
||||||
payload = self._get_position_template(msg.payload)
|
payload: ReceivePayloadType = self._get_position_template(msg.payload)
|
||||||
|
payload_dict: Any = None
|
||||||
|
|
||||||
if not payload:
|
if not payload:
|
||||||
_LOGGER.debug("Ignoring empty position message from '%s'", msg.topic)
|
_LOGGER.debug("Ignoring empty position message from '%s'", msg.topic)
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
payload = json_loads(payload)
|
payload_dict = json_loads(payload)
|
||||||
except JSON_DECODE_EXCEPTIONS:
|
except JSON_DECODE_EXCEPTIONS:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if isinstance(payload, dict):
|
if payload_dict and isinstance(payload_dict, dict):
|
||||||
if "position" not in payload:
|
if "position" not in payload_dict:
|
||||||
_LOGGER.warning(
|
_LOGGER.warning(
|
||||||
"Template (position_template) returned JSON without position attribute"
|
"Template (position_template) returned JSON without position attribute"
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
if "tilt_position" in payload:
|
if "tilt_position" in payload_dict:
|
||||||
if not self._config.get(CONF_TILT_STATE_OPTIMISTIC):
|
if not self._config.get(CONF_TILT_STATE_OPTIMISTIC):
|
||||||
# reset forced set tilt optimistic
|
# reset forced set tilt optimistic
|
||||||
self._tilt_optimistic = False
|
self._tilt_optimistic = False
|
||||||
self.tilt_payload_received(payload["tilt_position"])
|
self.tilt_payload_received(payload_dict["tilt_position"])
|
||||||
payload = payload["position"]
|
payload = payload_dict["position"]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
percentage_payload = self.find_percentage_in_range(
|
percentage_payload = self.find_percentage_in_range(
|
||||||
|
@ -481,7 +489,7 @@ class MqttCover(MqttEntity, CoverEntity):
|
||||||
self.hass, self._sub_state, topics
|
self.hass, self._sub_state, topics
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _subscribe_topics(self):
|
async def _subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
||||||
|
|
||||||
|
@ -719,13 +727,15 @@ class MqttCover(MqttEntity, CoverEntity):
|
||||||
else:
|
else:
|
||||||
await self.async_close_cover_tilt(**kwargs)
|
await self.async_close_cover_tilt(**kwargs)
|
||||||
|
|
||||||
def is_tilt_closed(self):
|
def is_tilt_closed(self) -> bool:
|
||||||
"""Return if the cover is tilted closed."""
|
"""Return if the cover is tilted closed."""
|
||||||
return self._tilt_value == self.find_percentage_in_range(
|
return self._tilt_value == self.find_percentage_in_range(
|
||||||
float(self._config[CONF_TILT_CLOSED_POSITION])
|
float(self._config[CONF_TILT_CLOSED_POSITION])
|
||||||
)
|
)
|
||||||
|
|
||||||
def find_percentage_in_range(self, position, range_type=TILT_PAYLOAD):
|
def find_percentage_in_range(
|
||||||
|
self, position: float, range_type: str = TILT_PAYLOAD
|
||||||
|
) -> int:
|
||||||
"""Find the 0-100% value within the specified range."""
|
"""Find the 0-100% value within the specified range."""
|
||||||
# the range of motion as defined by the min max values
|
# the range of motion as defined by the min max values
|
||||||
if range_type == COVER_PAYLOAD:
|
if range_type == COVER_PAYLOAD:
|
||||||
|
@ -745,7 +755,9 @@ class MqttCover(MqttEntity, CoverEntity):
|
||||||
|
|
||||||
return position_percentage
|
return position_percentage
|
||||||
|
|
||||||
def find_in_range_from_percent(self, percentage, range_type=TILT_PAYLOAD):
|
def find_in_range_from_percent(
|
||||||
|
self, percentage: float, range_type: str = TILT_PAYLOAD
|
||||||
|
) -> int:
|
||||||
"""
|
"""
|
||||||
Find the adjusted value for 0-100% within the specified range.
|
Find the adjusted value for 0-100% within the specified range.
|
||||||
|
|
||||||
|
@ -768,7 +780,7 @@ class MqttCover(MqttEntity, CoverEntity):
|
||||||
return position
|
return position
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def tilt_payload_received(self, _payload):
|
def tilt_payload_received(self, _payload: Any) -> None:
|
||||||
"""Set the tilt value."""
|
"""Set the tilt value."""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue