Prevent apple_tv rediscovery from secondary identifiers (#61973)
This commit is contained in:
parent
06c1949d2f
commit
048102e053
2 changed files with 77 additions and 17 deletions
|
@ -1,4 +1,6 @@
|
||||||
"""Config flow for Apple TV integration."""
|
"""Config flow for Apple TV integration."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from ipaddress import ip_address
|
from ipaddress import ip_address
|
||||||
|
@ -98,12 +100,19 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||||
re-used, otherwise the newly discovered identifier is used instead.
|
re-used, otherwise the newly discovered identifier is used instead.
|
||||||
"""
|
"""
|
||||||
all_identifiers = set(self.atv.all_identifiers)
|
all_identifiers = set(self.atv.all_identifiers)
|
||||||
|
if unique_id := self._entry_unique_id_from_identifers(all_identifiers):
|
||||||
|
return unique_id
|
||||||
|
return self.atv.identifier
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _entry_unique_id_from_identifers(self, all_identifiers: set[str]) -> str | None:
|
||||||
|
"""Search existing entries for an identifier and return the unique id."""
|
||||||
for entry in self._async_current_entries():
|
for entry in self._async_current_entries():
|
||||||
if all_identifiers.intersection(
|
if all_identifiers.intersection(
|
||||||
entry.data.get(CONF_IDENTIFIERS, [entry.unique_id])
|
entry.data.get(CONF_IDENTIFIERS, [entry.unique_id])
|
||||||
):
|
):
|
||||||
return entry.unique_id
|
return entry.unique_id
|
||||||
return self.atv.identifier
|
return None
|
||||||
|
|
||||||
async def async_step_reauth(self, user_input=None):
|
async def async_step_reauth(self, user_input=None):
|
||||||
"""Handle initial step when updating invalid credentials."""
|
"""Handle initial step when updating invalid credentials."""
|
||||||
|
@ -166,6 +175,20 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||||
if unique_id is None:
|
if unique_id is None:
|
||||||
return self.async_abort(reason="unknown")
|
return self.async_abort(reason="unknown")
|
||||||
|
|
||||||
|
if existing_unique_id := self._entry_unique_id_from_identifers({unique_id}):
|
||||||
|
await self.async_set_unique_id(existing_unique_id)
|
||||||
|
self._abort_if_unique_id_configured(updates={CONF_ADDRESS: host})
|
||||||
|
|
||||||
|
self._async_abort_entries_match({CONF_ADDRESS: host})
|
||||||
|
await self._async_aggregate_discoveries(host, unique_id)
|
||||||
|
# Scan for the device in order to extract _all_ unique identifiers assigned to
|
||||||
|
# it. Not doing it like this will yield multiple config flows for the same
|
||||||
|
# device, one per protocol, which is undesired.
|
||||||
|
self.scan_filter = host
|
||||||
|
return await self.async_find_device_wrapper(self.async_found_zeroconf_device)
|
||||||
|
|
||||||
|
async def _async_aggregate_discoveries(self, host: str, unique_id: str) -> None:
|
||||||
|
"""Wait for multiple zeroconf services to be discovered an aggregate them."""
|
||||||
#
|
#
|
||||||
# Suppose we have a device with three services: A, B and C. Let's assume
|
# Suppose we have a device with three services: A, B and C. Let's assume
|
||||||
# service A is discovered by Zeroconf, triggering a device scan that also finds
|
# service A is discovered by Zeroconf, triggering a device scan that also finds
|
||||||
|
@ -195,23 +218,18 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||||
# apple_tv device has multiple services that are discovered by
|
# apple_tv device has multiple services that are discovered by
|
||||||
# zeroconf.
|
# zeroconf.
|
||||||
#
|
#
|
||||||
|
self._async_check_and_update_in_progress(host, unique_id)
|
||||||
await asyncio.sleep(DISCOVERY_AGGREGATION_TIME)
|
await asyncio.sleep(DISCOVERY_AGGREGATION_TIME)
|
||||||
|
# Check again after sleeping in case another flow
|
||||||
self._async_check_in_progress_and_set_address(host, unique_id)
|
# has made progress while we yielded to the event loop
|
||||||
|
self._async_check_and_update_in_progress(host, unique_id)
|
||||||
# Scan for the device in order to extract _all_ unique identifiers assigned to
|
# Host must only be set AFTER checking and updating in progress
|
||||||
# it. Not doing it like this will yield multiple config flows for the same
|
# flows or we will have a race condition where no flows move forward.
|
||||||
# device, one per protocol, which is undesired.
|
self.context[CONF_ADDRESS] = host
|
||||||
self.scan_filter = host
|
|
||||||
return await self.async_find_device_wrapper(self.async_found_zeroconf_device)
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _async_check_in_progress_and_set_address(self, host: str, unique_id: str):
|
def _async_check_and_update_in_progress(self, host: str, unique_id: str) -> None:
|
||||||
"""Check for in-progress flows and update them with identifiers if needed.
|
"""Check for in-progress flows and update them with identifiers if needed."""
|
||||||
|
|
||||||
This code must not await between checking in progress and setting the host
|
|
||||||
or it will have a race condition where no flows move forward.
|
|
||||||
"""
|
|
||||||
for flow in self._async_in_progress(include_uninitialized=True):
|
for flow in self._async_in_progress(include_uninitialized=True):
|
||||||
context = flow["context"]
|
context = flow["context"]
|
||||||
if (
|
if (
|
||||||
|
@ -226,7 +244,6 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||||
# Add potentially new identifiers from this device to the existing flow
|
# Add potentially new identifiers from this device to the existing flow
|
||||||
context["all_identifiers"].append(unique_id)
|
context["all_identifiers"].append(unique_id)
|
||||||
raise data_entry_flow.AbortFlow("already_in_progress")
|
raise data_entry_flow.AbortFlow("already_in_progress")
|
||||||
self.context[CONF_ADDRESS] = host
|
|
||||||
|
|
||||||
async def async_found_zeroconf_device(self, user_input=None):
|
async def async_found_zeroconf_device(self, user_input=None):
|
||||||
"""Handle device found after Zeroconf discovery."""
|
"""Handle device found after Zeroconf discovery."""
|
||||||
|
|
|
@ -10,7 +10,11 @@ import pytest
|
||||||
from homeassistant import config_entries, data_entry_flow
|
from homeassistant import config_entries, data_entry_flow
|
||||||
from homeassistant.components import zeroconf
|
from homeassistant.components import zeroconf
|
||||||
from homeassistant.components.apple_tv import CONF_ADDRESS, config_flow
|
from homeassistant.components.apple_tv import CONF_ADDRESS, config_flow
|
||||||
from homeassistant.components.apple_tv.const import CONF_START_OFF, DOMAIN
|
from homeassistant.components.apple_tv.const import (
|
||||||
|
CONF_IDENTIFIERS,
|
||||||
|
CONF_START_OFF,
|
||||||
|
DOMAIN,
|
||||||
|
)
|
||||||
|
|
||||||
from .common import airplay_service, create_conf, mrp_service, raop_service
|
from .common import airplay_service, create_conf, mrp_service, raop_service
|
||||||
|
|
||||||
|
@ -652,6 +656,45 @@ async def test_zeroconf_ip_change(hass, mock_scan):
|
||||||
assert unrelated_entry.data[CONF_ADDRESS] == "127.0.0.2"
|
assert unrelated_entry.data[CONF_ADDRESS] == "127.0.0.2"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_zeroconf_ip_change_via_secondary_identifier(hass, mock_scan):
|
||||||
|
"""Test that the config entry gets updated when the ip changes and reloads.
|
||||||
|
|
||||||
|
Instead of checking only the unique id, all the identifiers
|
||||||
|
in the config entry are checked
|
||||||
|
"""
|
||||||
|
entry = MockConfigEntry(
|
||||||
|
domain="apple_tv",
|
||||||
|
unique_id="aa:bb:cc:dd:ee:ff",
|
||||||
|
data={CONF_IDENTIFIERS: ["mrpid"], CONF_ADDRESS: "127.0.0.2"},
|
||||||
|
)
|
||||||
|
unrelated_entry = MockConfigEntry(
|
||||||
|
domain="apple_tv", unique_id="unrelated", data={CONF_ADDRESS: "127.0.0.2"}
|
||||||
|
)
|
||||||
|
unrelated_entry.add_to_hass(hass)
|
||||||
|
entry.add_to_hass(hass)
|
||||||
|
mock_scan.result = [
|
||||||
|
create_conf(
|
||||||
|
IPv4Address("127.0.0.1"), "Device", mrp_service(), airplay_service()
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.apple_tv.async_setup_entry", return_value=True
|
||||||
|
) as mock_async_setup:
|
||||||
|
result = await hass.config_entries.flow.async_init(
|
||||||
|
DOMAIN,
|
||||||
|
context={"source": config_entries.SOURCE_ZEROCONF},
|
||||||
|
data=DMAP_SERVICE,
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT
|
||||||
|
assert result["reason"] == "already_configured"
|
||||||
|
assert len(mock_async_setup.mock_calls) == 2
|
||||||
|
assert entry.data[CONF_ADDRESS] == "127.0.0.1"
|
||||||
|
assert unrelated_entry.data[CONF_ADDRESS] == "127.0.0.2"
|
||||||
|
|
||||||
|
|
||||||
async def test_zeroconf_add_existing_aborts(hass, dmap_device):
|
async def test_zeroconf_add_existing_aborts(hass, dmap_device):
|
||||||
"""Test start new zeroconf flow while existing flow is active aborts."""
|
"""Test start new zeroconf flow while existing flow is active aborts."""
|
||||||
await hass.config_entries.flow.async_init(
|
await hass.config_entries.flow.async_init(
|
||||||
|
|
Loading…
Add table
Reference in a new issue