Automatically clean up executor as part of closing loop (#43284)

This commit is contained in:
Paulus Schoutsen 2020-11-16 15:43:48 +01:00 committed by GitHub
parent 5d83f0a911
commit 819dd27925
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 120 additions and 166 deletions

View file

@ -15,11 +15,7 @@ import yarl
from homeassistant import config as conf_util, config_entries, core, loader from homeassistant import config as conf_util, config_entries, core, loader
from homeassistant.components import http from homeassistant.components import http
from homeassistant.const import ( from homeassistant.const import REQUIRED_NEXT_PYTHON_DATE, REQUIRED_NEXT_PYTHON_VER
EVENT_HOMEASSISTANT_STOP,
REQUIRED_NEXT_PYTHON_DATE,
REQUIRED_NEXT_PYTHON_VER,
)
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from homeassistant.setup import ( from homeassistant.setup import (
@ -142,11 +138,9 @@ async def async_setup_hass(
_LOGGER.warning("Detected that frontend did not load. Activating safe mode") _LOGGER.warning("Detected that frontend did not load. Activating safe mode")
# Ask integrations to shut down. It's messy but we can't # Ask integrations to shut down. It's messy but we can't
# do a clean stop without knowing what is broken # do a clean stop without knowing what is broken
hass.async_track_tasks()
hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP, {})
with contextlib.suppress(asyncio.TimeoutError): with contextlib.suppress(asyncio.TimeoutError):
async with hass.timeout.async_timeout(10): async with hass.timeout.async_timeout(10):
await hass.async_block_till_done() await hass.async_stop()
safe_mode = True safe_mode = True
old_config = hass.config old_config = hass.config

View file

@ -257,12 +257,9 @@ class HomeAssistant:
fire_coroutine_threadsafe(self.async_start(), self.loop) fire_coroutine_threadsafe(self.async_start(), self.loop)
# Run forever # Run forever
try: # Block until stopped
# Block until stopped _LOGGER.info("Starting Home Assistant core loop")
_LOGGER.info("Starting Home Assistant core loop") self.loop.run_forever()
self.loop.run_forever()
finally:
self.loop.close()
return self.exit_code return self.exit_code
async def async_run(self, *, attach_signals: bool = True) -> int: async def async_run(self, *, attach_signals: bool = True) -> int:
@ -559,16 +556,11 @@ class HomeAssistant:
"Timed out waiting for shutdown stage 3 to complete, the shutdown will continue" "Timed out waiting for shutdown stage 3 to complete, the shutdown will continue"
) )
# Python 3.9+ and backported in runner.py
await self.loop.shutdown_default_executor() # type: ignore
self.exit_code = exit_code self.exit_code = exit_code
self.state = CoreState.stopped self.state = CoreState.stopped
if self._stopped is not None: if self._stopped is not None:
self._stopped.set() self._stopped.set()
else:
self.loop.stop()
@attr.s(slots=True, frozen=True) @attr.s(slots=True, frozen=True)

View file

@ -4,7 +4,6 @@ from concurrent.futures import ThreadPoolExecutor
import dataclasses import dataclasses
import logging import logging
import sys import sys
import threading
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from homeassistant import bootstrap from homeassistant import bootstrap
@ -77,29 +76,14 @@ class HassEventLoopPolicy(PolicyBase): # type: ignore
loop.set_default_executor, "sets default executor on the event loop" loop.set_default_executor, "sets default executor on the event loop"
) )
# Python 3.9+ # Shut down executor when we shut down loop
if hasattr(loop, "shutdown_default_executor"): orig_close = loop.close
return loop
# Copied from Python 3.9 source def close() -> None:
def _do_shutdown(future: asyncio.Future) -> None: executor.shutdown(wait=True)
try: orig_close()
executor.shutdown(wait=True)
loop.call_soon_threadsafe(future.set_result, None)
except Exception as ex: # pylint: disable=broad-except
loop.call_soon_threadsafe(future.set_exception, ex)
async def shutdown_default_executor() -> None: loop.close = close # type: ignore
"""Schedule the shutdown of the default executor."""
future = loop.create_future()
thread = threading.Thread(target=_do_shutdown, args=(future,))
thread.start()
try:
await future
finally:
thread.join()
setattr(loop, "shutdown_default_executor", shutdown_default_executor)
return loop return loop

View file

@ -9,7 +9,6 @@ from io import StringIO
import json import json
import logging import logging
import os import os
import sys
import threading import threading
import time import time
import uuid import uuid
@ -109,24 +108,21 @@ def get_test_config_dir(*add_path):
def get_test_home_assistant(): def get_test_home_assistant():
"""Return a Home Assistant object pointing at test config directory.""" """Return a Home Assistant object pointing at test config directory."""
if sys.platform == "win32": loop = asyncio.new_event_loop()
loop = asyncio.ProactorEventLoop()
else:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)
hass = loop.run_until_complete(async_test_home_assistant(loop)) hass = loop.run_until_complete(async_test_home_assistant(loop))
stop_event = threading.Event() loop_stop_event = threading.Event()
def run_loop(): def run_loop():
"""Run event loop.""" """Run event loop."""
# pylint: disable=protected-access # pylint: disable=protected-access
loop._thread_ident = threading.get_ident() loop._thread_ident = threading.get_ident()
loop.run_forever() loop.run_forever()
stop_event.set() loop_stop_event.set()
orig_stop = hass.stop orig_stop = hass.stop
hass._stopped = Mock(set=loop.stop)
def start_hass(*mocks): def start_hass(*mocks):
"""Start hass.""" """Start hass."""
@ -135,7 +131,7 @@ def get_test_home_assistant():
def stop_hass(): def stop_hass():
"""Stop hass.""" """Stop hass."""
orig_stop() orig_stop()
stop_event.wait() loop_stop_event.wait()
loop.close() loop.close()
hass.start = start_hass hass.start = start_hass

View file

@ -38,7 +38,11 @@ import homeassistant.util.dt as dt_util
from homeassistant.util.unit_system import METRIC_SYSTEM from homeassistant.util.unit_system import METRIC_SYSTEM
from tests.async_mock import MagicMock, Mock, PropertyMock, patch from tests.async_mock import MagicMock, Mock, PropertyMock, patch
from tests.common import async_mock_service, get_test_home_assistant from tests.common import (
async_capture_events,
async_mock_service,
get_test_home_assistant,
)
PST = pytz.timezone("America/Los_Angeles") PST = pytz.timezone("America/Los_Angeles")
@ -151,22 +155,14 @@ def test_async_run_hass_job_delegates_non_async():
assert len(hass.async_add_hass_job.mock_calls) == 1 assert len(hass.async_add_hass_job.mock_calls) == 1
def test_stage_shutdown(): async def test_stage_shutdown(hass):
"""Simulate a shutdown, test calling stuff.""" """Simulate a shutdown, test calling stuff."""
hass = get_test_home_assistant() test_stop = async_capture_events(hass, EVENT_HOMEASSISTANT_STOP)
test_stop = [] test_final_write = async_capture_events(hass, EVENT_HOMEASSISTANT_FINAL_WRITE)
test_final_write = [] test_close = async_capture_events(hass, EVENT_HOMEASSISTANT_CLOSE)
test_close = [] test_all = async_capture_events(hass, MATCH_ALL)
test_all = []
hass.bus.listen(EVENT_HOMEASSISTANT_STOP, lambda event: test_stop.append(event)) await hass.async_stop()
hass.bus.listen(
EVENT_HOMEASSISTANT_FINAL_WRITE, lambda event: test_final_write.append(event)
)
hass.bus.listen(EVENT_HOMEASSISTANT_CLOSE, lambda event: test_close.append(event))
hass.bus.listen("*", lambda event: test_all.append(event))
hass.stop()
assert len(test_stop) == 1 assert len(test_stop) == 1
assert len(test_close) == 1 assert len(test_close) == 1
@ -341,147 +337,139 @@ def test_state_as_dict():
assert state.as_dict() is state.as_dict() assert state.as_dict() is state.as_dict()
class TestEventBus(unittest.TestCase): async def test_add_remove_listener(hass):
"""Test EventBus methods.""" """Test remove_listener method."""
old_count = len(hass.bus.async_listeners())
# pylint: disable=invalid-name def listener(_):
def setUp(self): pass
"""Set up things to be run when tests are started."""
self.hass = get_test_home_assistant()
self.bus = self.hass.bus
# pylint: disable=invalid-name unsub = hass.bus.async_listen("test", listener)
def tearDown(self):
"""Stop down stuff we started."""
self.hass.stop()
def test_add_remove_listener(self): assert old_count + 1 == len(hass.bus.async_listeners())
"""Test remove_listener method."""
self.hass.allow_pool = False
old_count = len(self.bus.listeners)
def listener(_): # Remove listener
pass unsub()
assert old_count == len(hass.bus.async_listeners())
unsub = self.bus.listen("test", listener) # Should do nothing now
unsub()
assert old_count + 1 == len(self.bus.listeners)
# Remove listener async def test_unsubscribe_listener(hass):
unsub() """Test unsubscribe listener from returned function."""
assert old_count == len(self.bus.listeners) calls = []
# Should do nothing now @ha.callback
unsub() def listener(event):
"""Mock listener."""
calls.append(event)
def test_unsubscribe_listener(self): unsub = hass.bus.async_listen("test", listener)
"""Test unsubscribe listener from returned function."""
calls = []
@ha.callback hass.bus.async_fire("test")
def listener(event): await hass.async_block_till_done()
"""Mock listener."""
calls.append(event)
unsub = self.bus.listen("test", listener) assert len(calls) == 1
self.bus.fire("test") unsub()
self.hass.block_till_done()
assert len(calls) == 1 hass.bus.async_fire("event")
await hass.async_block_till_done()
unsub() assert len(calls) == 1
self.bus.fire("event")
self.hass.block_till_done()
assert len(calls) == 1 async def test_listen_once_event_with_callback(hass):
"""Test listen_once_event method."""
runs = []
def test_listen_once_event_with_callback(self): @ha.callback
"""Test listen_once_event method.""" def event_handler(event):
runs = [] runs.append(event)
@ha.callback hass.bus.async_listen_once("test_event", event_handler)
def event_handler(event):
runs.append(event)
self.bus.listen_once("test_event", event_handler) hass.bus.async_fire("test_event")
# Second time it should not increase runs
hass.bus.async_fire("test_event")
self.bus.fire("test_event") await hass.async_block_till_done()
# Second time it should not increase runs assert len(runs) == 1
self.bus.fire("test_event")
self.hass.block_till_done()
assert len(runs) == 1
def test_listen_once_event_with_coroutine(self): async def test_listen_once_event_with_coroutine(hass):
"""Test listen_once_event method.""" """Test listen_once_event method."""
runs = [] runs = []
async def event_handler(event): async def event_handler(event):
runs.append(event) runs.append(event)
self.bus.listen_once("test_event", event_handler) hass.bus.async_listen_once("test_event", event_handler)
self.bus.fire("test_event") hass.bus.async_fire("test_event")
# Second time it should not increase runs # Second time it should not increase runs
self.bus.fire("test_event") hass.bus.async_fire("test_event")
self.hass.block_till_done() await hass.async_block_till_done()
assert len(runs) == 1 assert len(runs) == 1
def test_listen_once_event_with_thread(self):
"""Test listen_once_event method."""
runs = []
def event_handler(event): async def test_listen_once_event_with_thread(hass):
runs.append(event) """Test listen_once_event method."""
runs = []
self.bus.listen_once("test_event", event_handler) def event_handler(event):
runs.append(event)
self.bus.fire("test_event") hass.bus.async_listen_once("test_event", event_handler)
# Second time it should not increase runs
self.bus.fire("test_event")
self.hass.block_till_done() hass.bus.async_fire("test_event")
assert len(runs) == 1 # Second time it should not increase runs
hass.bus.async_fire("test_event")
def test_thread_event_listener(self): await hass.async_block_till_done()
"""Test thread event listener.""" assert len(runs) == 1
thread_calls = []
def thread_listener(event):
thread_calls.append(event)
self.bus.listen("test_thread", thread_listener) async def test_thread_event_listener(hass):
self.bus.fire("test_thread") """Test thread event listener."""
self.hass.block_till_done() thread_calls = []
assert len(thread_calls) == 1
def test_callback_event_listener(self): def thread_listener(event):
"""Test callback event listener.""" thread_calls.append(event)
callback_calls = []
@ha.callback hass.bus.async_listen("test_thread", thread_listener)
def callback_listener(event): hass.bus.async_fire("test_thread")
callback_calls.append(event) await hass.async_block_till_done()
assert len(thread_calls) == 1
self.bus.listen("test_callback", callback_listener)
self.bus.fire("test_callback")
self.hass.block_till_done()
assert len(callback_calls) == 1
def test_coroutine_event_listener(self): async def test_callback_event_listener(hass):
"""Test coroutine event listener.""" """Test callback event listener."""
coroutine_calls = [] callback_calls = []
async def coroutine_listener(event): @ha.callback
coroutine_calls.append(event) def callback_listener(event):
callback_calls.append(event)
self.bus.listen("test_coroutine", coroutine_listener) hass.bus.async_listen("test_callback", callback_listener)
self.bus.fire("test_coroutine") hass.bus.async_fire("test_callback")
self.hass.block_till_done() await hass.async_block_till_done()
assert len(coroutine_calls) == 1 assert len(callback_calls) == 1
async def test_coroutine_event_listener(hass):
"""Test coroutine event listener."""
coroutine_calls = []
async def coroutine_listener(event):
coroutine_calls.append(event)
hass.bus.async_listen("test_coroutine", coroutine_listener)
hass.bus.async_fire("test_coroutine")
await hass.async_block_till_done()
assert len(coroutine_calls) == 1
def test_state_init(): def test_state_init():