Improve device registry type hints (#44919)

* Fix async_get_or_create via_device type hint

* Specify collection element types
This commit is contained in:
Ville Skyttä 2021-01-08 03:38:57 +02:00 committed by GitHub
parent 0426b211f6
commit 20e2493f68
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -13,6 +13,8 @@ from .debounce import Debouncer
from .singleton import singleton from .singleton import singleton
from .typing import UNDEFINED, HomeAssistantType, UndefinedType from .typing import UNDEFINED, HomeAssistantType, UndefinedType
# mypy: disallow_any_generics
if TYPE_CHECKING: if TYPE_CHECKING:
from . import entity_registry from . import entity_registry
@ -124,7 +126,7 @@ class DeviceRegistry:
devices: Dict[str, DeviceEntry] devices: Dict[str, DeviceEntry]
deleted_devices: Dict[str, DeletedDeviceEntry] deleted_devices: Dict[str, DeletedDeviceEntry]
_devices_index: Dict[str, Dict[str, Dict[str, str]]] _devices_index: Dict[str, Dict[str, Dict[Tuple[str, str], str]]]
def __init__(self, hass: HomeAssistantType) -> None: def __init__(self, hass: HomeAssistantType) -> None:
"""Initialize the device registry.""" """Initialize the device registry."""
@ -140,8 +142,8 @@ class DeviceRegistry:
@callback @callback
def async_get_device( def async_get_device(
self, self,
identifiers: set, identifiers: Set[Tuple[str, str]],
connections: Optional[set] = None, connections: Optional[Set[Tuple[str, str]]] = None,
) -> Optional[DeviceEntry]: ) -> Optional[DeviceEntry]:
"""Check if device is registered.""" """Check if device is registered."""
device_id = self._async_get_device_id_from_index( device_id = self._async_get_device_id_from_index(
@ -152,7 +154,9 @@ class DeviceRegistry:
return self.devices[device_id] return self.devices[device_id]
def _async_get_deleted_device( def _async_get_deleted_device(
self, identifiers: set, connections: Optional[set] self,
identifiers: Set[Tuple[str, str]],
connections: Optional[Set[Tuple[str, str]]],
) -> Optional[DeletedDeviceEntry]: ) -> Optional[DeletedDeviceEntry]:
"""Check if device is deleted.""" """Check if device is deleted."""
device_id = self._async_get_device_id_from_index( device_id = self._async_get_device_id_from_index(
@ -163,7 +167,10 @@ class DeviceRegistry:
return self.deleted_devices[device_id] return self.deleted_devices[device_id]
def _async_get_device_id_from_index( def _async_get_device_id_from_index(
self, index: str, identifiers: set, connections: Optional[set] self,
index: str,
identifiers: Set[Tuple[str, str]],
connections: Optional[Set[Tuple[str, str]]],
) -> Optional[str]: ) -> Optional[str]:
"""Check if device has previously been registered.""" """Check if device has previously been registered."""
devices_index = self._devices_index[index] devices_index = self._devices_index[index]
@ -227,8 +234,8 @@ class DeviceRegistry:
self, self,
*, *,
config_entry_id: str, config_entry_id: str,
connections: Optional[set] = None, connections: Optional[Set[Tuple[str, str]]] = None,
identifiers: Optional[set] = None, identifiers: Optional[Set[Tuple[str, str]]] = None,
manufacturer: Union[str, None, UndefinedType] = UNDEFINED, manufacturer: Union[str, None, UndefinedType] = UNDEFINED,
model: Union[str, None, UndefinedType] = UNDEFINED, model: Union[str, None, UndefinedType] = UNDEFINED,
name: Union[str, None, UndefinedType] = UNDEFINED, name: Union[str, None, UndefinedType] = UNDEFINED,
@ -237,7 +244,7 @@ class DeviceRegistry:
default_name: Union[str, None, UndefinedType] = UNDEFINED, default_name: Union[str, None, UndefinedType] = UNDEFINED,
sw_version: Union[str, None, UndefinedType] = UNDEFINED, sw_version: Union[str, None, UndefinedType] = UNDEFINED,
entry_type: Union[str, None, UndefinedType] = UNDEFINED, entry_type: Union[str, None, UndefinedType] = UNDEFINED,
via_device: Optional[str] = None, via_device: Optional[Tuple[str, str]] = None,
# To disable a device if it gets created # To disable a device if it gets created
disabled_by: Union[str, None, UndefinedType] = UNDEFINED, disabled_by: Union[str, None, UndefinedType] = UNDEFINED,
) -> Optional[DeviceEntry]: ) -> Optional[DeviceEntry]:
@ -305,7 +312,7 @@ class DeviceRegistry:
model: Union[str, None, UndefinedType] = UNDEFINED, model: Union[str, None, UndefinedType] = UNDEFINED,
name: Union[str, None, UndefinedType] = UNDEFINED, name: Union[str, None, UndefinedType] = UNDEFINED,
name_by_user: Union[str, None, UndefinedType] = UNDEFINED, name_by_user: Union[str, None, UndefinedType] = UNDEFINED,
new_identifiers: Union[set, UndefinedType] = UNDEFINED, new_identifiers: Union[Set[Tuple[str, str]], UndefinedType] = UNDEFINED,
sw_version: Union[str, None, UndefinedType] = UNDEFINED, sw_version: Union[str, None, UndefinedType] = UNDEFINED,
via_device_id: Union[str, None, UndefinedType] = UNDEFINED, via_device_id: Union[str, None, UndefinedType] = UNDEFINED,
remove_config_entry_id: Union[str, UndefinedType] = UNDEFINED, remove_config_entry_id: Union[str, UndefinedType] = UNDEFINED,
@ -333,9 +340,9 @@ class DeviceRegistry:
*, *,
add_config_entry_id: Union[str, UndefinedType] = UNDEFINED, add_config_entry_id: Union[str, UndefinedType] = UNDEFINED,
remove_config_entry_id: Union[str, UndefinedType] = UNDEFINED, remove_config_entry_id: Union[str, UndefinedType] = UNDEFINED,
merge_connections: Union[set, UndefinedType] = UNDEFINED, merge_connections: Union[Set[Tuple[str, str]], UndefinedType] = UNDEFINED,
merge_identifiers: Union[set, UndefinedType] = UNDEFINED, merge_identifiers: Union[Set[Tuple[str, str]], UndefinedType] = UNDEFINED,
new_identifiers: Union[set, UndefinedType] = UNDEFINED, new_identifiers: Union[Set[Tuple[str, str]], UndefinedType] = UNDEFINED,
manufacturer: Union[str, None, UndefinedType] = UNDEFINED, manufacturer: Union[str, None, UndefinedType] = UNDEFINED,
model: Union[str, None, UndefinedType] = UNDEFINED, model: Union[str, None, UndefinedType] = UNDEFINED,
name: Union[str, None, UndefinedType] = UNDEFINED, name: Union[str, None, UndefinedType] = UNDEFINED,
@ -657,7 +664,7 @@ def async_setup_cleanup(hass: HomeAssistantType, dev_reg: DeviceRegistry) -> Non
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STARTED, startup_clean) hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STARTED, startup_clean)
def _normalize_connections(connections: set) -> set: def _normalize_connections(connections: Set[Tuple[str, str]]) -> Set[Tuple[str, str]]:
"""Normalize connections to ensure we can match mac addresses.""" """Normalize connections to ensure we can match mac addresses."""
return { return {
(key, format_mac(value)) if key == CONNECTION_NETWORK_MAC else (key, value) (key, format_mac(value)) if key == CONNECTION_NETWORK_MAC else (key, value)
@ -666,7 +673,8 @@ def _normalize_connections(connections: set) -> set:
def _add_device_to_index( def _add_device_to_index(
devices_index: dict, device: Union[DeviceEntry, DeletedDeviceEntry] devices_index: Dict[str, Dict[Tuple[str, str], str]],
device: Union[DeviceEntry, DeletedDeviceEntry],
) -> None: ) -> None:
"""Add a device to the index.""" """Add a device to the index."""
for identifier in device.identifiers: for identifier in device.identifiers:
@ -676,7 +684,8 @@ def _add_device_to_index(
def _remove_device_from_index( def _remove_device_from_index(
devices_index: dict, device: Union[DeviceEntry, DeletedDeviceEntry] devices_index: Dict[str, Dict[Tuple[str, str], str]],
device: Union[DeviceEntry, DeletedDeviceEntry],
) -> None: ) -> None:
"""Remove a device from the index.""" """Remove a device from the index."""
for identifier in device.identifiers: for identifier in device.identifiers: