Prevent apple_tv rediscovery from secondary identifiers (#61973)

This commit is contained in:
J. Nick Koston 2021-12-16 02:25:18 -06:00 committed by GitHub
parent 06c1949d2f
commit 048102e053
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 77 additions and 17 deletions

View file

@ -1,4 +1,6 @@
"""Config flow for Apple TV integration."""
from __future__ import annotations
import asyncio
from collections import deque
from ipaddress import ip_address
@ -98,12 +100,19 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
re-used, otherwise the newly discovered identifier is used instead.
"""
all_identifiers = set(self.atv.all_identifiers)
if unique_id := self._entry_unique_id_from_identifers(all_identifiers):
return unique_id
return self.atv.identifier
@callback
def _entry_unique_id_from_identifers(self, all_identifiers: set[str]) -> str | None:
"""Search existing entries for an identifier and return the unique id."""
for entry in self._async_current_entries():
if all_identifiers.intersection(
entry.data.get(CONF_IDENTIFIERS, [entry.unique_id])
):
return entry.unique_id
return self.atv.identifier
return None
async def async_step_reauth(self, user_input=None):
"""Handle initial step when updating invalid credentials."""
@ -166,6 +175,20 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
if unique_id is None:
return self.async_abort(reason="unknown")
if existing_unique_id := self._entry_unique_id_from_identifers({unique_id}):
await self.async_set_unique_id(existing_unique_id)
self._abort_if_unique_id_configured(updates={CONF_ADDRESS: host})
self._async_abort_entries_match({CONF_ADDRESS: host})
await self._async_aggregate_discoveries(host, unique_id)
# Scan for the device in order to extract _all_ unique identifiers assigned to
# it. Not doing it like this will yield multiple config flows for the same
# device, one per protocol, which is undesired.
self.scan_filter = host
return await self.async_find_device_wrapper(self.async_found_zeroconf_device)
async def _async_aggregate_discoveries(self, host: str, unique_id: str) -> None:
"""Wait for multiple zeroconf services to be discovered an aggregate them."""
#
# Suppose we have a device with three services: A, B and C. Let's assume
# service A is discovered by Zeroconf, triggering a device scan that also finds
@ -195,23 +218,18 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
# apple_tv device has multiple services that are discovered by
# zeroconf.
#
self._async_check_and_update_in_progress(host, unique_id)
await asyncio.sleep(DISCOVERY_AGGREGATION_TIME)
self._async_check_in_progress_and_set_address(host, unique_id)
# Scan for the device in order to extract _all_ unique identifiers assigned to
# it. Not doing it like this will yield multiple config flows for the same
# device, one per protocol, which is undesired.
self.scan_filter = host
return await self.async_find_device_wrapper(self.async_found_zeroconf_device)
# Check again after sleeping in case another flow
# has made progress while we yielded to the event loop
self._async_check_and_update_in_progress(host, unique_id)
# Host must only be set AFTER checking and updating in progress
# flows or we will have a race condition where no flows move forward.
self.context[CONF_ADDRESS] = host
@callback
def _async_check_in_progress_and_set_address(self, host: str, unique_id: str):
"""Check for in-progress flows and update them with identifiers if needed.
This code must not await between checking in progress and setting the host
or it will have a race condition where no flows move forward.
"""
def _async_check_and_update_in_progress(self, host: str, unique_id: str) -> None:
"""Check for in-progress flows and update them with identifiers if needed."""
for flow in self._async_in_progress(include_uninitialized=True):
context = flow["context"]
if (
@ -226,7 +244,6 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
# Add potentially new identifiers from this device to the existing flow
context["all_identifiers"].append(unique_id)
raise data_entry_flow.AbortFlow("already_in_progress")
self.context[CONF_ADDRESS] = host
async def async_found_zeroconf_device(self, user_input=None):
"""Handle device found after Zeroconf discovery."""

View file

@ -10,7 +10,11 @@ import pytest
from homeassistant import config_entries, data_entry_flow
from homeassistant.components import zeroconf
from homeassistant.components.apple_tv import CONF_ADDRESS, config_flow
from homeassistant.components.apple_tv.const import CONF_START_OFF, DOMAIN
from homeassistant.components.apple_tv.const import (
CONF_IDENTIFIERS,
CONF_START_OFF,
DOMAIN,
)
from .common import airplay_service, create_conf, mrp_service, raop_service
@ -652,6 +656,45 @@ async def test_zeroconf_ip_change(hass, mock_scan):
assert unrelated_entry.data[CONF_ADDRESS] == "127.0.0.2"
async def test_zeroconf_ip_change_via_secondary_identifier(hass, mock_scan):
"""Test that the config entry gets updated when the ip changes and reloads.
Instead of checking only the unique id, all the identifiers
in the config entry are checked
"""
entry = MockConfigEntry(
domain="apple_tv",
unique_id="aa:bb:cc:dd:ee:ff",
data={CONF_IDENTIFIERS: ["mrpid"], CONF_ADDRESS: "127.0.0.2"},
)
unrelated_entry = MockConfigEntry(
domain="apple_tv", unique_id="unrelated", data={CONF_ADDRESS: "127.0.0.2"}
)
unrelated_entry.add_to_hass(hass)
entry.add_to_hass(hass)
mock_scan.result = [
create_conf(
IPv4Address("127.0.0.1"), "Device", mrp_service(), airplay_service()
)
]
with patch(
"homeassistant.components.apple_tv.async_setup_entry", return_value=True
) as mock_async_setup:
result = await hass.config_entries.flow.async_init(
DOMAIN,
context={"source": config_entries.SOURCE_ZEROCONF},
data=DMAP_SERVICE,
)
await hass.async_block_till_done()
assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT
assert result["reason"] == "already_configured"
assert len(mock_async_setup.mock_calls) == 2
assert entry.data[CONF_ADDRESS] == "127.0.0.1"
assert unrelated_entry.data[CONF_ADDRESS] == "127.0.0.2"
async def test_zeroconf_add_existing_aborts(hass, dmap_device):
"""Test start new zeroconf flow while existing flow is active aborts."""
await hass.config_entries.flow.async_init(