From f35e7d1129ccde6994a0969137efdca63a4b0a1e Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Mon, 2 May 2022 16:41:14 +0200 Subject: [PATCH] Allow cancelling async_at_start helper (#71196) --- homeassistant/helpers/start.py | 15 +++++++-- tests/helpers/test_start.py | 61 +++++++++++++++++++++++++++++++--- 2 files changed, 69 insertions(+), 7 deletions(-) diff --git a/homeassistant/helpers/start.py b/homeassistant/helpers/start.py index 7f919f5351d..6c17ae5be3a 100644 --- a/homeassistant/helpers/start.py +++ b/homeassistant/helpers/start.py @@ -20,8 +20,19 @@ def async_at_start( hass.async_run_hass_job(at_start_job, hass) return lambda: None - async def _matched_event(event: Event) -> None: + unsub: None | CALLBACK_TYPE = None + + @callback + def _matched_event(event: Event) -> None: """Call the callback when Home Assistant started.""" hass.async_run_hass_job(at_start_job, hass) + nonlocal unsub + unsub = None - return hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, _matched_event) + @callback + def cancel() -> None: + if unsub: + unsub() + + unsub = hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, _matched_event) + return cancel diff --git a/tests/helpers/test_start.py b/tests/helpers/test_start.py index 55f98cf60eb..bc32ffa35fd 100644 --- a/tests/helpers/test_start.py +++ b/tests/helpers/test_start.py @@ -27,7 +27,7 @@ async def test_at_start_when_running_awaitable(hass): assert len(calls) == 2 -async def test_at_start_when_running_callback(hass): +async def test_at_start_when_running_callback(hass, caplog): """Test at start when already running.""" assert hass.state == core.CoreState.running assert hass.is_running @@ -39,15 +39,19 @@ async def test_at_start_when_running_callback(hass): """Home Assistant is started.""" calls.append(1) - start.async_at_start(hass, cb_at_start) + start.async_at_start(hass, cb_at_start)() assert len(calls) == 1 hass.state = core.CoreState.starting assert hass.is_running - start.async_at_start(hass, cb_at_start) + start.async_at_start(hass, cb_at_start)() assert len(calls) == 2 + # Check the unnecessary cancel did not generate warnings or errors + for record in caplog.records: + assert record.levelname in ("DEBUG", "INFO") + async def test_at_start_when_starting_awaitable(hass): """Test at start when yet to start.""" @@ -69,7 +73,7 @@ async def test_at_start_when_starting_awaitable(hass): assert len(calls) == 1 -async def test_at_start_when_starting_callback(hass): +async def test_at_start_when_starting_callback(hass, caplog): """Test at start when yet to start.""" hass.state = core.CoreState.not_running assert not hass.is_running @@ -81,10 +85,57 @@ async def test_at_start_when_starting_callback(hass): """Home Assistant is started.""" calls.append(1) - start.async_at_start(hass, cb_at_start) + cancel = start.async_at_start(hass, cb_at_start) await hass.async_block_till_done() assert len(calls) == 0 hass.bus.async_fire(EVENT_HOMEASSISTANT_START) await hass.async_block_till_done() assert len(calls) == 1 + + cancel() + + # Check the unnecessary cancel did not generate warnings or errors + for record in caplog.records: + assert record.levelname in ("DEBUG", "INFO") + + +async def test_cancelling_when_running(hass, caplog): + """Test cancelling at start when already running.""" + assert hass.state == core.CoreState.running + assert hass.is_running + + calls = [] + + async def cb_at_start(hass): + """Home Assistant is started.""" + calls.append(1) + + start.async_at_start(hass, cb_at_start)() + await hass.async_block_till_done() + assert len(calls) == 1 + + # Check the unnecessary cancel did not generate warnings or errors + for record in caplog.records: + assert record.levelname in ("DEBUG", "INFO") + + +async def test_cancelling_when_starting(hass): + """Test cancelling at start when yet to start.""" + hass.state = core.CoreState.not_running + assert not hass.is_running + + calls = [] + + @core.callback + def cb_at_start(hass): + """Home Assistant is started.""" + calls.append(1) + + start.async_at_start(hass, cb_at_start)() + await hass.async_block_till_done() + assert len(calls) == 0 + + hass.bus.async_fire(EVENT_HOMEASSISTANT_START) + await hass.async_block_till_done() + assert len(calls) == 0