diff --git a/homeassistant/helpers/start.py b/homeassistant/helpers/start.py index f6c9a536a23..fe3bd2b0987 100644 --- a/homeassistant/helpers/start.py +++ b/homeassistant/helpers/start.py @@ -4,21 +4,31 @@ from __future__ import annotations from collections.abc import Callable, Coroutine from typing import Any -from homeassistant.const import EVENT_HOMEASSISTANT_START -from homeassistant.core import CALLBACK_TYPE, Event, HassJob, HomeAssistant, callback +from homeassistant.const import EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STARTED +from homeassistant.core import ( + CALLBACK_TYPE, + CoreState, + Event, + HassJob, + HomeAssistant, + callback, +) @callback -def async_at_start( +def _async_at_core_state( hass: HomeAssistant, at_start_cb: Callable[[HomeAssistant], Coroutine[Any, Any, None] | None], + event_type: str, + check_state: Callable[[HomeAssistant], bool], ) -> CALLBACK_TYPE: - """Execute something when Home Assistant is started. + """Execute a job at_start_cb when Home Assistant has the wanted state. - Will execute it now if Home Assistant is already started. + The job is executed immediately if Home Assistant is in the wanted state. + Will wait for event specified by event_type if it isn't. """ at_start_job = HassJob(at_start_cb) - if hass.is_running: + if check_state(hass): hass.async_run_hass_job(at_start_job, hass) return lambda: None @@ -36,5 +46,43 @@ def async_at_start( if unsub: unsub() - unsub = hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, _matched_event) + unsub = hass.bus.async_listen_once(event_type, _matched_event) return cancel + + +@callback +def async_at_start( + hass: HomeAssistant, + at_start_cb: Callable[[HomeAssistant], Coroutine[Any, Any, None] | None], +) -> CALLBACK_TYPE: + """Execute a job at_start_cb when Home Assistant is starting. + + The job is executed immediately if Home Assistant is already starting or started. + Will wait for EVENT_HOMEASSISTANT_START if it isn't. + """ + + def _is_running(hass: HomeAssistant) -> bool: + return hass.is_running + + return _async_at_core_state( + hass, at_start_cb, EVENT_HOMEASSISTANT_START, _is_running + ) + + +@callback +def async_at_started( + hass: HomeAssistant, + at_start_cb: Callable[[HomeAssistant], Coroutine[Any, Any, None] | None], +) -> CALLBACK_TYPE: + """Execute a job at_start_cb when Home Assistant has started. + + The job is executed immediately if Home Assistant is already started. + Will wait for EVENT_HOMEASSISTANT_STARTED if it isn't. + """ + + def _is_started(hass: HomeAssistant) -> bool: + return hass.state == CoreState.running + + return _async_at_core_state( + hass, at_start_cb, EVENT_HOMEASSISTANT_STARTED, _is_started + ) diff --git a/tests/helpers/test_start.py b/tests/helpers/test_start.py index bc32ffa35fd..bccf99a4274 100644 --- a/tests/helpers/test_start.py +++ b/tests/helpers/test_start.py @@ -1,6 +1,6 @@ """Test starting HA helpers.""" from homeassistant import core -from homeassistant.const import EVENT_HOMEASSISTANT_START +from homeassistant.const import EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STARTED from homeassistant.helpers import start @@ -100,7 +100,7 @@ async def test_at_start_when_starting_callback(hass, caplog): assert record.levelname in ("DEBUG", "INFO") -async def test_cancelling_when_running(hass, caplog): +async def test_cancelling_at_start_when_running(hass, caplog): """Test cancelling at start when already running.""" assert hass.state == core.CoreState.running assert hass.is_running @@ -120,7 +120,7 @@ async def test_cancelling_when_running(hass, caplog): assert record.levelname in ("DEBUG", "INFO") -async def test_cancelling_when_starting(hass): +async def test_cancelling_at_start_when_starting(hass): """Test cancelling at start when yet to start.""" hass.state = core.CoreState.not_running assert not hass.is_running @@ -139,3 +139,148 @@ async def test_cancelling_when_starting(hass): hass.bus.async_fire(EVENT_HOMEASSISTANT_START) await hass.async_block_till_done() assert len(calls) == 0 + + +async def test_at_started_when_running_awaitable(hass): + """Test at started when already started.""" + assert hass.state == core.CoreState.running + + calls = [] + + async def cb_at_start(hass): + """Home Assistant is started.""" + calls.append(1) + + start.async_at_started(hass, cb_at_start) + await hass.async_block_till_done() + assert len(calls) == 1 + + # Test the job is not run if state is CoreState.starting + hass.state = core.CoreState.starting + + start.async_at_started(hass, cb_at_start) + await hass.async_block_till_done() + assert len(calls) == 1 + + +async def test_at_started_when_running_callback(hass, caplog): + """Test at started when already running.""" + assert hass.state == core.CoreState.running + + calls = [] + + @core.callback + def cb_at_start(hass): + """Home Assistant is started.""" + calls.append(1) + + start.async_at_started(hass, cb_at_start)() + assert len(calls) == 1 + + # Test the job is not run if state is CoreState.starting + hass.state = core.CoreState.starting + + start.async_at_started(hass, cb_at_start)() + assert len(calls) == 1 + + # Check the unnecessary cancel did not generate warnings or errors + for record in caplog.records: + assert record.levelname in ("DEBUG", "INFO") + + +async def test_at_started_when_starting_awaitable(hass): + """Test at started when yet to start.""" + hass.state = core.CoreState.not_running + + calls = [] + + async def cb_at_start(hass): + """Home Assistant is started.""" + calls.append(1) + + start.async_at_started(hass, cb_at_start) + await hass.async_block_till_done() + assert len(calls) == 0 + + hass.bus.async_fire(EVENT_HOMEASSISTANT_START) + await hass.async_block_till_done() + assert len(calls) == 0 + + hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) + await hass.async_block_till_done() + assert len(calls) == 1 + + +async def test_at_started_when_starting_callback(hass, caplog): + """Test at started when yet to start.""" + hass.state = core.CoreState.not_running + + calls = [] + + @core.callback + def cb_at_start(hass): + """Home Assistant is started.""" + calls.append(1) + + cancel = start.async_at_started(hass, cb_at_start) + await hass.async_block_till_done() + assert len(calls) == 0 + + hass.bus.async_fire(EVENT_HOMEASSISTANT_START) + await hass.async_block_till_done() + assert len(calls) == 0 + + hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) + await hass.async_block_till_done() + assert len(calls) == 1 + + cancel() + + # Check the unnecessary cancel did not generate warnings or errors + for record in caplog.records: + assert record.levelname in ("DEBUG", "INFO") + + +async def test_cancelling_at_started_when_running(hass, caplog): + """Test cancelling at start when already running.""" + assert hass.state == core.CoreState.running + assert hass.is_running + + calls = [] + + async def cb_at_start(hass): + """Home Assistant is started.""" + calls.append(1) + + start.async_at_started(hass, cb_at_start)() + await hass.async_block_till_done() + assert len(calls) == 1 + + # Check the unnecessary cancel did not generate warnings or errors + for record in caplog.records: + assert record.levelname in ("DEBUG", "INFO") + + +async def test_cancelling_at_started_when_starting(hass): + """Test cancelling at start when yet to start.""" + hass.state = core.CoreState.not_running + assert not hass.is_running + + calls = [] + + @core.callback + def cb_at_start(hass): + """Home Assistant is started.""" + calls.append(1) + + start.async_at_started(hass, cb_at_start)() + await hass.async_block_till_done() + assert len(calls) == 0 + + hass.bus.async_fire(EVENT_HOMEASSISTANT_START) + await hass.async_block_till_done() + assert len(calls) == 0 + + hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) + await hass.async_block_till_done() + assert len(calls) == 0