Automatically clean up executor as part of closing loop (#43284)
This commit is contained in:
parent
5d83f0a911
commit
819dd27925
5 changed files with 120 additions and 166 deletions
|
@ -15,11 +15,7 @@ import yarl
|
|||
|
||||
from homeassistant import config as conf_util, config_entries, core, loader
|
||||
from homeassistant.components import http
|
||||
from homeassistant.const import (
|
||||
EVENT_HOMEASSISTANT_STOP,
|
||||
REQUIRED_NEXT_PYTHON_DATE,
|
||||
REQUIRED_NEXT_PYTHON_VER,
|
||||
)
|
||||
from homeassistant.const import REQUIRED_NEXT_PYTHON_DATE, REQUIRED_NEXT_PYTHON_VER
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
from homeassistant.setup import (
|
||||
|
@ -142,11 +138,9 @@ async def async_setup_hass(
|
|||
_LOGGER.warning("Detected that frontend did not load. Activating safe mode")
|
||||
# Ask integrations to shut down. It's messy but we can't
|
||||
# do a clean stop without knowing what is broken
|
||||
hass.async_track_tasks()
|
||||
hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP, {})
|
||||
with contextlib.suppress(asyncio.TimeoutError):
|
||||
async with hass.timeout.async_timeout(10):
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_stop()
|
||||
|
||||
safe_mode = True
|
||||
old_config = hass.config
|
||||
|
|
|
@ -257,12 +257,9 @@ class HomeAssistant:
|
|||
fire_coroutine_threadsafe(self.async_start(), self.loop)
|
||||
|
||||
# Run forever
|
||||
try:
|
||||
# Block until stopped
|
||||
_LOGGER.info("Starting Home Assistant core loop")
|
||||
self.loop.run_forever()
|
||||
finally:
|
||||
self.loop.close()
|
||||
return self.exit_code
|
||||
|
||||
async def async_run(self, *, attach_signals: bool = True) -> int:
|
||||
|
@ -559,16 +556,11 @@ class HomeAssistant:
|
|||
"Timed out waiting for shutdown stage 3 to complete, the shutdown will continue"
|
||||
)
|
||||
|
||||
# Python 3.9+ and backported in runner.py
|
||||
await self.loop.shutdown_default_executor() # type: ignore
|
||||
|
||||
self.exit_code = exit_code
|
||||
self.state = CoreState.stopped
|
||||
|
||||
if self._stopped is not None:
|
||||
self._stopped.set()
|
||||
else:
|
||||
self.loop.stop()
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True)
|
||||
|
|
|
@ -4,7 +4,6 @@ from concurrent.futures import ThreadPoolExecutor
|
|||
import dataclasses
|
||||
import logging
|
||||
import sys
|
||||
import threading
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from homeassistant import bootstrap
|
||||
|
@ -77,29 +76,14 @@ class HassEventLoopPolicy(PolicyBase): # type: ignore
|
|||
loop.set_default_executor, "sets default executor on the event loop"
|
||||
)
|
||||
|
||||
# Python 3.9+
|
||||
if hasattr(loop, "shutdown_default_executor"):
|
||||
return loop
|
||||
# Shut down executor when we shut down loop
|
||||
orig_close = loop.close
|
||||
|
||||
# Copied from Python 3.9 source
|
||||
def _do_shutdown(future: asyncio.Future) -> None:
|
||||
try:
|
||||
def close() -> None:
|
||||
executor.shutdown(wait=True)
|
||||
loop.call_soon_threadsafe(future.set_result, None)
|
||||
except Exception as ex: # pylint: disable=broad-except
|
||||
loop.call_soon_threadsafe(future.set_exception, ex)
|
||||
orig_close()
|
||||
|
||||
async def shutdown_default_executor() -> None:
|
||||
"""Schedule the shutdown of the default executor."""
|
||||
future = loop.create_future()
|
||||
thread = threading.Thread(target=_do_shutdown, args=(future,))
|
||||
thread.start()
|
||||
try:
|
||||
await future
|
||||
finally:
|
||||
thread.join()
|
||||
|
||||
setattr(loop, "shutdown_default_executor", shutdown_default_executor)
|
||||
loop.close = close # type: ignore
|
||||
|
||||
return loop
|
||||
|
||||
|
|
|
@ -9,7 +9,6 @@ from io import StringIO
|
|||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
|
@ -109,24 +108,21 @@ def get_test_config_dir(*add_path):
|
|||
|
||||
def get_test_home_assistant():
|
||||
"""Return a Home Assistant object pointing at test config directory."""
|
||||
if sys.platform == "win32":
|
||||
loop = asyncio.ProactorEventLoop()
|
||||
else:
|
||||
loop = asyncio.new_event_loop()
|
||||
|
||||
asyncio.set_event_loop(loop)
|
||||
hass = loop.run_until_complete(async_test_home_assistant(loop))
|
||||
|
||||
stop_event = threading.Event()
|
||||
loop_stop_event = threading.Event()
|
||||
|
||||
def run_loop():
|
||||
"""Run event loop."""
|
||||
# pylint: disable=protected-access
|
||||
loop._thread_ident = threading.get_ident()
|
||||
loop.run_forever()
|
||||
stop_event.set()
|
||||
loop_stop_event.set()
|
||||
|
||||
orig_stop = hass.stop
|
||||
hass._stopped = Mock(set=loop.stop)
|
||||
|
||||
def start_hass(*mocks):
|
||||
"""Start hass."""
|
||||
|
@ -135,7 +131,7 @@ def get_test_home_assistant():
|
|||
def stop_hass():
|
||||
"""Stop hass."""
|
||||
orig_stop()
|
||||
stop_event.wait()
|
||||
loop_stop_event.wait()
|
||||
loop.close()
|
||||
|
||||
hass.start = start_hass
|
||||
|
|
|
@ -38,7 +38,11 @@ import homeassistant.util.dt as dt_util
|
|||
from homeassistant.util.unit_system import METRIC_SYSTEM
|
||||
|
||||
from tests.async_mock import MagicMock, Mock, PropertyMock, patch
|
||||
from tests.common import async_mock_service, get_test_home_assistant
|
||||
from tests.common import (
|
||||
async_capture_events,
|
||||
async_mock_service,
|
||||
get_test_home_assistant,
|
||||
)
|
||||
|
||||
PST = pytz.timezone("America/Los_Angeles")
|
||||
|
||||
|
@ -151,22 +155,14 @@ def test_async_run_hass_job_delegates_non_async():
|
|||
assert len(hass.async_add_hass_job.mock_calls) == 1
|
||||
|
||||
|
||||
def test_stage_shutdown():
|
||||
async def test_stage_shutdown(hass):
|
||||
"""Simulate a shutdown, test calling stuff."""
|
||||
hass = get_test_home_assistant()
|
||||
test_stop = []
|
||||
test_final_write = []
|
||||
test_close = []
|
||||
test_all = []
|
||||
test_stop = async_capture_events(hass, EVENT_HOMEASSISTANT_STOP)
|
||||
test_final_write = async_capture_events(hass, EVENT_HOMEASSISTANT_FINAL_WRITE)
|
||||
test_close = async_capture_events(hass, EVENT_HOMEASSISTANT_CLOSE)
|
||||
test_all = async_capture_events(hass, MATCH_ALL)
|
||||
|
||||
hass.bus.listen(EVENT_HOMEASSISTANT_STOP, lambda event: test_stop.append(event))
|
||||
hass.bus.listen(
|
||||
EVENT_HOMEASSISTANT_FINAL_WRITE, lambda event: test_final_write.append(event)
|
||||
)
|
||||
hass.bus.listen(EVENT_HOMEASSISTANT_CLOSE, lambda event: test_close.append(event))
|
||||
hass.bus.listen("*", lambda event: test_all.append(event))
|
||||
|
||||
hass.stop()
|
||||
await hass.async_stop()
|
||||
|
||||
assert len(test_stop) == 1
|
||||
assert len(test_close) == 1
|
||||
|
@ -341,40 +337,26 @@ def test_state_as_dict():
|
|||
assert state.as_dict() is state.as_dict()
|
||||
|
||||
|
||||
class TestEventBus(unittest.TestCase):
|
||||
"""Test EventBus methods."""
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
def setUp(self):
|
||||
"""Set up things to be run when tests are started."""
|
||||
self.hass = get_test_home_assistant()
|
||||
self.bus = self.hass.bus
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
def tearDown(self):
|
||||
"""Stop down stuff we started."""
|
||||
self.hass.stop()
|
||||
|
||||
def test_add_remove_listener(self):
|
||||
async def test_add_remove_listener(hass):
|
||||
"""Test remove_listener method."""
|
||||
self.hass.allow_pool = False
|
||||
old_count = len(self.bus.listeners)
|
||||
old_count = len(hass.bus.async_listeners())
|
||||
|
||||
def listener(_):
|
||||
pass
|
||||
|
||||
unsub = self.bus.listen("test", listener)
|
||||
unsub = hass.bus.async_listen("test", listener)
|
||||
|
||||
assert old_count + 1 == len(self.bus.listeners)
|
||||
assert old_count + 1 == len(hass.bus.async_listeners())
|
||||
|
||||
# Remove listener
|
||||
unsub()
|
||||
assert old_count == len(self.bus.listeners)
|
||||
assert old_count == len(hass.bus.async_listeners())
|
||||
|
||||
# Should do nothing now
|
||||
unsub()
|
||||
|
||||
def test_unsubscribe_listener(self):
|
||||
|
||||
async def test_unsubscribe_listener(hass):
|
||||
"""Test unsubscribe listener from returned function."""
|
||||
calls = []
|
||||
|
||||
|
@ -383,21 +365,22 @@ class TestEventBus(unittest.TestCase):
|
|||
"""Mock listener."""
|
||||
calls.append(event)
|
||||
|
||||
unsub = self.bus.listen("test", listener)
|
||||
unsub = hass.bus.async_listen("test", listener)
|
||||
|
||||
self.bus.fire("test")
|
||||
self.hass.block_till_done()
|
||||
hass.bus.async_fire("test")
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(calls) == 1
|
||||
|
||||
unsub()
|
||||
|
||||
self.bus.fire("event")
|
||||
self.hass.block_till_done()
|
||||
hass.bus.async_fire("event")
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(calls) == 1
|
||||
|
||||
def test_listen_once_event_with_callback(self):
|
||||
|
||||
async def test_listen_once_event_with_callback(hass):
|
||||
"""Test listen_once_event method."""
|
||||
runs = []
|
||||
|
||||
|
@ -405,60 +388,64 @@ class TestEventBus(unittest.TestCase):
|
|||
def event_handler(event):
|
||||
runs.append(event)
|
||||
|
||||
self.bus.listen_once("test_event", event_handler)
|
||||
hass.bus.async_listen_once("test_event", event_handler)
|
||||
|
||||
self.bus.fire("test_event")
|
||||
hass.bus.async_fire("test_event")
|
||||
# Second time it should not increase runs
|
||||
self.bus.fire("test_event")
|
||||
hass.bus.async_fire("test_event")
|
||||
|
||||
self.hass.block_till_done()
|
||||
await hass.async_block_till_done()
|
||||
assert len(runs) == 1
|
||||
|
||||
def test_listen_once_event_with_coroutine(self):
|
||||
|
||||
async def test_listen_once_event_with_coroutine(hass):
|
||||
"""Test listen_once_event method."""
|
||||
runs = []
|
||||
|
||||
async def event_handler(event):
|
||||
runs.append(event)
|
||||
|
||||
self.bus.listen_once("test_event", event_handler)
|
||||
hass.bus.async_listen_once("test_event", event_handler)
|
||||
|
||||
self.bus.fire("test_event")
|
||||
hass.bus.async_fire("test_event")
|
||||
# Second time it should not increase runs
|
||||
self.bus.fire("test_event")
|
||||
hass.bus.async_fire("test_event")
|
||||
|
||||
self.hass.block_till_done()
|
||||
await hass.async_block_till_done()
|
||||
assert len(runs) == 1
|
||||
|
||||
def test_listen_once_event_with_thread(self):
|
||||
|
||||
async def test_listen_once_event_with_thread(hass):
|
||||
"""Test listen_once_event method."""
|
||||
runs = []
|
||||
|
||||
def event_handler(event):
|
||||
runs.append(event)
|
||||
|
||||
self.bus.listen_once("test_event", event_handler)
|
||||
hass.bus.async_listen_once("test_event", event_handler)
|
||||
|
||||
self.bus.fire("test_event")
|
||||
hass.bus.async_fire("test_event")
|
||||
# Second time it should not increase runs
|
||||
self.bus.fire("test_event")
|
||||
hass.bus.async_fire("test_event")
|
||||
|
||||
self.hass.block_till_done()
|
||||
await hass.async_block_till_done()
|
||||
assert len(runs) == 1
|
||||
|
||||
def test_thread_event_listener(self):
|
||||
|
||||
async def test_thread_event_listener(hass):
|
||||
"""Test thread event listener."""
|
||||
thread_calls = []
|
||||
|
||||
def thread_listener(event):
|
||||
thread_calls.append(event)
|
||||
|
||||
self.bus.listen("test_thread", thread_listener)
|
||||
self.bus.fire("test_thread")
|
||||
self.hass.block_till_done()
|
||||
hass.bus.async_listen("test_thread", thread_listener)
|
||||
hass.bus.async_fire("test_thread")
|
||||
await hass.async_block_till_done()
|
||||
assert len(thread_calls) == 1
|
||||
|
||||
def test_callback_event_listener(self):
|
||||
|
||||
async def test_callback_event_listener(hass):
|
||||
"""Test callback event listener."""
|
||||
callback_calls = []
|
||||
|
||||
|
@ -466,21 +453,22 @@ class TestEventBus(unittest.TestCase):
|
|||
def callback_listener(event):
|
||||
callback_calls.append(event)
|
||||
|
||||
self.bus.listen("test_callback", callback_listener)
|
||||
self.bus.fire("test_callback")
|
||||
self.hass.block_till_done()
|
||||
hass.bus.async_listen("test_callback", callback_listener)
|
||||
hass.bus.async_fire("test_callback")
|
||||
await hass.async_block_till_done()
|
||||
assert len(callback_calls) == 1
|
||||
|
||||
def test_coroutine_event_listener(self):
|
||||
|
||||
async def test_coroutine_event_listener(hass):
|
||||
"""Test coroutine event listener."""
|
||||
coroutine_calls = []
|
||||
|
||||
async def coroutine_listener(event):
|
||||
coroutine_calls.append(event)
|
||||
|
||||
self.bus.listen("test_coroutine", coroutine_listener)
|
||||
self.bus.fire("test_coroutine")
|
||||
self.hass.block_till_done()
|
||||
hass.bus.async_listen("test_coroutine", coroutine_listener)
|
||||
hass.bus.async_fire("test_coroutine")
|
||||
await hass.async_block_till_done()
|
||||
assert len(coroutine_calls) == 1
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue