Always do thread safety check when writing state (#118886)

* Always do thread safety check when writing state

Refactor the 3 most common places where the thread safety check
for the event loop is performed so that the check is inline, making
it fast enough that we can keep it long term. While code review catches
most of the thread safety issues in core, some of them still make
it through, and new ones keep getting added. It's not possible
to catch them all with manual code review, so it's worth the
tiny overhead to check each time.

Previously the check was limited to custom components
because they were the most common source of thread
safety issues.

* Always do thread safety check when writing state

Refactor the 3 most common places where the thread safety check
for the event loop is performed so that the check is inline, making
it fast enough that we can keep it long term. While code review catches
most of the thread safety issues in core, some of them still make
it through, and new ones keep getting added. It's not possible
to catch them all with manual code review, so it's worth the
tiny overhead to check each time.

Previously the check was limited to custom components
because they were the most common source of thread
safety issues.

* async_fire is more common than expected with ccs

* fix mock

* fix hass mocking
This commit is contained in:
J. Nick Koston 2024-06-05 22:41:55 -05:00 committed by GitHub
parent 64b23419e0
commit 475c20d529
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 34 additions and 33 deletions

View file

@ -434,25 +434,17 @@ class HomeAssistant:
self.import_executor = InterruptibleThreadPoolExecutor( self.import_executor = InterruptibleThreadPoolExecutor(
max_workers=1, thread_name_prefix="ImportExecutor" max_workers=1, thread_name_prefix="ImportExecutor"
) )
self._loop_thread_id = getattr( self.loop_thread_id = getattr(
self.loop, "_thread_ident", getattr(self.loop, "_thread_id") self.loop, "_thread_ident", getattr(self.loop, "_thread_id")
) )
def verify_event_loop_thread(self, what: str) -> None: def verify_event_loop_thread(self, what: str) -> None:
"""Report and raise if we are not running in the event loop thread.""" """Report and raise if we are not running in the event loop thread."""
if self._loop_thread_id != threading.get_ident(): if self.loop_thread_id != threading.get_ident():
# frame is a circular import, so we import it here
from .helpers import frame # pylint: disable=import-outside-toplevel from .helpers import frame # pylint: disable=import-outside-toplevel
# frame is a circular import, so we import it here frame.report_non_thread_safe_operation(what)
frame.report(
f"calls {what} from a thread other than the event loop, "
"which may cause Home Assistant to crash or data to corrupt. "
"For more information, see "
"https://developers.home-assistant.io/docs/asyncio_thread_safety/"
f"#{what.replace('.', '')}",
error_if_core=True,
error_if_integration=True,
)
@property @property
def _active_tasks(self) -> set[asyncio.Future[Any]]: def _active_tasks(self) -> set[asyncio.Future[Any]]:
@ -793,16 +785,10 @@ class HomeAssistant:
target: target to call. target: target to call.
""" """
# We turned on asyncio debug in April 2024 in the dev containers if self.loop_thread_id != threading.get_ident():
# in the hope of catching some of the issues that have been from .helpers import frame # pylint: disable=import-outside-toplevel
# reported. It will take a while to get all the issues fixed in
# custom components. frame.report_non_thread_safe_operation("hass.async_create_task")
#
# In 2025.5 we should guard the `verify_event_loop_thread`
# check with a check for the `hass.config.debug` flag being set as
# long term we don't want to be checking this in production
# environments since it is a performance hit.
self.verify_event_loop_thread("hass.async_create_task")
return self.async_create_task_internal(target, name, eager_start) return self.async_create_task_internal(target, name, eager_start)
@callback @callback
@ -1497,7 +1483,10 @@ class EventBus:
This method must be run in the event loop. This method must be run in the event loop.
""" """
_verify_event_type_length_or_raise(event_type) _verify_event_type_length_or_raise(event_type)
self._hass.verify_event_loop_thread("hass.bus.async_fire") if self._hass.loop_thread_id != threading.get_ident():
from .helpers import frame # pylint: disable=import-outside-toplevel
frame.report_non_thread_safe_operation("hass.bus.async_fire")
return self.async_fire_internal( return self.async_fire_internal(
event_type, event_data, origin, context, time_fired event_type, event_data, origin, context, time_fired
) )

View file

@ -14,6 +14,7 @@ import logging
import math import math
from operator import attrgetter from operator import attrgetter
import sys import sys
import threading
import time import time
from types import FunctionType from types import FunctionType
from typing import TYPE_CHECKING, Any, Final, Literal, NotRequired, TypedDict, final from typing import TYPE_CHECKING, Any, Final, Literal, NotRequired, TypedDict, final
@ -63,6 +64,7 @@ from .event import (
async_track_device_registry_updated_event, async_track_device_registry_updated_event,
async_track_entity_registry_updated_event, async_track_entity_registry_updated_event,
) )
from .frame import report_non_thread_safe_operation
from .typing import UNDEFINED, StateType, UndefinedType from .typing import UNDEFINED, StateType, UndefinedType
timer = time.time timer = time.time
@ -512,7 +514,6 @@ class Entity(
# While not purely typed, it makes typehinting more useful for us # While not purely typed, it makes typehinting more useful for us
# and removes the need for constant None checks or asserts. # and removes the need for constant None checks or asserts.
_state_info: StateInfo = None # type: ignore[assignment] _state_info: StateInfo = None # type: ignore[assignment]
_is_custom_component: bool = False
__capabilities_updated_at: deque[float] __capabilities_updated_at: deque[float]
__capabilities_updated_at_reported: bool = False __capabilities_updated_at_reported: bool = False
@ -995,8 +996,8 @@ class Entity(
def async_write_ha_state(self) -> None: def async_write_ha_state(self) -> None:
"""Write the state to the state machine.""" """Write the state to the state machine."""
self._async_verify_state_writable() self._async_verify_state_writable()
if self._is_custom_component or self.hass.config.debug: if self.hass.loop_thread_id != threading.get_ident():
self.hass.verify_event_loop_thread("async_write_ha_state") report_non_thread_safe_operation("async_write_ha_state")
self._async_write_ha_state() self._async_write_ha_state()
def _stringify_state(self, available: bool) -> str: def _stringify_state(self, available: bool) -> str:
@ -1440,8 +1441,6 @@ class Entity(
"domain": self.platform.platform_name, "domain": self.platform.platform_name,
"custom_component": is_custom_component, "custom_component": is_custom_component,
} }
self._is_custom_component = is_custom_component
if self.platform.config_entry: if self.platform.config_entry:
entity_info["config_entry"] = self.platform.config_entry.entry_id entity_info["config_entry"] = self.platform.config_entry.entry_id

View file

@ -218,3 +218,16 @@ def warn_use[_CallableT: Callable](func: _CallableT, what: str) -> _CallableT:
report(what) report(what)
return cast(_CallableT, report_use) return cast(_CallableT, report_use)
def report_non_thread_safe_operation(what: str) -> None:
    """Report that ``what`` was called from outside the event loop thread.

    Builds a message naming the offending call and linking to the asyncio
    thread-safety developer documentation, then passes it to ``report``
    with both ``error_if_core`` and ``error_if_integration`` enabled.

    Args:
        what: Dotted name of the API that was called (for example
            ``"hass.bus.async_fire"``). The documentation anchor is this
            name with the dots removed.
    """
    # The docs page uses the call name without dots as its section anchor.
    anchor = what.replace(".", "")
    message = (
        f"calls {what} from a thread other than the event loop, "
        "which may cause Home Assistant to crash or data to corrupt. "
        "For more information, see "
        "https://developers.home-assistant.io/docs/asyncio_thread_safety/"
        f"#{anchor}"
    )
    report(message, error_if_core=True, error_if_integration=True)

View file

@ -174,7 +174,7 @@ def get_test_home_assistant() -> Generator[HomeAssistant, None, None]:
"""Run event loop.""" """Run event loop."""
loop._thread_ident = threading.get_ident() loop._thread_ident = threading.get_ident()
hass._loop_thread_id = loop._thread_ident hass.loop_thread_id = loop._thread_ident
loop.run_forever() loop.run_forever()
loop_stop_event.set() loop_stop_event.set()

View file

@ -3,6 +3,7 @@
from collections.abc import Callable from collections.abc import Callable
import logging import logging
import math import math
import threading
from types import NoneType from types import NoneType
from unittest import mock from unittest import mock
from unittest.mock import AsyncMock, patch from unittest.mock import AsyncMock, patch
@ -86,6 +87,7 @@ def endpoint(zigpy_coordinator_device):
type(endpoint_mock.device).skip_configuration = mock.PropertyMock( type(endpoint_mock.device).skip_configuration = mock.PropertyMock(
return_value=False return_value=False
) )
endpoint_mock.device.hass.loop_thread_id = threading.get_ident()
endpoint_mock.id = 1 endpoint_mock.id = 1
return endpoint_mock return endpoint_mock

View file

@ -2617,13 +2617,12 @@ async def test_async_write_ha_state_thread_safety(hass: HomeAssistant) -> None:
assert not hass.states.get(ent2.entity_id) assert not hass.states.get(ent2.entity_id)
async def test_async_write_ha_state_thread_safety_custom_component( async def test_async_write_ha_state_thread_safety_always(
hass: HomeAssistant, hass: HomeAssistant,
) -> None: ) -> None:
"""Test async_write_ha_state thread safe for custom components.""" """Test async_write_ha_state thread safe check."""
ent = entity.Entity() ent = entity.Entity()
ent._is_custom_component = True
ent.entity_id = "test.any" ent.entity_id = "test.any"
ent.hass = hass ent.hass = hass
ent.platform = MockEntityPlatform(hass, domain="test") ent.platform = MockEntityPlatform(hass, domain="test")
@ -2631,7 +2630,6 @@ async def test_async_write_ha_state_thread_safety_custom_component(
assert hass.states.get(ent.entity_id) assert hass.states.get(ent.entity_id)
ent2 = entity.Entity() ent2 = entity.Entity()
ent2._is_custom_component = True
ent2.entity_id = "test.any2" ent2.entity_id = "test.any2"
ent2.hass = hass ent2.hass = hass
ent2.platform = MockEntityPlatform(hass, domain="test") ent2.platform = MockEntityPlatform(hass, domain="test")