Prevent apple_tv rediscovery from secondary identifiers (#61973)

This commit is contained in:
J. Nick Koston 2021-12-16 02:25:18 -06:00 committed by GitHub
parent 06c1949d2f
commit 048102e053
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 77 additions and 17 deletions

View file

@ -1,4 +1,6 @@
"""Config flow for Apple TV integration.""" """Config flow for Apple TV integration."""
from __future__ import annotations
import asyncio import asyncio
from collections import deque from collections import deque
from ipaddress import ip_address from ipaddress import ip_address
@ -98,12 +100,19 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
re-used, otherwise the newly discovered identifier is used instead. re-used, otherwise the newly discovered identifier is used instead.
""" """
all_identifiers = set(self.atv.all_identifiers) all_identifiers = set(self.atv.all_identifiers)
if unique_id := self._entry_unique_id_from_identifers(all_identifiers):
return unique_id
return self.atv.identifier
@callback
def _entry_unique_id_from_identifers(self, all_identifiers: set[str]) -> str | None:
"""Search existing entries for an identifier and return the unique id."""
for entry in self._async_current_entries(): for entry in self._async_current_entries():
if all_identifiers.intersection( if all_identifiers.intersection(
entry.data.get(CONF_IDENTIFIERS, [entry.unique_id]) entry.data.get(CONF_IDENTIFIERS, [entry.unique_id])
): ):
return entry.unique_id return entry.unique_id
return self.atv.identifier return None
async def async_step_reauth(self, user_input=None): async def async_step_reauth(self, user_input=None):
"""Handle initial step when updating invalid credentials.""" """Handle initial step when updating invalid credentials."""
@ -166,6 +175,20 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
if unique_id is None: if unique_id is None:
return self.async_abort(reason="unknown") return self.async_abort(reason="unknown")
if existing_unique_id := self._entry_unique_id_from_identifers({unique_id}):
await self.async_set_unique_id(existing_unique_id)
self._abort_if_unique_id_configured(updates={CONF_ADDRESS: host})
self._async_abort_entries_match({CONF_ADDRESS: host})
await self._async_aggregate_discoveries(host, unique_id)
# Scan for the device in order to extract _all_ unique identifiers assigned to
# it. Not doing it like this will yield multiple config flows for the same
# device, one per protocol, which is undesired.
self.scan_filter = host
return await self.async_find_device_wrapper(self.async_found_zeroconf_device)
async def _async_aggregate_discoveries(self, host: str, unique_id: str) -> None:
"""Wait for multiple zeroconf services to be discovered an aggregate them."""
# #
# Suppose we have a device with three services: A, B and C. Let's assume # Suppose we have a device with three services: A, B and C. Let's assume
# service A is discovered by Zeroconf, triggering a device scan that also finds # service A is discovered by Zeroconf, triggering a device scan that also finds
@ -195,23 +218,18 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
# apple_tv device has multiple services that are discovered by # apple_tv device has multiple services that are discovered by
# zeroconf. # zeroconf.
# #
self._async_check_and_update_in_progress(host, unique_id)
await asyncio.sleep(DISCOVERY_AGGREGATION_TIME) await asyncio.sleep(DISCOVERY_AGGREGATION_TIME)
# Check again after sleeping in case another flow
self._async_check_in_progress_and_set_address(host, unique_id) # has made progress while we yielded to the event loop
self._async_check_and_update_in_progress(host, unique_id)
# Scan for the device in order to extract _all_ unique identifiers assigned to # Host must only be set AFTER checking and updating in progress
# it. Not doing it like this will yield multiple config flows for the same # flows or we will have a race condition where no flows move forward.
# device, one per protocol, which is undesired. self.context[CONF_ADDRESS] = host
self.scan_filter = host
return await self.async_find_device_wrapper(self.async_found_zeroconf_device)
@callback @callback
def _async_check_in_progress_and_set_address(self, host: str, unique_id: str): def _async_check_and_update_in_progress(self, host: str, unique_id: str) -> None:
"""Check for in-progress flows and update them with identifiers if needed. """Check for in-progress flows and update them with identifiers if needed."""
This code must not await between checking in progress and setting the host
or it will have a race condition where no flows move forward.
"""
for flow in self._async_in_progress(include_uninitialized=True): for flow in self._async_in_progress(include_uninitialized=True):
context = flow["context"] context = flow["context"]
if ( if (
@ -226,7 +244,6 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
# Add potentially new identifiers from this device to the existing flow # Add potentially new identifiers from this device to the existing flow
context["all_identifiers"].append(unique_id) context["all_identifiers"].append(unique_id)
raise data_entry_flow.AbortFlow("already_in_progress") raise data_entry_flow.AbortFlow("already_in_progress")
self.context[CONF_ADDRESS] = host
async def async_found_zeroconf_device(self, user_input=None): async def async_found_zeroconf_device(self, user_input=None):
"""Handle device found after Zeroconf discovery.""" """Handle device found after Zeroconf discovery."""

View file

@ -10,7 +10,11 @@ import pytest
from homeassistant import config_entries, data_entry_flow from homeassistant import config_entries, data_entry_flow
from homeassistant.components import zeroconf from homeassistant.components import zeroconf
from homeassistant.components.apple_tv import CONF_ADDRESS, config_flow from homeassistant.components.apple_tv import CONF_ADDRESS, config_flow
from homeassistant.components.apple_tv.const import CONF_START_OFF, DOMAIN from homeassistant.components.apple_tv.const import (
CONF_IDENTIFIERS,
CONF_START_OFF,
DOMAIN,
)
from .common import airplay_service, create_conf, mrp_service, raop_service from .common import airplay_service, create_conf, mrp_service, raop_service
@ -652,6 +656,45 @@ async def test_zeroconf_ip_change(hass, mock_scan):
assert unrelated_entry.data[CONF_ADDRESS] == "127.0.0.2" assert unrelated_entry.data[CONF_ADDRESS] == "127.0.0.2"
async def test_zeroconf_ip_change_via_secondary_identifier(hass, mock_scan):
"""Test that the config entry gets updated when the ip changes and reloads.
Instead of checking only the unique id, all the identifiers
in the config entry are checked
"""
entry = MockConfigEntry(
domain="apple_tv",
unique_id="aa:bb:cc:dd:ee:ff",
data={CONF_IDENTIFIERS: ["mrpid"], CONF_ADDRESS: "127.0.0.2"},
)
unrelated_entry = MockConfigEntry(
domain="apple_tv", unique_id="unrelated", data={CONF_ADDRESS: "127.0.0.2"}
)
unrelated_entry.add_to_hass(hass)
entry.add_to_hass(hass)
mock_scan.result = [
create_conf(
IPv4Address("127.0.0.1"), "Device", mrp_service(), airplay_service()
)
]
with patch(
"homeassistant.components.apple_tv.async_setup_entry", return_value=True
) as mock_async_setup:
result = await hass.config_entries.flow.async_init(
DOMAIN,
context={"source": config_entries.SOURCE_ZEROCONF},
data=DMAP_SERVICE,
)
await hass.async_block_till_done()
assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT
assert result["reason"] == "already_configured"
assert len(mock_async_setup.mock_calls) == 2
assert entry.data[CONF_ADDRESS] == "127.0.0.1"
assert unrelated_entry.data[CONF_ADDRESS] == "127.0.0.2"
async def test_zeroconf_add_existing_aborts(hass, dmap_device): async def test_zeroconf_add_existing_aborts(hass, dmap_device):
"""Test start new zeroconf flow while existing flow is active aborts.""" """Test start new zeroconf flow while existing flow is active aborts."""
await hass.config_entries.flow.async_init( await hass.config_entries.flow.async_init(