Strict typing for shared MQTT modules (#80913)
Co-authored-by: Marc Mueller <30130371+cdce8p@users.noreply.github.com>
This commit is contained in:
parent
052c673c9e
commit
2d9f39d406
6 changed files with 75 additions and 48 deletions
|
@ -182,7 +182,7 @@ MQTT_PUBLISH_SCHEMA = vol.All(
|
|||
|
||||
|
||||
async def _async_setup_discovery(
|
||||
hass: HomeAssistant, conf: ConfigType, config_entry
|
||||
hass: HomeAssistant, conf: ConfigType, config_entry: ConfigEntry
|
||||
) -> None:
|
||||
"""Try to start the discovery of MQTT devices.
|
||||
|
||||
|
@ -377,7 +377,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||
retain: bool = call.data[ATTR_RETAIN]
|
||||
if msg_topic_template is not None:
|
||||
try:
|
||||
rendered_topic = template.Template(
|
||||
rendered_topic: Any = template.Template(
|
||||
msg_topic_template, hass
|
||||
).async_render(parse_result=False)
|
||||
msg_topic = valid_publish_topic(rendered_topic)
|
||||
|
@ -620,12 +620,12 @@ def async_subscribe_connection_status(
|
|||
"""Subscribe to MQTT connection changes."""
|
||||
connection_status_callback_job = HassJob(connection_status_callback)
|
||||
|
||||
async def connected():
|
||||
async def connected() -> None:
|
||||
task = hass.async_run_hass_job(connection_status_callback_job, True)
|
||||
if task:
|
||||
await task
|
||||
|
||||
async def disconnected():
|
||||
async def disconnected() -> None:
|
||||
task = hass.async_run_hass_job(connection_status_callback_job, False)
|
||||
if task:
|
||||
await task
|
||||
|
@ -636,7 +636,7 @@ def async_subscribe_connection_status(
|
|||
}
|
||||
|
||||
@callback
|
||||
def unsubscribe():
|
||||
def unsubscribe() -> None:
|
||||
subscriptions["connect"]()
|
||||
subscriptions["disconnect"]()
|
||||
|
||||
|
|
|
@ -245,7 +245,7 @@ def subscribe(
|
|||
async_subscribe(hass, topic, msg_callback, qos, encoding), hass.loop
|
||||
).result()
|
||||
|
||||
def remove():
|
||||
def remove() -> None:
|
||||
"""Remove listener convert."""
|
||||
run_callback_threadsafe(hass.loop, async_remove).result()
|
||||
|
||||
|
@ -341,7 +341,7 @@ class MQTT:
|
|||
self._ha_started = asyncio.Event()
|
||||
self._last_subscribe = time.time()
|
||||
self._mqttc: mqtt.Client = None
|
||||
self._cleanup_on_unload: list[Callable] = []
|
||||
self._cleanup_on_unload: list[Callable[[], None]] = []
|
||||
|
||||
self._paho_lock = asyncio.Lock() # Prevents parallel calls to the MQTT client
|
||||
self._pending_operations: dict[int, asyncio.Event] = {}
|
||||
|
@ -352,14 +352,14 @@ class MQTT:
|
|||
else:
|
||||
|
||||
@callback
|
||||
def ha_started(_):
|
||||
def ha_started(_: Event) -> None:
|
||||
self._ha_started.set()
|
||||
|
||||
self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STARTED, ha_started)
|
||||
|
||||
self.init_client()
|
||||
|
||||
async def async_stop_mqtt(_event: Event):
|
||||
async def async_stop_mqtt(_event: Event) -> None:
|
||||
"""Stop MQTT component."""
|
||||
await self.async_disconnect()
|
||||
|
||||
|
@ -506,9 +506,11 @@ class MQTT:
|
|||
|
||||
def _client_unsubscribe(topic: str) -> int:
|
||||
result: int | None = None
|
||||
mid: int | None = None
|
||||
result, mid = self._mqttc.unsubscribe(topic)
|
||||
_LOGGER.debug("Unsubscribing from %s, mid: %s", topic, mid)
|
||||
_raise_on_error(result)
|
||||
assert mid
|
||||
return mid
|
||||
|
||||
if any(other.topic == topic for other in self.subscriptions):
|
||||
|
@ -553,7 +555,13 @@ class MQTT:
|
|||
if errors:
|
||||
_raise_on_errors(errors)
|
||||
|
||||
def _mqtt_on_connect(self, _mqttc, _userdata, _flags, result_code: int) -> None:
|
||||
def _mqtt_on_connect(
|
||||
self,
|
||||
_mqttc: mqtt.Client,
|
||||
_userdata: None,
|
||||
_flags: dict[str, Any],
|
||||
result_code: int,
|
||||
) -> None:
|
||||
"""On connect callback.
|
||||
|
||||
Resubscribe to all topics we were subscribed to and publish birth
|
||||
|
@ -596,7 +604,7 @@ class MQTT:
|
|||
and ATTR_TOPIC in self.conf[CONF_BIRTH_MESSAGE]
|
||||
):
|
||||
|
||||
async def publish_birth_message(birth_message):
|
||||
async def publish_birth_message(birth_message: PublishMessage) -> None:
|
||||
await self._ha_started.wait() # Wait for Home Assistant to start
|
||||
await self._discovery_cooldown() # Wait for MQTT discovery to cool down
|
||||
await self.async_publish(
|
||||
|
@ -611,7 +619,9 @@ class MQTT:
|
|||
publish_birth_message(birth_message), self.hass.loop
|
||||
)
|
||||
|
||||
def _mqtt_on_message(self, _mqttc, _userdata, msg) -> None:
|
||||
def _mqtt_on_message(
|
||||
self, _mqttc: mqtt.Client, _userdata: None, msg: MQTTMessage
|
||||
) -> None:
|
||||
"""Message received callback."""
|
||||
self.hass.add_job(self._mqtt_handle_message, msg)
|
||||
|
||||
|
@ -663,7 +673,13 @@ class MQTT:
|
|||
)
|
||||
self._mqtt_data.state_write_requests.process_write_state_requests()
|
||||
|
||||
def _mqtt_on_callback(self, _mqttc, _userdata, mid, _granted_qos=None) -> None:
|
||||
def _mqtt_on_callback(
|
||||
self,
|
||||
_mqttc: mqtt.Client,
|
||||
_userdata: None,
|
||||
mid: int,
|
||||
_granted_qos: tuple[Any, ...] | None = None,
|
||||
) -> None:
|
||||
"""Publish / Subscribe / Unsubscribe callback."""
|
||||
self.hass.add_job(self._mqtt_handle_mid, mid)
|
||||
|
||||
|
@ -679,7 +695,9 @@ class MQTT:
|
|||
if mid not in self._pending_operations:
|
||||
self._pending_operations[mid] = asyncio.Event()
|
||||
|
||||
def _mqtt_on_disconnect(self, _mqttc, _userdata, result_code: int) -> None:
|
||||
def _mqtt_on_disconnect(
|
||||
self, _mqttc: mqtt.Client, _userdata: None, result_code: int
|
||||
) -> None:
|
||||
"""Disconnected callback."""
|
||||
self.connected = False
|
||||
dispatcher_send(self.hass, MQTT_DISCONNECTED)
|
||||
|
@ -707,7 +725,7 @@ class MQTT:
|
|||
del self._pending_operations[mid]
|
||||
self._pending_operations_condition.notify_all()
|
||||
|
||||
async def _discovery_cooldown(self):
|
||||
async def _discovery_cooldown(self) -> None:
|
||||
now = time.time()
|
||||
# Reset discovery and subscribe cooldowns
|
||||
self._mqtt_data.last_discovery = now
|
||||
|
|
|
@ -287,7 +287,7 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow):
|
|||
options_config: dict[str, Any] = {}
|
||||
bad_input: bool = False
|
||||
|
||||
def _birth_will(birt_or_will: str) -> dict:
|
||||
def _birth_will(birt_or_will: str) -> dict[str, Any]:
|
||||
"""Return the user input for birth or will."""
|
||||
assert user_input
|
||||
return {
|
||||
|
@ -298,8 +298,11 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow):
|
|||
}
|
||||
|
||||
def _validate(
|
||||
field: str, values: dict[str, Any], error_code: str, schema: Callable
|
||||
):
|
||||
field: str,
|
||||
values: dict[str, Any],
|
||||
error_code: str,
|
||||
schema: Callable[[Any], Any],
|
||||
) -> None:
|
||||
"""Validate the user input."""
|
||||
nonlocal bad_input
|
||||
try:
|
||||
|
@ -679,7 +682,9 @@ def try_connection(
|
|||
|
||||
result: queue.Queue[bool] = queue.Queue(maxsize=1)
|
||||
|
||||
def on_connect(client_, userdata, flags, result_code):
|
||||
def on_connect(
|
||||
client_: mqtt.Client, userdata: None, flags: dict[str, Any], result_code: int
|
||||
) -> None:
|
||||
"""Handle connection result."""
|
||||
result.put(result_code == mqtt.CONNACK_ACCEPTED)
|
||||
|
||||
|
|
|
@ -287,6 +287,7 @@ async def async_get_platform_config_from_yaml(
|
|||
config_yaml: ConfigType | None = None,
|
||||
) -> list[ConfigType]:
|
||||
"""Return a list of validated configurations for the domain."""
|
||||
platform_configs: Any | None
|
||||
mqtt_data = get_mqtt_data(hass)
|
||||
if config_yaml is None:
|
||||
config_yaml = mqtt_data.config
|
||||
|
@ -294,6 +295,7 @@ async def async_get_platform_config_from_yaml(
|
|||
return []
|
||||
if not (platform_configs := config_yaml.get(platform_domain)):
|
||||
return []
|
||||
assert isinstance(platform_configs, list)
|
||||
return platform_configs
|
||||
|
||||
|
||||
|
@ -662,7 +664,9 @@ def stop_discovery_updates(
|
|||
clear_discovery_hash(hass, discovery_hash)
|
||||
|
||||
|
||||
async def async_remove_discovery_payload(hass: HomeAssistant, discovery_data: dict):
|
||||
async def async_remove_discovery_payload(
|
||||
hass: HomeAssistant, discovery_data: DiscoveryInfoType
|
||||
) -> None:
|
||||
"""Clear retained discovery topic in broker to avoid rediscovery after a restart of HA."""
|
||||
discovery_topic = discovery_data[ATTR_DISCOVERY_TOPIC]
|
||||
await async_publish(hass, discovery_topic, "", retain=True)
|
||||
|
@ -820,7 +824,7 @@ class MqttDiscoveryUpdate(Entity):
|
|||
"""Initialize the discovery update mixin."""
|
||||
self._discovery_data = discovery_data
|
||||
self._discovery_update = discovery_update
|
||||
self._remove_discovery_updated: Callable | None = None
|
||||
self._remove_discovery_updated: Callable[[], None] | None = None
|
||||
self._removed_from_hass = False
|
||||
if discovery_data is None:
|
||||
return
|
||||
|
@ -1169,7 +1173,7 @@ def update_device(
|
|||
device_info = device_info_from_specifications(config[CONF_DEVICE])
|
||||
|
||||
if config_entry_id is not None and device_info is not None:
|
||||
update_device_info = cast(dict, device_info)
|
||||
update_device_info = cast(dict[str, Any], device_info)
|
||||
update_device_info["config_entry_id"] = config_entry_id
|
||||
device = device_registry.async_get_or_create(**update_device_info)
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ class EntitySubscription:
|
|||
hass: HomeAssistant = attr.ib()
|
||||
topic: str = attr.ib()
|
||||
message_callback: MessageCallbackType = attr.ib()
|
||||
subscribe_task: Coroutine | None = attr.ib()
|
||||
subscribe_task: Coroutine[Any, Any, Callable[[], None]] | None = attr.ib()
|
||||
unsubscribe_callback: Callable[[], None] | None = attr.ib()
|
||||
qos: int = attr.ib(default=0)
|
||||
encoding: str = attr.ib(default="utf-8")
|
||||
|
@ -53,7 +53,7 @@ class EntitySubscription:
|
|||
hass, self.topic, self.message_callback, self.qos, self.encoding
|
||||
)
|
||||
|
||||
async def subscribe(self):
|
||||
async def subscribe(self) -> None:
|
||||
"""Subscribe to a topic."""
|
||||
if not self.subscribe_task:
|
||||
return
|
||||
|
|
|
@ -40,58 +40,58 @@ def mqtt_config_entry_enabled(hass: HomeAssistant) -> bool | None:
|
|||
return not bool(hass.config_entries.async_entries(DOMAIN)[0].disabled_by)
|
||||
|
||||
|
||||
def valid_topic(value: Any) -> str:
|
||||
def valid_topic(topic: Any) -> str:
|
||||
"""Validate that this is a valid topic name/filter."""
|
||||
value = cv.string(value)
|
||||
validated_topic = cv.string(topic)
|
||||
try:
|
||||
raw_value = value.encode("utf-8")
|
||||
raw_validated_topic = validated_topic.encode("utf-8")
|
||||
except UnicodeError as err:
|
||||
raise vol.Invalid("MQTT topic name/filter must be valid UTF-8 string.") from err
|
||||
if not raw_value:
|
||||
if not raw_validated_topic:
|
||||
raise vol.Invalid("MQTT topic name/filter must not be empty.")
|
||||
if len(raw_value) > 65535:
|
||||
if len(raw_validated_topic) > 65535:
|
||||
raise vol.Invalid(
|
||||
"MQTT topic name/filter must not be longer than 65535 encoded bytes."
|
||||
)
|
||||
if "\0" in value:
|
||||
if "\0" in validated_topic:
|
||||
raise vol.Invalid("MQTT topic name/filter must not contain null character.")
|
||||
if any(char <= "\u001F" for char in value):
|
||||
if any(char <= "\u001F" for char in validated_topic):
|
||||
raise vol.Invalid("MQTT topic name/filter must not contain control characters.")
|
||||
if any("\u007f" <= char <= "\u009F" for char in value):
|
||||
if any("\u007f" <= char <= "\u009F" for char in validated_topic):
|
||||
raise vol.Invalid("MQTT topic name/filter must not contain control characters.")
|
||||
if any("\ufdd0" <= char <= "\ufdef" for char in value):
|
||||
if any("\ufdd0" <= char <= "\ufdef" for char in validated_topic):
|
||||
raise vol.Invalid("MQTT topic name/filter must not contain non-characters.")
|
||||
if any((ord(char) & 0xFFFF) in (0xFFFE, 0xFFFF) for char in value):
|
||||
if any((ord(char) & 0xFFFF) in (0xFFFE, 0xFFFF) for char in validated_topic):
|
||||
raise vol.Invalid("MQTT topic name/filter must not contain noncharacters.")
|
||||
|
||||
return value
|
||||
return validated_topic
|
||||
|
||||
|
||||
def valid_subscribe_topic(value: Any) -> str:
|
||||
def valid_subscribe_topic(topic: Any) -> str:
|
||||
"""Validate that we can subscribe using this MQTT topic."""
|
||||
value = valid_topic(value)
|
||||
for i in (i for i, c in enumerate(value) if c == "+"):
|
||||
if (i > 0 and value[i - 1] != "/") or (
|
||||
i < len(value) - 1 and value[i + 1] != "/"
|
||||
validated_topic = valid_topic(topic)
|
||||
for i in (i for i, c in enumerate(validated_topic) if c == "+"):
|
||||
if (i > 0 and validated_topic[i - 1] != "/") or (
|
||||
i < len(validated_topic) - 1 and validated_topic[i + 1] != "/"
|
||||
):
|
||||
raise vol.Invalid(
|
||||
"Single-level wildcard must occupy an entire level of the filter"
|
||||
)
|
||||
|
||||
index = value.find("#")
|
||||
index = validated_topic.find("#")
|
||||
if index != -1:
|
||||
if index != len(value) - 1:
|
||||
if index != len(validated_topic) - 1:
|
||||
# If there are multiple wildcards, this will also trigger
|
||||
raise vol.Invalid(
|
||||
"Multi-level wildcard must be the last "
|
||||
"character in the topic filter."
|
||||
)
|
||||
if len(value) > 1 and value[index - 1] != "/":
|
||||
if len(validated_topic) > 1 and validated_topic[index - 1] != "/":
|
||||
raise vol.Invalid(
|
||||
"Multi-level wildcard must be after a topic level separator."
|
||||
)
|
||||
|
||||
return value
|
||||
return validated_topic
|
||||
|
||||
|
||||
def valid_subscribe_topic_template(value: Any) -> template.Template:
|
||||
|
@ -104,12 +104,12 @@ def valid_subscribe_topic_template(value: Any) -> template.Template:
|
|||
return tpl
|
||||
|
||||
|
||||
def valid_publish_topic(value: Any) -> str:
|
||||
def valid_publish_topic(topic: Any) -> str:
|
||||
"""Validate that we can publish using this MQTT topic."""
|
||||
value = valid_topic(value)
|
||||
if "+" in value or "#" in value:
|
||||
validated_topic = valid_topic(topic)
|
||||
if "+" in validated_topic or "#" in validated_topic:
|
||||
raise vol.Invalid("Wildcards can not be used in topic names")
|
||||
return value
|
||||
return validated_topic
|
||||
|
||||
|
||||
_VALID_QOS_SCHEMA = vol.All(vol.Coerce(int), vol.In([0, 1, 2]))
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue