Simplify MQTT subscribe debouncer execution (#117006)
This commit is contained in:
parent
38a3c3a823
commit
649dd55da9
3 changed files with 21 additions and 23 deletions
|
@ -316,7 +316,7 @@ class EnsureJobAfterCooldown:
|
|||
self._loop = asyncio.get_running_loop()
|
||||
self._timeout = timeout
|
||||
self._callback = callback_job
|
||||
self._task: asyncio.Future | None = None
|
||||
self._task: asyncio.Task | None = None
|
||||
self._timer: asyncio.TimerHandle | None = None
|
||||
|
||||
def set_timeout(self, timeout: float) -> None:
|
||||
|
@ -331,28 +331,23 @@ class EnsureJobAfterCooldown:
|
|||
_LOGGER.error("%s", ha_error)
|
||||
|
||||
@callback
|
||||
def _async_task_done(self, task: asyncio.Future) -> None:
|
||||
def _async_task_done(self, task: asyncio.Task) -> None:
|
||||
"""Handle task done."""
|
||||
self._task = None
|
||||
|
||||
@callback
|
||||
def _async_execute(self) -> None:
|
||||
def async_execute(self) -> asyncio.Task:
|
||||
"""Execute the job."""
|
||||
if self._task:
|
||||
# Task already running,
|
||||
# so we schedule another run
|
||||
self.async_schedule()
|
||||
return
|
||||
return self._task
|
||||
|
||||
self._async_cancel_timer()
|
||||
self._task = create_eager_task(self._async_job())
|
||||
self._task.add_done_callback(self._async_task_done)
|
||||
|
||||
async def async_fire(self) -> None:
|
||||
"""Execute the job immediately."""
|
||||
if self._task:
|
||||
await self._task
|
||||
self._async_execute()
|
||||
return self._task
|
||||
|
||||
@callback
|
||||
def _async_cancel_timer(self) -> None:
|
||||
|
@ -367,7 +362,7 @@ class EnsureJobAfterCooldown:
|
|||
# We want to reschedule the timer in the future
|
||||
# every time this is called.
|
||||
self._async_cancel_timer()
|
||||
self._timer = self._loop.call_later(self._timeout, self._async_execute)
|
||||
self._timer = self._loop.call_later(self._timeout, self.async_execute)
|
||||
|
||||
async def async_cleanup(self) -> None:
|
||||
"""Cleanup any pending task."""
|
||||
|
@ -882,7 +877,7 @@ class MQTT:
|
|||
await self._discovery_cooldown() # Wait for MQTT discovery to cool down
|
||||
# Update subscribe cooldown period to a shorter time
|
||||
# and make sure we flush the debouncer
|
||||
await self._subscribe_debouncer.async_fire()
|
||||
await self._subscribe_debouncer.async_execute()
|
||||
self._subscribe_debouncer.set_timeout(SUBSCRIBE_COOLDOWN)
|
||||
await self.async_publish(
|
||||
topic=birth_message.topic,
|
||||
|
|
|
@ -2658,19 +2658,19 @@ async def test_subscription_done_when_birth_message_is_sent(
|
|||
mqtt_client_mock.on_connect(None, None, 0, 0)
|
||||
await hass.async_block_till_done()
|
||||
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
|
||||
await mqtt.async_subscribe(hass, "topic/test", record_calls)
|
||||
# We wait until we receive a birth message
|
||||
await asyncio.wait_for(birth.wait(), 1)
|
||||
# Assert we already have subscribed at the client
|
||||
# for new config payloads at the time we the birth message is received
|
||||
assert ("homeassistant/+/+/config", 0) in help_all_subscribe_calls(
|
||||
mqtt_client_mock
|
||||
)
|
||||
assert ("homeassistant/+/+/+/config", 0) in help_all_subscribe_calls(
|
||||
mqtt_client_mock
|
||||
)
|
||||
mqtt_client_mock.publish.assert_called_with(
|
||||
"homeassistant/status", "online", 0, False
|
||||
)
|
||||
|
||||
# Assert we already have subscribed at the client
|
||||
# for new config payloads at the time we the birth message is received
|
||||
subscribe_calls = help_all_subscribe_calls(mqtt_client_mock)
|
||||
assert ("homeassistant/+/+/config", 0) in subscribe_calls
|
||||
assert ("homeassistant/+/+/+/config", 0) in subscribe_calls
|
||||
mqtt_client_mock.publish.assert_called_with(
|
||||
"homeassistant/status", "online", 0, False
|
||||
)
|
||||
assert ("topic/test", 0) in subscribe_calls
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
|
|
@ -335,6 +335,9 @@ async def test_default_entity_and_device_name(
|
|||
|
||||
# Assert that no issues ware registered
|
||||
assert len(events) == 0
|
||||
await hass.async_block_till_done()
|
||||
# Assert that no issues ware registered
|
||||
assert len(events) == 0
|
||||
|
||||
|
||||
async def test_name_attribute_is_set_or_not(
|
||||
|
|
Loading…
Add table
Reference in a new issue