Move thread safety check in device_registry sooner (#116264)

It turns out we have custom components that are writing to the device registry using the async APIs from threads. We now catch it at the point async_fire is called. Instead we should check sooner and use async_fire_internal so we catch the unsafe operation before it can corrupt the registry.
This commit is contained in:
J. Nick Koston 2024-04-27 02:24:55 -05:00 committed by GitHub
parent 8bae614d4e
commit 09ebbfa0e1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 51 additions and 2 deletions

View file

@ -904,6 +904,7 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
if not new_values: if not new_values:
return old return old
self.hass.verify_event_loop_thread("async_update_device")
new = attr.evolve(old, **new_values) new = attr.evolve(old, **new_values)
self.devices[device_id] = new self.devices[device_id] = new
@ -923,13 +924,14 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
else: else:
data = {"action": "update", "device_id": new.id, "changes": old_values} data = {"action": "update", "device_id": new.id, "changes": old_values}
self.hass.bus.async_fire(EVENT_DEVICE_REGISTRY_UPDATED, data) self.hass.bus.async_fire_internal(EVENT_DEVICE_REGISTRY_UPDATED, data)
return new return new
@callback @callback
def async_remove_device(self, device_id: str) -> None: def async_remove_device(self, device_id: str) -> None:
"""Remove a device from the device registry.""" """Remove a device from the device registry."""
self.hass.verify_event_loop_thread("async_remove_device")
device = self.devices.pop(device_id) device = self.devices.pop(device_id)
self.deleted_devices[device_id] = DeletedDeviceEntry( self.deleted_devices[device_id] = DeletedDeviceEntry(
config_entries=device.config_entries, config_entries=device.config_entries,
@ -941,7 +943,7 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
for other_device in list(self.devices.values()): for other_device in list(self.devices.values()):
if other_device.via_device_id == device_id: if other_device.via_device_id == device_id:
self.async_update_device(other_device.id, via_device_id=None) self.async_update_device(other_device.id, via_device_id=None)
self.hass.bus.async_fire( self.hass.bus.async_fire_internal(
EVENT_DEVICE_REGISTRY_UPDATED, EVENT_DEVICE_REGISTRY_UPDATED,
_EventDeviceRegistryUpdatedData_CreateRemove( _EventDeviceRegistryUpdatedData_CreateRemove(
action="remove", device_id=device_id action="remove", device_id=device_id

View file

@ -2,6 +2,7 @@
from collections.abc import Iterable from collections.abc import Iterable
from contextlib import AbstractContextManager, nullcontext from contextlib import AbstractContextManager, nullcontext
from functools import partial
import time import time
from typing import Any from typing import Any
from unittest.mock import patch from unittest.mock import patch
@ -2473,3 +2474,49 @@ async def test_device_name_translation_placeholders_errors(
) )
assert expected_error in caplog.text assert expected_error in caplog.text
async def test_async_get_or_create_thread_safety(
hass: HomeAssistant,
device_registry: dr.DeviceRegistry,
mock_config_entry: MockConfigEntry,
) -> None:
"""Test async_get_or_create raises when called from wrong thread."""
with pytest.raises(
RuntimeError,
match="Detected code that calls async_update_device from a thread. Please report this issue.",
):
await hass.async_add_executor_job(
partial(
device_registry.async_get_or_create,
config_entry_id=mock_config_entry.entry_id,
connections={(dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
identifiers=set(),
manufacturer="manufacturer",
model="model",
)
)
async def test_async_remove_device_thread_safety(
hass: HomeAssistant,
device_registry: dr.DeviceRegistry,
mock_config_entry: MockConfigEntry,
) -> None:
"""Test async_remove_device raises when called from wrong thread."""
device = device_registry.async_get_or_create(
config_entry_id=mock_config_entry.entry_id,
connections={(dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
identifiers=set(),
manufacturer="manufacturer",
model="model",
)
with pytest.raises(
RuntimeError,
match="Detected code that calls async_remove_device from a thread. Please report this issue.",
):
await hass.async_add_executor_job(
device_registry.async_remove_device, device.id
)