Wait before sending MQTT birth message (#39120)

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
Erik Montnemery 2020-08-25 16:42:24 +02:00 committed by GitHub
parent 415213a325
commit 114a7226d6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 87 additions and 18 deletions

View file

@ -8,6 +8,7 @@ import logging
from operator import attrgetter
import os
import ssl
import time
from typing import Any, Callable, List, Optional, Union
import attr
@ -26,10 +27,11 @@ from homeassistant.const import (
CONF_PROTOCOL,
CONF_USERNAME,
CONF_VALUE_TEMPLATE,
EVENT_HOMEASSISTANT_STARTED,
EVENT_HOMEASSISTANT_STOP,
)
from homeassistant.const import CONF_UNIQUE_ID # noqa: F401
from homeassistant.core import Event, ServiceCall, callback
from homeassistant.core import CoreState, Event, ServiceCall, callback
from homeassistant.exceptions import HomeAssistantError, Unauthorized
from homeassistant.helpers import config_validation as cv, event, template
from homeassistant.helpers.dispatcher import async_dispatcher_connect, dispatcher_send
@ -71,7 +73,12 @@ from .const import (
PROTOCOL_311,
)
from .debug_info import log_messages
from .discovery import MQTT_DISCOVERY_UPDATED, clear_discovery_hash, set_discovery_hash
from .discovery import (
LAST_DISCOVERY,
MQTT_DISCOVERY_UPDATED,
clear_discovery_hash,
set_discovery_hash,
)
from .models import Message, MessageCallbackType, PublishPayloadType
from .subscription import async_subscribe_topics, async_unsubscribe_topics
from .util import _VALID_QOS_SCHEMA, valid_publish_topic, valid_subscribe_topic
@ -126,6 +133,7 @@ CONNECTION_SUCCESS = "connection_success"
CONNECTION_FAILED = "connection_failed"
CONNECTION_FAILED_RECOVERABLE = "connection_failed_recoverable"
DISCOVERY_COOLDOWN = 2
TIMEOUT_ACK = 1
@ -623,11 +631,23 @@ class MQTT:
self.conf = conf
self.subscriptions: List[Subscription] = []
self.connected = False
self._ha_started = asyncio.Event()
self._last_subscribe = time.time()
self._mqttc: mqtt.Client = None
self._paho_lock = asyncio.Lock()
self._pending_operations = {}
if self.hass.state == CoreState.running:
self._ha_started.set()
else:
@callback
def ha_started(_):
self._ha_started.set()
self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STARTED, ha_started)
self.init_client()
self.config_entry.add_update_listener(self.async_config_entry_updated)
@ -800,6 +820,7 @@ class MQTT:
# Only subscribe if currently connected.
if self.connected:
self._last_subscribe = time.time()
await self._async_perform_subscription(topic, qos)
@callback
@ -880,15 +901,19 @@ class MQTT:
CONF_BIRTH_MESSAGE in self.conf
and ATTR_TOPIC in self.conf[CONF_BIRTH_MESSAGE]
):
birth_message = Message(**self.conf[CONF_BIRTH_MESSAGE])
self.hass.add_job(
self.async_publish( # pylint: disable=no-value-for-parameter
async def publish_birth_message(birth_message):
await self._ha_started.wait() # Wait for Home Assistant to start
await self._discovery_cooldown() # Wait for MQTT discovery to cool down
await self.async_publish( # pylint: disable=no-value-for-parameter
topic=birth_message.topic,
payload=birth_message.payload,
qos=birth_message.qos,
retain=birth_message.retain,
)
)
birth_message = Message(**self.conf[CONF_BIRTH_MESSAGE])
self.hass.loop.create_task(publish_birth_message(birth_message))
def _mqtt_on_message(self, _mqttc, _userdata, msg) -> None:
"""Message received callback."""
@ -970,6 +995,26 @@ class MQTT:
finally:
del self._pending_operations[mid]
async def _discovery_cooldown(self):
now = time.time()
# Reset discovery and subscribe cooldowns
self.hass.data[LAST_DISCOVERY] = now
self._last_subscribe = now
last_discovery = self.hass.data[LAST_DISCOVERY]
last_subscribe = self._last_subscribe
wait_until = max(
last_discovery + DISCOVERY_COOLDOWN, last_subscribe + DISCOVERY_COOLDOWN
)
while now < wait_until:
await asyncio.sleep(wait_until - now)
now = time.time()
last_discovery = self.hass.data[LAST_DISCOVERY]
last_subscribe = self._last_subscribe
wait_until = max(
last_discovery + DISCOVERY_COOLDOWN, last_subscribe + DISCOVERY_COOLDOWN
)
def _raise_on_error(result_code: int) -> None:
"""Raise error if error result."""

View file

@ -3,6 +3,7 @@ import asyncio
import json
import logging
import re
import time
from homeassistant.components import mqtt
from homeassistant.const import CONF_DEVICE, CONF_PLATFORM
@ -40,6 +41,7 @@ DATA_CONFIG_ENTRY_LOCK = "mqtt_config_entry_lock"
DISCOVERY_UNSUBSCRIBE = "mqtt_discovery_unsubscribe"
MQTT_DISCOVERY_UPDATED = "mqtt_discovery_updated_{}"
MQTT_DISCOVERY_NEW = "mqtt_discovery_new_{}_{}"
LAST_DISCOVERY = "mqtt_last_discovery"
TOPIC_BASE = "~"
@ -65,6 +67,7 @@ async def async_start(
async def async_device_message_received(msg):
"""Process the received message."""
hass.data[LAST_DISCOVERY] = time.time()
payload = msg.payload
topic = msg.topic
topic_trimmed = topic.replace(f"{discovery_topic}/", "", 1)
@ -167,6 +170,7 @@ async def async_start(
hass.data[DISCOVERY_UNSUBSCRIBE] = await mqtt.async_subscribe(
hass, f"{discovery_topic}/#", async_device_message_received, 0
)
hass.data[LAST_DISCOVERY] = time.time()
return True

View file

@ -1,4 +1,5 @@
"""The tests for the MQTT component."""
import asyncio
from datetime import datetime, timedelta
import json
import ssl
@ -361,7 +362,6 @@ async def test_subscribe_deprecated_async(hass, mqtt_mock):
"""Test the subscription of a topic using deprecated callback signature."""
calls = []
@callback
async def record_calls(topic, payload, qos):
"""Record calls."""
calls.append((topic, payload, qos))
@ -758,18 +758,36 @@ async def test_setup_without_tls_config_uses_tlsv1_under_python36(hass):
)
async def test_custom_birth_message(hass, mqtt_client_mock, mqtt_mock):
"""Test sending birth message."""
mqtt_mock._mqtt_on_connect(None, None, 0, 0)
await hass.async_block_till_done()
mqtt_client_mock.publish.assert_called_with("birth", "birth", 0, False)
birth = asyncio.Event()
async def wait_birth(topic, payload, qos):
"""Handle birth message."""
birth.set()
with patch("homeassistant.components.mqtt.DISCOVERY_COOLDOWN", 0.1):
await mqtt.async_subscribe(hass, "birth", wait_birth)
mqtt_mock._mqtt_on_connect(None, None, 0, 0)
await hass.async_block_till_done()
await birth.wait()
mqtt_client_mock.publish.assert_called_with("birth", "birth", 0, False)
async def test_default_birth_message(hass, mqtt_client_mock, mqtt_mock):
"""Test sending birth message."""
mqtt_mock._mqtt_on_connect(None, None, 0, 0)
await hass.async_block_till_done()
mqtt_client_mock.publish.assert_called_with(
"homeassistant/status", "online", 0, False
)
birth = asyncio.Event()
async def wait_birth(topic, payload, qos):
"""Handle birth message."""
birth.set()
with patch("homeassistant.components.mqtt.DISCOVERY_COOLDOWN", 0.1):
await mqtt.async_subscribe(hass, "homeassistant/status", wait_birth)
mqtt_mock._mqtt_on_connect(None, None, 0, 0)
await hass.async_block_till_done()
await birth.wait()
mqtt_client_mock.publish.assert_called_with(
"homeassistant/status", "online", 0, False
)
@pytest.mark.parametrize(
@ -777,9 +795,11 @@ async def test_default_birth_message(hass, mqtt_client_mock, mqtt_mock):
)
async def test_no_birth_message(hass, mqtt_client_mock, mqtt_mock):
"""Test disabling birth message."""
mqtt_mock._mqtt_on_connect(None, None, 0, 0)
await hass.async_block_till_done()
mqtt_client_mock.publish.assert_not_called()
with patch("homeassistant.components.mqtt.DISCOVERY_COOLDOWN", 0.1):
mqtt_mock._mqtt_on_connect(None, None, 0, 0)
await hass.async_block_till_done()
await asyncio.sleep(0.2)
mqtt_client_mock.publish.assert_not_called()
@pytest.mark.parametrize(