Prevent apple_tv rediscovery from secondary identifiers (#61973)
parent 06c1949d2f
commit 048102e053

2 changed files with 77 additions and 17 deletions
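In short, the change below keys rediscovery on every identifier stored in a config entry (CONF_IDENTIFIERS), not only on the entry's unique_id, so a device that re-announces itself under a secondary identifier updates the existing entry instead of opening a new discovery flow. A minimal standalone sketch of that matching rule, using plain dataclasses rather than Home Assistant config entries (all names in the sketch are illustrative, not the integration's API):

# Illustrative sketch only; mirrors entry.data.get(CONF_IDENTIFIERS, [entry.unique_id])
# from the diff below with plain data structures.
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class FakeEntry:
    unique_id: str
    identifiers: set[str] = field(default_factory=set)


def existing_unique_id(entries: list[FakeEntry], discovered: set[str]) -> str | None:
    """Return the unique_id of any entry sharing an identifier with the discovery."""
    for entry in entries:
        # Fall back to the unique_id itself when no identifier list is stored.
        known = entry.identifiers or {entry.unique_id}
        if known & discovered:
            return entry.unique_id
    return None


entries = [FakeEntry("aa:bb:cc:dd:ee:ff", {"mrpid", "airplayid"})]
assert existing_unique_id(entries, {"mrpid"}) == "aa:bb:cc:dd:ee:ff"  # secondary id matches
assert existing_unique_id(entries, {"raopid"}) is None  # genuinely new device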
@@ -1,4 +1,6 @@
 """Config flow for Apple TV integration."""
+from __future__ import annotations
+
 import asyncio
 from collections import deque
 from ipaddress import ip_address
@@ -98,12 +100,19 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
         re-used, otherwise the newly discovered identifier is used instead.
         """
         all_identifiers = set(self.atv.all_identifiers)
+        if unique_id := self._entry_unique_id_from_identifers(all_identifiers):
+            return unique_id
+        return self.atv.identifier
+
+    @callback
+    def _entry_unique_id_from_identifers(self, all_identifiers: set[str]) -> str | None:
+        """Search existing entries for an identifier and return the unique id."""
         for entry in self._async_current_entries():
             if all_identifiers.intersection(
                 entry.data.get(CONF_IDENTIFIERS, [entry.unique_id])
             ):
                 return entry.unique_id
-        return self.atv.identifier
+        return None
 
     async def async_step_reauth(self, user_input=None):
         """Handle initial step when updating invalid credentials."""
@@ -166,6 +175,20 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
         if unique_id is None:
             return self.async_abort(reason="unknown")
 
+        if existing_unique_id := self._entry_unique_id_from_identifers({unique_id}):
+            await self.async_set_unique_id(existing_unique_id)
+            self._abort_if_unique_id_configured(updates={CONF_ADDRESS: host})
+
+        self._async_abort_entries_match({CONF_ADDRESS: host})
+        await self._async_aggregate_discoveries(host, unique_id)
+        # Scan for the device in order to extract _all_ unique identifiers assigned to
+        # it. Not doing it like this will yield multiple config flows for the same
+        # device, one per protocol, which is undesired.
+        self.scan_filter = host
+        return await self.async_find_device_wrapper(self.async_found_zeroconf_device)
+
+    async def _async_aggregate_discoveries(self, host: str, unique_id: str) -> None:
+        """Wait for multiple zeroconf services to be discovered an aggregate them."""
         #
         # Suppose we have a device with three services: A, B and C. Let's assume
         # service A is discovered by Zeroconf, triggering a device scan that also finds
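When a discovered identifier resolves to an existing entry, the step above adopts that entry's unique_id and aborts, and _abort_if_unique_id_configured(updates={CONF_ADDRESS: host}) refreshes the stored address in passing. A standalone sketch of that update-on-abort outcome, with a plain dict standing in for the config entry store (names here are illustrative):

# Illustrative only: models "abort as already configured, but refresh the address",
# the behaviour the new test at the bottom of this diff asserts via CONF_ADDRESS.
def handle_discovery(entries: dict[str, dict[str, str]], unique_id: str, host: str) -> str:
    if unique_id in entries:
        entries[unique_id]["address"] = host  # keep the entry, update where the device lives
        return "already_configured"
    entries[unique_id] = {"address": host}  # otherwise a new flow would be started
    return "new_flow"


store = {"aa:bb:cc:dd:ee:ff": {"address": "127.0.0.2"}}
assert handle_discovery(store, "aa:bb:cc:dd:ee:ff", "127.0.0.1") == "already_configured"
assert store["aa:bb:cc:dd:ee:ff"]["address"] == "127.0.0.1"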
@@ -195,23 +218,18 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
         # apple_tv device has multiple services that are discovered by
         # zeroconf.
         #
+        self._async_check_and_update_in_progress(host, unique_id)
         await asyncio.sleep(DISCOVERY_AGGREGATION_TIME)
-
-        self._async_check_in_progress_and_set_address(host, unique_id)
-
-        # Scan for the device in order to extract _all_ unique identifiers assigned to
-        # it. Not doing it like this will yield multiple config flows for the same
-        # device, one per protocol, which is undesired.
-        self.scan_filter = host
-        return await self.async_find_device_wrapper(self.async_found_zeroconf_device)
+        # Check again after sleeping in case another flow
+        # has made progress while we yielded to the event loop
+        self._async_check_and_update_in_progress(host, unique_id)
+        # Host must only be set AFTER checking and updating in progress
+        # flows or we will have a race condition where no flows move forward.
+        self.context[CONF_ADDRESS] = host
 
     @callback
-    def _async_check_in_progress_and_set_address(self, host: str, unique_id: str):
-        """Check for in-progress flows and update them with identifiers if needed.
-
-        This code must not await between checking in progress and setting the host
-        or it will have a race condition where no flows move forward.
-        """
+    def _async_check_and_update_in_progress(self, host: str, unique_id: str) -> None:
+        """Check for in-progress flows and update them with identifiers if needed."""
         for flow in self._async_in_progress(include_uninitialized=True):
             context = flow["context"]
             if (
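The helper above checks in-progress flows, sleeps for the aggregation window, re-checks after yielding to the event loop, and only then records the host in the flow context, so concurrent zeroconf announcements collapse into one flow. A standalone asyncio sketch of that check/sleep/re-check pattern; the dict registry and timing constant below are stand-ins, not the integration's flow manager:

# Illustrative sketch of the aggregation timing pattern, not Home Assistant code.
import asyncio

AGGREGATION_TIME = 0.1  # stand-in for DISCOVERY_AGGREGATION_TIME
in_progress: dict[str, set[str]] = {}  # host -> identifiers fed to the flow handling it


async def claim_discovery(host: str, identifier: str) -> bool:
    """Return True if this caller should start the (single) flow for the host."""
    if host in in_progress:
        in_progress[host].add(identifier)  # feed the existing flow instead of starting one
        return False
    await asyncio.sleep(AGGREGATION_TIME)  # wait for sibling zeroconf services to arrive
    if host in in_progress:  # re-check: another caller may have claimed it while we slept
        in_progress[host].add(identifier)
        return False
    in_progress[host] = {identifier}  # claim without awaiting between check and set
    return True


async def main() -> None:
    results = await asyncio.gather(
        claim_discovery("10.0.0.5", "mrpid"),
        claim_discovery("10.0.0.5", "airplayid"),
    )
    assert sorted(results) == [False, True]  # exactly one caller starts a flow
    assert in_progress["10.0.0.5"] == {"mrpid", "airplayid"}


asyncio.run(main())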
@@ -226,7 +244,6 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
                 # Add potentially new identifiers from this device to the existing flow
                 context["all_identifiers"].append(unique_id)
             raise data_entry_flow.AbortFlow("already_in_progress")
-        self.context[CONF_ADDRESS] = host
 
     async def async_found_zeroconf_device(self, user_input=None):
         """Handle device found after Zeroconf discovery."""
@@ -10,7 +10,11 @@ import pytest
 from homeassistant import config_entries, data_entry_flow
 from homeassistant.components import zeroconf
 from homeassistant.components.apple_tv import CONF_ADDRESS, config_flow
-from homeassistant.components.apple_tv.const import CONF_START_OFF, DOMAIN
+from homeassistant.components.apple_tv.const import (
+    CONF_IDENTIFIERS,
+    CONF_START_OFF,
+    DOMAIN,
+)
 
 from .common import airplay_service, create_conf, mrp_service, raop_service
 
@@ -652,6 +656,45 @@ async def test_zeroconf_ip_change(hass, mock_scan):
     assert unrelated_entry.data[CONF_ADDRESS] == "127.0.0.2"
 
 
+async def test_zeroconf_ip_change_via_secondary_identifier(hass, mock_scan):
+    """Test that the config entry gets updated when the ip changes and reloads.
+
+    Instead of checking only the unique id, all the identifiers
+    in the config entry are checked
+    """
+    entry = MockConfigEntry(
+        domain="apple_tv",
+        unique_id="aa:bb:cc:dd:ee:ff",
+        data={CONF_IDENTIFIERS: ["mrpid"], CONF_ADDRESS: "127.0.0.2"},
+    )
+    unrelated_entry = MockConfigEntry(
+        domain="apple_tv", unique_id="unrelated", data={CONF_ADDRESS: "127.0.0.2"}
+    )
+    unrelated_entry.add_to_hass(hass)
+    entry.add_to_hass(hass)
+    mock_scan.result = [
+        create_conf(
+            IPv4Address("127.0.0.1"), "Device", mrp_service(), airplay_service()
+        )
+    ]
+
+    with patch(
+        "homeassistant.components.apple_tv.async_setup_entry", return_value=True
+    ) as mock_async_setup:
+        result = await hass.config_entries.flow.async_init(
+            DOMAIN,
+            context={"source": config_entries.SOURCE_ZEROCONF},
+            data=DMAP_SERVICE,
+        )
+        await hass.async_block_till_done()
+
+    assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT
+    assert result["reason"] == "already_configured"
+    assert len(mock_async_setup.mock_calls) == 2
+    assert entry.data[CONF_ADDRESS] == "127.0.0.1"
+    assert unrelated_entry.data[CONF_ADDRESS] == "127.0.0.2"
+
+
 async def test_zeroconf_add_existing_aborts(hass, dmap_device):
     """Test start new zeroconf flow while existing flow is active aborts."""
     await hass.config_entries.flow.async_init(