Wait before sending MQTT birth message (#39120)

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
Erik Montnemery 2020-08-25 16:42:24 +02:00 committed by GitHub
parent 415213a325
commit 114a7226d6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 87 additions and 18 deletions

View file

@ -8,6 +8,7 @@ import logging
from operator import attrgetter from operator import attrgetter
import os import os
import ssl import ssl
import time
from typing import Any, Callable, List, Optional, Union from typing import Any, Callable, List, Optional, Union
import attr import attr
@ -26,10 +27,11 @@ from homeassistant.const import (
CONF_PROTOCOL, CONF_PROTOCOL,
CONF_USERNAME, CONF_USERNAME,
CONF_VALUE_TEMPLATE, CONF_VALUE_TEMPLATE,
EVENT_HOMEASSISTANT_STARTED,
EVENT_HOMEASSISTANT_STOP, EVENT_HOMEASSISTANT_STOP,
) )
from homeassistant.const import CONF_UNIQUE_ID # noqa: F401 from homeassistant.const import CONF_UNIQUE_ID # noqa: F401
from homeassistant.core import Event, ServiceCall, callback from homeassistant.core import CoreState, Event, ServiceCall, callback
from homeassistant.exceptions import HomeAssistantError, Unauthorized from homeassistant.exceptions import HomeAssistantError, Unauthorized
from homeassistant.helpers import config_validation as cv, event, template from homeassistant.helpers import config_validation as cv, event, template
from homeassistant.helpers.dispatcher import async_dispatcher_connect, dispatcher_send from homeassistant.helpers.dispatcher import async_dispatcher_connect, dispatcher_send
@ -71,7 +73,12 @@ from .const import (
PROTOCOL_311, PROTOCOL_311,
) )
from .debug_info import log_messages from .debug_info import log_messages
from .discovery import MQTT_DISCOVERY_UPDATED, clear_discovery_hash, set_discovery_hash from .discovery import (
LAST_DISCOVERY,
MQTT_DISCOVERY_UPDATED,
clear_discovery_hash,
set_discovery_hash,
)
from .models import Message, MessageCallbackType, PublishPayloadType from .models import Message, MessageCallbackType, PublishPayloadType
from .subscription import async_subscribe_topics, async_unsubscribe_topics from .subscription import async_subscribe_topics, async_unsubscribe_topics
from .util import _VALID_QOS_SCHEMA, valid_publish_topic, valid_subscribe_topic from .util import _VALID_QOS_SCHEMA, valid_publish_topic, valid_subscribe_topic
@ -126,6 +133,7 @@ CONNECTION_SUCCESS = "connection_success"
CONNECTION_FAILED = "connection_failed" CONNECTION_FAILED = "connection_failed"
CONNECTION_FAILED_RECOVERABLE = "connection_failed_recoverable" CONNECTION_FAILED_RECOVERABLE = "connection_failed_recoverable"
DISCOVERY_COOLDOWN = 2
TIMEOUT_ACK = 1 TIMEOUT_ACK = 1
@ -623,11 +631,23 @@ class MQTT:
self.conf = conf self.conf = conf
self.subscriptions: List[Subscription] = [] self.subscriptions: List[Subscription] = []
self.connected = False self.connected = False
self._ha_started = asyncio.Event()
self._last_subscribe = time.time()
self._mqttc: mqtt.Client = None self._mqttc: mqtt.Client = None
self._paho_lock = asyncio.Lock() self._paho_lock = asyncio.Lock()
self._pending_operations = {} self._pending_operations = {}
if self.hass.state == CoreState.running:
self._ha_started.set()
else:
@callback
def ha_started(_):
self._ha_started.set()
self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STARTED, ha_started)
self.init_client() self.init_client()
self.config_entry.add_update_listener(self.async_config_entry_updated) self.config_entry.add_update_listener(self.async_config_entry_updated)
@ -800,6 +820,7 @@ class MQTT:
# Only subscribe if currently connected. # Only subscribe if currently connected.
if self.connected: if self.connected:
self._last_subscribe = time.time()
await self._async_perform_subscription(topic, qos) await self._async_perform_subscription(topic, qos)
@callback @callback
@ -880,15 +901,19 @@ class MQTT:
CONF_BIRTH_MESSAGE in self.conf CONF_BIRTH_MESSAGE in self.conf
and ATTR_TOPIC in self.conf[CONF_BIRTH_MESSAGE] and ATTR_TOPIC in self.conf[CONF_BIRTH_MESSAGE]
): ):
birth_message = Message(**self.conf[CONF_BIRTH_MESSAGE])
self.hass.add_job( async def publish_birth_message(birth_message):
self.async_publish( # pylint: disable=no-value-for-parameter await self._ha_started.wait() # Wait for Home Assistant to start
await self._discovery_cooldown() # Wait for MQTT discovery to cool down
await self.async_publish( # pylint: disable=no-value-for-parameter
topic=birth_message.topic, topic=birth_message.topic,
payload=birth_message.payload, payload=birth_message.payload,
qos=birth_message.qos, qos=birth_message.qos,
retain=birth_message.retain, retain=birth_message.retain,
) )
)
birth_message = Message(**self.conf[CONF_BIRTH_MESSAGE])
self.hass.loop.create_task(publish_birth_message(birth_message))
def _mqtt_on_message(self, _mqttc, _userdata, msg) -> None: def _mqtt_on_message(self, _mqttc, _userdata, msg) -> None:
"""Message received callback.""" """Message received callback."""
@ -970,6 +995,26 @@ class MQTT:
finally: finally:
del self._pending_operations[mid] del self._pending_operations[mid]
async def _discovery_cooldown(self):
now = time.time()
# Reset discovery and subscribe cooldowns
self.hass.data[LAST_DISCOVERY] = now
self._last_subscribe = now
last_discovery = self.hass.data[LAST_DISCOVERY]
last_subscribe = self._last_subscribe
wait_until = max(
last_discovery + DISCOVERY_COOLDOWN, last_subscribe + DISCOVERY_COOLDOWN
)
while now < wait_until:
await asyncio.sleep(wait_until - now)
now = time.time()
last_discovery = self.hass.data[LAST_DISCOVERY]
last_subscribe = self._last_subscribe
wait_until = max(
last_discovery + DISCOVERY_COOLDOWN, last_subscribe + DISCOVERY_COOLDOWN
)
def _raise_on_error(result_code: int) -> None: def _raise_on_error(result_code: int) -> None:
"""Raise error if error result.""" """Raise error if error result."""

View file

@ -3,6 +3,7 @@ import asyncio
import json import json
import logging import logging
import re import re
import time
from homeassistant.components import mqtt from homeassistant.components import mqtt
from homeassistant.const import CONF_DEVICE, CONF_PLATFORM from homeassistant.const import CONF_DEVICE, CONF_PLATFORM
@ -40,6 +41,7 @@ DATA_CONFIG_ENTRY_LOCK = "mqtt_config_entry_lock"
DISCOVERY_UNSUBSCRIBE = "mqtt_discovery_unsubscribe" DISCOVERY_UNSUBSCRIBE = "mqtt_discovery_unsubscribe"
MQTT_DISCOVERY_UPDATED = "mqtt_discovery_updated_{}" MQTT_DISCOVERY_UPDATED = "mqtt_discovery_updated_{}"
MQTT_DISCOVERY_NEW = "mqtt_discovery_new_{}_{}" MQTT_DISCOVERY_NEW = "mqtt_discovery_new_{}_{}"
LAST_DISCOVERY = "mqtt_last_discovery"
TOPIC_BASE = "~" TOPIC_BASE = "~"
@ -65,6 +67,7 @@ async def async_start(
async def async_device_message_received(msg): async def async_device_message_received(msg):
"""Process the received message.""" """Process the received message."""
hass.data[LAST_DISCOVERY] = time.time()
payload = msg.payload payload = msg.payload
topic = msg.topic topic = msg.topic
topic_trimmed = topic.replace(f"{discovery_topic}/", "", 1) topic_trimmed = topic.replace(f"{discovery_topic}/", "", 1)
@ -167,6 +170,7 @@ async def async_start(
hass.data[DISCOVERY_UNSUBSCRIBE] = await mqtt.async_subscribe( hass.data[DISCOVERY_UNSUBSCRIBE] = await mqtt.async_subscribe(
hass, f"{discovery_topic}/#", async_device_message_received, 0 hass, f"{discovery_topic}/#", async_device_message_received, 0
) )
hass.data[LAST_DISCOVERY] = time.time()
return True return True

View file

@ -1,4 +1,5 @@
"""The tests for the MQTT component.""" """The tests for the MQTT component."""
import asyncio
from datetime import datetime, timedelta from datetime import datetime, timedelta
import json import json
import ssl import ssl
@ -361,7 +362,6 @@ async def test_subscribe_deprecated_async(hass, mqtt_mock):
"""Test the subscription of a topic using deprecated callback signature.""" """Test the subscription of a topic using deprecated callback signature."""
calls = [] calls = []
@callback
async def record_calls(topic, payload, qos): async def record_calls(topic, payload, qos):
"""Record calls.""" """Record calls."""
calls.append((topic, payload, qos)) calls.append((topic, payload, qos))
@ -758,18 +758,36 @@ async def test_setup_without_tls_config_uses_tlsv1_under_python36(hass):
) )
async def test_custom_birth_message(hass, mqtt_client_mock, mqtt_mock): async def test_custom_birth_message(hass, mqtt_client_mock, mqtt_mock):
"""Test sending birth message.""" """Test sending birth message."""
mqtt_mock._mqtt_on_connect(None, None, 0, 0) birth = asyncio.Event()
await hass.async_block_till_done()
mqtt_client_mock.publish.assert_called_with("birth", "birth", 0, False) async def wait_birth(topic, payload, qos):
"""Handle birth message."""
birth.set()
with patch("homeassistant.components.mqtt.DISCOVERY_COOLDOWN", 0.1):
await mqtt.async_subscribe(hass, "birth", wait_birth)
mqtt_mock._mqtt_on_connect(None, None, 0, 0)
await hass.async_block_till_done()
await birth.wait()
mqtt_client_mock.publish.assert_called_with("birth", "birth", 0, False)
async def test_default_birth_message(hass, mqtt_client_mock, mqtt_mock): async def test_default_birth_message(hass, mqtt_client_mock, mqtt_mock):
"""Test sending birth message.""" """Test sending birth message."""
mqtt_mock._mqtt_on_connect(None, None, 0, 0) birth = asyncio.Event()
await hass.async_block_till_done()
mqtt_client_mock.publish.assert_called_with( async def wait_birth(topic, payload, qos):
"homeassistant/status", "online", 0, False """Handle birth message."""
) birth.set()
with patch("homeassistant.components.mqtt.DISCOVERY_COOLDOWN", 0.1):
await mqtt.async_subscribe(hass, "homeassistant/status", wait_birth)
mqtt_mock._mqtt_on_connect(None, None, 0, 0)
await hass.async_block_till_done()
await birth.wait()
mqtt_client_mock.publish.assert_called_with(
"homeassistant/status", "online", 0, False
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -777,9 +795,11 @@ async def test_default_birth_message(hass, mqtt_client_mock, mqtt_mock):
) )
async def test_no_birth_message(hass, mqtt_client_mock, mqtt_mock): async def test_no_birth_message(hass, mqtt_client_mock, mqtt_mock):
"""Test disabling birth message.""" """Test disabling birth message."""
mqtt_mock._mqtt_on_connect(None, None, 0, 0) with patch("homeassistant.components.mqtt.DISCOVERY_COOLDOWN", 0.1):
await hass.async_block_till_done() mqtt_mock._mqtt_on_connect(None, None, 0, 0)
mqtt_client_mock.publish.assert_not_called() await hass.async_block_till_done()
await asyncio.sleep(0.2)
mqtt_client_mock.publish.assert_not_called()
@pytest.mark.parametrize( @pytest.mark.parametrize(