Wait before sending MQTT birth message (#39120)
Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
parent
415213a325
commit
114a7226d6
3 changed files with 87 additions and 18 deletions
|
@ -8,6 +8,7 @@ import logging
|
||||||
from operator import attrgetter
|
from operator import attrgetter
|
||||||
import os
|
import os
|
||||||
import ssl
|
import ssl
|
||||||
|
import time
|
||||||
from typing import Any, Callable, List, Optional, Union
|
from typing import Any, Callable, List, Optional, Union
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
@ -26,10 +27,11 @@ from homeassistant.const import (
|
||||||
CONF_PROTOCOL,
|
CONF_PROTOCOL,
|
||||||
CONF_USERNAME,
|
CONF_USERNAME,
|
||||||
CONF_VALUE_TEMPLATE,
|
CONF_VALUE_TEMPLATE,
|
||||||
|
EVENT_HOMEASSISTANT_STARTED,
|
||||||
EVENT_HOMEASSISTANT_STOP,
|
EVENT_HOMEASSISTANT_STOP,
|
||||||
)
|
)
|
||||||
from homeassistant.const import CONF_UNIQUE_ID # noqa: F401
|
from homeassistant.const import CONF_UNIQUE_ID # noqa: F401
|
||||||
from homeassistant.core import Event, ServiceCall, callback
|
from homeassistant.core import CoreState, Event, ServiceCall, callback
|
||||||
from homeassistant.exceptions import HomeAssistantError, Unauthorized
|
from homeassistant.exceptions import HomeAssistantError, Unauthorized
|
||||||
from homeassistant.helpers import config_validation as cv, event, template
|
from homeassistant.helpers import config_validation as cv, event, template
|
||||||
from homeassistant.helpers.dispatcher import async_dispatcher_connect, dispatcher_send
|
from homeassistant.helpers.dispatcher import async_dispatcher_connect, dispatcher_send
|
||||||
|
@ -71,7 +73,12 @@ from .const import (
|
||||||
PROTOCOL_311,
|
PROTOCOL_311,
|
||||||
)
|
)
|
||||||
from .debug_info import log_messages
|
from .debug_info import log_messages
|
||||||
from .discovery import MQTT_DISCOVERY_UPDATED, clear_discovery_hash, set_discovery_hash
|
from .discovery import (
|
||||||
|
LAST_DISCOVERY,
|
||||||
|
MQTT_DISCOVERY_UPDATED,
|
||||||
|
clear_discovery_hash,
|
||||||
|
set_discovery_hash,
|
||||||
|
)
|
||||||
from .models import Message, MessageCallbackType, PublishPayloadType
|
from .models import Message, MessageCallbackType, PublishPayloadType
|
||||||
from .subscription import async_subscribe_topics, async_unsubscribe_topics
|
from .subscription import async_subscribe_topics, async_unsubscribe_topics
|
||||||
from .util import _VALID_QOS_SCHEMA, valid_publish_topic, valid_subscribe_topic
|
from .util import _VALID_QOS_SCHEMA, valid_publish_topic, valid_subscribe_topic
|
||||||
|
@ -126,6 +133,7 @@ CONNECTION_SUCCESS = "connection_success"
|
||||||
CONNECTION_FAILED = "connection_failed"
|
CONNECTION_FAILED = "connection_failed"
|
||||||
CONNECTION_FAILED_RECOVERABLE = "connection_failed_recoverable"
|
CONNECTION_FAILED_RECOVERABLE = "connection_failed_recoverable"
|
||||||
|
|
||||||
|
DISCOVERY_COOLDOWN = 2
|
||||||
TIMEOUT_ACK = 1
|
TIMEOUT_ACK = 1
|
||||||
|
|
||||||
|
|
||||||
|
@ -623,11 +631,23 @@ class MQTT:
|
||||||
self.conf = conf
|
self.conf = conf
|
||||||
self.subscriptions: List[Subscription] = []
|
self.subscriptions: List[Subscription] = []
|
||||||
self.connected = False
|
self.connected = False
|
||||||
|
self._ha_started = asyncio.Event()
|
||||||
|
self._last_subscribe = time.time()
|
||||||
self._mqttc: mqtt.Client = None
|
self._mqttc: mqtt.Client = None
|
||||||
self._paho_lock = asyncio.Lock()
|
self._paho_lock = asyncio.Lock()
|
||||||
|
|
||||||
self._pending_operations = {}
|
self._pending_operations = {}
|
||||||
|
|
||||||
|
if self.hass.state == CoreState.running:
|
||||||
|
self._ha_started.set()
|
||||||
|
else:
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def ha_started(_):
|
||||||
|
self._ha_started.set()
|
||||||
|
|
||||||
|
self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STARTED, ha_started)
|
||||||
|
|
||||||
self.init_client()
|
self.init_client()
|
||||||
self.config_entry.add_update_listener(self.async_config_entry_updated)
|
self.config_entry.add_update_listener(self.async_config_entry_updated)
|
||||||
|
|
||||||
|
@ -800,6 +820,7 @@ class MQTT:
|
||||||
|
|
||||||
# Only subscribe if currently connected.
|
# Only subscribe if currently connected.
|
||||||
if self.connected:
|
if self.connected:
|
||||||
|
self._last_subscribe = time.time()
|
||||||
await self._async_perform_subscription(topic, qos)
|
await self._async_perform_subscription(topic, qos)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
|
@ -880,15 +901,19 @@ class MQTT:
|
||||||
CONF_BIRTH_MESSAGE in self.conf
|
CONF_BIRTH_MESSAGE in self.conf
|
||||||
and ATTR_TOPIC in self.conf[CONF_BIRTH_MESSAGE]
|
and ATTR_TOPIC in self.conf[CONF_BIRTH_MESSAGE]
|
||||||
):
|
):
|
||||||
birth_message = Message(**self.conf[CONF_BIRTH_MESSAGE])
|
|
||||||
self.hass.add_job(
|
async def publish_birth_message(birth_message):
|
||||||
self.async_publish( # pylint: disable=no-value-for-parameter
|
await self._ha_started.wait() # Wait for Home Assistant to start
|
||||||
|
await self._discovery_cooldown() # Wait for MQTT discovery to cool down
|
||||||
|
await self.async_publish( # pylint: disable=no-value-for-parameter
|
||||||
topic=birth_message.topic,
|
topic=birth_message.topic,
|
||||||
payload=birth_message.payload,
|
payload=birth_message.payload,
|
||||||
qos=birth_message.qos,
|
qos=birth_message.qos,
|
||||||
retain=birth_message.retain,
|
retain=birth_message.retain,
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
birth_message = Message(**self.conf[CONF_BIRTH_MESSAGE])
|
||||||
|
self.hass.loop.create_task(publish_birth_message(birth_message))
|
||||||
|
|
||||||
def _mqtt_on_message(self, _mqttc, _userdata, msg) -> None:
|
def _mqtt_on_message(self, _mqttc, _userdata, msg) -> None:
|
||||||
"""Message received callback."""
|
"""Message received callback."""
|
||||||
|
@ -970,6 +995,26 @@ class MQTT:
|
||||||
finally:
|
finally:
|
||||||
del self._pending_operations[mid]
|
del self._pending_operations[mid]
|
||||||
|
|
||||||
|
async def _discovery_cooldown(self):
|
||||||
|
now = time.time()
|
||||||
|
# Reset discovery and subscribe cooldowns
|
||||||
|
self.hass.data[LAST_DISCOVERY] = now
|
||||||
|
self._last_subscribe = now
|
||||||
|
|
||||||
|
last_discovery = self.hass.data[LAST_DISCOVERY]
|
||||||
|
last_subscribe = self._last_subscribe
|
||||||
|
wait_until = max(
|
||||||
|
last_discovery + DISCOVERY_COOLDOWN, last_subscribe + DISCOVERY_COOLDOWN
|
||||||
|
)
|
||||||
|
while now < wait_until:
|
||||||
|
await asyncio.sleep(wait_until - now)
|
||||||
|
now = time.time()
|
||||||
|
last_discovery = self.hass.data[LAST_DISCOVERY]
|
||||||
|
last_subscribe = self._last_subscribe
|
||||||
|
wait_until = max(
|
||||||
|
last_discovery + DISCOVERY_COOLDOWN, last_subscribe + DISCOVERY_COOLDOWN
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _raise_on_error(result_code: int) -> None:
|
def _raise_on_error(result_code: int) -> None:
|
||||||
"""Raise error if error result."""
|
"""Raise error if error result."""
|
||||||
|
|
|
@ -3,6 +3,7 @@ import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
import time
|
||||||
|
|
||||||
from homeassistant.components import mqtt
|
from homeassistant.components import mqtt
|
||||||
from homeassistant.const import CONF_DEVICE, CONF_PLATFORM
|
from homeassistant.const import CONF_DEVICE, CONF_PLATFORM
|
||||||
|
@ -40,6 +41,7 @@ DATA_CONFIG_ENTRY_LOCK = "mqtt_config_entry_lock"
|
||||||
DISCOVERY_UNSUBSCRIBE = "mqtt_discovery_unsubscribe"
|
DISCOVERY_UNSUBSCRIBE = "mqtt_discovery_unsubscribe"
|
||||||
MQTT_DISCOVERY_UPDATED = "mqtt_discovery_updated_{}"
|
MQTT_DISCOVERY_UPDATED = "mqtt_discovery_updated_{}"
|
||||||
MQTT_DISCOVERY_NEW = "mqtt_discovery_new_{}_{}"
|
MQTT_DISCOVERY_NEW = "mqtt_discovery_new_{}_{}"
|
||||||
|
LAST_DISCOVERY = "mqtt_last_discovery"
|
||||||
|
|
||||||
TOPIC_BASE = "~"
|
TOPIC_BASE = "~"
|
||||||
|
|
||||||
|
@ -65,6 +67,7 @@ async def async_start(
|
||||||
|
|
||||||
async def async_device_message_received(msg):
|
async def async_device_message_received(msg):
|
||||||
"""Process the received message."""
|
"""Process the received message."""
|
||||||
|
hass.data[LAST_DISCOVERY] = time.time()
|
||||||
payload = msg.payload
|
payload = msg.payload
|
||||||
topic = msg.topic
|
topic = msg.topic
|
||||||
topic_trimmed = topic.replace(f"{discovery_topic}/", "", 1)
|
topic_trimmed = topic.replace(f"{discovery_topic}/", "", 1)
|
||||||
|
@ -167,6 +170,7 @@ async def async_start(
|
||||||
hass.data[DISCOVERY_UNSUBSCRIBE] = await mqtt.async_subscribe(
|
hass.data[DISCOVERY_UNSUBSCRIBE] = await mqtt.async_subscribe(
|
||||||
hass, f"{discovery_topic}/#", async_device_message_received, 0
|
hass, f"{discovery_topic}/#", async_device_message_received, 0
|
||||||
)
|
)
|
||||||
|
hass.data[LAST_DISCOVERY] = time.time()
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
"""The tests for the MQTT component."""
|
"""The tests for the MQTT component."""
|
||||||
|
import asyncio
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
import json
|
import json
|
||||||
import ssl
|
import ssl
|
||||||
|
@ -361,7 +362,6 @@ async def test_subscribe_deprecated_async(hass, mqtt_mock):
|
||||||
"""Test the subscription of a topic using deprecated callback signature."""
|
"""Test the subscription of a topic using deprecated callback signature."""
|
||||||
calls = []
|
calls = []
|
||||||
|
|
||||||
@callback
|
|
||||||
async def record_calls(topic, payload, qos):
|
async def record_calls(topic, payload, qos):
|
||||||
"""Record calls."""
|
"""Record calls."""
|
||||||
calls.append((topic, payload, qos))
|
calls.append((topic, payload, qos))
|
||||||
|
@ -758,18 +758,36 @@ async def test_setup_without_tls_config_uses_tlsv1_under_python36(hass):
|
||||||
)
|
)
|
||||||
async def test_custom_birth_message(hass, mqtt_client_mock, mqtt_mock):
|
async def test_custom_birth_message(hass, mqtt_client_mock, mqtt_mock):
|
||||||
"""Test sending birth message."""
|
"""Test sending birth message."""
|
||||||
mqtt_mock._mqtt_on_connect(None, None, 0, 0)
|
birth = asyncio.Event()
|
||||||
await hass.async_block_till_done()
|
|
||||||
mqtt_client_mock.publish.assert_called_with("birth", "birth", 0, False)
|
async def wait_birth(topic, payload, qos):
|
||||||
|
"""Handle birth message."""
|
||||||
|
birth.set()
|
||||||
|
|
||||||
|
with patch("homeassistant.components.mqtt.DISCOVERY_COOLDOWN", 0.1):
|
||||||
|
await mqtt.async_subscribe(hass, "birth", wait_birth)
|
||||||
|
mqtt_mock._mqtt_on_connect(None, None, 0, 0)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
await birth.wait()
|
||||||
|
mqtt_client_mock.publish.assert_called_with("birth", "birth", 0, False)
|
||||||
|
|
||||||
|
|
||||||
async def test_default_birth_message(hass, mqtt_client_mock, mqtt_mock):
|
async def test_default_birth_message(hass, mqtt_client_mock, mqtt_mock):
|
||||||
"""Test sending birth message."""
|
"""Test sending birth message."""
|
||||||
mqtt_mock._mqtt_on_connect(None, None, 0, 0)
|
birth = asyncio.Event()
|
||||||
await hass.async_block_till_done()
|
|
||||||
mqtt_client_mock.publish.assert_called_with(
|
async def wait_birth(topic, payload, qos):
|
||||||
"homeassistant/status", "online", 0, False
|
"""Handle birth message."""
|
||||||
)
|
birth.set()
|
||||||
|
|
||||||
|
with patch("homeassistant.components.mqtt.DISCOVERY_COOLDOWN", 0.1):
|
||||||
|
await mqtt.async_subscribe(hass, "homeassistant/status", wait_birth)
|
||||||
|
mqtt_mock._mqtt_on_connect(None, None, 0, 0)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
await birth.wait()
|
||||||
|
mqtt_client_mock.publish.assert_called_with(
|
||||||
|
"homeassistant/status", "online", 0, False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -777,9 +795,11 @@ async def test_default_birth_message(hass, mqtt_client_mock, mqtt_mock):
|
||||||
)
|
)
|
||||||
async def test_no_birth_message(hass, mqtt_client_mock, mqtt_mock):
|
async def test_no_birth_message(hass, mqtt_client_mock, mqtt_mock):
|
||||||
"""Test disabling birth message."""
|
"""Test disabling birth message."""
|
||||||
mqtt_mock._mqtt_on_connect(None, None, 0, 0)
|
with patch("homeassistant.components.mqtt.DISCOVERY_COOLDOWN", 0.1):
|
||||||
await hass.async_block_till_done()
|
mqtt_mock._mqtt_on_connect(None, None, 0, 0)
|
||||||
mqtt_client_mock.publish.assert_not_called()
|
await hass.async_block_till_done()
|
||||||
|
await asyncio.sleep(0.2)
|
||||||
|
mqtt_client_mock.publish.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
|
Loading…
Add table
Reference in a new issue