Use singleton helper for registries (#117027)

This commit is contained in:
J. Nick Koston 2024-05-07 14:04:01 -05:00 committed by GitHub
parent 6e024d54f1
commit 26cc1cd3db
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 27 additions and 18 deletions

View file

@ -18,6 +18,7 @@ from .normalized_name_base_registry import (
normalize_name, normalize_name,
) )
from .registry import BaseRegistry from .registry import BaseRegistry
from .singleton import singleton
from .storage import Store from .storage import Store
from .typing import UNDEFINED, UndefinedType from .typing import UNDEFINED, UndefinedType
@ -417,16 +418,16 @@ class AreaRegistry(BaseRegistry[AreasRegistryStoreData]):
@callback @callback
@singleton(DATA_REGISTRY)
def async_get(hass: HomeAssistant) -> AreaRegistry: def async_get(hass: HomeAssistant) -> AreaRegistry:
"""Get area registry.""" """Get area registry."""
return hass.data[DATA_REGISTRY] return AreaRegistry(hass)
async def async_load(hass: HomeAssistant) -> None: async def async_load(hass: HomeAssistant) -> None:
"""Load area registry.""" """Load area registry."""
assert DATA_REGISTRY not in hass.data assert DATA_REGISTRY not in hass.data
hass.data[DATA_REGISTRY] = AreaRegistry(hass) await async_get(hass).async_load()
await hass.data[DATA_REGISTRY].async_load()
@callback @callback

View file

@ -13,6 +13,7 @@ from homeassistant.util.hass_dict import HassKey
from homeassistant.util.ulid import ulid_now from homeassistant.util.ulid import ulid_now
from .registry import BaseRegistry from .registry import BaseRegistry
from .singleton import singleton
from .storage import Store from .storage import Store
from .typing import UNDEFINED, UndefinedType from .typing import UNDEFINED, UndefinedType
@ -217,13 +218,13 @@ class CategoryRegistry(BaseRegistry[CategoryRegistryStoreData]):
@callback @callback
@singleton(DATA_REGISTRY)
def async_get(hass: HomeAssistant) -> CategoryRegistry: def async_get(hass: HomeAssistant) -> CategoryRegistry:
"""Get category registry.""" """Get category registry."""
return hass.data[DATA_REGISTRY] return CategoryRegistry(hass)
async def async_load(hass: HomeAssistant) -> None: async def async_load(hass: HomeAssistant) -> None:
"""Load category registry.""" """Load category registry."""
assert DATA_REGISTRY not in hass.data assert DATA_REGISTRY not in hass.data
hass.data[DATA_REGISTRY] = CategoryRegistry(hass) await async_get(hass).async_load()
await hass.data[DATA_REGISTRY].async_load()

View file

@ -38,6 +38,7 @@ from .deprecation import (
from .frame import report from .frame import report
from .json import JSON_DUMP, find_paths_unserializable_data, json_bytes, json_fragment from .json import JSON_DUMP, find_paths_unserializable_data, json_bytes, json_fragment
from .registry import BaseRegistry, BaseRegistryItems from .registry import BaseRegistry, BaseRegistryItems
from .singleton import singleton
from .typing import UNDEFINED, UndefinedType from .typing import UNDEFINED, UndefinedType
if TYPE_CHECKING: if TYPE_CHECKING:
@ -1077,16 +1078,16 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
@callback @callback
@singleton(DATA_REGISTRY)
def async_get(hass: HomeAssistant) -> DeviceRegistry: def async_get(hass: HomeAssistant) -> DeviceRegistry:
"""Get device registry.""" """Get device registry."""
return hass.data[DATA_REGISTRY] return DeviceRegistry(hass)
async def async_load(hass: HomeAssistant) -> None: async def async_load(hass: HomeAssistant) -> None:
"""Load device registry.""" """Load device registry."""
assert DATA_REGISTRY not in hass.data assert DATA_REGISTRY not in hass.data
hass.data[DATA_REGISTRY] = DeviceRegistry(hass) await async_get(hass).async_load()
await hass.data[DATA_REGISTRY].async_load()
@callback @callback

View file

@ -59,6 +59,7 @@ from .device_registry import (
) )
from .json import JSON_DUMP, find_paths_unserializable_data, json_bytes, json_fragment from .json import JSON_DUMP, find_paths_unserializable_data, json_bytes, json_fragment
from .registry import BaseRegistry, BaseRegistryItems from .registry import BaseRegistry, BaseRegistryItems
from .singleton import singleton
from .typing import UNDEFINED, UndefinedType from .typing import UNDEFINED, UndefinedType
if TYPE_CHECKING: if TYPE_CHECKING:
@ -1374,16 +1375,16 @@ class EntityRegistry(BaseRegistry):
@callback @callback
@singleton(DATA_REGISTRY)
def async_get(hass: HomeAssistant) -> EntityRegistry: def async_get(hass: HomeAssistant) -> EntityRegistry:
"""Get entity registry.""" """Get entity registry."""
return hass.data[DATA_REGISTRY] return EntityRegistry(hass)
async def async_load(hass: HomeAssistant) -> None: async def async_load(hass: HomeAssistant) -> None:
"""Load entity registry.""" """Load entity registry."""
assert DATA_REGISTRY not in hass.data assert DATA_REGISTRY not in hass.data
hass.data[DATA_REGISTRY] = EntityRegistry(hass) await async_get(hass).async_load()
await hass.data[DATA_REGISTRY].async_load()
@callback @callback

View file

@ -18,6 +18,7 @@ from .normalized_name_base_registry import (
normalize_name, normalize_name,
) )
from .registry import BaseRegistry from .registry import BaseRegistry
from .singleton import singleton
from .storage import Store from .storage import Store
from .typing import UNDEFINED, UndefinedType from .typing import UNDEFINED, UndefinedType
@ -239,13 +240,13 @@ class FloorRegistry(BaseRegistry[FloorRegistryStoreData]):
@callback @callback
@singleton(DATA_REGISTRY)
def async_get(hass: HomeAssistant) -> FloorRegistry: def async_get(hass: HomeAssistant) -> FloorRegistry:
"""Get floor registry.""" """Get floor registry."""
return hass.data[DATA_REGISTRY] return FloorRegistry(hass)
async def async_load(hass: HomeAssistant) -> None: async def async_load(hass: HomeAssistant) -> None:
"""Load floor registry.""" """Load floor registry."""
assert DATA_REGISTRY not in hass.data assert DATA_REGISTRY not in hass.data
hass.data[DATA_REGISTRY] = FloorRegistry(hass) await async_get(hass).async_load()
await hass.data[DATA_REGISTRY].async_load()

View file

@ -18,6 +18,7 @@ from .normalized_name_base_registry import (
normalize_name, normalize_name,
) )
from .registry import BaseRegistry from .registry import BaseRegistry
from .singleton import singleton
from .storage import Store from .storage import Store
from .typing import UNDEFINED, UndefinedType from .typing import UNDEFINED, UndefinedType
@ -240,13 +241,13 @@ class LabelRegistry(BaseRegistry[LabelRegistryStoreData]):
@callback @callback
@singleton(DATA_REGISTRY)
def async_get(hass: HomeAssistant) -> LabelRegistry: def async_get(hass: HomeAssistant) -> LabelRegistry:
"""Get label registry.""" """Get label registry."""
return hass.data[DATA_REGISTRY] return LabelRegistry(hass)
async def async_load(hass: HomeAssistant) -> None: async def async_load(hass: HomeAssistant) -> None:
"""Load label registry.""" """Load label registry."""
assert DATA_REGISTRY not in hass.data assert DATA_REGISTRY not in hass.data
hass.data[DATA_REGISTRY] = LabelRegistry(hass) await async_get(hass).async_load()
await hass.data[DATA_REGISTRY].async_load()

View file

@ -631,6 +631,7 @@ def mock_registry(
registry.entities[key] = entry registry.entities[key] = entry
hass.data[er.DATA_REGISTRY] = registry hass.data[er.DATA_REGISTRY] = registry
er.async_get.cache_clear()
return registry return registry
@ -654,6 +655,7 @@ def mock_area_registry(
registry.areas[key] = entry registry.areas[key] = entry
hass.data[ar.DATA_REGISTRY] = registry hass.data[ar.DATA_REGISTRY] = registry
ar.async_get.cache_clear()
return registry return registry
@ -682,6 +684,7 @@ def mock_device_registry(
registry.deleted_devices = dr.DeviceRegistryItems() registry.deleted_devices = dr.DeviceRegistryItems()
hass.data[dr.DATA_REGISTRY] = registry hass.data[dr.DATA_REGISTRY] = registry
dr.async_get.cache_clear()
return registry return registry