diff --git a/homeassistant/helpers/area_registry.py b/homeassistant/helpers/area_registry.py index 96200c7b43a..56d6b8be224 100644 --- a/homeassistant/helpers/area_registry.py +++ b/homeassistant/helpers/area_registry.py @@ -18,6 +18,7 @@ from .normalized_name_base_registry import ( normalize_name, ) from .registry import BaseRegistry +from .singleton import singleton from .storage import Store from .typing import UNDEFINED, UndefinedType @@ -417,16 +418,16 @@ class AreaRegistry(BaseRegistry[AreasRegistryStoreData]): @callback +@singleton(DATA_REGISTRY) def async_get(hass: HomeAssistant) -> AreaRegistry: """Get area registry.""" - return hass.data[DATA_REGISTRY] + return AreaRegistry(hass) async def async_load(hass: HomeAssistant) -> None: """Load area registry.""" assert DATA_REGISTRY not in hass.data - hass.data[DATA_REGISTRY] = AreaRegistry(hass) - await hass.data[DATA_REGISTRY].async_load() + await async_get(hass).async_load() @callback diff --git a/homeassistant/helpers/category_registry.py b/homeassistant/helpers/category_registry.py index dafb81d02ce..b0a465314f7 100644 --- a/homeassistant/helpers/category_registry.py +++ b/homeassistant/helpers/category_registry.py @@ -13,6 +13,7 @@ from homeassistant.util.hass_dict import HassKey from homeassistant.util.ulid import ulid_now from .registry import BaseRegistry +from .singleton import singleton from .storage import Store from .typing import UNDEFINED, UndefinedType @@ -217,13 +218,13 @@ class CategoryRegistry(BaseRegistry[CategoryRegistryStoreData]): @callback +@singleton(DATA_REGISTRY) def async_get(hass: HomeAssistant) -> CategoryRegistry: """Get category registry.""" - return hass.data[DATA_REGISTRY] + return CategoryRegistry(hass) async def async_load(hass: HomeAssistant) -> None: """Load category registry.""" assert DATA_REGISTRY not in hass.data - hass.data[DATA_REGISTRY] = CategoryRegistry(hass) - await hass.data[DATA_REGISTRY].async_load() + await async_get(hass).async_load() diff --git a/homeassistant/helpers/device_registry.py b/homeassistant/helpers/device_registry.py index e32f2b77284..3a7ef2f2352 100644 --- a/homeassistant/helpers/device_registry.py +++ b/homeassistant/helpers/device_registry.py @@ -38,6 +38,7 @@ from .deprecation import ( from .frame import report from .json import JSON_DUMP, find_paths_unserializable_data, json_bytes, json_fragment from .registry import BaseRegistry, BaseRegistryItems +from .singleton import singleton from .typing import UNDEFINED, UndefinedType if TYPE_CHECKING: @@ -1077,16 +1078,16 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]): @callback +@singleton(DATA_REGISTRY) def async_get(hass: HomeAssistant) -> DeviceRegistry: """Get device registry.""" - return hass.data[DATA_REGISTRY] + return DeviceRegistry(hass) async def async_load(hass: HomeAssistant) -> None: """Load device registry.""" assert DATA_REGISTRY not in hass.data - hass.data[DATA_REGISTRY] = DeviceRegistry(hass) - await hass.data[DATA_REGISTRY].async_load() + await async_get(hass).async_load() @callback diff --git a/homeassistant/helpers/entity_registry.py b/homeassistant/helpers/entity_registry.py index ac41326ed95..ac2307feea5 100644 --- a/homeassistant/helpers/entity_registry.py +++ b/homeassistant/helpers/entity_registry.py @@ -59,6 +59,7 @@ from .device_registry import ( ) from .json import JSON_DUMP, find_paths_unserializable_data, json_bytes, json_fragment from .registry import BaseRegistry, BaseRegistryItems +from .singleton import singleton from .typing import UNDEFINED, UndefinedType if TYPE_CHECKING: @@ -1374,16 +1375,16 @@ class EntityRegistry(BaseRegistry): @callback +@singleton(DATA_REGISTRY) def async_get(hass: HomeAssistant) -> EntityRegistry: """Get entity registry.""" - return hass.data[DATA_REGISTRY] + return EntityRegistry(hass) async def async_load(hass: HomeAssistant) -> None: """Load entity registry.""" assert DATA_REGISTRY not in hass.data - hass.data[DATA_REGISTRY] = EntityRegistry(hass) - await hass.data[DATA_REGISTRY].async_load() + await async_get(hass).async_load() @callback diff --git a/homeassistant/helpers/floor_registry.py b/homeassistant/helpers/floor_registry.py index ad17d214b44..63d3bb56100 100644 --- a/homeassistant/helpers/floor_registry.py +++ b/homeassistant/helpers/floor_registry.py @@ -18,6 +18,7 @@ from .normalized_name_base_registry import ( normalize_name, ) from .registry import BaseRegistry +from .singleton import singleton from .storage import Store from .typing import UNDEFINED, UndefinedType @@ -239,13 +240,13 @@ class FloorRegistry(BaseRegistry[FloorRegistryStoreData]): @callback +@singleton(DATA_REGISTRY) def async_get(hass: HomeAssistant) -> FloorRegistry: """Get floor registry.""" - return hass.data[DATA_REGISTRY] + return FloorRegistry(hass) async def async_load(hass: HomeAssistant) -> None: """Load floor registry.""" assert DATA_REGISTRY not in hass.data - hass.data[DATA_REGISTRY] = FloorRegistry(hass) - await hass.data[DATA_REGISTRY].async_load() + await async_get(hass).async_load() diff --git a/homeassistant/helpers/label_registry.py b/homeassistant/helpers/label_registry.py index 8be63257de3..5c9b1eb066e 100644 --- a/homeassistant/helpers/label_registry.py +++ b/homeassistant/helpers/label_registry.py @@ -18,6 +18,7 @@ from .normalized_name_base_registry import ( normalize_name, ) from .registry import BaseRegistry +from .singleton import singleton from .storage import Store from .typing import UNDEFINED, UndefinedType @@ -240,13 +241,13 @@ class LabelRegistry(BaseRegistry[LabelRegistryStoreData]): @callback +@singleton(DATA_REGISTRY) def async_get(hass: HomeAssistant) -> LabelRegistry: """Get label registry.""" - return hass.data[DATA_REGISTRY] + return LabelRegistry(hass) async def async_load(hass: HomeAssistant) -> None: """Load label registry.""" assert DATA_REGISTRY not in hass.data - hass.data[DATA_REGISTRY] = LabelRegistry(hass) - await hass.data[DATA_REGISTRY].async_load() + await async_get(hass).async_load() diff --git a/tests/common.py b/tests/common.py index 8e220f59215..41b79f29475 100644 --- a/tests/common.py +++ b/tests/common.py @@ -631,6 +631,7 @@ def mock_registry( registry.entities[key] = entry hass.data[er.DATA_REGISTRY] = registry + er.async_get.cache_clear() return registry @@ -654,6 +655,7 @@ def mock_area_registry( registry.areas[key] = entry hass.data[ar.DATA_REGISTRY] = registry + ar.async_get.cache_clear() return registry @@ -682,6 +684,7 @@ def mock_device_registry( registry.deleted_devices = dr.DeviceRegistryItems() hass.data[dr.DATA_REGISTRY] = registry + dr.async_get.cache_clear() return registry