Automatically clean up executor as part of closing loop (#43284)

This commit is contained in:
Paulus Schoutsen 2020-11-16 15:43:48 +01:00 committed by GitHub
parent 5d83f0a911
commit 819dd27925
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 120 additions and 166 deletions

View file

@ -15,11 +15,7 @@ import yarl
from homeassistant import config as conf_util, config_entries, core, loader from homeassistant import config as conf_util, config_entries, core, loader
from homeassistant.components import http from homeassistant.components import http
from homeassistant.const import ( from homeassistant.const import REQUIRED_NEXT_PYTHON_DATE, REQUIRED_NEXT_PYTHON_VER
EVENT_HOMEASSISTANT_STOP,
REQUIRED_NEXT_PYTHON_DATE,
REQUIRED_NEXT_PYTHON_VER,
)
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from homeassistant.setup import ( from homeassistant.setup import (
@ -142,11 +138,9 @@ async def async_setup_hass(
_LOGGER.warning("Detected that frontend did not load. Activating safe mode") _LOGGER.warning("Detected that frontend did not load. Activating safe mode")
# Ask integrations to shut down. It's messy but we can't # Ask integrations to shut down. It's messy but we can't
# do a clean stop without knowing what is broken # do a clean stop without knowing what is broken
hass.async_track_tasks()
hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP, {})
with contextlib.suppress(asyncio.TimeoutError): with contextlib.suppress(asyncio.TimeoutError):
async with hass.timeout.async_timeout(10): async with hass.timeout.async_timeout(10):
await hass.async_block_till_done() await hass.async_stop()
safe_mode = True safe_mode = True
old_config = hass.config old_config = hass.config

View file

@ -257,12 +257,9 @@ class HomeAssistant:
fire_coroutine_threadsafe(self.async_start(), self.loop) fire_coroutine_threadsafe(self.async_start(), self.loop)
# Run forever # Run forever
try:
# Block until stopped # Block until stopped
_LOGGER.info("Starting Home Assistant core loop") _LOGGER.info("Starting Home Assistant core loop")
self.loop.run_forever() self.loop.run_forever()
finally:
self.loop.close()
return self.exit_code return self.exit_code
async def async_run(self, *, attach_signals: bool = True) -> int: async def async_run(self, *, attach_signals: bool = True) -> int:
@ -559,16 +556,11 @@ class HomeAssistant:
"Timed out waiting for shutdown stage 3 to complete, the shutdown will continue" "Timed out waiting for shutdown stage 3 to complete, the shutdown will continue"
) )
# Python 3.9+ and backported in runner.py
await self.loop.shutdown_default_executor() # type: ignore
self.exit_code = exit_code self.exit_code = exit_code
self.state = CoreState.stopped self.state = CoreState.stopped
if self._stopped is not None: if self._stopped is not None:
self._stopped.set() self._stopped.set()
else:
self.loop.stop()
@attr.s(slots=True, frozen=True) @attr.s(slots=True, frozen=True)

View file

@ -4,7 +4,6 @@ from concurrent.futures import ThreadPoolExecutor
import dataclasses import dataclasses
import logging import logging
import sys import sys
import threading
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from homeassistant import bootstrap from homeassistant import bootstrap
@ -77,29 +76,14 @@ class HassEventLoopPolicy(PolicyBase): # type: ignore
loop.set_default_executor, "sets default executor on the event loop" loop.set_default_executor, "sets default executor on the event loop"
) )
# Python 3.9+ # Shut down executor when we shut down loop
if hasattr(loop, "shutdown_default_executor"): orig_close = loop.close
return loop
# Copied from Python 3.9 source def close() -> None:
def _do_shutdown(future: asyncio.Future) -> None:
try:
executor.shutdown(wait=True) executor.shutdown(wait=True)
loop.call_soon_threadsafe(future.set_result, None) orig_close()
except Exception as ex: # pylint: disable=broad-except
loop.call_soon_threadsafe(future.set_exception, ex)
async def shutdown_default_executor() -> None: loop.close = close # type: ignore
"""Schedule the shutdown of the default executor."""
future = loop.create_future()
thread = threading.Thread(target=_do_shutdown, args=(future,))
thread.start()
try:
await future
finally:
thread.join()
setattr(loop, "shutdown_default_executor", shutdown_default_executor)
return loop return loop

View file

@ -9,7 +9,6 @@ from io import StringIO
import json import json
import logging import logging
import os import os
import sys
import threading import threading
import time import time
import uuid import uuid
@ -109,24 +108,21 @@ def get_test_config_dir(*add_path):
def get_test_home_assistant(): def get_test_home_assistant():
"""Return a Home Assistant object pointing at test config directory.""" """Return a Home Assistant object pointing at test config directory."""
if sys.platform == "win32":
loop = asyncio.ProactorEventLoop()
else:
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)
hass = loop.run_until_complete(async_test_home_assistant(loop)) hass = loop.run_until_complete(async_test_home_assistant(loop))
stop_event = threading.Event() loop_stop_event = threading.Event()
def run_loop(): def run_loop():
"""Run event loop.""" """Run event loop."""
# pylint: disable=protected-access # pylint: disable=protected-access
loop._thread_ident = threading.get_ident() loop._thread_ident = threading.get_ident()
loop.run_forever() loop.run_forever()
stop_event.set() loop_stop_event.set()
orig_stop = hass.stop orig_stop = hass.stop
hass._stopped = Mock(set=loop.stop)
def start_hass(*mocks): def start_hass(*mocks):
"""Start hass.""" """Start hass."""
@ -135,7 +131,7 @@ def get_test_home_assistant():
def stop_hass(): def stop_hass():
"""Stop hass.""" """Stop hass."""
orig_stop() orig_stop()
stop_event.wait() loop_stop_event.wait()
loop.close() loop.close()
hass.start = start_hass hass.start = start_hass

View file

@ -38,7 +38,11 @@ import homeassistant.util.dt as dt_util
from homeassistant.util.unit_system import METRIC_SYSTEM from homeassistant.util.unit_system import METRIC_SYSTEM
from tests.async_mock import MagicMock, Mock, PropertyMock, patch from tests.async_mock import MagicMock, Mock, PropertyMock, patch
from tests.common import async_mock_service, get_test_home_assistant from tests.common import (
async_capture_events,
async_mock_service,
get_test_home_assistant,
)
PST = pytz.timezone("America/Los_Angeles") PST = pytz.timezone("America/Los_Angeles")
@ -151,22 +155,14 @@ def test_async_run_hass_job_delegates_non_async():
assert len(hass.async_add_hass_job.mock_calls) == 1 assert len(hass.async_add_hass_job.mock_calls) == 1
def test_stage_shutdown(): async def test_stage_shutdown(hass):
"""Simulate a shutdown, test calling stuff.""" """Simulate a shutdown, test calling stuff."""
hass = get_test_home_assistant() test_stop = async_capture_events(hass, EVENT_HOMEASSISTANT_STOP)
test_stop = [] test_final_write = async_capture_events(hass, EVENT_HOMEASSISTANT_FINAL_WRITE)
test_final_write = [] test_close = async_capture_events(hass, EVENT_HOMEASSISTANT_CLOSE)
test_close = [] test_all = async_capture_events(hass, MATCH_ALL)
test_all = []
hass.bus.listen(EVENT_HOMEASSISTANT_STOP, lambda event: test_stop.append(event)) await hass.async_stop()
hass.bus.listen(
EVENT_HOMEASSISTANT_FINAL_WRITE, lambda event: test_final_write.append(event)
)
hass.bus.listen(EVENT_HOMEASSISTANT_CLOSE, lambda event: test_close.append(event))
hass.bus.listen("*", lambda event: test_all.append(event))
hass.stop()
assert len(test_stop) == 1 assert len(test_stop) == 1
assert len(test_close) == 1 assert len(test_close) == 1
@ -341,40 +337,26 @@ def test_state_as_dict():
assert state.as_dict() is state.as_dict() assert state.as_dict() is state.as_dict()
class TestEventBus(unittest.TestCase): async def test_add_remove_listener(hass):
"""Test EventBus methods."""
# pylint: disable=invalid-name
def setUp(self):
"""Set up things to be run when tests are started."""
self.hass = get_test_home_assistant()
self.bus = self.hass.bus
# pylint: disable=invalid-name
def tearDown(self):
"""Stop down stuff we started."""
self.hass.stop()
def test_add_remove_listener(self):
"""Test remove_listener method.""" """Test remove_listener method."""
self.hass.allow_pool = False old_count = len(hass.bus.async_listeners())
old_count = len(self.bus.listeners)
def listener(_): def listener(_):
pass pass
unsub = self.bus.listen("test", listener) unsub = hass.bus.async_listen("test", listener)
assert old_count + 1 == len(self.bus.listeners) assert old_count + 1 == len(hass.bus.async_listeners())
# Remove listener # Remove listener
unsub() unsub()
assert old_count == len(self.bus.listeners) assert old_count == len(hass.bus.async_listeners())
# Should do nothing now # Should do nothing now
unsub() unsub()
def test_unsubscribe_listener(self):
async def test_unsubscribe_listener(hass):
"""Test unsubscribe listener from returned function.""" """Test unsubscribe listener from returned function."""
calls = [] calls = []
@ -383,21 +365,22 @@ class TestEventBus(unittest.TestCase):
"""Mock listener.""" """Mock listener."""
calls.append(event) calls.append(event)
unsub = self.bus.listen("test", listener) unsub = hass.bus.async_listen("test", listener)
self.bus.fire("test") hass.bus.async_fire("test")
self.hass.block_till_done() await hass.async_block_till_done()
assert len(calls) == 1 assert len(calls) == 1
unsub() unsub()
self.bus.fire("event") hass.bus.async_fire("event")
self.hass.block_till_done() await hass.async_block_till_done()
assert len(calls) == 1 assert len(calls) == 1
def test_listen_once_event_with_callback(self):
async def test_listen_once_event_with_callback(hass):
"""Test listen_once_event method.""" """Test listen_once_event method."""
runs = [] runs = []
@ -405,60 +388,64 @@ class TestEventBus(unittest.TestCase):
def event_handler(event): def event_handler(event):
runs.append(event) runs.append(event)
self.bus.listen_once("test_event", event_handler) hass.bus.async_listen_once("test_event", event_handler)
self.bus.fire("test_event") hass.bus.async_fire("test_event")
# Second time it should not increase runs # Second time it should not increase runs
self.bus.fire("test_event") hass.bus.async_fire("test_event")
self.hass.block_till_done() await hass.async_block_till_done()
assert len(runs) == 1 assert len(runs) == 1
def test_listen_once_event_with_coroutine(self):
async def test_listen_once_event_with_coroutine(hass):
"""Test listen_once_event method.""" """Test listen_once_event method."""
runs = [] runs = []
async def event_handler(event): async def event_handler(event):
runs.append(event) runs.append(event)
self.bus.listen_once("test_event", event_handler) hass.bus.async_listen_once("test_event", event_handler)
self.bus.fire("test_event") hass.bus.async_fire("test_event")
# Second time it should not increase runs # Second time it should not increase runs
self.bus.fire("test_event") hass.bus.async_fire("test_event")
self.hass.block_till_done() await hass.async_block_till_done()
assert len(runs) == 1 assert len(runs) == 1
def test_listen_once_event_with_thread(self):
async def test_listen_once_event_with_thread(hass):
"""Test listen_once_event method.""" """Test listen_once_event method."""
runs = [] runs = []
def event_handler(event): def event_handler(event):
runs.append(event) runs.append(event)
self.bus.listen_once("test_event", event_handler) hass.bus.async_listen_once("test_event", event_handler)
self.bus.fire("test_event") hass.bus.async_fire("test_event")
# Second time it should not increase runs # Second time it should not increase runs
self.bus.fire("test_event") hass.bus.async_fire("test_event")
self.hass.block_till_done() await hass.async_block_till_done()
assert len(runs) == 1 assert len(runs) == 1
def test_thread_event_listener(self):
async def test_thread_event_listener(hass):
"""Test thread event listener.""" """Test thread event listener."""
thread_calls = [] thread_calls = []
def thread_listener(event): def thread_listener(event):
thread_calls.append(event) thread_calls.append(event)
self.bus.listen("test_thread", thread_listener) hass.bus.async_listen("test_thread", thread_listener)
self.bus.fire("test_thread") hass.bus.async_fire("test_thread")
self.hass.block_till_done() await hass.async_block_till_done()
assert len(thread_calls) == 1 assert len(thread_calls) == 1
def test_callback_event_listener(self):
async def test_callback_event_listener(hass):
"""Test callback event listener.""" """Test callback event listener."""
callback_calls = [] callback_calls = []
@ -466,21 +453,22 @@ class TestEventBus(unittest.TestCase):
def callback_listener(event): def callback_listener(event):
callback_calls.append(event) callback_calls.append(event)
self.bus.listen("test_callback", callback_listener) hass.bus.async_listen("test_callback", callback_listener)
self.bus.fire("test_callback") hass.bus.async_fire("test_callback")
self.hass.block_till_done() await hass.async_block_till_done()
assert len(callback_calls) == 1 assert len(callback_calls) == 1
def test_coroutine_event_listener(self):
async def test_coroutine_event_listener(hass):
"""Test coroutine event listener.""" """Test coroutine event listener."""
coroutine_calls = [] coroutine_calls = []
async def coroutine_listener(event): async def coroutine_listener(event):
coroutine_calls.append(event) coroutine_calls.append(event)
self.bus.listen("test_coroutine", coroutine_listener) hass.bus.async_listen("test_coroutine", coroutine_listener)
self.bus.fire("test_coroutine") hass.bus.async_fire("test_coroutine")
self.hass.block_till_done() await hass.async_block_till_done()
assert len(coroutine_calls) == 1 assert len(coroutine_calls) == 1