Improve device registry type hints (#44919)

* Fix async_get_or_create via_device type hint

* Specify collection element types
This commit is contained in:
Ville Skyttä 2021-01-08 03:38:57 +02:00 committed by GitHub
parent 0426b211f6
commit 20e2493f68
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -13,6 +13,8 @@ from .debounce import Debouncer
from .singleton import singleton
from .typing import UNDEFINED, HomeAssistantType, UndefinedType
# mypy: disallow_any_generics
if TYPE_CHECKING:
from . import entity_registry
@ -124,7 +126,7 @@ class DeviceRegistry:
devices: Dict[str, DeviceEntry]
deleted_devices: Dict[str, DeletedDeviceEntry]
_devices_index: Dict[str, Dict[str, Dict[str, str]]]
_devices_index: Dict[str, Dict[str, Dict[Tuple[str, str], str]]]
def __init__(self, hass: HomeAssistantType) -> None:
"""Initialize the device registry."""
@ -140,8 +142,8 @@ class DeviceRegistry:
@callback
def async_get_device(
self,
identifiers: set,
connections: Optional[set] = None,
identifiers: Set[Tuple[str, str]],
connections: Optional[Set[Tuple[str, str]]] = None,
) -> Optional[DeviceEntry]:
"""Check if device is registered."""
device_id = self._async_get_device_id_from_index(
@ -152,7 +154,9 @@ class DeviceRegistry:
return self.devices[device_id]
def _async_get_deleted_device(
self, identifiers: set, connections: Optional[set]
self,
identifiers: Set[Tuple[str, str]],
connections: Optional[Set[Tuple[str, str]]],
) -> Optional[DeletedDeviceEntry]:
"""Check if device is deleted."""
device_id = self._async_get_device_id_from_index(
@ -163,7 +167,10 @@ class DeviceRegistry:
return self.deleted_devices[device_id]
def _async_get_device_id_from_index(
self, index: str, identifiers: set, connections: Optional[set]
self,
index: str,
identifiers: Set[Tuple[str, str]],
connections: Optional[Set[Tuple[str, str]]],
) -> Optional[str]:
"""Check if device has previously been registered."""
devices_index = self._devices_index[index]
@ -227,8 +234,8 @@ class DeviceRegistry:
self,
*,
config_entry_id: str,
connections: Optional[set] = None,
identifiers: Optional[set] = None,
connections: Optional[Set[Tuple[str, str]]] = None,
identifiers: Optional[Set[Tuple[str, str]]] = None,
manufacturer: Union[str, None, UndefinedType] = UNDEFINED,
model: Union[str, None, UndefinedType] = UNDEFINED,
name: Union[str, None, UndefinedType] = UNDEFINED,
@ -237,7 +244,7 @@ class DeviceRegistry:
default_name: Union[str, None, UndefinedType] = UNDEFINED,
sw_version: Union[str, None, UndefinedType] = UNDEFINED,
entry_type: Union[str, None, UndefinedType] = UNDEFINED,
via_device: Optional[str] = None,
via_device: Optional[Tuple[str, str]] = None,
# To disable a device if it gets created
disabled_by: Union[str, None, UndefinedType] = UNDEFINED,
) -> Optional[DeviceEntry]:
@ -305,7 +312,7 @@ class DeviceRegistry:
model: Union[str, None, UndefinedType] = UNDEFINED,
name: Union[str, None, UndefinedType] = UNDEFINED,
name_by_user: Union[str, None, UndefinedType] = UNDEFINED,
new_identifiers: Union[set, UndefinedType] = UNDEFINED,
new_identifiers: Union[Set[Tuple[str, str]], UndefinedType] = UNDEFINED,
sw_version: Union[str, None, UndefinedType] = UNDEFINED,
via_device_id: Union[str, None, UndefinedType] = UNDEFINED,
remove_config_entry_id: Union[str, UndefinedType] = UNDEFINED,
@ -333,9 +340,9 @@ class DeviceRegistry:
*,
add_config_entry_id: Union[str, UndefinedType] = UNDEFINED,
remove_config_entry_id: Union[str, UndefinedType] = UNDEFINED,
merge_connections: Union[set, UndefinedType] = UNDEFINED,
merge_identifiers: Union[set, UndefinedType] = UNDEFINED,
new_identifiers: Union[set, UndefinedType] = UNDEFINED,
merge_connections: Union[Set[Tuple[str, str]], UndefinedType] = UNDEFINED,
merge_identifiers: Union[Set[Tuple[str, str]], UndefinedType] = UNDEFINED,
new_identifiers: Union[Set[Tuple[str, str]], UndefinedType] = UNDEFINED,
manufacturer: Union[str, None, UndefinedType] = UNDEFINED,
model: Union[str, None, UndefinedType] = UNDEFINED,
name: Union[str, None, UndefinedType] = UNDEFINED,
@ -657,7 +664,7 @@ def async_setup_cleanup(hass: HomeAssistantType, dev_reg: DeviceRegistry) -> Non
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STARTED, startup_clean)
def _normalize_connections(connections: set) -> set:
def _normalize_connections(connections: Set[Tuple[str, str]]) -> Set[Tuple[str, str]]:
"""Normalize connections to ensure we can match mac addresses."""
return {
(key, format_mac(value)) if key == CONNECTION_NETWORK_MAC else (key, value)
@ -666,7 +673,8 @@ def _normalize_connections(connections: set) -> set:
def _add_device_to_index(
devices_index: dict, device: Union[DeviceEntry, DeletedDeviceEntry]
devices_index: Dict[str, Dict[Tuple[str, str], str]],
device: Union[DeviceEntry, DeletedDeviceEntry],
) -> None:
"""Add a device to the index."""
for identifier in device.identifiers:
@ -676,7 +684,8 @@ def _add_device_to_index(
def _remove_device_from_index(
devices_index: dict, device: Union[DeviceEntry, DeletedDeviceEntry]
devices_index: Dict[str, Dict[Tuple[str, str], str]],
device: Union[DeviceEntry, DeletedDeviceEntry],
) -> None:
"""Remove a device from the index."""
for identifier in device.identifiers: