Debounce and group MQTT subscriptions (#88862)

* Debounce and group mqtt subscriptions

* Cleanup

* Do not cooldown on resubscribe

* Remove lock from task

Co-authored-by: Erik Montnemery <erik@montnemery.com>

* ruff

* Longer initial cool down. Manages unsubscribes

* Own lock for access to self._pending_subscriptions

* adjust

* Subscribe to highest QoS when sharing subscription

* do not block _pending_subscriptions_lock with io

* Test the highest qos is subscribed at

* Cleanup max qos

* Follow up comments part 1

* Make docstr more generic

* Make max qos update thread safe

* Add lock on clearing _max_qos when resubscribing

* Wait for linger task

* User copy

* Check for key before cleaning up

* Fix lingering task

* Do not use a lock

* do not await _async_queue_subscriptions

* Replace copy with assignment

* Update max qos before returning

* Do not iterate if max_qos == 0

* Do not ieterate subs if max qos == 0

* Set initial cooldown correctly

* Ensure discovery cooldown ends after subscribing

* plan last subscribe with debouncer timeout

* cooldown if self._pending_subscriptions is set

* Revert format changes

* Remove stale assingnment self._last_subscribe

* Remove not used property

* Also check while for pending subscriptions

* revert first added sleep()

* Optimize

---------

Co-authored-by: Erik Montnemery <erik@montnemery.com>
Co-authored-by: J. Nick Koston <nick@koston.org>
This commit is contained in:
Jan Bouwhuis 2023-03-14 11:13:55 +01:00 committed by GitHub
parent 03b204f445
commit ec1b8b616f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 361 additions and 38 deletions

View file

@ -83,6 +83,8 @@ if TYPE_CHECKING:
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
DISCOVERY_COOLDOWN = 2 DISCOVERY_COOLDOWN = 2
INITIAL_SUBSCRIBE_COOLDOWN = 1.0
SUBSCRIBE_COOLDOWN = 0.1
TIMEOUT_ACK = 10 TIMEOUT_ACK = 10
SubscribePayloadType = str | bytes # Only bytes if encoding is None SubscribePayloadType = str | bytes # Only bytes if encoding is None
@ -295,10 +297,86 @@ def _is_simple_match(topic: str) -> bool:
return not ("+" in topic or "#" in topic) return not ("+" in topic or "#" in topic)
class EnsureJobAfterCooldown:
"""Ensure a cool down period before executing a job.
When a new execute request arrives we cancel the current request
and start a new one.
"""
def __init__(
self, timeout: float, callback_job: Callable[[], Coroutine[Any, None, None]]
) -> None:
"""Initialize the timer."""
self._loop = asyncio.get_running_loop()
self._timeout = timeout
self._callback = callback_job
self._task: asyncio.Future | None = None
self._timer: asyncio.TimerHandle | None = None
def set_timeout(self, timeout: float) -> None:
"""Set a new timeout period."""
self._timeout = timeout
async def _async_job(self) -> None:
"""Execute after a cooldown period."""
try:
await self._callback()
except HomeAssistantError as ha_error:
_LOGGER.error("%s", ha_error)
@callback
def _async_task_done(self, task: asyncio.Future) -> None:
"""Handle task done."""
self._task = None
@callback
def _async_execute(self) -> None:
"""Execute the job."""
if self._task:
# Task already running,
# so we schedule another run
self.async_schedule()
return
self._async_cancel_timer()
self._task = asyncio.create_task(self._async_job())
self._task.add_done_callback(self._async_task_done)
@callback
def _async_cancel_timer(self) -> None:
"""Cancel any pending task."""
if self._timer:
self._timer.cancel()
self._timer = None
@callback
def async_schedule(self) -> None:
"""Ensure we execute after a cooldown period."""
# We want to reschedule the timer in the future
# every time this is called.
self._async_cancel_timer()
self._timer = self._loop.call_later(self._timeout, self._async_execute)
async def async_cleanup(self) -> None:
"""Cleanup any pending task."""
self._async_cancel_timer()
if not self._task:
return
self._task.cancel()
try:
await self._task
except asyncio.CancelledError:
pass
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Error cleaning up task", exc_info=True)
class MQTT: class MQTT:
"""Home Assistant MQTT client.""" """Home Assistant MQTT client."""
_mqttc: mqtt.Client _mqttc: mqtt.Client
_last_subscribe: float
def __init__( def __init__(
self, self,
@ -316,12 +394,16 @@ class MQTT:
self._wildcard_subscriptions: list[Subscription] = [] self._wildcard_subscriptions: list[Subscription] = []
self.connected = False self.connected = False
self._ha_started = asyncio.Event() self._ha_started = asyncio.Event()
self._last_subscribe = time.time()
self._cleanup_on_unload: list[Callable[[], None]] = [] self._cleanup_on_unload: list[Callable[[], None]] = []
self._paho_lock = asyncio.Lock() # Prevents parallel calls to the MQTT client self._paho_lock = asyncio.Lock() # Prevents parallel calls to the MQTT client
self._pending_operations: dict[int, asyncio.Event] = {} self._pending_operations: dict[int, asyncio.Event] = {}
self._pending_operations_condition = asyncio.Condition() self._pending_operations_condition = asyncio.Condition()
self._subscribe_debouncer = EnsureJobAfterCooldown(
INITIAL_SUBSCRIBE_COOLDOWN, self._async_perform_subscriptions
)
self._max_qos: dict[str, int] = {} # topic, max qos
self._pending_subscriptions: dict[str, int] = {} # topic, qos
if self.hass.state == CoreState.running: if self.hass.state == CoreState.running:
self._ha_started.set() self._ha_started.set()
@ -442,6 +524,11 @@ class MQTT:
"""Return False if there are unprocessed ACKs.""" """Return False if there are unprocessed ACKs."""
return not any(not op.is_set() for op in self._pending_operations.values()) return not any(not op.is_set() for op in self._pending_operations.values())
# stop waiting for any pending subscriptions
await self._subscribe_debouncer.async_cleanup()
# reset timeout to initial subscribe cooldown
self._subscribe_debouncer.set_timeout(INITIAL_SUBSCRIBE_COOLDOWN)
# wait for ACKs to be processed # wait for ACKs to be processed
async with self._pending_operations_condition: async with self._pending_operations_condition:
await self._pending_operations_condition.wait_for(no_more_acks) await self._pending_operations_condition.wait_for(no_more_acks)
@ -494,6 +581,20 @@ class MQTT:
except (KeyError, ValueError) as ex: except (KeyError, ValueError) as ex:
raise HomeAssistantError("Can't remove subscription twice") from ex raise HomeAssistantError("Can't remove subscription twice") from ex
@callback
def _async_queue_subscriptions(
self, subscriptions: Iterable[tuple[str, int]], queue_only: bool = False
) -> None:
"""Queue requested subscriptions."""
for subscription in subscriptions:
topic, qos = subscription
max_qos = max(qos, self._max_qos.setdefault(topic, qos))
self._max_qos[topic] = max_qos
self._pending_subscriptions[topic] = max_qos
if queue_only:
return
self._subscribe_debouncer.async_schedule()
async def async_subscribe( async def async_subscribe(
self, self,
topic: str, topic: str,
@ -516,15 +617,13 @@ class MQTT:
# Only subscribe if currently connected. # Only subscribe if currently connected.
if self.connected: if self.connected:
self._last_subscribe = time.time() self._async_queue_subscriptions(((topic, qos),))
await self._async_perform_subscriptions(((topic, qos),))
@callback @callback
def async_remove() -> None: def async_remove() -> None:
"""Remove subscription.""" """Remove subscription."""
self._async_untrack_subscription(subscription) self._async_untrack_subscription(subscription)
self._matching_subscriptions.cache_clear() self._matching_subscriptions.cache_clear()
# Only unsubscribe if currently connected # Only unsubscribe if currently connected
if self.connected: if self.connected:
self.hass.async_create_task(self._async_unsubscribe(topic)) self.hass.async_create_task(self._async_unsubscribe(topic))
@ -543,21 +642,27 @@ class MQTT:
_raise_on_error(result) _raise_on_error(result)
return mid return mid
async with self._paho_lock: if self._is_active_subscription(topic):
if self._is_active_subscription(topic): if self._max_qos[topic] == 0:
# Other subscriptions on topic remaining - don't unsubscribe.
return return
subs = self._matching_subscriptions(topic)
self._max_qos[topic] = max(sub.qos for sub in subs)
# Other subscriptions on topic remaining - don't unsubscribe.
return
if topic in self._max_qos:
del self._max_qos[topic]
if topic in self._pending_subscriptions:
# avoid any pending subscription to be executed
del self._pending_subscriptions[topic]
async with self._paho_lock:
mid = await self.hass.async_add_executor_job(_client_unsubscribe, topic) mid = await self.hass.async_add_executor_job(_client_unsubscribe, topic)
await self._register_mid(mid) await self._register_mid(mid)
self.hass.async_create_task(self._wait_for_mid(mid)) self.hass.async_create_task(self._wait_for_mid(mid))
async def _async_perform_subscriptions( async def _async_perform_subscriptions(self) -> None:
self, subscriptions: Iterable[tuple[str, int]]
) -> None:
"""Perform MQTT client subscriptions.""" """Perform MQTT client subscriptions."""
subscriptions: dict[str, int]
# Section 3.3.1.3 in the specification: # Section 3.3.1.3 in the specification:
# http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html # http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html
# When sending a PUBLISH Packet to a Client the Server MUST # When sending a PUBLISH Packet to a Client the Server MUST
@ -573,16 +678,20 @@ class MQTT:
def _process_client_subscriptions() -> list[tuple[int, int]]: def _process_client_subscriptions() -> list[tuple[int, int]]:
"""Initiate all subscriptions on the MQTT client and return the results.""" """Initiate all subscriptions on the MQTT client and return the results."""
subscribe_result_list = [] subscribe_result_list = []
for topic, qos in subscriptions: for topic, qos in subscriptions.items():
result, mid = self._mqttc.subscribe(topic, qos) result, mid = self._mqttc.subscribe(topic, qos)
subscribe_result_list.append((result, mid)) subscribe_result_list.append((result, mid))
_LOGGER.debug("Subscribing to %s, mid: %s, qos: %s", topic, mid, qos) _LOGGER.debug("Subscribing to %s, mid: %s, qos: %s", topic, mid, qos)
return subscribe_result_list return subscribe_result_list
subscriptions = self._pending_subscriptions
self._pending_subscriptions = {}
async with self._paho_lock: async with self._paho_lock:
results = await self.hass.async_add_executor_job( results = await self.hass.async_add_executor_job(
_process_client_subscriptions _process_client_subscriptions
) )
self._last_subscribe = time.time()
tasks: list[Coroutine[Any, Any, None]] = [] tasks: list[Coroutine[Any, Any, None]] = []
errors: list[int] = [] errors: list[int] = []
@ -639,6 +748,8 @@ class MQTT:
async def publish_birth_message(birth_message: PublishMessage) -> None: async def publish_birth_message(birth_message: PublishMessage) -> None:
await self._ha_started.wait() # Wait for Home Assistant to start await self._ha_started.wait() # Wait for Home Assistant to start
await self._discovery_cooldown() # Wait for MQTT discovery to cool down await self._discovery_cooldown() # Wait for MQTT discovery to cool down
# Update subscribe cooldown period to a shorter time
self._subscribe_debouncer.set_timeout(SUBSCRIBE_COOLDOWN)
await self.async_publish( await self.async_publish(
topic=birth_message.topic, topic=birth_message.topic,
payload=birth_message.payload, payload=birth_message.payload,
@ -654,16 +765,19 @@ class MQTT:
async def _async_resubscribe(self) -> None: async def _async_resubscribe(self) -> None:
"""Resubscribe on reconnect.""" """Resubscribe on reconnect."""
# Group subscriptions to only re-subscribe once for each topic. # Group subscriptions to only re-subscribe once for each topic.
self._max_qos.clear()
keyfunc = attrgetter("topic") keyfunc = attrgetter("topic")
await self._async_perform_subscriptions( self._async_queue_subscriptions(
[ [
# Re-subscribe with the highest requested qos # Re-subscribe with the highest requested qos
(topic, max(subscription.qos for subscription in subs)) (topic, max(subscription.qos for subscription in subs))
for topic, subs in groupby( for topic, subs in groupby(
sorted(self.subscriptions, key=keyfunc), keyfunc sorted(self.subscriptions, key=keyfunc), keyfunc
) )
] ],
queue_only=True,
) )
await self._async_perform_subscriptions()
def _mqtt_on_message( def _mqtt_on_message(
self, _mqttc: mqtt.Client, _userdata: None, msg: mqtt.MQTTMessage self, _mqttc: mqtt.Client, _userdata: None, msg: mqtt.MQTTMessage
@ -785,13 +899,14 @@ class MQTT:
self._pending_operations_condition.notify_all() self._pending_operations_condition.notify_all()
async def _discovery_cooldown(self) -> None: async def _discovery_cooldown(self) -> None:
"""Wait until all discovery and subscriptions are processed."""
now = time.time() now = time.time()
# Reset discovery and subscribe cooldowns # Reset discovery and subscribe cooldowns
self._mqtt_data.last_discovery = now self._mqtt_data.last_discovery = now
self._last_subscribe = now self._last_subscribe = now
last_discovery = self._mqtt_data.last_discovery last_discovery = self._mqtt_data.last_discovery
last_subscribe = self._last_subscribe last_subscribe = now if self._pending_subscriptions else self._last_subscribe
wait_until = max( wait_until = max(
last_discovery + DISCOVERY_COOLDOWN, last_subscribe + DISCOVERY_COOLDOWN last_discovery + DISCOVERY_COOLDOWN, last_subscribe + DISCOVERY_COOLDOWN
) )
@ -799,7 +914,9 @@ class MQTT:
await asyncio.sleep(wait_until - now) await asyncio.sleep(wait_until - now)
now = time.time() now = time.time()
last_discovery = self._mqtt_data.last_discovery last_discovery = self._mqtt_data.last_discovery
last_subscribe = self._last_subscribe last_subscribe = (
now if self._pending_subscriptions else self._last_subscribe
)
wait_until = max( wait_until = max(
last_discovery + DISCOVERY_COOLDOWN, last_subscribe + DISCOVERY_COOLDOWN last_discovery + DISCOVERY_COOLDOWN, last_subscribe + DISCOVERY_COOLDOWN
) )

View file

@ -1374,6 +1374,8 @@ async def test_complex_discovery_topic_prefix(
@patch("homeassistant.components.mqtt.PLATFORMS", []) @patch("homeassistant.components.mqtt.PLATFORMS", [])
@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0)
@patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 0.0)
async def test_mqtt_integration_discovery_subscribe_unsubscribe( async def test_mqtt_integration_discovery_subscribe_unsubscribe(
hass: HomeAssistant, hass: HomeAssistant,
mqtt_client_mock: MqttMockPahoClient, mqtt_client_mock: MqttMockPahoClient,
@ -1392,6 +1394,7 @@ async def test_mqtt_integration_discovery_subscribe_unsubscribe(
): ):
await async_start(hass, "homeassistant", entry) await async_start(hass, "homeassistant", entry)
await hass.async_block_till_done() await hass.async_block_till_done()
await hass.async_block_till_done()
mqtt_client_mock.subscribe.assert_any_call("comp/discovery/#", 0) mqtt_client_mock.subscribe.assert_any_call("comp/discovery/#", 0)
assert not mqtt_client_mock.unsubscribe.called assert not mqtt_client_mock.unsubscribe.called
@ -1418,6 +1421,8 @@ async def test_mqtt_integration_discovery_subscribe_unsubscribe(
@patch("homeassistant.components.mqtt.PLATFORMS", []) @patch("homeassistant.components.mqtt.PLATFORMS", [])
@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0)
@patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 0.0)
async def test_mqtt_discovery_unsubscribe_once( async def test_mqtt_discovery_unsubscribe_once(
hass: HomeAssistant, hass: HomeAssistant,
mqtt_client_mock: MqttMockPahoClient, mqtt_client_mock: MqttMockPahoClient,
@ -1436,6 +1441,7 @@ async def test_mqtt_discovery_unsubscribe_once(
): ):
await async_start(hass, "homeassistant", entry) await async_start(hass, "homeassistant", entry)
await hass.async_block_till_done() await hass.async_block_till_done()
await hass.async_block_till_done()
mqtt_client_mock.subscribe.assert_any_call("comp/discovery/#", 0) mqtt_client_mock.subscribe.assert_any_call("comp/discovery/#", 0)
assert not mqtt_client_mock.unsubscribe.called assert not mqtt_client_mock.unsubscribe.called

View file

@ -17,6 +17,7 @@ import yaml
from homeassistant import config as hass_config from homeassistant import config as hass_config
from homeassistant.components import mqtt from homeassistant.components import mqtt
from homeassistant.components.mqtt import CONFIG_SCHEMA, debug_info from homeassistant.components.mqtt import CONFIG_SCHEMA, debug_info
from homeassistant.components.mqtt.client import EnsureJobAfterCooldown
from homeassistant.components.mqtt.mixins import MQTT_ENTITY_DEVICE_INFO_SCHEMA from homeassistant.components.mqtt.mixins import MQTT_ENTITY_DEVICE_INFO_SCHEMA
from homeassistant.components.mqtt.models import MessageCallbackType, ReceiveMessage from homeassistant.components.mqtt.models import MessageCallbackType, ReceiveMessage
from homeassistant.config_entries import ConfigEntryDisabler, ConfigEntryState from homeassistant.config_entries import ConfigEntryDisabler, ConfigEntryState
@ -1262,6 +1263,9 @@ async def test_subscribe_special_characters(
assert calls[0].payload == payload assert calls[0].payload == payload
@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0)
@patch("homeassistant.components.mqtt.client.DISCOVERY_COOLDOWN", 0.0)
@patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 0.0)
async def test_subscribe_same_topic( async def test_subscribe_same_topic(
hass: HomeAssistant, hass: HomeAssistant,
mqtt_client_mock: MqttMockPahoClient, mqtt_client_mock: MqttMockPahoClient,
@ -1286,26 +1290,35 @@ async def test_subscribe_same_topic(
def _callback_b(msg: ReceiveMessage) -> None: def _callback_b(msg: ReceiveMessage) -> None:
calls_b.append(msg) calls_b.append(msg)
await mqtt.async_subscribe(hass, "test/state", _callback_a) await mqtt.async_subscribe(hass, "test/state", _callback_a, qos=0)
async_fire_mqtt_message( async_fire_mqtt_message(
hass, "test/state", "online" hass, "test/state", "online"
) # Simulate a (retained) message ) # Simulate a (retained) message replaying
async_fire_time_changed(hass, utcnow() + timedelta(seconds=1))
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(calls_a) == 1 assert len(calls_a) == 1
mqtt_client_mock.subscribe.assert_called() mqtt_client_mock.subscribe.assert_called()
calls_a = [] calls_a = []
mqtt_client_mock.reset_mock() mqtt_client_mock.reset_mock()
await mqtt.async_subscribe(hass, "test/state", _callback_b) async_fire_time_changed(hass, utcnow() + timedelta(seconds=3))
await hass.async_block_till_done()
await mqtt.async_subscribe(hass, "test/state", _callback_b, qos=1)
async_fire_mqtt_message( async_fire_mqtt_message(
hass, "test/state", "online" hass, "test/state", "online"
) # Simulate a (retained) message ) # Simulate a (retained) message replaying
async_fire_time_changed(hass, utcnow() + timedelta(seconds=1))
await hass.async_block_till_done()
async_fire_time_changed(hass, utcnow() + timedelta(seconds=1))
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(calls_a) == 1 assert len(calls_a) == 1
assert len(calls_b) == 1 assert len(calls_b) == 1
mqtt_client_mock.subscribe.assert_called() mqtt_client_mock.subscribe.assert_called()
@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0)
@patch("homeassistant.components.mqtt.client.DISCOVERY_COOLDOWN", 0.0)
@patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 0.0)
async def test_not_calling_unsubscribe_with_active_subscribers( async def test_not_calling_unsubscribe_with_active_subscribers(
hass: HomeAssistant, hass: HomeAssistant,
mqtt_client_mock: MqttMockPahoClient, mqtt_client_mock: MqttMockPahoClient,
@ -1317,8 +1330,10 @@ async def test_not_calling_unsubscribe_with_active_subscribers(
# Fake that the client is connected # Fake that the client is connected
mqtt_mock().connected = True mqtt_mock().connected = True
unsub = await mqtt.async_subscribe(hass, "test/state", record_calls) unsub = await mqtt.async_subscribe(hass, "test/state", record_calls, 2)
await mqtt.async_subscribe(hass, "test/state", record_calls) await mqtt.async_subscribe(hass, "test/state", record_calls, 1)
await hass.async_block_till_done()
async_fire_time_changed(hass, utcnow() + timedelta(seconds=3)) # cooldown
await hass.async_block_till_done() await hass.async_block_till_done()
assert mqtt_client_mock.subscribe.called assert mqtt_client_mock.subscribe.called
@ -1327,6 +1342,30 @@ async def test_not_calling_unsubscribe_with_active_subscribers(
assert not mqtt_client_mock.unsubscribe.called assert not mqtt_client_mock.unsubscribe.called
async def test_not_calling_subscribe_when_unsubscribed_within_cooldown(
hass: HomeAssistant,
mqtt_client_mock: MqttMockPahoClient,
mqtt_mock_entry_no_yaml_config: MqttMockHAClientGenerator,
record_calls: MessageCallbackType,
) -> None:
"""Test not calling subscribe() when it is unsubscribed.
Make sure subscriptions are cleared if unsubscribed before
the subscribe cool down period has ended.
"""
mqtt_mock = await mqtt_mock_entry_no_yaml_config()
# Fake that the client is connected
mqtt_mock().connected = True
unsub = await mqtt.async_subscribe(hass, "test/state", record_calls)
unsub()
async_fire_time_changed(hass, utcnow() + timedelta(seconds=3)) # cooldown
await hass.async_block_till_done()
assert not mqtt_client_mock.subscribe.called
@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0)
@patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 0.0)
async def test_unsubscribe_race( async def test_unsubscribe_race(
hass: HomeAssistant, hass: HomeAssistant,
mqtt_client_mock: MqttMockPahoClient, mqtt_client_mock: MqttMockPahoClient,
@ -1351,13 +1390,15 @@ async def test_unsubscribe_race(
unsub() unsub()
await mqtt.async_subscribe(hass, "test/state", _callback_b) await mqtt.async_subscribe(hass, "test/state", _callback_b)
await hass.async_block_till_done() await hass.async_block_till_done()
await hass.async_block_till_done()
async_fire_mqtt_message(hass, "test/state", "online") async_fire_mqtt_message(hass, "test/state", "online")
await hass.async_block_till_done() await hass.async_block_till_done()
assert not calls_a assert not calls_a
assert calls_b assert calls_b
# We allow either calls [subscribe, unsubscribe, subscribe] or [subscribe, subscribe] # We allow either calls [subscribe, unsubscribe, subscribe], [subscribe, subscribe] or
# when both subscriptions were combined [subscribe]
expected_calls_1 = [ expected_calls_1 = [
call.subscribe("test/state", 0), call.subscribe("test/state", 0),
call.unsubscribe("test/state"), call.unsubscribe("test/state"),
@ -1367,13 +1408,23 @@ async def test_unsubscribe_race(
call.subscribe("test/state", 0), call.subscribe("test/state", 0),
call.subscribe("test/state", 0), call.subscribe("test/state", 0),
] ]
assert mqtt_client_mock.mock_calls in (expected_calls_1, expected_calls_2) expected_calls_3 = [
call.subscribe("test/state", 0),
]
assert mqtt_client_mock.mock_calls in (
expected_calls_1,
expected_calls_2,
expected_calls_3,
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"mqtt_config_entry_data", "mqtt_config_entry_data",
[{mqtt.CONF_BROKER: "mock-broker", mqtt.CONF_DISCOVERY: False}], [{mqtt.CONF_BROKER: "mock-broker", mqtt.CONF_DISCOVERY: False}],
) )
@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0)
@patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 0.0)
@patch("homeassistant.components.mqtt.client.DISCOVERY_COOLDOWN", 0.0)
async def test_restore_subscriptions_on_reconnect( async def test_restore_subscriptions_on_reconnect(
hass: HomeAssistant, hass: HomeAssistant,
mqtt_client_mock: MqttMockPahoClient, mqtt_client_mock: MqttMockPahoClient,
@ -1386,13 +1437,15 @@ async def test_restore_subscriptions_on_reconnect(
mqtt_mock().connected = True mqtt_mock().connected = True
await mqtt.async_subscribe(hass, "test/state", record_calls) await mqtt.async_subscribe(hass, "test/state", record_calls)
async_fire_time_changed(hass, utcnow() + timedelta(seconds=3)) # cooldown
await hass.async_block_till_done() await hass.async_block_till_done()
assert mqtt_client_mock.subscribe.call_count == 1 assert mqtt_client_mock.subscribe.call_count == 1
mqtt_client_mock.on_disconnect(None, None, 0) mqtt_client_mock.on_disconnect(None, None, 0)
with patch("homeassistant.components.mqtt.client.DISCOVERY_COOLDOWN", 0): mqtt_client_mock.on_connect(None, None, None, 0)
mqtt_client_mock.on_connect(None, None, None, 0) async_fire_time_changed(hass, utcnow() + timedelta(seconds=3)) # cooldown
await hass.async_block_till_done() await hass.async_block_till_done()
await hass.async_block_till_done()
assert mqtt_client_mock.subscribe.call_count == 2 assert mqtt_client_mock.subscribe.call_count == 2
@ -1400,6 +1453,9 @@ async def test_restore_subscriptions_on_reconnect(
"mqtt_config_entry_data", "mqtt_config_entry_data",
[{mqtt.CONF_BROKER: "mock-broker", mqtt.CONF_DISCOVERY: False}], [{mqtt.CONF_BROKER: "mock-broker", mqtt.CONF_DISCOVERY: False}],
) )
@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 1.0)
@patch("homeassistant.components.mqtt.client.DISCOVERY_COOLDOWN", 0.0)
@patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 1.0)
async def test_restore_all_active_subscriptions_on_reconnect( async def test_restore_all_active_subscriptions_on_reconnect(
hass: HomeAssistant, hass: HomeAssistant,
mqtt_client_mock: MqttMockPahoClient, mqtt_client_mock: MqttMockPahoClient,
@ -1412,14 +1468,15 @@ async def test_restore_all_active_subscriptions_on_reconnect(
mqtt_mock().connected = True mqtt_mock().connected = True
unsub = await mqtt.async_subscribe(hass, "test/state", record_calls, qos=2) unsub = await mqtt.async_subscribe(hass, "test/state", record_calls, qos=2)
await mqtt.async_subscribe(hass, "test/state", record_calls)
await mqtt.async_subscribe(hass, "test/state", record_calls, qos=1) await mqtt.async_subscribe(hass, "test/state", record_calls, qos=1)
await mqtt.async_subscribe(hass, "test/state", record_calls, qos=0)
await hass.async_block_till_done()
async_fire_time_changed(hass, utcnow() + timedelta(seconds=3)) # cooldown
await hass.async_block_till_done() await hass.async_block_till_done()
# the subscribtion with the highest QoS should survive
expected = [ expected = [
call("test/state", 2), call("test/state", 2),
call("test/state", 0),
call("test/state", 1),
] ]
assert mqtt_client_mock.subscribe.mock_calls == expected assert mqtt_client_mock.subscribe.mock_calls == expected
@ -1428,13 +1485,60 @@ async def test_restore_all_active_subscriptions_on_reconnect(
assert mqtt_client_mock.unsubscribe.call_count == 0 assert mqtt_client_mock.unsubscribe.call_count == 0
mqtt_client_mock.on_disconnect(None, None, 0) mqtt_client_mock.on_disconnect(None, None, 0)
with patch("homeassistant.components.mqtt.client.DISCOVERY_COOLDOWN", 0): await hass.async_block_till_done()
mqtt_client_mock.on_connect(None, None, None, 0) mqtt_client_mock.on_connect(None, None, None, 0)
await hass.async_block_till_done() async_fire_time_changed(hass, utcnow() + timedelta(seconds=3)) # cooldown
await hass.async_block_till_done()
expected.append(call("test/state", 1)) expected.append(call("test/state", 1))
assert mqtt_client_mock.subscribe.mock_calls == expected assert mqtt_client_mock.subscribe.mock_calls == expected
async_fire_time_changed(hass, utcnow() + timedelta(seconds=3)) # cooldown
await hass.async_block_till_done()
async_fire_time_changed(hass, utcnow() + timedelta(seconds=3)) # cooldown
await hass.async_block_till_done()
@pytest.mark.parametrize(
"mqtt_config_entry_data",
[{mqtt.CONF_BROKER: "mock-broker", mqtt.CONF_DISCOVERY: False}],
)
@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 1.0)
@patch("homeassistant.components.mqtt.client.DISCOVERY_COOLDOWN", 0.0)
@patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 1.0)
async def test_subscribed_at_highest_qos(
hass: HomeAssistant,
mqtt_client_mock: MqttMockPahoClient,
mqtt_mock_entry_no_yaml_config: MqttMockHAClientGenerator,
record_calls: MessageCallbackType,
) -> None:
"""Test the highest qos as assigned when subscribing to the same topic."""
mqtt_mock = await mqtt_mock_entry_no_yaml_config()
# Fake that the client is connected
mqtt_mock().connected = True
await mqtt.async_subscribe(hass, "test/state", record_calls, qos=0)
await hass.async_block_till_done()
async_fire_time_changed(hass, utcnow() + timedelta(seconds=5)) # cooldown
await hass.async_block_till_done()
assert mqtt_client_mock.subscribe.mock_calls == [
call("test/state", 0),
]
mqtt_client_mock.reset_mock()
async_fire_time_changed(hass, utcnow() + timedelta(seconds=5)) # cooldown
await hass.async_block_till_done()
await hass.async_block_till_done()
await mqtt.async_subscribe(hass, "test/state", record_calls, qos=1)
await mqtt.async_subscribe(hass, "test/state", record_calls, qos=2)
await hass.async_block_till_done()
async_fire_time_changed(hass, utcnow() + timedelta(seconds=5)) # cooldown
await hass.async_block_till_done()
# the subscribtion with the highest QoS should survive
assert mqtt_client_mock.subscribe.mock_calls == [
call("test/state", 2),
]
async def test_reload_entry_with_restored_subscriptions( async def test_reload_entry_with_restored_subscriptions(
hass: HomeAssistant, hass: HomeAssistant,
@ -1499,6 +1603,93 @@ async def test_reload_entry_with_restored_subscriptions(
assert calls[1].payload == "wild-card-payload3" assert calls[1].payload == "wild-card-payload3"
@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 2)
@patch("homeassistant.components.mqtt.client.DISCOVERY_COOLDOWN", 2)
@patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 2)
async def test_canceling_debouncer_on_shutdown(
hass: HomeAssistant,
record_calls: MessageCallbackType,
mqtt_client_mock: MqttMockPahoClient,
mqtt_mock_entry_no_yaml_config: MqttMockHAClientGenerator,
) -> None:
"""Test canceling the debouncer when HA shuts down."""
mqtt_mock = await mqtt_mock_entry_no_yaml_config()
# Fake that the client is connected
mqtt_mock().connected = True
await mqtt.async_subscribe(hass, "test/state1", record_calls)
async_fire_time_changed(hass, utcnow() + timedelta(seconds=0.2))
await hass.async_block_till_done()
await mqtt.async_subscribe(hass, "test/state2", record_calls)
async_fire_time_changed(hass, utcnow() + timedelta(seconds=0.2))
await hass.async_block_till_done()
await mqtt.async_subscribe(hass, "test/state3", record_calls)
async_fire_time_changed(hass, utcnow() + timedelta(seconds=0.2))
await hass.async_block_till_done()
await mqtt.async_subscribe(hass, "test/state4", record_calls)
async_fire_time_changed(hass, utcnow() + timedelta(seconds=0.2))
await hass.async_block_till_done()
await mqtt.async_subscribe(hass, "test/state5", record_calls)
mqtt_client_mock.subscribe.assert_not_called()
# Stop HA so the scheduled task will be canceled
hass.bus.fire(EVENT_HOMEASSISTANT_STOP)
# mock disconnect status
mqtt_client_mock.on_disconnect(None, None, 0)
await hass.async_block_till_done()
await hass.async_block_till_done()
async_fire_time_changed(hass, utcnow() + timedelta(seconds=5))
await hass.async_block_till_done()
mqtt_client_mock.subscribe.assert_not_called()
async def test_canceling_debouncer_normal(
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test canceling the debouncer before completion."""
async def _async_myjob() -> None:
await asyncio.sleep(1.0)
debouncer = EnsureJobAfterCooldown(0.0, _async_myjob)
debouncer.async_schedule()
await asyncio.sleep(0.01)
assert debouncer._task is not None
await debouncer.async_cleanup()
assert debouncer._task is None
async def test_canceling_debouncer_throws(
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test canceling the debouncer when HA shuts down."""
async def _async_myjob() -> None:
await asyncio.sleep(1.0)
debouncer = EnsureJobAfterCooldown(0.0, _async_myjob)
debouncer.async_schedule()
await asyncio.sleep(0.01)
assert debouncer._task is not None
# let debouncer._task fail by mocking it
with patch.object(debouncer, "_task") as task:
task.cancel = MagicMock(return_value=True)
await debouncer.async_cleanup()
assert "Error cleaning up task" in caplog.text
await hass.async_block_till_done()
async_fire_time_changed(hass, utcnow() + timedelta(seconds=5))
await hass.async_block_till_done()
async def test_initial_setup_logs_error( async def test_initial_setup_logs_error(
hass: HomeAssistant, hass: HomeAssistant,
caplog: pytest.LogCaptureFixture, caplog: pytest.LogCaptureFixture,
@ -1575,21 +1766,30 @@ async def test_publish_error(
assert "Failed to connect to MQTT server: Out of memory." in caplog.text assert "Failed to connect to MQTT server: Out of memory." in caplog.text
@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0)
async def test_subscribe_error( async def test_subscribe_error(
hass: HomeAssistant, hass: HomeAssistant,
mqtt_mock_entry_no_yaml_config: MqttMockHAClientGenerator, mqtt_mock_entry_no_yaml_config: MqttMockHAClientGenerator,
mqtt_client_mock: MqttMockPahoClient, mqtt_client_mock: MqttMockPahoClient,
record_calls: MessageCallbackType, record_calls: MessageCallbackType,
caplog: pytest.LogCaptureFixture,
) -> None: ) -> None:
"""Test publish error.""" """Test publish error."""
await mqtt_mock_entry_no_yaml_config() await mqtt_mock_entry_no_yaml_config()
mqtt_client_mock.on_connect(mqtt_client_mock, None, None, 0) mqtt_client_mock.on_connect(mqtt_client_mock, None, None, 0)
await hass.async_block_till_done() await hass.async_block_till_done()
with pytest.raises(HomeAssistantError): await hass.async_block_till_done()
# simulate client is not connected error before subscribing mqtt_client_mock.reset_mock()
mqtt_client_mock.subscribe.side_effect = lambda *args: (4, None) # simulate client is not connected error before subscribing
await mqtt.async_subscribe(hass, "some-topic", record_calls) mqtt_client_mock.subscribe.side_effect = lambda *args: (4, None)
await mqtt.async_subscribe(hass, "some-topic", record_calls)
while mqtt_client_mock.subscribe.call_count == 0:
await hass.async_block_till_done() await hass.async_block_till_done()
await hass.async_block_till_done()
await hass.async_block_till_done()
assert (
"Error talking to MQTT: The client is not currently connected." in caplog.text
)
async def test_handle_message_callback( async def test_handle_message_callback(