Move thread safety check in device_registry sooner (#116264)
It turns out we have custom components that are writing to the device registry using the async APIs from threads. We now catch it at the point async_fire is called. Instead we should check sooner and use async_fire_internal so we catch the unsafe operation before it can corrupt the registry.
This commit is contained in:
parent
8bae614d4e
commit
09ebbfa0e1
2 changed files with 51 additions and 2 deletions
|
@ -904,6 +904,7 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
|
||||||
if not new_values:
|
if not new_values:
|
||||||
return old
|
return old
|
||||||
|
|
||||||
|
self.hass.verify_event_loop_thread("async_update_device")
|
||||||
new = attr.evolve(old, **new_values)
|
new = attr.evolve(old, **new_values)
|
||||||
self.devices[device_id] = new
|
self.devices[device_id] = new
|
||||||
|
|
||||||
|
@ -923,13 +924,14 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
|
||||||
else:
|
else:
|
||||||
data = {"action": "update", "device_id": new.id, "changes": old_values}
|
data = {"action": "update", "device_id": new.id, "changes": old_values}
|
||||||
|
|
||||||
self.hass.bus.async_fire(EVENT_DEVICE_REGISTRY_UPDATED, data)
|
self.hass.bus.async_fire_internal(EVENT_DEVICE_REGISTRY_UPDATED, data)
|
||||||
|
|
||||||
return new
|
return new
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_remove_device(self, device_id: str) -> None:
|
def async_remove_device(self, device_id: str) -> None:
|
||||||
"""Remove a device from the device registry."""
|
"""Remove a device from the device registry."""
|
||||||
|
self.hass.verify_event_loop_thread("async_remove_device")
|
||||||
device = self.devices.pop(device_id)
|
device = self.devices.pop(device_id)
|
||||||
self.deleted_devices[device_id] = DeletedDeviceEntry(
|
self.deleted_devices[device_id] = DeletedDeviceEntry(
|
||||||
config_entries=device.config_entries,
|
config_entries=device.config_entries,
|
||||||
|
@ -941,7 +943,7 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
|
||||||
for other_device in list(self.devices.values()):
|
for other_device in list(self.devices.values()):
|
||||||
if other_device.via_device_id == device_id:
|
if other_device.via_device_id == device_id:
|
||||||
self.async_update_device(other_device.id, via_device_id=None)
|
self.async_update_device(other_device.id, via_device_id=None)
|
||||||
self.hass.bus.async_fire(
|
self.hass.bus.async_fire_internal(
|
||||||
EVENT_DEVICE_REGISTRY_UPDATED,
|
EVENT_DEVICE_REGISTRY_UPDATED,
|
||||||
_EventDeviceRegistryUpdatedData_CreateRemove(
|
_EventDeviceRegistryUpdatedData_CreateRemove(
|
||||||
action="remove", device_id=device_id
|
action="remove", device_id=device_id
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
|
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
from contextlib import AbstractContextManager, nullcontext
|
from contextlib import AbstractContextManager, nullcontext
|
||||||
|
from functools import partial
|
||||||
import time
|
import time
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
@ -2473,3 +2474,49 @@ async def test_device_name_translation_placeholders_errors(
|
||||||
)
|
)
|
||||||
|
|
||||||
assert expected_error in caplog.text
|
assert expected_error in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
async def test_async_get_or_create_thread_safety(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
device_registry: dr.DeviceRegistry,
|
||||||
|
mock_config_entry: MockConfigEntry,
|
||||||
|
) -> None:
|
||||||
|
"""Test async_get_or_create raises when called from wrong thread."""
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
RuntimeError,
|
||||||
|
match="Detected code that calls async_update_device from a thread. Please report this issue.",
|
||||||
|
):
|
||||||
|
await hass.async_add_executor_job(
|
||||||
|
partial(
|
||||||
|
device_registry.async_get_or_create,
|
||||||
|
config_entry_id=mock_config_entry.entry_id,
|
||||||
|
connections={(dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
|
||||||
|
identifiers=set(),
|
||||||
|
manufacturer="manufacturer",
|
||||||
|
model="model",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_async_remove_device_thread_safety(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
device_registry: dr.DeviceRegistry,
|
||||||
|
mock_config_entry: MockConfigEntry,
|
||||||
|
) -> None:
|
||||||
|
"""Test async_remove_device raises when called from wrong thread."""
|
||||||
|
device = device_registry.async_get_or_create(
|
||||||
|
config_entry_id=mock_config_entry.entry_id,
|
||||||
|
connections={(dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
|
||||||
|
identifiers=set(),
|
||||||
|
manufacturer="manufacturer",
|
||||||
|
model="model",
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
RuntimeError,
|
||||||
|
match="Detected code that calls async_remove_device from a thread. Please report this issue.",
|
||||||
|
):
|
||||||
|
await hass.async_add_executor_job(
|
||||||
|
device_registry.async_remove_device, device.id
|
||||||
|
)
|
||||||
|
|
Loading…
Add table
Reference in a new issue