Wait before sending MQTT birth message (#39120)
Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
parent
415213a325
commit
114a7226d6
3 changed files with 87 additions and 18 deletions
|
@ -8,6 +8,7 @@ import logging
|
|||
from operator import attrgetter
|
||||
import os
|
||||
import ssl
|
||||
import time
|
||||
from typing import Any, Callable, List, Optional, Union
|
||||
|
||||
import attr
|
||||
|
@ -26,10 +27,11 @@ from homeassistant.const import (
|
|||
CONF_PROTOCOL,
|
||||
CONF_USERNAME,
|
||||
CONF_VALUE_TEMPLATE,
|
||||
EVENT_HOMEASSISTANT_STARTED,
|
||||
EVENT_HOMEASSISTANT_STOP,
|
||||
)
|
||||
from homeassistant.const import CONF_UNIQUE_ID # noqa: F401
|
||||
from homeassistant.core import Event, ServiceCall, callback
|
||||
from homeassistant.core import CoreState, Event, ServiceCall, callback
|
||||
from homeassistant.exceptions import HomeAssistantError, Unauthorized
|
||||
from homeassistant.helpers import config_validation as cv, event, template
|
||||
from homeassistant.helpers.dispatcher import async_dispatcher_connect, dispatcher_send
|
||||
|
@ -71,7 +73,12 @@ from .const import (
|
|||
PROTOCOL_311,
|
||||
)
|
||||
from .debug_info import log_messages
|
||||
from .discovery import MQTT_DISCOVERY_UPDATED, clear_discovery_hash, set_discovery_hash
|
||||
from .discovery import (
|
||||
LAST_DISCOVERY,
|
||||
MQTT_DISCOVERY_UPDATED,
|
||||
clear_discovery_hash,
|
||||
set_discovery_hash,
|
||||
)
|
||||
from .models import Message, MessageCallbackType, PublishPayloadType
|
||||
from .subscription import async_subscribe_topics, async_unsubscribe_topics
|
||||
from .util import _VALID_QOS_SCHEMA, valid_publish_topic, valid_subscribe_topic
|
||||
|
@ -126,6 +133,7 @@ CONNECTION_SUCCESS = "connection_success"
|
|||
CONNECTION_FAILED = "connection_failed"
|
||||
CONNECTION_FAILED_RECOVERABLE = "connection_failed_recoverable"
|
||||
|
||||
DISCOVERY_COOLDOWN = 2
|
||||
TIMEOUT_ACK = 1
|
||||
|
||||
|
||||
|
@ -623,11 +631,23 @@ class MQTT:
|
|||
self.conf = conf
|
||||
self.subscriptions: List[Subscription] = []
|
||||
self.connected = False
|
||||
self._ha_started = asyncio.Event()
|
||||
self._last_subscribe = time.time()
|
||||
self._mqttc: mqtt.Client = None
|
||||
self._paho_lock = asyncio.Lock()
|
||||
|
||||
self._pending_operations = {}
|
||||
|
||||
if self.hass.state == CoreState.running:
|
||||
self._ha_started.set()
|
||||
else:
|
||||
|
||||
@callback
|
||||
def ha_started(_):
|
||||
self._ha_started.set()
|
||||
|
||||
self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STARTED, ha_started)
|
||||
|
||||
self.init_client()
|
||||
self.config_entry.add_update_listener(self.async_config_entry_updated)
|
||||
|
||||
|
@ -800,6 +820,7 @@ class MQTT:
|
|||
|
||||
# Only subscribe if currently connected.
|
||||
if self.connected:
|
||||
self._last_subscribe = time.time()
|
||||
await self._async_perform_subscription(topic, qos)
|
||||
|
||||
@callback
|
||||
|
@ -880,15 +901,19 @@ class MQTT:
|
|||
CONF_BIRTH_MESSAGE in self.conf
|
||||
and ATTR_TOPIC in self.conf[CONF_BIRTH_MESSAGE]
|
||||
):
|
||||
birth_message = Message(**self.conf[CONF_BIRTH_MESSAGE])
|
||||
self.hass.add_job(
|
||||
self.async_publish( # pylint: disable=no-value-for-parameter
|
||||
|
||||
async def publish_birth_message(birth_message):
|
||||
await self._ha_started.wait() # Wait for Home Assistant to start
|
||||
await self._discovery_cooldown() # Wait for MQTT discovery to cool down
|
||||
await self.async_publish( # pylint: disable=no-value-for-parameter
|
||||
topic=birth_message.topic,
|
||||
payload=birth_message.payload,
|
||||
qos=birth_message.qos,
|
||||
retain=birth_message.retain,
|
||||
)
|
||||
)
|
||||
|
||||
birth_message = Message(**self.conf[CONF_BIRTH_MESSAGE])
|
||||
self.hass.loop.create_task(publish_birth_message(birth_message))
|
||||
|
||||
def _mqtt_on_message(self, _mqttc, _userdata, msg) -> None:
|
||||
"""Message received callback."""
|
||||
|
@ -970,6 +995,26 @@ class MQTT:
|
|||
finally:
|
||||
del self._pending_operations[mid]
|
||||
|
||||
async def _discovery_cooldown(self):
|
||||
now = time.time()
|
||||
# Reset discovery and subscribe cooldowns
|
||||
self.hass.data[LAST_DISCOVERY] = now
|
||||
self._last_subscribe = now
|
||||
|
||||
last_discovery = self.hass.data[LAST_DISCOVERY]
|
||||
last_subscribe = self._last_subscribe
|
||||
wait_until = max(
|
||||
last_discovery + DISCOVERY_COOLDOWN, last_subscribe + DISCOVERY_COOLDOWN
|
||||
)
|
||||
while now < wait_until:
|
||||
await asyncio.sleep(wait_until - now)
|
||||
now = time.time()
|
||||
last_discovery = self.hass.data[LAST_DISCOVERY]
|
||||
last_subscribe = self._last_subscribe
|
||||
wait_until = max(
|
||||
last_discovery + DISCOVERY_COOLDOWN, last_subscribe + DISCOVERY_COOLDOWN
|
||||
)
|
||||
|
||||
|
||||
def _raise_on_error(result_code: int) -> None:
|
||||
"""Raise error if error result."""
|
||||
|
|
|
@ -3,6 +3,7 @@ import asyncio
|
|||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
|
||||
from homeassistant.components import mqtt
|
||||
from homeassistant.const import CONF_DEVICE, CONF_PLATFORM
|
||||
|
@ -40,6 +41,7 @@ DATA_CONFIG_ENTRY_LOCK = "mqtt_config_entry_lock"
|
|||
DISCOVERY_UNSUBSCRIBE = "mqtt_discovery_unsubscribe"
|
||||
MQTT_DISCOVERY_UPDATED = "mqtt_discovery_updated_{}"
|
||||
MQTT_DISCOVERY_NEW = "mqtt_discovery_new_{}_{}"
|
||||
LAST_DISCOVERY = "mqtt_last_discovery"
|
||||
|
||||
TOPIC_BASE = "~"
|
||||
|
||||
|
@ -65,6 +67,7 @@ async def async_start(
|
|||
|
||||
async def async_device_message_received(msg):
|
||||
"""Process the received message."""
|
||||
hass.data[LAST_DISCOVERY] = time.time()
|
||||
payload = msg.payload
|
||||
topic = msg.topic
|
||||
topic_trimmed = topic.replace(f"{discovery_topic}/", "", 1)
|
||||
|
@ -167,6 +170,7 @@ async def async_start(
|
|||
hass.data[DISCOVERY_UNSUBSCRIBE] = await mqtt.async_subscribe(
|
||||
hass, f"{discovery_topic}/#", async_device_message_received, 0
|
||||
)
|
||||
hass.data[LAST_DISCOVERY] = time.time()
|
||||
|
||||
return True
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
"""The tests for the MQTT component."""
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
import json
|
||||
import ssl
|
||||
|
@ -361,7 +362,6 @@ async def test_subscribe_deprecated_async(hass, mqtt_mock):
|
|||
"""Test the subscription of a topic using deprecated callback signature."""
|
||||
calls = []
|
||||
|
||||
@callback
|
||||
async def record_calls(topic, payload, qos):
|
||||
"""Record calls."""
|
||||
calls.append((topic, payload, qos))
|
||||
|
@ -758,18 +758,36 @@ async def test_setup_without_tls_config_uses_tlsv1_under_python36(hass):
|
|||
)
|
||||
async def test_custom_birth_message(hass, mqtt_client_mock, mqtt_mock):
|
||||
"""Test sending birth message."""
|
||||
mqtt_mock._mqtt_on_connect(None, None, 0, 0)
|
||||
await hass.async_block_till_done()
|
||||
mqtt_client_mock.publish.assert_called_with("birth", "birth", 0, False)
|
||||
birth = asyncio.Event()
|
||||
|
||||
async def wait_birth(topic, payload, qos):
|
||||
"""Handle birth message."""
|
||||
birth.set()
|
||||
|
||||
with patch("homeassistant.components.mqtt.DISCOVERY_COOLDOWN", 0.1):
|
||||
await mqtt.async_subscribe(hass, "birth", wait_birth)
|
||||
mqtt_mock._mqtt_on_connect(None, None, 0, 0)
|
||||
await hass.async_block_till_done()
|
||||
await birth.wait()
|
||||
mqtt_client_mock.publish.assert_called_with("birth", "birth", 0, False)
|
||||
|
||||
|
||||
async def test_default_birth_message(hass, mqtt_client_mock, mqtt_mock):
|
||||
"""Test sending birth message."""
|
||||
mqtt_mock._mqtt_on_connect(None, None, 0, 0)
|
||||
await hass.async_block_till_done()
|
||||
mqtt_client_mock.publish.assert_called_with(
|
||||
"homeassistant/status", "online", 0, False
|
||||
)
|
||||
birth = asyncio.Event()
|
||||
|
||||
async def wait_birth(topic, payload, qos):
|
||||
"""Handle birth message."""
|
||||
birth.set()
|
||||
|
||||
with patch("homeassistant.components.mqtt.DISCOVERY_COOLDOWN", 0.1):
|
||||
await mqtt.async_subscribe(hass, "homeassistant/status", wait_birth)
|
||||
mqtt_mock._mqtt_on_connect(None, None, 0, 0)
|
||||
await hass.async_block_till_done()
|
||||
await birth.wait()
|
||||
mqtt_client_mock.publish.assert_called_with(
|
||||
"homeassistant/status", "online", 0, False
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@ -777,9 +795,11 @@ async def test_default_birth_message(hass, mqtt_client_mock, mqtt_mock):
|
|||
)
|
||||
async def test_no_birth_message(hass, mqtt_client_mock, mqtt_mock):
|
||||
"""Test disabling birth message."""
|
||||
mqtt_mock._mqtt_on_connect(None, None, 0, 0)
|
||||
await hass.async_block_till_done()
|
||||
mqtt_client_mock.publish.assert_not_called()
|
||||
with patch("homeassistant.components.mqtt.DISCOVERY_COOLDOWN", 0.1):
|
||||
mqtt_mock._mqtt_on_connect(None, None, 0, 0)
|
||||
await hass.async_block_till_done()
|
||||
await asyncio.sleep(0.2)
|
||||
mqtt_client_mock.publish.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
|
Loading…
Add table
Reference in a new issue