diff --git a/homeassistant/helpers/device_registry.py b/homeassistant/helpers/device_registry.py index aec5dbc6c4a..6b653784824 100644 --- a/homeassistant/helpers/device_registry.py +++ b/homeassistant/helpers/device_registry.py @@ -904,6 +904,7 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]): if not new_values: return old + self.hass.verify_event_loop_thread("async_update_device") new = attr.evolve(old, **new_values) self.devices[device_id] = new @@ -923,13 +924,14 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]): else: data = {"action": "update", "device_id": new.id, "changes": old_values} - self.hass.bus.async_fire(EVENT_DEVICE_REGISTRY_UPDATED, data) + self.hass.bus.async_fire_internal(EVENT_DEVICE_REGISTRY_UPDATED, data) return new @callback def async_remove_device(self, device_id: str) -> None: """Remove a device from the device registry.""" + self.hass.verify_event_loop_thread("async_remove_device") device = self.devices.pop(device_id) self.deleted_devices[device_id] = DeletedDeviceEntry( config_entries=device.config_entries, @@ -941,7 +943,7 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]): for other_device in list(self.devices.values()): if other_device.via_device_id == device_id: self.async_update_device(other_device.id, via_device_id=None) - self.hass.bus.async_fire( + self.hass.bus.async_fire_internal( EVENT_DEVICE_REGISTRY_UPDATED, _EventDeviceRegistryUpdatedData_CreateRemove( action="remove", device_id=device_id diff --git a/tests/helpers/test_device_registry.py b/tests/helpers/test_device_registry.py index ee895e3fd3e..6b167f8ee49 100644 --- a/tests/helpers/test_device_registry.py +++ b/tests/helpers/test_device_registry.py @@ -2,6 +2,7 @@ from collections.abc import Iterable from contextlib import AbstractContextManager, nullcontext +from functools import partial import time from typing import Any from unittest.mock import patch @@ -2473,3 +2474,49 @@ async def test_device_name_translation_placeholders_errors( ) assert expected_error in caplog.text + + +async def test_async_get_or_create_thread_safety( + hass: HomeAssistant, + device_registry: dr.DeviceRegistry, + mock_config_entry: MockConfigEntry, +) -> None: + """Test async_get_or_create raises when called from wrong thread.""" + + with pytest.raises( + RuntimeError, + match="Detected code that calls async_update_device from a thread. Please report this issue.", + ): + await hass.async_add_executor_job( + partial( + device_registry.async_get_or_create, + config_entry_id=mock_config_entry.entry_id, + connections={(dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, + identifiers=set(), + manufacturer="manufacturer", + model="model", + ) + ) + + +async def test_async_remove_device_thread_safety( + hass: HomeAssistant, + device_registry: dr.DeviceRegistry, + mock_config_entry: MockConfigEntry, +) -> None: + """Test async_remove_device raises when called from wrong thread.""" + device = device_registry.async_get_or_create( + config_entry_id=mock_config_entry.entry_id, + connections={(dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, + identifiers=set(), + manufacturer="manufacturer", + model="model", + ) + + with pytest.raises( + RuntimeError, + match="Detected code that calls async_remove_device from a thread. Please report this issue.", + ): + await hass.async_add_executor_job( + device_registry.async_remove_device, device.id + )