Complete device and entity registry type hints (#44406)

This commit is contained in:
Ville Skyttä 2021-01-05 03:03:16 +02:00 committed by GitHub
parent d315ab2cf5
commit 65e56d03bf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 133 additions and 136 deletions

View file

@ -11,13 +11,11 @@ import homeassistant.util.uuid as uuid_util
from .debounce import Debouncer from .debounce import Debouncer
from .singleton import singleton from .singleton import singleton
from .typing import UNDEFINED, HomeAssistantType from .typing import UNDEFINED, HomeAssistantType, UndefinedType
if TYPE_CHECKING: if TYPE_CHECKING:
from . import entity_registry from . import entity_registry
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
DATA_REGISTRY = "device_registry" DATA_REGISTRY = "device_registry"
@ -40,26 +38,6 @@ DISABLED_INTEGRATION = "integration"
DISABLED_USER = "user" DISABLED_USER = "user"
@attr.s(slots=True, frozen=True)
class DeletedDeviceEntry:
"""Deleted Device Registry Entry."""
config_entries: Set[str] = attr.ib()
connections: Set[Tuple[str, str]] = attr.ib()
identifiers: Set[Tuple[str, str]] = attr.ib()
id: str = attr.ib()
def to_device_entry(self, config_entry_id, connections, identifiers):
"""Create DeviceEntry from DeletedDeviceEntry."""
return DeviceEntry(
config_entries={config_entry_id},
connections=self.connections & connections,
identifiers=self.identifiers & identifiers,
id=self.id,
is_new=True,
)
@attr.s(slots=True, frozen=True) @attr.s(slots=True, frozen=True)
class DeviceEntry: class DeviceEntry:
"""Device Registry Entry.""" """Device Registry Entry."""
@ -67,14 +45,14 @@ class DeviceEntry:
config_entries: Set[str] = attr.ib(converter=set, factory=set) config_entries: Set[str] = attr.ib(converter=set, factory=set)
connections: Set[Tuple[str, str]] = attr.ib(converter=set, factory=set) connections: Set[Tuple[str, str]] = attr.ib(converter=set, factory=set)
identifiers: Set[Tuple[str, str]] = attr.ib(converter=set, factory=set) identifiers: Set[Tuple[str, str]] = attr.ib(converter=set, factory=set)
manufacturer: str = attr.ib(default=None) manufacturer: Optional[str] = attr.ib(default=None)
model: str = attr.ib(default=None) model: Optional[str] = attr.ib(default=None)
name: str = attr.ib(default=None) name: Optional[str] = attr.ib(default=None)
sw_version: str = attr.ib(default=None) sw_version: Optional[str] = attr.ib(default=None)
via_device_id: str = attr.ib(default=None) via_device_id: Optional[str] = attr.ib(default=None)
area_id: str = attr.ib(default=None) area_id: Optional[str] = attr.ib(default=None)
name_by_user: str = attr.ib(default=None) name_by_user: Optional[str] = attr.ib(default=None)
entry_type: str = attr.ib(default=None) entry_type: Optional[str] = attr.ib(default=None)
id: str = attr.ib(factory=uuid_util.random_uuid_hex) id: str = attr.ib(factory=uuid_util.random_uuid_hex)
# This value is not stored, just used to keep track of events to fire. # This value is not stored, just used to keep track of events to fire.
is_new: bool = attr.ib(default=False) is_new: bool = attr.ib(default=False)
@ -95,6 +73,32 @@ class DeviceEntry:
return self.disabled_by is not None return self.disabled_by is not None
@attr.s(slots=True, frozen=True)
class DeletedDeviceEntry:
"""Deleted Device Registry Entry."""
config_entries: Set[str] = attr.ib()
connections: Set[Tuple[str, str]] = attr.ib()
identifiers: Set[Tuple[str, str]] = attr.ib()
id: str = attr.ib()
def to_device_entry(
self,
config_entry_id: str,
connections: Set[Tuple[str, str]],
identifiers: Set[Tuple[str, str]],
) -> DeviceEntry:
"""Create DeviceEntry from DeletedDeviceEntry."""
return DeviceEntry(
# type ignores: likely https://github.com/python/mypy/issues/8625
config_entries={config_entry_id}, # type: ignore[arg-type]
connections=self.connections & connections, # type: ignore[arg-type]
identifiers=self.identifiers & identifiers, # type: ignore[arg-type]
id=self.id,
is_new=True,
)
def format_mac(mac: str) -> str: def format_mac(mac: str) -> str:
"""Format the mac address string for entry into dev reg.""" """Format the mac address string for entry into dev reg."""
to_test = mac to_test = mac
@ -201,40 +205,40 @@ class DeviceRegistry:
_remove_device_from_index(devices_index, old_device) _remove_device_from_index(devices_index, old_device)
_add_device_to_index(devices_index, new_device) _add_device_to_index(devices_index, new_device)
def _clear_index(self): def _clear_index(self) -> None:
"""Clear the index.""" """Clear the index."""
self._devices_index = { self._devices_index = {
REGISTERED_DEVICE: {IDX_IDENTIFIERS: {}, IDX_CONNECTIONS: {}}, REGISTERED_DEVICE: {IDX_IDENTIFIERS: {}, IDX_CONNECTIONS: {}},
DELETED_DEVICE: {IDX_IDENTIFIERS: {}, IDX_CONNECTIONS: {}}, DELETED_DEVICE: {IDX_IDENTIFIERS: {}, IDX_CONNECTIONS: {}},
} }
def _rebuild_index(self): def _rebuild_index(self) -> None:
"""Create the index after loading devices.""" """Create the index after loading devices."""
self._clear_index() self._clear_index()
for device in self.devices.values(): for device in self.devices.values():
_add_device_to_index(self._devices_index[REGISTERED_DEVICE], device) _add_device_to_index(self._devices_index[REGISTERED_DEVICE], device)
for device in self.deleted_devices.values(): for deleted_device in self.deleted_devices.values():
_add_device_to_index(self._devices_index[DELETED_DEVICE], device) _add_device_to_index(self._devices_index[DELETED_DEVICE], deleted_device)
@callback @callback
def async_get_or_create( def async_get_or_create(
self, self,
*, *,
config_entry_id, config_entry_id: str,
connections=None, connections: Optional[set] = None,
identifiers=None, identifiers: Optional[set] = None,
manufacturer=UNDEFINED, manufacturer: Union[str, None, UndefinedType] = UNDEFINED,
model=UNDEFINED, model: Union[str, None, UndefinedType] = UNDEFINED,
name=UNDEFINED, name: Union[str, None, UndefinedType] = UNDEFINED,
default_manufacturer=UNDEFINED, default_manufacturer: Union[str, None, UndefinedType] = UNDEFINED,
default_model=UNDEFINED, default_model: Union[str, None, UndefinedType] = UNDEFINED,
default_name=UNDEFINED, default_name: Union[str, None, UndefinedType] = UNDEFINED,
sw_version=UNDEFINED, sw_version: Union[str, None, UndefinedType] = UNDEFINED,
entry_type=UNDEFINED, entry_type: Union[str, None, UndefinedType] = UNDEFINED,
via_device=None, via_device: Optional[str] = None,
# To disable a device if it gets created # To disable a device if it gets created
disabled_by=UNDEFINED, disabled_by: Union[str, None, UndefinedType] = UNDEFINED,
): ) -> Optional[DeviceEntry]:
"""Get device. Create if it doesn't exist.""" """Get device. Create if it doesn't exist."""
if not identifiers and not connections: if not identifiers and not connections:
return None return None
@ -271,7 +275,7 @@ class DeviceRegistry:
if via_device is not None: if via_device is not None:
via = self.async_get_device({via_device}, set()) via = self.async_get_device({via_device}, set())
via_device_id = via.id if via else UNDEFINED via_device_id: Union[str, UndefinedType] = via.id if via else UNDEFINED
else: else:
via_device_id = UNDEFINED via_device_id = UNDEFINED
@ -292,19 +296,19 @@ class DeviceRegistry:
@callback @callback
def async_update_device( def async_update_device(
self, self,
device_id, device_id: str,
*, *,
area_id=UNDEFINED, area_id: Union[str, None, UndefinedType] = UNDEFINED,
manufacturer=UNDEFINED, manufacturer: Union[str, None, UndefinedType] = UNDEFINED,
model=UNDEFINED, model: Union[str, None, UndefinedType] = UNDEFINED,
name=UNDEFINED, name: Union[str, None, UndefinedType] = UNDEFINED,
name_by_user=UNDEFINED, name_by_user: Union[str, None, UndefinedType] = UNDEFINED,
new_identifiers=UNDEFINED, new_identifiers: Union[set, UndefinedType] = UNDEFINED,
sw_version=UNDEFINED, sw_version: Union[str, None, UndefinedType] = UNDEFINED,
via_device_id=UNDEFINED, via_device_id: Union[str, None, UndefinedType] = UNDEFINED,
remove_config_entry_id=UNDEFINED, remove_config_entry_id: Union[str, UndefinedType] = UNDEFINED,
disabled_by=UNDEFINED, disabled_by: Union[str, None, UndefinedType] = UNDEFINED,
): ) -> Optional[DeviceEntry]:
"""Update properties of a device.""" """Update properties of a device."""
return self._async_update_device( return self._async_update_device(
device_id, device_id,
@ -323,27 +327,27 @@ class DeviceRegistry:
@callback @callback
def _async_update_device( def _async_update_device(
self, self,
device_id, device_id: str,
*, *,
add_config_entry_id=UNDEFINED, add_config_entry_id: Union[str, UndefinedType] = UNDEFINED,
remove_config_entry_id=UNDEFINED, remove_config_entry_id: Union[str, UndefinedType] = UNDEFINED,
merge_connections=UNDEFINED, merge_connections: Union[set, UndefinedType] = UNDEFINED,
merge_identifiers=UNDEFINED, merge_identifiers: Union[set, UndefinedType] = UNDEFINED,
new_identifiers=UNDEFINED, new_identifiers: Union[set, UndefinedType] = UNDEFINED,
manufacturer=UNDEFINED, manufacturer: Union[str, None, UndefinedType] = UNDEFINED,
model=UNDEFINED, model: Union[str, None, UndefinedType] = UNDEFINED,
name=UNDEFINED, name: Union[str, None, UndefinedType] = UNDEFINED,
sw_version=UNDEFINED, sw_version: Union[str, None, UndefinedType] = UNDEFINED,
entry_type=UNDEFINED, entry_type: Union[str, None, UndefinedType] = UNDEFINED,
via_device_id=UNDEFINED, via_device_id: Union[str, None, UndefinedType] = UNDEFINED,
area_id=UNDEFINED, area_id: Union[str, None, UndefinedType] = UNDEFINED,
name_by_user=UNDEFINED, name_by_user: Union[str, None, UndefinedType] = UNDEFINED,
disabled_by=UNDEFINED, disabled_by: Union[str, None, UndefinedType] = UNDEFINED,
): ) -> Optional[DeviceEntry]:
"""Update device attributes.""" """Update device attributes."""
old = self.devices[device_id] old = self.devices[device_id]
changes = {} changes: Dict[str, Any] = {}
config_entries = old.config_entries config_entries = old.config_entries
@ -359,21 +363,21 @@ class DeviceRegistry:
): ):
if config_entries == {remove_config_entry_id}: if config_entries == {remove_config_entry_id}:
self.async_remove_device(device_id) self.async_remove_device(device_id)
return return None
config_entries = config_entries - {remove_config_entry_id} config_entries = config_entries - {remove_config_entry_id}
if config_entries != old.config_entries: if config_entries != old.config_entries:
changes["config_entries"] = config_entries changes["config_entries"] = config_entries
for attr_name, value in ( for attr_name, setvalue in (
("connections", merge_connections), ("connections", merge_connections),
("identifiers", merge_identifiers), ("identifiers", merge_identifiers),
): ):
old_value = getattr(old, attr_name) old_value = getattr(old, attr_name)
# If not undefined, check if `value` contains new items. # If not undefined, check if `value` contains new items.
if value is not UNDEFINED and not value.issubset(old_value): if setvalue is not UNDEFINED and not setvalue.issubset(old_value):
changes[attr_name] = old_value | value changes[attr_name] = old_value | setvalue
if new_identifiers is not UNDEFINED: if new_identifiers is not UNDEFINED:
changes["identifiers"] = new_identifiers changes["identifiers"] = new_identifiers
@ -434,7 +438,7 @@ class DeviceRegistry:
) )
self.async_schedule_save() self.async_schedule_save()
async def async_load(self): async def async_load(self) -> None:
"""Load the device registry.""" """Load the device registry."""
async_setup_cleanup(self.hass, self) async_setup_cleanup(self.hass, self)
@ -447,8 +451,9 @@ class DeviceRegistry:
for device in data["devices"]: for device in data["devices"]:
devices[device["id"]] = DeviceEntry( devices[device["id"]] = DeviceEntry(
config_entries=set(device["config_entries"]), config_entries=set(device["config_entries"]),
connections={tuple(conn) for conn in device["connections"]}, # type ignores (if tuple arg was cast): likely https://github.com/python/mypy/issues/8625
identifiers={tuple(iden) for iden in device["identifiers"]}, connections={tuple(conn) for conn in device["connections"]}, # type: ignore[misc]
identifiers={tuple(iden) for iden in device["identifiers"]}, # type: ignore[misc]
manufacturer=device["manufacturer"], manufacturer=device["manufacturer"],
model=device["model"], model=device["model"],
name=device["name"], name=device["name"],
@ -471,8 +476,9 @@ class DeviceRegistry:
for device in data.get("deleted_devices", []): for device in data.get("deleted_devices", []):
deleted_devices[device["id"]] = DeletedDeviceEntry( deleted_devices[device["id"]] = DeletedDeviceEntry(
config_entries=set(device["config_entries"]), config_entries=set(device["config_entries"]),
connections={tuple(conn) for conn in device["connections"]}, # type ignores (if tuple arg was cast): likely https://github.com/python/mypy/issues/8625
identifiers={tuple(iden) for iden in device["identifiers"]}, connections={tuple(conn) for conn in device["connections"]}, # type: ignore[misc]
identifiers={tuple(iden) for iden in device["identifiers"]}, # type: ignore[misc]
id=device["id"], id=device["id"],
) )
@ -614,7 +620,7 @@ def async_setup_cleanup(hass: HomeAssistantType, dev_reg: DeviceRegistry) -> Non
"""Clean up device registry when entities removed.""" """Clean up device registry when entities removed."""
from . import entity_registry # pylint: disable=import-outside-toplevel from . import entity_registry # pylint: disable=import-outside-toplevel
async def cleanup(): async def cleanup() -> None:
"""Cleanup.""" """Cleanup."""
ent_reg = await entity_registry.async_get_registry(hass) ent_reg = await entity_registry.async_get_registry(hass)
async_cleanup(hass, dev_reg, ent_reg) async_cleanup(hass, dev_reg, ent_reg)

View file

@ -18,7 +18,7 @@ from typing import (
List, List,
Optional, Optional,
Tuple, Tuple,
cast, Union,
) )
import attr import attr
@ -39,13 +39,11 @@ from homeassistant.util import slugify
from homeassistant.util.yaml import load_yaml from homeassistant.util.yaml import load_yaml
from .singleton import singleton from .singleton import singleton
from .typing import UNDEFINED, HomeAssistantType from .typing import UNDEFINED, HomeAssistantType, UndefinedType
if TYPE_CHECKING: if TYPE_CHECKING:
from homeassistant.config_entries import ConfigEntry # noqa: F401 from homeassistant.config_entries import ConfigEntry # noqa: F401
# mypy: allow-untyped-defs, no-check-untyped-defs
PATH_REGISTRY = "entity_registry.yaml" PATH_REGISTRY = "entity_registry.yaml"
DATA_REGISTRY = "entity_registry" DATA_REGISTRY = "entity_registry"
EVENT_ENTITY_REGISTRY_UPDATED = "entity_registry_updated" EVENT_ENTITY_REGISTRY_UPDATED = "entity_registry_updated"
@ -222,7 +220,7 @@ class EntityRegistry:
entity_id = self.async_get_entity_id(domain, platform, unique_id) entity_id = self.async_get_entity_id(domain, platform, unique_id)
if entity_id: if entity_id:
return self._async_update_entity( # type: ignore return self._async_update_entity(
entity_id, entity_id,
config_entry_id=config_entry_id or UNDEFINED, config_entry_id=config_entry_id or UNDEFINED,
device_id=device_id or UNDEFINED, device_id=device_id or UNDEFINED,
@ -316,63 +314,56 @@ class EntityRegistry:
for entity in entities: for entity in entities:
if entity.disabled_by != DISABLED_DEVICE: if entity.disabled_by != DISABLED_DEVICE:
continue continue
self.async_update_entity( # type: ignore self.async_update_entity(entity.entity_id, disabled_by=None)
entity.entity_id, disabled_by=None
)
return return
entities = async_entries_for_device(self, event.data["device_id"]) entities = async_entries_for_device(self, event.data["device_id"])
for entity in entities: for entity in entities:
self.async_update_entity( # type: ignore self.async_update_entity(entity.entity_id, disabled_by=DISABLED_DEVICE)
entity.entity_id, disabled_by=DISABLED_DEVICE
)
@callback @callback
def async_update_entity( def async_update_entity(
self, self,
entity_id, entity_id: str,
*, *,
name=UNDEFINED, name: Union[str, None, UndefinedType] = UNDEFINED,
icon=UNDEFINED, icon: Union[str, None, UndefinedType] = UNDEFINED,
area_id=UNDEFINED, area_id: Union[str, None, UndefinedType] = UNDEFINED,
new_entity_id=UNDEFINED, new_entity_id: Union[str, UndefinedType] = UNDEFINED,
new_unique_id=UNDEFINED, new_unique_id: Union[str, UndefinedType] = UNDEFINED,
disabled_by=UNDEFINED, disabled_by: Union[str, None, UndefinedType] = UNDEFINED,
): ) -> RegistryEntry:
"""Update properties of an entity.""" """Update properties of an entity."""
return cast( # cast until we have _async_update_entity type hinted return self._async_update_entity(
RegistryEntry, entity_id,
self._async_update_entity( name=name,
entity_id, icon=icon,
name=name, area_id=area_id,
icon=icon, new_entity_id=new_entity_id,
area_id=area_id, new_unique_id=new_unique_id,
new_entity_id=new_entity_id, disabled_by=disabled_by,
new_unique_id=new_unique_id,
disabled_by=disabled_by,
),
) )
@callback @callback
def _async_update_entity( def _async_update_entity(
self, self,
entity_id, entity_id: str,
*, *,
name=UNDEFINED, name: Union[str, None, UndefinedType] = UNDEFINED,
icon=UNDEFINED, icon: Union[str, None, UndefinedType] = UNDEFINED,
config_entry_id=UNDEFINED, config_entry_id: Union[str, None, UndefinedType] = UNDEFINED,
new_entity_id=UNDEFINED, new_entity_id: Union[str, UndefinedType] = UNDEFINED,
device_id=UNDEFINED, device_id: Union[str, None, UndefinedType] = UNDEFINED,
area_id=UNDEFINED, area_id: Union[str, None, UndefinedType] = UNDEFINED,
new_unique_id=UNDEFINED, new_unique_id: Union[str, UndefinedType] = UNDEFINED,
disabled_by=UNDEFINED, disabled_by: Union[str, None, UndefinedType] = UNDEFINED,
capabilities=UNDEFINED, capabilities: Union[Dict[str, Any], None, UndefinedType] = UNDEFINED,
supported_features=UNDEFINED, supported_features: Union[int, UndefinedType] = UNDEFINED,
device_class=UNDEFINED, device_class: Union[str, None, UndefinedType] = UNDEFINED,
unit_of_measurement=UNDEFINED, unit_of_measurement: Union[str, None, UndefinedType] = UNDEFINED,
original_name=UNDEFINED, original_name: Union[str, None, UndefinedType] = UNDEFINED,
original_icon=UNDEFINED, original_icon: Union[str, None, UndefinedType] = UNDEFINED,
): ) -> RegistryEntry:
"""Private facing update properties method.""" """Private facing update properties method."""
old = self.entities[entity_id] old = self.entities[entity_id]
@ -526,7 +517,7 @@ class EntityRegistry:
"""Clear area id from registry entries.""" """Clear area id from registry entries."""
for entity_id, entry in self.entities.items(): for entity_id, entry in self.entities.items():
if area_id == entry.area_id: if area_id == entry.area_id:
self._async_update_entity(entity_id, area_id=None) # type: ignore self._async_update_entity(entity_id, area_id=None)
def _register_entry(self, entry: RegistryEntry) -> None: def _register_entry(self, entry: RegistryEntry) -> None:
self.entities[entry.entity_id] = entry self.entities[entry.entity_id] = entry