From f5ffef3f7259028122097bed1f763924a8794dfe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pierre=20St=C3=A5hl?= Date: Wed, 13 Sep 2017 04:57:31 +0200 Subject: [PATCH] Support specifying no Apple TVs (#9394) --- homeassistant/components/apple_tv.py | 36 ++++++++++++++++++++-------- 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/homeassistant/components/apple_tv.py b/homeassistant/components/apple_tv.py index 7a2ff7610f7..4fce508ba7e 100644 --- a/homeassistant/components/apple_tv.py +++ b/homeassistant/components/apple_tv.py @@ -10,6 +10,7 @@ import logging import voluptuous as vol +from typing import Union, TypeVar, Sequence from homeassistant.const import (CONF_HOST, CONF_NAME, ATTR_ENTITY_ID) from homeassistant.config import load_yaml_config_file from homeassistant.helpers.aiohttp_client import async_get_clientsession @@ -45,8 +46,19 @@ NOTIFICATION_AUTH_TITLE = 'Apple TV Authentication' NOTIFICATION_SCAN_ID = 'apple_tv_scan_notification' NOTIFICATION_SCAN_TITLE = 'Apple TV Scan' +T = TypeVar('T') + + +# This version of ensure_list interprets an empty dict as no value +def ensure_list(value: Union[T, Sequence[T]]) -> Sequence[T]: + """Wrap value in list if it is not one.""" + if value is None or (isinstance(value, dict) and not value): + return [] + return value if isinstance(value, list) else [value] + + CONFIG_SCHEMA = vol.Schema({ - DOMAIN: vol.All(cv.ensure_list, [vol.Schema({ + DOMAIN: vol.All(ensure_list, [vol.Schema({ vol.Required(CONF_HOST): cv.string, vol.Required(CONF_LOGIN_ID): cv.string, vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string, @@ -133,6 +145,10 @@ def async_setup(hass, config): """Handler for service calls.""" entity_ids = service.data.get(ATTR_ENTITY_ID) + if service.service == SERVICE_SCAN: + hass.async_add_job(scan_for_apple_tvs, hass) + return + if entity_ids: devices = [device for device in hass.data[DATA_ENTITIES] if device.entity_id in entity_ids] @@ -140,16 +156,16 @@ def async_setup(hass, config): devices = hass.data[DATA_ENTITIES] for device in devices: + if service.service != SERVICE_AUTHENTICATE: + continue + atv = device.atv - if service.service == SERVICE_AUTHENTICATE: - credentials = yield from atv.airplay.generate_credentials() - yield from atv.airplay.load_credentials(credentials) - _LOGGER.debug('Generated new credentials: %s', credentials) - yield from atv.airplay.start_authentication() - hass.async_add_job(request_configuration, - hass, config, atv, credentials) - elif service.service == SERVICE_SCAN: - hass.async_add_job(scan_for_apple_tvs, hass) + credentials = yield from atv.airplay.generate_credentials() + yield from atv.airplay.load_credentials(credentials) + _LOGGER.debug('Generated new credentials: %s', credentials) + yield from atv.airplay.start_authentication() + hass.async_add_job(request_configuration, + hass, config, atv, credentials) @asyncio.coroutine def atv_discovered(service, info):