From 114a7226d68718c27a4155d854cbeb1926a89b19 Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Tue, 25 Aug 2020 16:42:24 +0200 Subject: [PATCH] Wait before sending MQTT birth message (#39120) Co-authored-by: Paulus Schoutsen --- homeassistant/components/mqtt/__init__.py | 57 +++++++++++++++++++--- homeassistant/components/mqtt/discovery.py | 4 ++ tests/components/mqtt/test_init.py | 44 ++++++++++++----- 3 files changed, 87 insertions(+), 18 deletions(-) diff --git a/homeassistant/components/mqtt/__init__.py b/homeassistant/components/mqtt/__init__.py index 819fbc9838f..865f21b9d38 100644 --- a/homeassistant/components/mqtt/__init__.py +++ b/homeassistant/components/mqtt/__init__.py @@ -8,6 +8,7 @@ import logging from operator import attrgetter import os import ssl +import time from typing import Any, Callable, List, Optional, Union import attr @@ -26,10 +27,11 @@ from homeassistant.const import ( CONF_PROTOCOL, CONF_USERNAME, CONF_VALUE_TEMPLATE, + EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP, ) from homeassistant.const import CONF_UNIQUE_ID # noqa: F401 -from homeassistant.core import Event, ServiceCall, callback +from homeassistant.core import CoreState, Event, ServiceCall, callback from homeassistant.exceptions import HomeAssistantError, Unauthorized from homeassistant.helpers import config_validation as cv, event, template from homeassistant.helpers.dispatcher import async_dispatcher_connect, dispatcher_send @@ -71,7 +73,12 @@ from .const import ( PROTOCOL_311, ) from .debug_info import log_messages -from .discovery import MQTT_DISCOVERY_UPDATED, clear_discovery_hash, set_discovery_hash +from .discovery import ( + LAST_DISCOVERY, + MQTT_DISCOVERY_UPDATED, + clear_discovery_hash, + set_discovery_hash, +) from .models import Message, MessageCallbackType, PublishPayloadType from .subscription import async_subscribe_topics, async_unsubscribe_topics from .util import _VALID_QOS_SCHEMA, valid_publish_topic, valid_subscribe_topic @@ -126,6 +133,7 @@ CONNECTION_SUCCESS = "connection_success" CONNECTION_FAILED = "connection_failed" CONNECTION_FAILED_RECOVERABLE = "connection_failed_recoverable" +DISCOVERY_COOLDOWN = 2 TIMEOUT_ACK = 1 @@ -623,11 +631,23 @@ class MQTT: self.conf = conf self.subscriptions: List[Subscription] = [] self.connected = False + self._ha_started = asyncio.Event() + self._last_subscribe = time.time() self._mqttc: mqtt.Client = None self._paho_lock = asyncio.Lock() self._pending_operations = {} + if self.hass.state == CoreState.running: + self._ha_started.set() + else: + + @callback + def ha_started(_): + self._ha_started.set() + + self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STARTED, ha_started) + self.init_client() self.config_entry.add_update_listener(self.async_config_entry_updated) @@ -800,6 +820,7 @@ class MQTT: # Only subscribe if currently connected. if self.connected: + self._last_subscribe = time.time() await self._async_perform_subscription(topic, qos) @callback @@ -880,15 +901,19 @@ class MQTT: CONF_BIRTH_MESSAGE in self.conf and ATTR_TOPIC in self.conf[CONF_BIRTH_MESSAGE] ): - birth_message = Message(**self.conf[CONF_BIRTH_MESSAGE]) - self.hass.add_job( - self.async_publish( # pylint: disable=no-value-for-parameter + + async def publish_birth_message(birth_message): + await self._ha_started.wait() # Wait for Home Assistant to start + await self._discovery_cooldown() # Wait for MQTT discovery to cool down + await self.async_publish( # pylint: disable=no-value-for-parameter topic=birth_message.topic, payload=birth_message.payload, qos=birth_message.qos, retain=birth_message.retain, ) - ) + + birth_message = Message(**self.conf[CONF_BIRTH_MESSAGE]) + self.hass.loop.create_task(publish_birth_message(birth_message)) def _mqtt_on_message(self, _mqttc, _userdata, msg) -> None: """Message received callback.""" @@ -970,6 +995,26 @@ class MQTT: finally: del self._pending_operations[mid] + async def _discovery_cooldown(self): + now = time.time() + # Reset discovery and subscribe cooldowns + self.hass.data[LAST_DISCOVERY] = now + self._last_subscribe = now + + last_discovery = self.hass.data[LAST_DISCOVERY] + last_subscribe = self._last_subscribe + wait_until = max( + last_discovery + DISCOVERY_COOLDOWN, last_subscribe + DISCOVERY_COOLDOWN + ) + while now < wait_until: + await asyncio.sleep(wait_until - now) + now = time.time() + last_discovery = self.hass.data[LAST_DISCOVERY] + last_subscribe = self._last_subscribe + wait_until = max( + last_discovery + DISCOVERY_COOLDOWN, last_subscribe + DISCOVERY_COOLDOWN + ) + def _raise_on_error(result_code: int) -> None: """Raise error if error result.""" diff --git a/homeassistant/components/mqtt/discovery.py b/homeassistant/components/mqtt/discovery.py index aff66954968..a7d5236148b 100644 --- a/homeassistant/components/mqtt/discovery.py +++ b/homeassistant/components/mqtt/discovery.py @@ -3,6 +3,7 @@ import asyncio import json import logging import re +import time from homeassistant.components import mqtt from homeassistant.const import CONF_DEVICE, CONF_PLATFORM @@ -40,6 +41,7 @@ DATA_CONFIG_ENTRY_LOCK = "mqtt_config_entry_lock" DISCOVERY_UNSUBSCRIBE = "mqtt_discovery_unsubscribe" MQTT_DISCOVERY_UPDATED = "mqtt_discovery_updated_{}" MQTT_DISCOVERY_NEW = "mqtt_discovery_new_{}_{}" +LAST_DISCOVERY = "mqtt_last_discovery" TOPIC_BASE = "~" @@ -65,6 +67,7 @@ async def async_start( async def async_device_message_received(msg): """Process the received message.""" + hass.data[LAST_DISCOVERY] = time.time() payload = msg.payload topic = msg.topic topic_trimmed = topic.replace(f"{discovery_topic}/", "", 1) @@ -167,6 +170,7 @@ async def async_start( hass.data[DISCOVERY_UNSUBSCRIBE] = await mqtt.async_subscribe( hass, f"{discovery_topic}/#", async_device_message_received, 0 ) + hass.data[LAST_DISCOVERY] = time.time() return True diff --git a/tests/components/mqtt/test_init.py b/tests/components/mqtt/test_init.py index dea0852d580..0dfef17a145 100644 --- a/tests/components/mqtt/test_init.py +++ b/tests/components/mqtt/test_init.py @@ -1,4 +1,5 @@ """The tests for the MQTT component.""" +import asyncio from datetime import datetime, timedelta import json import ssl @@ -361,7 +362,6 @@ async def test_subscribe_deprecated_async(hass, mqtt_mock): """Test the subscription of a topic using deprecated callback signature.""" calls = [] - @callback async def record_calls(topic, payload, qos): """Record calls.""" calls.append((topic, payload, qos)) @@ -758,18 +758,36 @@ async def test_setup_without_tls_config_uses_tlsv1_under_python36(hass): ) async def test_custom_birth_message(hass, mqtt_client_mock, mqtt_mock): """Test sending birth message.""" - mqtt_mock._mqtt_on_connect(None, None, 0, 0) - await hass.async_block_till_done() - mqtt_client_mock.publish.assert_called_with("birth", "birth", 0, False) + birth = asyncio.Event() + + async def wait_birth(topic, payload, qos): + """Handle birth message.""" + birth.set() + + with patch("homeassistant.components.mqtt.DISCOVERY_COOLDOWN", 0.1): + await mqtt.async_subscribe(hass, "birth", wait_birth) + mqtt_mock._mqtt_on_connect(None, None, 0, 0) + await hass.async_block_till_done() + await birth.wait() + mqtt_client_mock.publish.assert_called_with("birth", "birth", 0, False) async def test_default_birth_message(hass, mqtt_client_mock, mqtt_mock): """Test sending birth message.""" - mqtt_mock._mqtt_on_connect(None, None, 0, 0) - await hass.async_block_till_done() - mqtt_client_mock.publish.assert_called_with( - "homeassistant/status", "online", 0, False - ) + birth = asyncio.Event() + + async def wait_birth(topic, payload, qos): + """Handle birth message.""" + birth.set() + + with patch("homeassistant.components.mqtt.DISCOVERY_COOLDOWN", 0.1): + await mqtt.async_subscribe(hass, "homeassistant/status", wait_birth) + mqtt_mock._mqtt_on_connect(None, None, 0, 0) + await hass.async_block_till_done() + await birth.wait() + mqtt_client_mock.publish.assert_called_with( + "homeassistant/status", "online", 0, False + ) @pytest.mark.parametrize( @@ -777,9 +795,11 @@ async def test_default_birth_message(hass, mqtt_client_mock, mqtt_mock): ) async def test_no_birth_message(hass, mqtt_client_mock, mqtt_mock): """Test disabling birth message.""" - mqtt_mock._mqtt_on_connect(None, None, 0, 0) - await hass.async_block_till_done() - mqtt_client_mock.publish.assert_not_called() + with patch("homeassistant.components.mqtt.DISCOVERY_COOLDOWN", 0.1): + mqtt_mock._mqtt_on_connect(None, None, 0, 0) + await hass.async_block_till_done() + await asyncio.sleep(0.2) + mqtt_client_mock.publish.assert_not_called() @pytest.mark.parametrize(