Convert MQTT to use asyncio (#115910)
This commit is contained in:
parent
5a24690d79
commit
423544401e
6 changed files with 464 additions and 90 deletions
|
@ -3,12 +3,14 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable, Coroutine, Iterable
|
||||
from collections.abc import AsyncGenerator, Callable, Coroutine, Iterable
|
||||
import contextlib
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from functools import lru_cache, partial
|
||||
from itertools import chain, groupby
|
||||
import logging
|
||||
from operator import attrgetter
|
||||
import socket
|
||||
import ssl
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
@ -35,7 +37,7 @@ from homeassistant.core import (
|
|||
callback,
|
||||
)
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers.dispatcher import dispatcher_send
|
||||
from homeassistant.helpers.dispatcher import async_dispatcher_send
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
from homeassistant.loader import bind_hass
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
@ -92,6 +94,9 @@ INITIAL_SUBSCRIBE_COOLDOWN = 1.0
|
|||
SUBSCRIBE_COOLDOWN = 0.1
|
||||
UNSUBSCRIBE_COOLDOWN = 0.1
|
||||
TIMEOUT_ACK = 10
|
||||
RECONNECT_INTERVAL_SECONDS = 10
|
||||
|
||||
SocketType = socket.socket | ssl.SSLSocket | Any
|
||||
|
||||
SubscribePayloadType = str | bytes # Only bytes if encoding is None
|
||||
|
||||
|
@ -258,7 +263,9 @@ class MqttClientSetup:
|
|||
# However, that feature is not mandatory so we generate our own.
|
||||
client_id = mqtt.base62(uuid.uuid4().int, padding=22)
|
||||
transport = config.get(CONF_TRANSPORT, DEFAULT_TRANSPORT)
|
||||
self._client = mqtt.Client(client_id, protocol=proto, transport=transport)
|
||||
self._client = mqtt.Client(
|
||||
client_id, protocol=proto, transport=transport, reconnect_on_failure=False
|
||||
)
|
||||
|
||||
# Enable logging
|
||||
self._client.enable_logger()
|
||||
|
@ -404,12 +411,17 @@ class MQTT:
|
|||
self._ha_started = asyncio.Event()
|
||||
self._cleanup_on_unload: list[Callable[[], None]] = []
|
||||
|
||||
self._paho_lock = asyncio.Lock() # Prevents parallel calls to the MQTT client
|
||||
self._connection_lock = asyncio.Lock()
|
||||
self._pending_operations: dict[int, asyncio.Event] = {}
|
||||
self._pending_operations_condition = asyncio.Condition()
|
||||
self._subscribe_debouncer = EnsureJobAfterCooldown(
|
||||
INITIAL_SUBSCRIBE_COOLDOWN, self._async_perform_subscriptions
|
||||
)
|
||||
self._misc_task: asyncio.Task | None = None
|
||||
self._reconnect_task: asyncio.Task | None = None
|
||||
self._should_reconnect: bool = True
|
||||
self._available_future: asyncio.Future[bool] | None = None
|
||||
|
||||
self._max_qos: dict[str, int] = {} # topic, max qos
|
||||
self._pending_subscriptions: dict[str, int] = {} # topic, qos
|
||||
self._unsubscribe_debouncer = EnsureJobAfterCooldown(
|
||||
|
@ -456,25 +468,140 @@ class MQTT:
|
|||
while self._cleanup_on_unload:
|
||||
self._cleanup_on_unload.pop()()
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def _async_connect_in_executor(self) -> AsyncGenerator[None, None]:
|
||||
# While we are connecting in the executor we need to
|
||||
# handle on_socket_open and on_socket_register_write
|
||||
# in the executor as well.
|
||||
mqttc = self._mqttc
|
||||
try:
|
||||
mqttc.on_socket_open = self._on_socket_open
|
||||
mqttc.on_socket_register_write = self._on_socket_register_write
|
||||
yield
|
||||
finally:
|
||||
# Once the executor job is done, we can switch back to
|
||||
# handling these in the event loop.
|
||||
mqttc.on_socket_open = self._async_on_socket_open
|
||||
mqttc.on_socket_register_write = self._async_on_socket_register_write
|
||||
|
||||
def init_client(self) -> None:
|
||||
"""Initialize paho client."""
|
||||
self._mqttc = MqttClientSetup(self.conf).client
|
||||
self._mqttc.on_connect = self._mqtt_on_connect
|
||||
self._mqttc.on_disconnect = self._mqtt_on_disconnect
|
||||
self._mqttc.on_message = self._mqtt_on_message
|
||||
self._mqttc.on_publish = self._mqtt_on_callback
|
||||
self._mqttc.on_subscribe = self._mqtt_on_callback
|
||||
self._mqttc.on_unsubscribe = self._mqtt_on_callback
|
||||
mqttc = MqttClientSetup(self.conf).client
|
||||
# on_socket_unregister_write and _async_on_socket_close
|
||||
# are only ever called in the event loop
|
||||
mqttc.on_socket_close = self._async_on_socket_close
|
||||
mqttc.on_socket_unregister_write = self._async_on_socket_unregister_write
|
||||
|
||||
# These will be called in the event loop
|
||||
mqttc.on_connect = self._async_mqtt_on_connect
|
||||
mqttc.on_disconnect = self._async_mqtt_on_disconnect
|
||||
mqttc.on_message = self._async_mqtt_on_message
|
||||
mqttc.on_publish = self._async_mqtt_on_callback
|
||||
mqttc.on_subscribe = self._async_mqtt_on_callback
|
||||
mqttc.on_unsubscribe = self._async_mqtt_on_callback
|
||||
|
||||
if will := self.conf.get(CONF_WILL_MESSAGE, DEFAULT_WILL):
|
||||
will_message = PublishMessage(**will)
|
||||
self._mqttc.will_set(
|
||||
mqttc.will_set(
|
||||
topic=will_message.topic,
|
||||
payload=will_message.payload,
|
||||
qos=will_message.qos,
|
||||
retain=will_message.retain,
|
||||
)
|
||||
|
||||
self._mqttc = mqttc
|
||||
|
||||
async def _misc_loop(self) -> None:
|
||||
"""Start the MQTT client misc loop."""
|
||||
# pylint: disable=import-outside-toplevel
|
||||
import paho.mqtt.client as mqtt
|
||||
|
||||
while self._mqttc.loop_misc() == mqtt.MQTT_ERR_SUCCESS:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
@callback
|
||||
def _async_reader_callback(self, client: mqtt.Client) -> None:
|
||||
"""Handle reading data from the socket."""
|
||||
if (status := client.loop_read()) != 0:
|
||||
self._async_on_disconnect(status)
|
||||
|
||||
@callback
|
||||
def _async_start_misc_loop(self) -> None:
|
||||
"""Start the misc loop."""
|
||||
if self._misc_task is None or self._misc_task.done():
|
||||
_LOGGER.debug("%s: Starting client misc loop", self.config_entry.title)
|
||||
self._misc_task = self.config_entry.async_create_background_task(
|
||||
self.hass, self._misc_loop(), name="mqtt misc loop"
|
||||
)
|
||||
|
||||
def _on_socket_open(
|
||||
self, client: mqtt.Client, userdata: Any, sock: SocketType
|
||||
) -> None:
|
||||
"""Handle socket open."""
|
||||
self.loop.call_soon_threadsafe(
|
||||
self._async_on_socket_open, client, userdata, sock
|
||||
)
|
||||
|
||||
@callback
|
||||
def _async_on_socket_open(
|
||||
self, client: mqtt.Client, userdata: Any, sock: SocketType
|
||||
) -> None:
|
||||
"""Handle socket open."""
|
||||
fileno = sock.fileno()
|
||||
_LOGGER.debug("%s: connection opened %s", self.config_entry.title, fileno)
|
||||
if fileno > -1:
|
||||
self.loop.add_reader(sock, partial(self._async_reader_callback, client))
|
||||
self._async_start_misc_loop()
|
||||
|
||||
@callback
|
||||
def _async_on_socket_close(
|
||||
self, client: mqtt.Client, userdata: Any, sock: SocketType
|
||||
) -> None:
|
||||
"""Handle socket close."""
|
||||
fileno = sock.fileno()
|
||||
_LOGGER.debug("%s: connection closed %s", self.config_entry.title, fileno)
|
||||
# If socket close is called before the connect
|
||||
# result is set make sure the first connection result is set
|
||||
self._async_connection_result(False)
|
||||
if fileno > -1:
|
||||
self.loop.remove_reader(sock)
|
||||
if self._misc_task is not None and not self._misc_task.done():
|
||||
self._misc_task.cancel()
|
||||
|
||||
@callback
|
||||
def _async_writer_callback(self, client: mqtt.Client) -> None:
|
||||
"""Handle writing data to the socket."""
|
||||
if (status := client.loop_write()) != 0:
|
||||
self._async_on_disconnect(status)
|
||||
|
||||
def _on_socket_register_write(
|
||||
self, client: mqtt.Client, userdata: Any, sock: SocketType
|
||||
) -> None:
|
||||
"""Register the socket for writing."""
|
||||
self.loop.call_soon_threadsafe(
|
||||
self._async_on_socket_register_write, client, None, sock
|
||||
)
|
||||
|
||||
@callback
|
||||
def _async_on_socket_register_write(
|
||||
self, client: mqtt.Client, userdata: Any, sock: SocketType
|
||||
) -> None:
|
||||
"""Register the socket for writing."""
|
||||
fileno = sock.fileno()
|
||||
_LOGGER.debug("%s: register write %s", self.config_entry.title, fileno)
|
||||
if fileno > -1:
|
||||
self.loop.add_writer(sock, partial(self._async_writer_callback, client))
|
||||
|
||||
@callback
|
||||
def _async_on_socket_unregister_write(
|
||||
self, client: mqtt.Client, userdata: Any, sock: SocketType
|
||||
) -> None:
|
||||
"""Unregister the socket for writing."""
|
||||
fileno = sock.fileno()
|
||||
_LOGGER.debug("%s: unregister write %s", self.config_entry.title, fileno)
|
||||
if fileno > -1:
|
||||
self.loop.remove_writer(sock)
|
||||
|
||||
def _is_active_subscription(self, topic: str) -> bool:
|
||||
"""Check if a topic has an active subscription."""
|
||||
return topic in self._simple_subscriptions or any(
|
||||
|
@ -485,10 +612,7 @@ class MQTT:
|
|||
self, topic: str, payload: PublishPayloadType, qos: int, retain: bool
|
||||
) -> None:
|
||||
"""Publish a MQTT message."""
|
||||
async with self._paho_lock:
|
||||
msg_info = await self.hass.async_add_executor_job(
|
||||
self._mqttc.publish, topic, payload, qos, retain
|
||||
)
|
||||
msg_info = self._mqttc.publish(topic, payload, qos, retain)
|
||||
_LOGGER.debug(
|
||||
"Transmitting%s message on %s: '%s', mid: %s, qos: %s",
|
||||
" retained" if retain else "",
|
||||
|
@ -500,37 +624,71 @@ class MQTT:
|
|||
_raise_on_error(msg_info.rc)
|
||||
await self._wait_for_mid(msg_info.mid)
|
||||
|
||||
async def async_connect(self) -> None:
|
||||
async def async_connect(self, client_available: asyncio.Future[bool]) -> None:
|
||||
"""Connect to the host. Does not process messages yet."""
|
||||
# pylint: disable-next=import-outside-toplevel
|
||||
import paho.mqtt.client as mqtt
|
||||
|
||||
result: int | None = None
|
||||
self._available_future = client_available
|
||||
self._should_reconnect = True
|
||||
try:
|
||||
result = await self.hass.async_add_executor_job(
|
||||
self._mqttc.connect,
|
||||
self.conf[CONF_BROKER],
|
||||
self.conf.get(CONF_PORT, DEFAULT_PORT),
|
||||
self.conf.get(CONF_KEEPALIVE, DEFAULT_KEEPALIVE),
|
||||
)
|
||||
async with self._connection_lock, self._async_connect_in_executor():
|
||||
result = await self.hass.async_add_executor_job(
|
||||
self._mqttc.connect,
|
||||
self.conf[CONF_BROKER],
|
||||
self.conf.get(CONF_PORT, DEFAULT_PORT),
|
||||
self.conf.get(CONF_KEEPALIVE, DEFAULT_KEEPALIVE),
|
||||
)
|
||||
except OSError as err:
|
||||
_LOGGER.error("Failed to connect to MQTT server due to exception: %s", err)
|
||||
self._async_connection_result(False)
|
||||
finally:
|
||||
if result is not None and result != 0:
|
||||
if result is not None:
|
||||
_LOGGER.error(
|
||||
"Failed to connect to MQTT server: %s",
|
||||
mqtt.error_string(result),
|
||||
)
|
||||
self._async_connection_result(False)
|
||||
|
||||
if result is not None and result != 0:
|
||||
_LOGGER.error(
|
||||
"Failed to connect to MQTT server: %s", mqtt.error_string(result)
|
||||
@callback
|
||||
def _async_connection_result(self, connected: bool) -> None:
|
||||
"""Handle a connection result."""
|
||||
if self._available_future and not self._available_future.done():
|
||||
self._available_future.set_result(connected)
|
||||
|
||||
if connected:
|
||||
self._async_cancel_reconnect()
|
||||
elif self._should_reconnect and not self._reconnect_task:
|
||||
self._reconnect_task = self.config_entry.async_create_background_task(
|
||||
self.hass, self._reconnect_loop(), "mqtt reconnect loop"
|
||||
)
|
||||
|
||||
self._mqttc.loop_start()
|
||||
@callback
|
||||
def _async_cancel_reconnect(self) -> None:
|
||||
"""Cancel the reconnect task."""
|
||||
if self._reconnect_task:
|
||||
self._reconnect_task.cancel()
|
||||
self._reconnect_task = None
|
||||
|
||||
async def _reconnect_loop(self) -> None:
|
||||
"""Reconnect to the MQTT server."""
|
||||
while True:
|
||||
if not self.connected:
|
||||
try:
|
||||
async with self._connection_lock, self._async_connect_in_executor():
|
||||
await self.hass.async_add_executor_job(self._mqttc.reconnect)
|
||||
except OSError as err:
|
||||
_LOGGER.debug(
|
||||
"Error re-connecting to MQTT server due to exception: %s", err
|
||||
)
|
||||
|
||||
await asyncio.sleep(RECONNECT_INTERVAL_SECONDS)
|
||||
|
||||
async def async_disconnect(self) -> None:
|
||||
"""Stop the MQTT client."""
|
||||
|
||||
def stop() -> None:
|
||||
"""Stop the MQTT client."""
|
||||
# Do not disconnect, we want the broker to always publish will
|
||||
self._mqttc.loop_stop()
|
||||
|
||||
def no_more_acks() -> bool:
|
||||
"""Return False if there are unprocessed ACKs."""
|
||||
return not any(not op.is_set() for op in self._pending_operations.values())
|
||||
|
@ -549,8 +707,10 @@ class MQTT:
|
|||
await self._pending_operations_condition.wait_for(no_more_acks)
|
||||
|
||||
# stop the MQTT loop
|
||||
async with self._paho_lock:
|
||||
await self.hass.async_add_executor_job(stop)
|
||||
async with self._connection_lock:
|
||||
self._should_reconnect = False
|
||||
self._async_cancel_reconnect()
|
||||
self._mqttc.disconnect()
|
||||
|
||||
@callback
|
||||
def async_restore_tracked_subscriptions(
|
||||
|
@ -689,11 +849,8 @@ class MQTT:
|
|||
subscriptions: dict[str, int] = self._pending_subscriptions
|
||||
self._pending_subscriptions = {}
|
||||
|
||||
async with self._paho_lock:
|
||||
subscription_list = list(subscriptions.items())
|
||||
result, mid = await self.hass.async_add_executor_job(
|
||||
self._mqttc.subscribe, subscription_list
|
||||
)
|
||||
subscription_list = list(subscriptions.items())
|
||||
result, mid = self._mqttc.subscribe(subscription_list)
|
||||
|
||||
for topic, qos in subscriptions.items():
|
||||
_LOGGER.debug("Subscribing to %s, mid: %s, qos: %s", topic, mid, qos)
|
||||
|
@ -712,17 +869,15 @@ class MQTT:
|
|||
topics = list(self._pending_unsubscribes)
|
||||
self._pending_unsubscribes = set()
|
||||
|
||||
async with self._paho_lock:
|
||||
result, mid = await self.hass.async_add_executor_job(
|
||||
self._mqttc.unsubscribe, topics
|
||||
)
|
||||
result, mid = self._mqttc.unsubscribe(topics)
|
||||
_raise_on_error(result)
|
||||
for topic in topics:
|
||||
_LOGGER.debug("Unsubscribing from %s, mid: %s", topic, mid)
|
||||
|
||||
await self._wait_for_mid(mid)
|
||||
|
||||
def _mqtt_on_connect(
|
||||
@callback
|
||||
def _async_mqtt_on_connect(
|
||||
self,
|
||||
_mqttc: mqtt.Client,
|
||||
_userdata: None,
|
||||
|
@ -746,7 +901,7 @@ class MQTT:
|
|||
return
|
||||
|
||||
self.connected = True
|
||||
dispatcher_send(self.hass, MQTT_CONNECTED)
|
||||
async_dispatcher_send(self.hass, MQTT_CONNECTED)
|
||||
_LOGGER.info(
|
||||
"Connected to MQTT server %s:%s (%s)",
|
||||
self.conf[CONF_BROKER],
|
||||
|
@ -754,7 +909,7 @@ class MQTT:
|
|||
result_code,
|
||||
)
|
||||
|
||||
self.hass.create_task(self._async_resubscribe())
|
||||
self.hass.async_create_task(self._async_resubscribe())
|
||||
|
||||
if birth := self.conf.get(CONF_BIRTH_MESSAGE, DEFAULT_BIRTH):
|
||||
|
||||
|
@ -771,13 +926,17 @@ class MQTT:
|
|||
)
|
||||
|
||||
birth_message = PublishMessage(**birth)
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
publish_birth_message(birth_message), self.hass.loop
|
||||
self.config_entry.async_create_background_task(
|
||||
self.hass,
|
||||
publish_birth_message(birth_message),
|
||||
name="mqtt birth message",
|
||||
)
|
||||
else:
|
||||
# Update subscribe cooldown period to a shorter time
|
||||
self._subscribe_debouncer.set_timeout(SUBSCRIBE_COOLDOWN)
|
||||
|
||||
self._async_connection_result(True)
|
||||
|
||||
async def _async_resubscribe(self) -> None:
|
||||
"""Resubscribe on reconnect."""
|
||||
self._max_qos.clear()
|
||||
|
@ -796,16 +955,6 @@ class MQTT:
|
|||
)
|
||||
await self._async_perform_subscriptions()
|
||||
|
||||
def _mqtt_on_message(
|
||||
self, _mqttc: mqtt.Client, _userdata: None, msg: mqtt.MQTTMessage
|
||||
) -> None:
|
||||
"""Message received callback."""
|
||||
# MQTT messages tend to be high volume,
|
||||
# and since they come in via a thread and need to be processed in the event loop,
|
||||
# we want to avoid hass.add_job since most of the time is spent calling
|
||||
# inspect to figure out how to run the callback.
|
||||
self.loop.call_soon_threadsafe(self._mqtt_handle_message, msg)
|
||||
|
||||
@lru_cache(None) # pylint: disable=method-cache-max-size-none
|
||||
def _matching_subscriptions(self, topic: str) -> list[Subscription]:
|
||||
subscriptions: list[Subscription] = []
|
||||
|
@ -819,7 +968,9 @@ class MQTT:
|
|||
return subscriptions
|
||||
|
||||
@callback
|
||||
def _mqtt_handle_message(self, msg: mqtt.MQTTMessage) -> None:
|
||||
def _async_mqtt_on_message(
|
||||
self, _mqttc: mqtt.Client, _userdata: None, msg: mqtt.MQTTMessage
|
||||
) -> None:
|
||||
topic = msg.topic
|
||||
# msg.topic is a property that decodes the topic to a string
|
||||
# every time it is accessed. Save the result to avoid
|
||||
|
@ -878,7 +1029,8 @@ class MQTT:
|
|||
self.hass.async_run_hass_job(subscription.job, receive_msg)
|
||||
self._mqtt_data.state_write_requests.process_write_state_requests(msg)
|
||||
|
||||
def _mqtt_on_callback(
|
||||
@callback
|
||||
def _async_mqtt_on_callback(
|
||||
self,
|
||||
_mqttc: mqtt.Client,
|
||||
_userdata: None,
|
||||
|
@ -890,7 +1042,7 @@ class MQTT:
|
|||
# The callback signature for on_unsubscribe is different from on_subscribe
|
||||
# see https://github.com/eclipse/paho.mqtt.python/issues/687
|
||||
# properties and reasoncodes are not used in Home Assistant
|
||||
self.hass.create_task(self._mqtt_handle_mid(mid))
|
||||
self.hass.async_create_task(self._mqtt_handle_mid(mid))
|
||||
|
||||
async def _mqtt_handle_mid(self, mid: int) -> None:
|
||||
# Create the mid event if not created, either _mqtt_handle_mid or _wait_for_mid
|
||||
|
@ -906,7 +1058,8 @@ class MQTT:
|
|||
if mid not in self._pending_operations:
|
||||
self._pending_operations[mid] = asyncio.Event()
|
||||
|
||||
def _mqtt_on_disconnect(
|
||||
@callback
|
||||
def _async_mqtt_on_disconnect(
|
||||
self,
|
||||
_mqttc: mqtt.Client,
|
||||
_userdata: None,
|
||||
|
@ -914,8 +1067,19 @@ class MQTT:
|
|||
properties: mqtt.Properties | None = None,
|
||||
) -> None:
|
||||
"""Disconnected callback."""
|
||||
self._async_on_disconnect(result_code)
|
||||
|
||||
@callback
|
||||
def _async_on_disconnect(self, result_code: int) -> None:
|
||||
if not self.connected:
|
||||
# This function is re-entrant and may be called multiple times
|
||||
# when there is a broken pipe error.
|
||||
return
|
||||
# If disconnect is called before the connect
|
||||
# result is set make sure the first connection result is set
|
||||
self._async_connection_result(False)
|
||||
self.connected = False
|
||||
dispatcher_send(self.hass, MQTT_DISCONNECTED)
|
||||
async_dispatcher_send(self.hass, MQTT_DISCONNECTED)
|
||||
_LOGGER.warning(
|
||||
"Disconnected from MQTT server %s:%s (%s)",
|
||||
self.conf[CONF_BROKER],
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue