Use zeroconf for scanning in apple_tv (#64528)

This commit is contained in:
J. Nick Koston 2022-01-24 02:07:22 -10:00 committed by GitHub
parent d47a25856b
commit 7112c5b52a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 37 additions and 8 deletions

View file

@ -7,6 +7,7 @@ from pyatv import connect, exceptions, scan
from pyatv.const import DeviceModel, Protocol
from pyatv.convert import model_str
from homeassistant.components import zeroconf
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import (
ATTR_CONNECTIONS,
@ -269,8 +270,13 @@ class AppleTVManager:
}
_LOGGER.debug("Discovering device %s", self.config_entry.title)
aiozc = await zeroconf.async_get_async_instance(self.hass)
atvs = await scan(
self.hass.loop, identifier=identifiers, protocol=protocols, hosts=[address]
self.hass.loop,
identifier=identifiers,
protocol=protocols,
hosts=[address],
aiozc=aiozc,
)
if atvs:
return atvs[0]

View file

@ -33,7 +33,7 @@ DEFAULT_START_OFF = False
DISCOVERY_AGGREGATION_TIME = 15 # seconds
async def device_scan(identifier, loop):
async def device_scan(hass, identifier, loop):
"""Scan for a specific device using identifier as filter."""
def _filter_device(dev):
@ -53,7 +53,8 @@ async def device_scan(identifier, loop):
# If we have an address, only probe that address to avoid
# broadcast traffic on the network
scan_result = await scan(loop, timeout=3, hosts=_host_filter())
aiozc = await zeroconf.async_get_async_instance(hass)
scan_result = await scan(loop, timeout=3, hosts=_host_filter(), aiozc=aiozc)
matches = [atv for atv in scan_result if _filter_device(atv)]
if matches:
@ -180,6 +181,7 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
self._abort_if_unique_id_configured(updates={CONF_ADDRESS: host})
self._async_abort_entries_match({CONF_ADDRESS: host})
await self._async_aggregate_discoveries(host, unique_id)
# Scan for the device in order to extract _all_ unique identifiers assigned to
# it. Not doing it like this will yield multiple config flows for the same
@ -279,7 +281,7 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
async def async_find_device(self, allow_exist=False):
"""Scan for the selected device to discover services."""
self.atv, self.atv_identifiers = await device_scan(
self.scan_filter, self.hass.loop
self.hass, self.scan_filter, self.hass.loop
)
if not self.atv:
raise DeviceNotFound()

View file

@ -3,9 +3,13 @@
"name": "Apple TV",
"config_flow": true,
"documentation": "https://www.home-assistant.io/integrations/apple_tv",
"requirements": ["pyatv==0.9.8"],
"requirements": ["pyatv==0.10.0"],
"dependencies": ["zeroconf"],
"zeroconf": [
"_mediaremotetv._tcp.local.",
"_companion-link._tcp.local.",
"_airport._tcp.local.",
"_sleep-proxy._udp.local.",
"_touch-able._tcp.local.",
"_appletv-v2._tcp.local.",
"_hscp._tcp.local.",

View file

@ -37,6 +37,11 @@ ZEROCONF = {
}
}
],
"_airport._tcp.local.": [
{
"domain": "apple_tv"
}
],
"_api._udp.local.": [
{
"domain": "guardian"
@ -78,6 +83,11 @@ ZEROCONF = {
"domain": "bond"
}
],
"_companion-link._tcp.local.": [
{
"domain": "apple_tv"
}
],
"_daap._tcp.local.": [
{
"domain": "forked_daapd"
@ -302,6 +312,11 @@ ZEROCONF = {
}
}
],
"_sleep-proxy._udp.local.": [
{
"domain": "apple_tv"
}
],
"_sonos._tcp.local.": [
{
"domain": "sonos"

View file

@ -1401,7 +1401,7 @@ pyatmo==6.2.2
pyatome==0.1.1
# homeassistant.components.apple_tv
pyatv==0.9.8
pyatv==0.10.0
# homeassistant.components.aussie_broadband
pyaussiebb==0.0.9

View file

@ -875,7 +875,7 @@ pyatag==0.3.5.3
pyatmo==6.2.2
# homeassistant.components.apple_tv
pyatv==0.9.8
pyatv==0.10.0
# homeassistant.components.aussie_broadband
pyaussiebb==0.0.9

View file

@ -15,7 +15,9 @@ def mock_scan_fixture():
"""Mock pyatv.scan."""
with patch("homeassistant.components.apple_tv.config_flow.scan") as mock_scan:
async def _scan(loop, timeout=5, identifier=None, protocol=None, hosts=None):
async def _scan(
loop, timeout=5, identifier=None, protocol=None, hosts=None, aiozc=None
):
if not mock_scan.hosts:
mock_scan.hosts = hosts
return mock_scan.result