Type hint device registry identifiers as set of str 2-tuples (#50355)

* Type hint device registry identifiers as set of str 2-tuples

* Fix airly test

* Really fix airly test, add another migration test
This commit is contained in:
Ville Skyttä 2021-05-10 13:13:45 +03:00 committed by GitHub
parent 1c98df5d18
commit b89c53f759
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 39 additions and 27 deletions

View file

@ -77,14 +77,25 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
entry, unique_id=f"{latitude}-{longitude}" entry, unique_id=f"{latitude}-{longitude}"
) )
# identifiers in device_info should use Tuple[str, str, str] type, but latitude and # identifiers in device_info should use tuple[str, str] type, but latitude and
# longitude are float, so we convert old device entries to use correct types # longitude are float, so we convert old device entries to use correct types
# We used to use a str 3-tuple here sometime, convert that to a 2-tuple too.
device_registry = await async_get_registry(hass) device_registry = await async_get_registry(hass)
old_ids = (DOMAIN, latitude, longitude) old_ids = (DOMAIN, latitude, longitude)
device_entry = device_registry.async_get_device({old_ids}) for old_ids in (
if device_entry and entry.entry_id in device_entry.config_entries: (DOMAIN, latitude, longitude),
new_ids = (DOMAIN, str(latitude), str(longitude)) (
device_registry.async_update_device(device_entry.id, new_identifiers={new_ids}) DOMAIN,
str(latitude),
str(longitude),
),
):
device_entry = device_registry.async_get_device({old_ids}) # type: ignore[arg-type]
if device_entry and entry.entry_id in device_entry.config_entries:
new_ids = (DOMAIN, f"{latitude}-{longitude}")
device_registry.async_update_device(
device_entry.id, new_identifiers={new_ids}
)
websession = async_get_clientsession(hass) websession = async_get_clientsession(hass)

View file

@ -109,8 +109,7 @@ class AirlyAirQuality(CoordinatorEntity, AirQualityEntity):
"identifiers": { "identifiers": {
( (
DOMAIN, DOMAIN,
str(self.coordinator.latitude), f"{self.coordinator.latitude}-{self.coordinator.longitude}",
str(self.coordinator.longitude),
) )
}, },
"name": DEFAULT_NAME, "name": DEFAULT_NAME,

View file

@ -100,8 +100,7 @@ class AirlySensor(CoordinatorEntity, SensorEntity):
"identifiers": { "identifiers": {
( (
DOMAIN, DOMAIN,
str(self.coordinator.latitude), f"{self.coordinator.latitude}-{self.coordinator.longitude}",
str(self.coordinator.longitude),
) )
}, },
"name": DEFAULT_NAME, "name": DEFAULT_NAME,

View file

@ -163,7 +163,7 @@ class Router:
return DEFAULT_DEVICE_NAME return DEFAULT_DEVICE_NAME
@property @property
def device_identifiers(self) -> set[tuple[str, ...]]: def device_identifiers(self) -> set[tuple[str, str]]:
"""Get router identifiers for device registry.""" """Get router identifiers for device registry."""
try: try:
return {(DOMAIN, self.data[KEY_DEVICE_INFORMATION]["SerialNumber"])} return {(DOMAIN, self.data[KEY_DEVICE_INFORMATION]["SerialNumber"])}

View file

@ -76,7 +76,7 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
return unload_ok return unload_ok
def device_identifiers(printer: SyncThru) -> set[tuple[str, ...]] | None: def device_identifiers(printer: SyncThru) -> set[tuple[str, str]] | None:
"""Get device identifiers for device registry.""" """Get device identifiers for device registry."""
serial = printer.serial_number() serial = printer.serial_number()
if serial is None: if serial is None:

View file

@ -45,7 +45,7 @@ ORPHANED_DEVICE_KEEP_SECONDS = 86400 * 30
class _DeviceIndex(NamedTuple): class _DeviceIndex(NamedTuple):
identifiers: dict[tuple[str, ...], str] identifiers: dict[tuple[str, str], str]
connections: dict[tuple[str, str], str] connections: dict[tuple[str, str], str]
@ -55,7 +55,7 @@ class DeviceEntry:
config_entries: set[str] = attr.ib(converter=set, factory=set) config_entries: set[str] = attr.ib(converter=set, factory=set)
connections: set[tuple[str, str]] = attr.ib(converter=set, factory=set) connections: set[tuple[str, str]] = attr.ib(converter=set, factory=set)
identifiers: set[tuple[str, ...]] = attr.ib(converter=set, factory=set) identifiers: set[tuple[str, str]] = attr.ib(converter=set, factory=set)
manufacturer: str | None = attr.ib(default=None) manufacturer: str | None = attr.ib(default=None)
model: str | None = attr.ib(default=None) model: str | None = attr.ib(default=None)
name: str | None = attr.ib(default=None) name: str | None = attr.ib(default=None)
@ -92,7 +92,7 @@ class DeletedDeviceEntry:
config_entries: set[str] = attr.ib() config_entries: set[str] = attr.ib()
connections: set[tuple[str, str]] = attr.ib() connections: set[tuple[str, str]] = attr.ib()
identifiers: set[tuple[str, ...]] = attr.ib() identifiers: set[tuple[str, str]] = attr.ib()
id: str = attr.ib() id: str = attr.ib()
orphaned_timestamp: float | None = attr.ib() orphaned_timestamp: float | None = attr.ib()
@ -100,7 +100,7 @@ class DeletedDeviceEntry:
self, self,
config_entry_id: str, config_entry_id: str,
connections: set[tuple[str, str]], connections: set[tuple[str, str]],
identifiers: set[tuple[str, ...]], identifiers: set[tuple[str, str]],
) -> DeviceEntry: ) -> DeviceEntry:
"""Create DeviceEntry from DeletedDeviceEntry.""" """Create DeviceEntry from DeletedDeviceEntry."""
return DeviceEntry( return DeviceEntry(
@ -135,7 +135,7 @@ def format_mac(mac: str) -> str:
def _async_get_device_id_from_index( def _async_get_device_id_from_index(
devices_index: _DeviceIndex, devices_index: _DeviceIndex,
identifiers: set[tuple[str, ...]], identifiers: set[tuple[str, str]],
connections: set[tuple[str, str]] | None, connections: set[tuple[str, str]] | None,
) -> str | None: ) -> str | None:
"""Check if device has previously been registered.""" """Check if device has previously been registered."""
@ -172,7 +172,7 @@ class DeviceRegistry:
@callback @callback
def async_get_device( def async_get_device(
self, self,
identifiers: set[tuple[str, ...]], identifiers: set[tuple[str, str]],
connections: set[tuple[str, str]] | None = None, connections: set[tuple[str, str]] | None = None,
) -> DeviceEntry | None: ) -> DeviceEntry | None:
"""Check if device is registered.""" """Check if device is registered."""
@ -185,7 +185,7 @@ class DeviceRegistry:
def _async_get_deleted_device( def _async_get_deleted_device(
self, self,
identifiers: set[tuple[str, ...]], identifiers: set[tuple[str, str]],
connections: set[tuple[str, str]] | None, connections: set[tuple[str, str]] | None,
) -> DeletedDeviceEntry | None: ) -> DeletedDeviceEntry | None:
"""Check if device is deleted.""" """Check if device is deleted."""
@ -245,7 +245,7 @@ class DeviceRegistry:
*, *,
config_entry_id: str, config_entry_id: str,
connections: set[tuple[str, str]] | None = None, connections: set[tuple[str, str]] | None = None,
identifiers: set[tuple[str, ...]] | None = None, identifiers: set[tuple[str, str]] | None = None,
manufacturer: str | None | UndefinedType = UNDEFINED, manufacturer: str | None | UndefinedType = UNDEFINED,
model: str | None | UndefinedType = UNDEFINED, model: str | None | UndefinedType = UNDEFINED,
name: str | None | UndefinedType = UNDEFINED, name: str | None | UndefinedType = UNDEFINED,
@ -329,7 +329,7 @@ class DeviceRegistry:
model: str | None | UndefinedType = UNDEFINED, model: str | None | UndefinedType = UNDEFINED,
name: str | None | UndefinedType = UNDEFINED, name: str | None | UndefinedType = UNDEFINED,
name_by_user: str | None | UndefinedType = UNDEFINED, name_by_user: str | None | UndefinedType = UNDEFINED,
new_identifiers: set[tuple[str, ...]] | UndefinedType = UNDEFINED, new_identifiers: set[tuple[str, str]] | UndefinedType = UNDEFINED,
sw_version: str | None | UndefinedType = UNDEFINED, sw_version: str | None | UndefinedType = UNDEFINED,
via_device_id: str | None | UndefinedType = UNDEFINED, via_device_id: str | None | UndefinedType = UNDEFINED,
remove_config_entry_id: str | UndefinedType = UNDEFINED, remove_config_entry_id: str | UndefinedType = UNDEFINED,
@ -360,8 +360,8 @@ class DeviceRegistry:
add_config_entry_id: str | UndefinedType = UNDEFINED, add_config_entry_id: str | UndefinedType = UNDEFINED,
remove_config_entry_id: str | UndefinedType = UNDEFINED, remove_config_entry_id: str | UndefinedType = UNDEFINED,
merge_connections: set[tuple[str, str]] | UndefinedType = UNDEFINED, merge_connections: set[tuple[str, str]] | UndefinedType = UNDEFINED,
merge_identifiers: set[tuple[str, ...]] | UndefinedType = UNDEFINED, merge_identifiers: set[tuple[str, str]] | UndefinedType = UNDEFINED,
new_identifiers: set[tuple[str, ...]] | UndefinedType = UNDEFINED, new_identifiers: set[tuple[str, str]] | UndefinedType = UNDEFINED,
manufacturer: str | None | UndefinedType = UNDEFINED, manufacturer: str | None | UndefinedType = UNDEFINED,
model: str | None | UndefinedType = UNDEFINED, model: str | None | UndefinedType = UNDEFINED,
name: str | None | UndefinedType = UNDEFINED, name: str | None | UndefinedType = UNDEFINED,
@ -519,7 +519,7 @@ class DeviceRegistry:
config_entries=set(device["config_entries"]), config_entries=set(device["config_entries"]),
# type ignores (if tuple arg was cast): likely https://github.com/python/mypy/issues/8625 # type ignores (if tuple arg was cast): likely https://github.com/python/mypy/issues/8625
connections={tuple(conn) for conn in device["connections"]}, # type: ignore[misc] connections={tuple(conn) for conn in device["connections"]}, # type: ignore[misc]
identifiers={tuple(iden) for iden in device["identifiers"]}, identifiers={tuple(iden) for iden in device["identifiers"]}, # type: ignore[misc]
id=device["id"], id=device["id"],
# Introduced in 2021.2 # Introduced in 2021.2
orphaned_timestamp=device.get("orphaned_timestamp"), orphaned_timestamp=device.get("orphaned_timestamp"),

View file

@ -115,7 +115,7 @@ class DeviceInfo(TypedDict, total=False):
name: str name: str
connections: set[tuple[str, str]] connections: set[tuple[str, str]]
identifiers: set[tuple[str, ...]] identifiers: set[tuple[str, str]]
manufacturer: str manufacturer: str
model: str model: str
suggested_area: str suggested_area: str

View file

@ -1,6 +1,8 @@
"""Test init of Airly integration.""" """Test init of Airly integration."""
from unittest.mock import patch from unittest.mock import patch
import pytest
from homeassistant.components.airly import set_update_interval from homeassistant.components.airly import set_update_interval
from homeassistant.components.airly.const import DOMAIN from homeassistant.components.airly.const import DOMAIN
from homeassistant.config_entries import ( from homeassistant.config_entries import (
@ -188,7 +190,8 @@ async def test_unload_entry(hass, aioclient_mock):
assert not hass.data.get(DOMAIN) assert not hass.data.get(DOMAIN)
async def test_migrate_device_entry(hass, aioclient_mock): @pytest.mark.parametrize("old_identifier", ((DOMAIN, 123, 456), (DOMAIN, "123", "456")))
async def test_migrate_device_entry(hass, aioclient_mock, old_identifier):
"""Test device_info identifiers migration.""" """Test device_info identifiers migration."""
config_entry = MockConfigEntry( config_entry = MockConfigEntry(
domain=DOMAIN, domain=DOMAIN,
@ -207,13 +210,13 @@ async def test_migrate_device_entry(hass, aioclient_mock):
device_reg = mock_device_registry(hass) device_reg = mock_device_registry(hass)
device_entry = device_reg.async_get_or_create( device_entry = device_reg.async_get_or_create(
config_entry_id=config_entry.entry_id, identifiers={(DOMAIN, 123, 456)} config_entry_id=config_entry.entry_id, identifiers={old_identifier}
) )
await hass.config_entries.async_setup(config_entry.entry_id) await hass.config_entries.async_setup(config_entry.entry_id)
await hass.async_block_till_done() await hass.async_block_till_done()
migrated_device_entry = device_reg.async_get_or_create( migrated_device_entry = device_reg.async_get_or_create(
config_entry_id=config_entry.entry_id, identifiers={(DOMAIN, "123", "456")} config_entry_id=config_entry.entry_id, identifiers={(DOMAIN, "123-456")}
) )
assert device_entry.id == migrated_device_entry.id assert device_entry.id == migrated_device_entry.id