Convert MQTT to use asyncio (#115910)

This commit is contained in:
J. Nick Koston 2024-04-21 22:33:58 +02:00 committed by GitHub
parent 5a24690d79
commit 423544401e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 464 additions and 90 deletions

View file

@ -265,7 +265,9 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
conf: dict[str, Any] conf: dict[str, Any]
mqtt_data: MqttData mqtt_data: MqttData
async def _setup_client() -> tuple[MqttData, dict[str, Any]]: async def _setup_client(
client_available: asyncio.Future[bool],
) -> tuple[MqttData, dict[str, Any]]:
"""Set up the MQTT client.""" """Set up the MQTT client."""
# Fetch configuration # Fetch configuration
conf = dict(entry.data) conf = dict(entry.data)
@ -294,7 +296,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
entry.add_update_listener(_async_config_entry_updated) entry.add_update_listener(_async_config_entry_updated)
) )
await mqtt_data.client.async_connect() await mqtt_data.client.async_connect(client_available)
return (mqtt_data, conf) return (mqtt_data, conf)
client_available: asyncio.Future[bool] client_available: asyncio.Future[bool]
@ -303,13 +305,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
else: else:
client_available = hass.data[DATA_MQTT_AVAILABLE] client_available = hass.data[DATA_MQTT_AVAILABLE]
setup_ok: bool = False mqtt_data, conf = await _setup_client(client_available)
try:
mqtt_data, conf = await _setup_client()
setup_ok = True
finally:
if not client_available.done():
client_available.set_result(setup_ok)
async def async_publish_service(call: ServiceCall) -> None: async def async_publish_service(call: ServiceCall) -> None:
"""Handle MQTT publish service calls.""" """Handle MQTT publish service calls."""

View file

@ -3,12 +3,14 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections.abc import Callable, Coroutine, Iterable from collections.abc import AsyncGenerator, Callable, Coroutine, Iterable
import contextlib
from dataclasses import dataclass from dataclasses import dataclass
from functools import lru_cache from functools import lru_cache, partial
from itertools import chain, groupby from itertools import chain, groupby
import logging import logging
from operator import attrgetter from operator import attrgetter
import socket
import ssl import ssl
import time import time
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
@ -35,7 +37,7 @@ from homeassistant.core import (
callback, callback,
) )
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.dispatcher import dispatcher_send from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
@ -92,6 +94,9 @@ INITIAL_SUBSCRIBE_COOLDOWN = 1.0
SUBSCRIBE_COOLDOWN = 0.1 SUBSCRIBE_COOLDOWN = 0.1
UNSUBSCRIBE_COOLDOWN = 0.1 UNSUBSCRIBE_COOLDOWN = 0.1
TIMEOUT_ACK = 10 TIMEOUT_ACK = 10
RECONNECT_INTERVAL_SECONDS = 10
SocketType = socket.socket | ssl.SSLSocket | Any
SubscribePayloadType = str | bytes # Only bytes if encoding is None SubscribePayloadType = str | bytes # Only bytes if encoding is None
@ -258,7 +263,9 @@ class MqttClientSetup:
# However, that feature is not mandatory so we generate our own. # However, that feature is not mandatory so we generate our own.
client_id = mqtt.base62(uuid.uuid4().int, padding=22) client_id = mqtt.base62(uuid.uuid4().int, padding=22)
transport = config.get(CONF_TRANSPORT, DEFAULT_TRANSPORT) transport = config.get(CONF_TRANSPORT, DEFAULT_TRANSPORT)
self._client = mqtt.Client(client_id, protocol=proto, transport=transport) self._client = mqtt.Client(
client_id, protocol=proto, transport=transport, reconnect_on_failure=False
)
# Enable logging # Enable logging
self._client.enable_logger() self._client.enable_logger()
@ -404,12 +411,17 @@ class MQTT:
self._ha_started = asyncio.Event() self._ha_started = asyncio.Event()
self._cleanup_on_unload: list[Callable[[], None]] = [] self._cleanup_on_unload: list[Callable[[], None]] = []
self._paho_lock = asyncio.Lock() # Prevents parallel calls to the MQTT client self._connection_lock = asyncio.Lock()
self._pending_operations: dict[int, asyncio.Event] = {} self._pending_operations: dict[int, asyncio.Event] = {}
self._pending_operations_condition = asyncio.Condition() self._pending_operations_condition = asyncio.Condition()
self._subscribe_debouncer = EnsureJobAfterCooldown( self._subscribe_debouncer = EnsureJobAfterCooldown(
INITIAL_SUBSCRIBE_COOLDOWN, self._async_perform_subscriptions INITIAL_SUBSCRIBE_COOLDOWN, self._async_perform_subscriptions
) )
self._misc_task: asyncio.Task | None = None
self._reconnect_task: asyncio.Task | None = None
self._should_reconnect: bool = True
self._available_future: asyncio.Future[bool] | None = None
self._max_qos: dict[str, int] = {} # topic, max qos self._max_qos: dict[str, int] = {} # topic, max qos
self._pending_subscriptions: dict[str, int] = {} # topic, qos self._pending_subscriptions: dict[str, int] = {} # topic, qos
self._unsubscribe_debouncer = EnsureJobAfterCooldown( self._unsubscribe_debouncer = EnsureJobAfterCooldown(
@ -456,25 +468,140 @@ class MQTT:
while self._cleanup_on_unload: while self._cleanup_on_unload:
self._cleanup_on_unload.pop()() self._cleanup_on_unload.pop()()
@contextlib.asynccontextmanager
async def _async_connect_in_executor(self) -> AsyncGenerator[None, None]:
# While we are connecting in the executor we need to
# handle on_socket_open and on_socket_register_write
# in the executor as well.
mqttc = self._mqttc
try:
mqttc.on_socket_open = self._on_socket_open
mqttc.on_socket_register_write = self._on_socket_register_write
yield
finally:
# Once the executor job is done, we can switch back to
# handling these in the event loop.
mqttc.on_socket_open = self._async_on_socket_open
mqttc.on_socket_register_write = self._async_on_socket_register_write
def init_client(self) -> None: def init_client(self) -> None:
"""Initialize paho client.""" """Initialize paho client."""
self._mqttc = MqttClientSetup(self.conf).client mqttc = MqttClientSetup(self.conf).client
self._mqttc.on_connect = self._mqtt_on_connect # on_socket_unregister_write and _async_on_socket_close
self._mqttc.on_disconnect = self._mqtt_on_disconnect # are only ever called in the event loop
self._mqttc.on_message = self._mqtt_on_message mqttc.on_socket_close = self._async_on_socket_close
self._mqttc.on_publish = self._mqtt_on_callback mqttc.on_socket_unregister_write = self._async_on_socket_unregister_write
self._mqttc.on_subscribe = self._mqtt_on_callback
self._mqttc.on_unsubscribe = self._mqtt_on_callback # These will be called in the event loop
mqttc.on_connect = self._async_mqtt_on_connect
mqttc.on_disconnect = self._async_mqtt_on_disconnect
mqttc.on_message = self._async_mqtt_on_message
mqttc.on_publish = self._async_mqtt_on_callback
mqttc.on_subscribe = self._async_mqtt_on_callback
mqttc.on_unsubscribe = self._async_mqtt_on_callback
if will := self.conf.get(CONF_WILL_MESSAGE, DEFAULT_WILL): if will := self.conf.get(CONF_WILL_MESSAGE, DEFAULT_WILL):
will_message = PublishMessage(**will) will_message = PublishMessage(**will)
self._mqttc.will_set( mqttc.will_set(
topic=will_message.topic, topic=will_message.topic,
payload=will_message.payload, payload=will_message.payload,
qos=will_message.qos, qos=will_message.qos,
retain=will_message.retain, retain=will_message.retain,
) )
self._mqttc = mqttc
async def _misc_loop(self) -> None:
"""Start the MQTT client misc loop."""
# pylint: disable=import-outside-toplevel
import paho.mqtt.client as mqtt
while self._mqttc.loop_misc() == mqtt.MQTT_ERR_SUCCESS:
await asyncio.sleep(1)
@callback
def _async_reader_callback(self, client: mqtt.Client) -> None:
"""Handle reading data from the socket."""
if (status := client.loop_read()) != 0:
self._async_on_disconnect(status)
@callback
def _async_start_misc_loop(self) -> None:
"""Start the misc loop."""
if self._misc_task is None or self._misc_task.done():
_LOGGER.debug("%s: Starting client misc loop", self.config_entry.title)
self._misc_task = self.config_entry.async_create_background_task(
self.hass, self._misc_loop(), name="mqtt misc loop"
)
def _on_socket_open(
self, client: mqtt.Client, userdata: Any, sock: SocketType
) -> None:
"""Handle socket open."""
self.loop.call_soon_threadsafe(
self._async_on_socket_open, client, userdata, sock
)
@callback
def _async_on_socket_open(
self, client: mqtt.Client, userdata: Any, sock: SocketType
) -> None:
"""Handle socket open."""
fileno = sock.fileno()
_LOGGER.debug("%s: connection opened %s", self.config_entry.title, fileno)
if fileno > -1:
self.loop.add_reader(sock, partial(self._async_reader_callback, client))
self._async_start_misc_loop()
@callback
def _async_on_socket_close(
self, client: mqtt.Client, userdata: Any, sock: SocketType
) -> None:
"""Handle socket close."""
fileno = sock.fileno()
_LOGGER.debug("%s: connection closed %s", self.config_entry.title, fileno)
# If socket close is called before the connect
# result is set make sure the first connection result is set
self._async_connection_result(False)
if fileno > -1:
self.loop.remove_reader(sock)
if self._misc_task is not None and not self._misc_task.done():
self._misc_task.cancel()
@callback
def _async_writer_callback(self, client: mqtt.Client) -> None:
"""Handle writing data to the socket."""
if (status := client.loop_write()) != 0:
self._async_on_disconnect(status)
def _on_socket_register_write(
self, client: mqtt.Client, userdata: Any, sock: SocketType
) -> None:
"""Register the socket for writing."""
self.loop.call_soon_threadsafe(
self._async_on_socket_register_write, client, None, sock
)
@callback
def _async_on_socket_register_write(
self, client: mqtt.Client, userdata: Any, sock: SocketType
) -> None:
"""Register the socket for writing."""
fileno = sock.fileno()
_LOGGER.debug("%s: register write %s", self.config_entry.title, fileno)
if fileno > -1:
self.loop.add_writer(sock, partial(self._async_writer_callback, client))
@callback
def _async_on_socket_unregister_write(
self, client: mqtt.Client, userdata: Any, sock: SocketType
) -> None:
"""Unregister the socket for writing."""
fileno = sock.fileno()
_LOGGER.debug("%s: unregister write %s", self.config_entry.title, fileno)
if fileno > -1:
self.loop.remove_writer(sock)
def _is_active_subscription(self, topic: str) -> bool: def _is_active_subscription(self, topic: str) -> bool:
"""Check if a topic has an active subscription.""" """Check if a topic has an active subscription."""
return topic in self._simple_subscriptions or any( return topic in self._simple_subscriptions or any(
@ -485,10 +612,7 @@ class MQTT:
self, topic: str, payload: PublishPayloadType, qos: int, retain: bool self, topic: str, payload: PublishPayloadType, qos: int, retain: bool
) -> None: ) -> None:
"""Publish a MQTT message.""" """Publish a MQTT message."""
async with self._paho_lock: msg_info = self._mqttc.publish(topic, payload, qos, retain)
msg_info = await self.hass.async_add_executor_job(
self._mqttc.publish, topic, payload, qos, retain
)
_LOGGER.debug( _LOGGER.debug(
"Transmitting%s message on %s: '%s', mid: %s, qos: %s", "Transmitting%s message on %s: '%s', mid: %s, qos: %s",
" retained" if retain else "", " retained" if retain else "",
@ -500,37 +624,71 @@ class MQTT:
_raise_on_error(msg_info.rc) _raise_on_error(msg_info.rc)
await self._wait_for_mid(msg_info.mid) await self._wait_for_mid(msg_info.mid)
async def async_connect(self) -> None: async def async_connect(self, client_available: asyncio.Future[bool]) -> None:
"""Connect to the host. Does not process messages yet.""" """Connect to the host. Does not process messages yet."""
# pylint: disable-next=import-outside-toplevel # pylint: disable-next=import-outside-toplevel
import paho.mqtt.client as mqtt import paho.mqtt.client as mqtt
result: int | None = None result: int | None = None
self._available_future = client_available
self._should_reconnect = True
try: try:
result = await self.hass.async_add_executor_job( async with self._connection_lock, self._async_connect_in_executor():
self._mqttc.connect, result = await self.hass.async_add_executor_job(
self.conf[CONF_BROKER], self._mqttc.connect,
self.conf.get(CONF_PORT, DEFAULT_PORT), self.conf[CONF_BROKER],
self.conf.get(CONF_KEEPALIVE, DEFAULT_KEEPALIVE), self.conf.get(CONF_PORT, DEFAULT_PORT),
) self.conf.get(CONF_KEEPALIVE, DEFAULT_KEEPALIVE),
)
except OSError as err: except OSError as err:
_LOGGER.error("Failed to connect to MQTT server due to exception: %s", err) _LOGGER.error("Failed to connect to MQTT server due to exception: %s", err)
self._async_connection_result(False)
finally:
if result is not None and result != 0:
if result is not None:
_LOGGER.error(
"Failed to connect to MQTT server: %s",
mqtt.error_string(result),
)
self._async_connection_result(False)
if result is not None and result != 0: @callback
_LOGGER.error( def _async_connection_result(self, connected: bool) -> None:
"Failed to connect to MQTT server: %s", mqtt.error_string(result) """Handle a connection result."""
if self._available_future and not self._available_future.done():
self._available_future.set_result(connected)
if connected:
self._async_cancel_reconnect()
elif self._should_reconnect and not self._reconnect_task:
self._reconnect_task = self.config_entry.async_create_background_task(
self.hass, self._reconnect_loop(), "mqtt reconnect loop"
) )
self._mqttc.loop_start() @callback
def _async_cancel_reconnect(self) -> None:
"""Cancel the reconnect task."""
if self._reconnect_task:
self._reconnect_task.cancel()
self._reconnect_task = None
async def _reconnect_loop(self) -> None:
"""Reconnect to the MQTT server."""
while True:
if not self.connected:
try:
async with self._connection_lock, self._async_connect_in_executor():
await self.hass.async_add_executor_job(self._mqttc.reconnect)
except OSError as err:
_LOGGER.debug(
"Error re-connecting to MQTT server due to exception: %s", err
)
await asyncio.sleep(RECONNECT_INTERVAL_SECONDS)
async def async_disconnect(self) -> None: async def async_disconnect(self) -> None:
"""Stop the MQTT client.""" """Stop the MQTT client."""
def stop() -> None:
"""Stop the MQTT client."""
# Do not disconnect, we want the broker to always publish will
self._mqttc.loop_stop()
def no_more_acks() -> bool: def no_more_acks() -> bool:
"""Return False if there are unprocessed ACKs.""" """Return False if there are unprocessed ACKs."""
return not any(not op.is_set() for op in self._pending_operations.values()) return not any(not op.is_set() for op in self._pending_operations.values())
@ -549,8 +707,10 @@ class MQTT:
await self._pending_operations_condition.wait_for(no_more_acks) await self._pending_operations_condition.wait_for(no_more_acks)
# stop the MQTT loop # stop the MQTT loop
async with self._paho_lock: async with self._connection_lock:
await self.hass.async_add_executor_job(stop) self._should_reconnect = False
self._async_cancel_reconnect()
self._mqttc.disconnect()
@callback @callback
def async_restore_tracked_subscriptions( def async_restore_tracked_subscriptions(
@ -689,11 +849,8 @@ class MQTT:
subscriptions: dict[str, int] = self._pending_subscriptions subscriptions: dict[str, int] = self._pending_subscriptions
self._pending_subscriptions = {} self._pending_subscriptions = {}
async with self._paho_lock: subscription_list = list(subscriptions.items())
subscription_list = list(subscriptions.items()) result, mid = self._mqttc.subscribe(subscription_list)
result, mid = await self.hass.async_add_executor_job(
self._mqttc.subscribe, subscription_list
)
for topic, qos in subscriptions.items(): for topic, qos in subscriptions.items():
_LOGGER.debug("Subscribing to %s, mid: %s, qos: %s", topic, mid, qos) _LOGGER.debug("Subscribing to %s, mid: %s, qos: %s", topic, mid, qos)
@ -712,17 +869,15 @@ class MQTT:
topics = list(self._pending_unsubscribes) topics = list(self._pending_unsubscribes)
self._pending_unsubscribes = set() self._pending_unsubscribes = set()
async with self._paho_lock: result, mid = self._mqttc.unsubscribe(topics)
result, mid = await self.hass.async_add_executor_job(
self._mqttc.unsubscribe, topics
)
_raise_on_error(result) _raise_on_error(result)
for topic in topics: for topic in topics:
_LOGGER.debug("Unsubscribing from %s, mid: %s", topic, mid) _LOGGER.debug("Unsubscribing from %s, mid: %s", topic, mid)
await self._wait_for_mid(mid) await self._wait_for_mid(mid)
def _mqtt_on_connect( @callback
def _async_mqtt_on_connect(
self, self,
_mqttc: mqtt.Client, _mqttc: mqtt.Client,
_userdata: None, _userdata: None,
@ -746,7 +901,7 @@ class MQTT:
return return
self.connected = True self.connected = True
dispatcher_send(self.hass, MQTT_CONNECTED) async_dispatcher_send(self.hass, MQTT_CONNECTED)
_LOGGER.info( _LOGGER.info(
"Connected to MQTT server %s:%s (%s)", "Connected to MQTT server %s:%s (%s)",
self.conf[CONF_BROKER], self.conf[CONF_BROKER],
@ -754,7 +909,7 @@ class MQTT:
result_code, result_code,
) )
self.hass.create_task(self._async_resubscribe()) self.hass.async_create_task(self._async_resubscribe())
if birth := self.conf.get(CONF_BIRTH_MESSAGE, DEFAULT_BIRTH): if birth := self.conf.get(CONF_BIRTH_MESSAGE, DEFAULT_BIRTH):
@ -771,13 +926,17 @@ class MQTT:
) )
birth_message = PublishMessage(**birth) birth_message = PublishMessage(**birth)
asyncio.run_coroutine_threadsafe( self.config_entry.async_create_background_task(
publish_birth_message(birth_message), self.hass.loop self.hass,
publish_birth_message(birth_message),
name="mqtt birth message",
) )
else: else:
# Update subscribe cooldown period to a shorter time # Update subscribe cooldown period to a shorter time
self._subscribe_debouncer.set_timeout(SUBSCRIBE_COOLDOWN) self._subscribe_debouncer.set_timeout(SUBSCRIBE_COOLDOWN)
self._async_connection_result(True)
async def _async_resubscribe(self) -> None: async def _async_resubscribe(self) -> None:
"""Resubscribe on reconnect.""" """Resubscribe on reconnect."""
self._max_qos.clear() self._max_qos.clear()
@ -796,16 +955,6 @@ class MQTT:
) )
await self._async_perform_subscriptions() await self._async_perform_subscriptions()
def _mqtt_on_message(
self, _mqttc: mqtt.Client, _userdata: None, msg: mqtt.MQTTMessage
) -> None:
"""Message received callback."""
# MQTT messages tend to be high volume,
# and since they come in via a thread and need to be processed in the event loop,
# we want to avoid hass.add_job since most of the time is spent calling
# inspect to figure out how to run the callback.
self.loop.call_soon_threadsafe(self._mqtt_handle_message, msg)
@lru_cache(None) # pylint: disable=method-cache-max-size-none @lru_cache(None) # pylint: disable=method-cache-max-size-none
def _matching_subscriptions(self, topic: str) -> list[Subscription]: def _matching_subscriptions(self, topic: str) -> list[Subscription]:
subscriptions: list[Subscription] = [] subscriptions: list[Subscription] = []
@ -819,7 +968,9 @@ class MQTT:
return subscriptions return subscriptions
@callback @callback
def _mqtt_handle_message(self, msg: mqtt.MQTTMessage) -> None: def _async_mqtt_on_message(
self, _mqttc: mqtt.Client, _userdata: None, msg: mqtt.MQTTMessage
) -> None:
topic = msg.topic topic = msg.topic
# msg.topic is a property that decodes the topic to a string # msg.topic is a property that decodes the topic to a string
# every time it is accessed. Save the result to avoid # every time it is accessed. Save the result to avoid
@ -878,7 +1029,8 @@ class MQTT:
self.hass.async_run_hass_job(subscription.job, receive_msg) self.hass.async_run_hass_job(subscription.job, receive_msg)
self._mqtt_data.state_write_requests.process_write_state_requests(msg) self._mqtt_data.state_write_requests.process_write_state_requests(msg)
def _mqtt_on_callback( @callback
def _async_mqtt_on_callback(
self, self,
_mqttc: mqtt.Client, _mqttc: mqtt.Client,
_userdata: None, _userdata: None,
@ -890,7 +1042,7 @@ class MQTT:
# The callback signature for on_unsubscribe is different from on_subscribe # The callback signature for on_unsubscribe is different from on_subscribe
# see https://github.com/eclipse/paho.mqtt.python/issues/687 # see https://github.com/eclipse/paho.mqtt.python/issues/687
# properties and reasoncodes are not used in Home Assistant # properties and reasoncodes are not used in Home Assistant
self.hass.create_task(self._mqtt_handle_mid(mid)) self.hass.async_create_task(self._mqtt_handle_mid(mid))
async def _mqtt_handle_mid(self, mid: int) -> None: async def _mqtt_handle_mid(self, mid: int) -> None:
# Create the mid event if not created, either _mqtt_handle_mid or _wait_for_mid # Create the mid event if not created, either _mqtt_handle_mid or _wait_for_mid
@ -906,7 +1058,8 @@ class MQTT:
if mid not in self._pending_operations: if mid not in self._pending_operations:
self._pending_operations[mid] = asyncio.Event() self._pending_operations[mid] = asyncio.Event()
def _mqtt_on_disconnect( @callback
def _async_mqtt_on_disconnect(
self, self,
_mqttc: mqtt.Client, _mqttc: mqtt.Client,
_userdata: None, _userdata: None,
@ -914,8 +1067,19 @@ class MQTT:
properties: mqtt.Properties | None = None, properties: mqtt.Properties | None = None,
) -> None: ) -> None:
"""Disconnected callback.""" """Disconnected callback."""
self._async_on_disconnect(result_code)
@callback
def _async_on_disconnect(self, result_code: int) -> None:
if not self.connected:
# This function is re-entrant and may be called multiple times
# when there is a broken pipe error.
return
# If disconnect is called before the connect
# result is set make sure the first connection result is set
self._async_connection_result(False)
self.connected = False self.connected = False
dispatcher_send(self.hass, MQTT_DISCONNECTED) async_dispatcher_send(self.hass, MQTT_DISCONNECTED)
_LOGGER.warning( _LOGGER.warning(
"Disconnected from MQTT server %s:%s (%s)", "Disconnected from MQTT server %s:%s (%s)",
self.conf[CONF_BROKER], self.conf[CONF_BROKER],

View file

@ -452,7 +452,7 @@ def async_fire_mqtt_message(
mqtt_data: MqttData = hass.data["mqtt"] mqtt_data: MqttData = hass.data["mqtt"]
assert mqtt_data.client assert mqtt_data.client
mqtt_data.client._mqtt_handle_message(msg) mqtt_data.client._async_mqtt_on_message(Mock(), None, msg)
fire_mqtt_message = threadsafe_callback_factory(async_fire_mqtt_message) fire_mqtt_message = threadsafe_callback_factory(async_fire_mqtt_message)

View file

@ -4,17 +4,22 @@ import asyncio
from copy import deepcopy from copy import deepcopy
from datetime import datetime, timedelta from datetime import datetime, timedelta
import json import json
import socket
import ssl import ssl
from typing import Any, TypedDict from typing import Any, TypedDict
from unittest.mock import ANY, MagicMock, call, mock_open, patch from unittest.mock import ANY, MagicMock, call, mock_open, patch
from freezegun.api import FrozenDateTimeFactory from freezegun.api import FrozenDateTimeFactory
import paho.mqtt.client as paho_mqtt
import pytest import pytest
import voluptuous as vol import voluptuous as vol
from homeassistant.components import mqtt from homeassistant.components import mqtt
from homeassistant.components.mqtt import debug_info from homeassistant.components.mqtt import debug_info
from homeassistant.components.mqtt.client import EnsureJobAfterCooldown from homeassistant.components.mqtt.client import (
RECONNECT_INTERVAL_SECONDS,
EnsureJobAfterCooldown,
)
from homeassistant.components.mqtt.mixins import MQTT_ENTITY_DEVICE_INFO_SCHEMA from homeassistant.components.mqtt.mixins import MQTT_ENTITY_DEVICE_INFO_SCHEMA
from homeassistant.components.mqtt.models import ( from homeassistant.components.mqtt.models import (
MessageCallbackType, MessageCallbackType,
@ -146,7 +151,7 @@ async def test_mqtt_disconnects_on_home_assistant_stop(
hass.bus.fire(EVENT_HOMEASSISTANT_STOP) hass.bus.fire(EVENT_HOMEASSISTANT_STOP)
await hass.async_block_till_done() await hass.async_block_till_done()
await hass.async_block_till_done() await hass.async_block_till_done()
assert mqtt_client_mock.loop_stop.call_count == 1 assert mqtt_client_mock.disconnect.call_count == 1
async def test_mqtt_await_ack_at_disconnect( async def test_mqtt_await_ack_at_disconnect(
@ -161,8 +166,14 @@ async def test_mqtt_await_ack_at_disconnect(
rc = 0 rc = 0
with patch("paho.mqtt.client.Client") as mock_client: with patch("paho.mqtt.client.Client") as mock_client:
mock_client().connect = MagicMock(return_value=0) mqtt_client = mock_client.return_value
mock_client().publish = MagicMock(return_value=FakeInfo()) mqtt_client.connect = MagicMock(
return_value=0,
side_effect=lambda *args, **kwargs: hass.loop.call_soon_threadsafe(
mqtt_client.on_connect, mqtt_client, None, 0, 0, 0
),
)
mqtt_client.publish = MagicMock(return_value=FakeInfo())
entry = MockConfigEntry( entry = MockConfigEntry(
domain=mqtt.DOMAIN, domain=mqtt.DOMAIN,
data={"certificate": "auto", mqtt.CONF_BROKER: "test-broker"}, data={"certificate": "auto", mqtt.CONF_BROKER: "test-broker"},
@ -1669,6 +1680,7 @@ async def test_not_calling_subscribe_when_unsubscribed_within_cooldown(
the subscribe cool down period has ended. the subscribe cool down period has ended.
""" """
mqtt_mock = await mqtt_mock_entry() mqtt_mock = await mqtt_mock_entry()
mqtt_client_mock.subscribe.reset_mock()
# Fake that the client is connected # Fake that the client is connected
mqtt_mock().connected = True mqtt_mock().connected = True
@ -1925,6 +1937,7 @@ async def test_canceling_debouncer_on_shutdown(
"""Test canceling the debouncer when HA shuts down.""" """Test canceling the debouncer when HA shuts down."""
mqtt_mock = await mqtt_mock_entry() mqtt_mock = await mqtt_mock_entry()
mqtt_client_mock.subscribe.reset_mock()
# Fake that the client is connected # Fake that the client is connected
mqtt_mock().connected = True mqtt_mock().connected = True
@ -2008,7 +2021,7 @@ async def test_initial_setup_logs_error(
"""Test for setup failure if initial client connection fails.""" """Test for setup failure if initial client connection fails."""
entry = MockConfigEntry(domain=mqtt.DOMAIN, data={mqtt.CONF_BROKER: "test-broker"}) entry = MockConfigEntry(domain=mqtt.DOMAIN, data={mqtt.CONF_BROKER: "test-broker"})
entry.add_to_hass(hass) entry.add_to_hass(hass)
mqtt_client_mock.connect.return_value = 1 mqtt_client_mock.connect.side_effect = MagicMock(return_value=1)
try: try:
assert await hass.config_entries.async_setup(entry.entry_id) assert await hass.config_entries.async_setup(entry.entry_id)
except HomeAssistantError: except HomeAssistantError:
@ -2230,7 +2243,12 @@ async def test_handle_mqtt_timeout_on_callback(
mock_client = mock_client.return_value mock_client = mock_client.return_value
mock_client.publish.return_value = FakeInfo() mock_client.publish.return_value = FakeInfo()
mock_client.subscribe.side_effect = _mock_ack mock_client.subscribe.side_effect = _mock_ack
mock_client.connect.return_value = 0 mock_client.connect = MagicMock(
return_value=0,
side_effect=lambda *args, **kwargs: hass.loop.call_soon_threadsafe(
mock_client.on_connect, mock_client, None, 0, 0, 0
),
)
entry = MockConfigEntry( entry = MockConfigEntry(
domain=mqtt.DOMAIN, data={mqtt.CONF_BROKER: "test-broker"} domain=mqtt.DOMAIN, data={mqtt.CONF_BROKER: "test-broker"}
@ -4144,3 +4162,179 @@ async def test_multi_platform_discovery(
) )
is not None is not None
) )
@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0)
@patch("homeassistant.components.mqtt.client.DISCOVERY_COOLDOWN", 0.0)
@patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 0.0)
async def test_auto_reconnect(
hass: HomeAssistant,
mqtt_client_mock: MqttMockPahoClient,
mqtt_mock_entry: MqttMockHAClientGenerator,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test reconnection is automatically done."""
mqtt_mock = await mqtt_mock_entry()
await hass.async_block_till_done()
assert mqtt_mock.connected is True
mqtt_client_mock.reconnect.reset_mock()
mqtt_client_mock.disconnect()
mqtt_client_mock.on_disconnect(None, None, 0)
await hass.async_block_till_done()
mqtt_client_mock.reconnect.side_effect = OSError("foo")
async_fire_time_changed(
hass, utcnow() + timedelta(seconds=RECONNECT_INTERVAL_SECONDS)
)
await hass.async_block_till_done()
assert len(mqtt_client_mock.reconnect.mock_calls) == 1
assert "Error re-connecting to MQTT server due to exception: foo" in caplog.text
mqtt_client_mock.reconnect.side_effect = None
async_fire_time_changed(
hass, utcnow() + timedelta(seconds=RECONNECT_INTERVAL_SECONDS)
)
await hass.async_block_till_done()
assert len(mqtt_client_mock.reconnect.mock_calls) == 2
hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP)
mqtt_client_mock.disconnect()
mqtt_client_mock.on_disconnect(None, None, 0)
await hass.async_block_till_done()
async_fire_time_changed(
hass, utcnow() + timedelta(seconds=RECONNECT_INTERVAL_SECONDS)
)
await hass.async_block_till_done()
# Should not reconnect after stop
assert len(mqtt_client_mock.reconnect.mock_calls) == 2
@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0)
@patch("homeassistant.components.mqtt.client.DISCOVERY_COOLDOWN", 0.0)
@patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 0.0)
async def test_server_sock_connect_and_disconnect(
hass: HomeAssistant,
mqtt_client_mock: MqttMockPahoClient,
mqtt_mock_entry: MqttMockHAClientGenerator,
calls: list[ReceiveMessage],
record_calls: MessageCallbackType,
) -> None:
"""Test handling the socket connected and disconnected."""
mqtt_mock = await mqtt_mock_entry()
await hass.async_block_till_done()
assert mqtt_mock.connected is True
mqtt_client_mock.loop_misc.return_value = paho_mqtt.MQTT_ERR_SUCCESS
client, server = socket.socketpair(
family=socket.AF_UNIX, type=socket.SOCK_STREAM, proto=0
)
client.setblocking(False)
server.setblocking(False)
mqtt_client_mock.on_socket_open(mqtt_client_mock, None, client)
mqtt_client_mock.on_socket_register_write(mqtt_client_mock, None, client)
await hass.async_block_till_done()
server.close() # mock the server closing the connection on us
unsub = await mqtt.async_subscribe(hass, "test-topic", record_calls)
mqtt_client_mock.loop_misc.return_value = paho_mqtt.MQTT_ERR_CONN_LOST
mqtt_client_mock.on_socket_unregister_write(mqtt_client_mock, None, client)
mqtt_client_mock.on_socket_close(mqtt_client_mock, None, client)
mqtt_client_mock.on_disconnect(mqtt_client_mock, None, client)
await hass.async_block_till_done()
unsub()
# Should have failed
assert len(calls) == 0
@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0)
@patch("homeassistant.components.mqtt.client.DISCOVERY_COOLDOWN", 0.0)
@patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 0.0)
async def test_client_sock_failure_after_connect(
hass: HomeAssistant,
mqtt_client_mock: MqttMockPahoClient,
mqtt_mock_entry: MqttMockHAClientGenerator,
calls: list[ReceiveMessage],
record_calls: MessageCallbackType,
) -> None:
"""Test handling the socket connected and disconnected."""
mqtt_mock = await mqtt_mock_entry()
# Fake that the client is connected
mqtt_mock().connected = True
await hass.async_block_till_done()
assert mqtt_mock.connected is True
mqtt_client_mock.loop_misc.return_value = paho_mqtt.MQTT_ERR_SUCCESS
client, server = socket.socketpair(
family=socket.AF_UNIX, type=socket.SOCK_STREAM, proto=0
)
client.setblocking(False)
server.setblocking(False)
mqtt_client_mock.on_socket_open(mqtt_client_mock, None, client)
mqtt_client_mock.on_socket_register_writer(mqtt_client_mock, None, client)
await hass.async_block_till_done()
mqtt_client_mock.loop_write.side_effect = OSError("foo")
client.close() # close the client socket out from under the client
assert mqtt_mock.connected is True
unsub = await mqtt.async_subscribe(hass, "test-topic", record_calls)
async_fire_time_changed(hass, utcnow() + timedelta(seconds=5))
await hass.async_block_till_done()
unsub()
# Should have failed
assert len(calls) == 0
@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0)
@patch("homeassistant.components.mqtt.client.DISCOVERY_COOLDOWN", 0.0)
@patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 0.0)
async def test_loop_write_failure(
hass: HomeAssistant,
mqtt_client_mock: MqttMockPahoClient,
mqtt_mock_entry: MqttMockHAClientGenerator,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test handling the socket connected and disconnected."""
mqtt_mock = await mqtt_mock_entry()
await hass.async_block_till_done()
assert mqtt_mock.connected is True
mqtt_client_mock.loop_misc.return_value = paho_mqtt.MQTT_ERR_SUCCESS
client, server = socket.socketpair(
family=socket.AF_UNIX, type=socket.SOCK_STREAM, proto=0
)
client.setblocking(False)
server.setblocking(False)
mqtt_client_mock.on_socket_open(mqtt_client_mock, None, client)
mqtt_client_mock.on_socket_register_write(mqtt_client_mock, None, client)
mqtt_client_mock.loop_write.return_value = paho_mqtt.MQTT_ERR_CONN_LOST
mqtt_client_mock.loop_read.return_value = paho_mqtt.MQTT_ERR_CONN_LOST
# Fill up the outgoing buffer to ensure that loop_write
# and loop_read are called that next time control is
# returned to the event loop
try:
for _ in range(1000):
server.send(b"long" * 100)
except BlockingIOError:
pass
server.close()
# Once for the reader callback
await hass.async_block_till_done()
# Another for the writer callback
await hass.async_block_till_done()
# Final for the disconnect callback
await hass.async_block_till_done()
assert "Disconnected from MQTT server mock-broker:1883 (7)" in caplog.text

View file

@ -163,7 +163,7 @@ async def help_test_availability_when_connection_lost(
# Disconnected from MQTT server -> state changed to unavailable # Disconnected from MQTT server -> state changed to unavailable
mqtt_mock.connected = False mqtt_mock.connected = False
await hass.async_add_executor_job(mqtt_client_mock.on_disconnect, None, None, 0) mqtt_client_mock.on_disconnect(None, None, 0)
await hass.async_block_till_done() await hass.async_block_till_done()
await hass.async_block_till_done() await hass.async_block_till_done()
await hass.async_block_till_done() await hass.async_block_till_done()
@ -172,7 +172,7 @@ async def help_test_availability_when_connection_lost(
# Reconnected to MQTT server -> state still unavailable # Reconnected to MQTT server -> state still unavailable
mqtt_mock.connected = True mqtt_mock.connected = True
await hass.async_add_executor_job(mqtt_client_mock.on_connect, None, None, None, 0) mqtt_client_mock.on_connect(None, None, None, 0)
await hass.async_block_till_done() await hass.async_block_till_done()
await hass.async_block_till_done() await hass.async_block_till_done()
await hass.async_block_till_done() await hass.async_block_till_done()
@ -224,7 +224,7 @@ async def help_test_deep_sleep_availability_when_connection_lost(
# Disconnected from MQTT server -> state changed to unavailable # Disconnected from MQTT server -> state changed to unavailable
mqtt_mock.connected = False mqtt_mock.connected = False
await hass.async_add_executor_job(mqtt_client_mock.on_disconnect, None, None, 0) mqtt_client_mock.on_disconnect(None, None, 0)
await hass.async_block_till_done() await hass.async_block_till_done()
await hass.async_block_till_done() await hass.async_block_till_done()
await hass.async_block_till_done() await hass.async_block_till_done()
@ -233,7 +233,7 @@ async def help_test_deep_sleep_availability_when_connection_lost(
# Reconnected to MQTT server -> state no longer unavailable # Reconnected to MQTT server -> state no longer unavailable
mqtt_mock.connected = True mqtt_mock.connected = True
await hass.async_add_executor_job(mqtt_client_mock.on_connect, None, None, None, 0) mqtt_client_mock.on_connect(None, None, None, 0)
await hass.async_block_till_done() await hass.async_block_till_done()
await hass.async_block_till_done() await hass.async_block_till_done()
await hass.async_block_till_done() await hass.async_block_till_done()
@ -476,7 +476,7 @@ async def help_test_availability_poll_state(
# Disconnected from MQTT server # Disconnected from MQTT server
mqtt_mock.connected = False mqtt_mock.connected = False
await hass.async_add_executor_job(mqtt_client_mock.on_disconnect, None, None, 0) mqtt_client_mock.on_disconnect(None, None, 0)
await hass.async_block_till_done() await hass.async_block_till_done()
await hass.async_block_till_done() await hass.async_block_till_done()
await hass.async_block_till_done() await hass.async_block_till_done()
@ -484,7 +484,7 @@ async def help_test_availability_poll_state(
# Reconnected to MQTT server # Reconnected to MQTT server
mqtt_mock.connected = True mqtt_mock.connected = True
await hass.async_add_executor_job(mqtt_client_mock.on_connect, None, None, None, 0) mqtt_client_mock.on_connect(None, None, None, 0)
await hass.async_block_till_done() await hass.async_block_till_done()
await hass.async_block_till_done() await hass.async_block_till_done()
await hass.async_block_till_done() await hass.async_block_till_done()

View file

@ -904,26 +904,45 @@ def mqtt_client_mock(hass: HomeAssistant) -> Generator[MqttMockPahoClient, None,
self.rc = 0 self.rc = 0
with patch("paho.mqtt.client.Client") as mock_client: with patch("paho.mqtt.client.Client") as mock_client:
# The below use a call_soon for the on_publish/on_subscribe/on_unsubscribe
# callbacks to simulate the behavior of the real MQTT client which will
# not be synchronous.
@ha.callback @ha.callback
def _async_fire_mqtt_message(topic, payload, qos, retain): def _async_fire_mqtt_message(topic, payload, qos, retain):
async_fire_mqtt_message(hass, topic, payload, qos, retain) async_fire_mqtt_message(hass, topic, payload, qos, retain)
mid = get_mid() mid = get_mid()
mock_client.on_publish(0, 0, mid) hass.loop.call_soon(mock_client.on_publish, 0, 0, mid)
return FakeInfo(mid) return FakeInfo(mid)
def _subscribe(topic, qos=0): def _subscribe(topic, qos=0):
mid = get_mid() mid = get_mid()
mock_client.on_subscribe(0, 0, mid) hass.loop.call_soon(mock_client.on_subscribe, 0, 0, mid)
return (0, mid) return (0, mid)
def _unsubscribe(topic): def _unsubscribe(topic):
mid = get_mid() mid = get_mid()
mock_client.on_unsubscribe(0, 0, mid) hass.loop.call_soon(mock_client.on_unsubscribe, 0, 0, mid)
return (0, mid) return (0, mid)
def _connect(*args, **kwargs):
# Connect always calls reconnect once, but we
# mock it out so we call reconnect to simulate
# the behavior.
mock_client.reconnect()
hass.loop.call_soon_threadsafe(
mock_client.on_connect, mock_client, None, 0, 0, 0
)
mock_client.on_socket_open(
mock_client, None, Mock(fileno=Mock(return_value=-1))
)
mock_client.on_socket_register_write(
mock_client, None, Mock(fileno=Mock(return_value=-1))
)
return 0
mock_client = mock_client.return_value mock_client = mock_client.return_value
mock_client.connect.return_value = 0 mock_client.connect.side_effect = _connect
mock_client.subscribe.side_effect = _subscribe mock_client.subscribe.side_effect = _subscribe
mock_client.unsubscribe.side_effect = _unsubscribe mock_client.unsubscribe.side_effect = _unsubscribe
mock_client.publish.side_effect = _async_fire_mqtt_message mock_client.publish.side_effect = _async_fire_mqtt_message
@ -985,6 +1004,7 @@ async def _mqtt_mock_entry(
# connected set to True to get a more realistic behavior when subscribing # connected set to True to get a more realistic behavior when subscribing
mock_mqtt_instance.connected = True mock_mqtt_instance.connected = True
mqtt_client_mock.on_connect(mqtt_client_mock, None, 0, 0, 0)
async_dispatcher_send(hass, mqtt.MQTT_CONNECTED) async_dispatcher_send(hass, mqtt.MQTT_CONNECTED)
await hass.async_block_till_done() await hass.async_block_till_done()