Config flow tradfri (#16665)
* Fix comments * Add config flow tests * Fix Tradfri light tests * Lint * Remove import group from config flow * fix stale comments
This commit is contained in:
parent
3160fa5de8
commit
a1c524d372
13 changed files with 519 additions and 262 deletions
|
@ -48,6 +48,7 @@ CONFIG_ENTRY_HANDLERS = {
|
||||||
SERVICE_DECONZ: 'deconz',
|
SERVICE_DECONZ: 'deconz',
|
||||||
'google_cast': 'cast',
|
'google_cast': 'cast',
|
||||||
SERVICE_HUE: 'hue',
|
SERVICE_HUE: 'hue',
|
||||||
|
SERVICE_IKEA_TRADFRI: 'tradfri',
|
||||||
'sonos': 'sonos',
|
'sonos': 'sonos',
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -55,7 +56,6 @@ SERVICE_HANDLERS = {
|
||||||
SERVICE_HASS_IOS_APP: ('ios', None),
|
SERVICE_HASS_IOS_APP: ('ios', None),
|
||||||
SERVICE_NETGEAR: ('device_tracker', None),
|
SERVICE_NETGEAR: ('device_tracker', None),
|
||||||
SERVICE_WEMO: ('wemo', None),
|
SERVICE_WEMO: ('wemo', None),
|
||||||
SERVICE_IKEA_TRADFRI: ('tradfri', None),
|
|
||||||
SERVICE_HASSIO: ('hassio', None),
|
SERVICE_HASSIO: ('hassio', None),
|
||||||
SERVICE_AXIS: ('axis', None),
|
SERVICE_AXIS: ('axis', None),
|
||||||
SERVICE_APPLE_TV: ('apple_tv', None),
|
SERVICE_APPLE_TV: ('apple_tv', None),
|
||||||
|
|
|
@ -13,8 +13,9 @@ from homeassistant.components.light import (
|
||||||
SUPPORT_COLOR, Light)
|
SUPPORT_COLOR, Light)
|
||||||
from homeassistant.components.light import \
|
from homeassistant.components.light import \
|
||||||
PLATFORM_SCHEMA as LIGHT_PLATFORM_SCHEMA
|
PLATFORM_SCHEMA as LIGHT_PLATFORM_SCHEMA
|
||||||
from homeassistant.components.tradfri import KEY_GATEWAY, KEY_TRADFRI_GROUPS, \
|
from homeassistant.components.tradfri import KEY_GATEWAY, KEY_API
|
||||||
KEY_API
|
from homeassistant.components.tradfri.const import (
|
||||||
|
CONF_IMPORT_GROUPS, CONF_GATEWAY_ID)
|
||||||
import homeassistant.util.color as color_util
|
import homeassistant.util.color as color_util
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
@ -31,28 +32,21 @@ SUPPORTED_FEATURES = SUPPORT_TRANSITION
|
||||||
SUPPORTED_GROUP_FEATURES = SUPPORT_BRIGHTNESS | SUPPORT_TRANSITION
|
SUPPORTED_GROUP_FEATURES = SUPPORT_BRIGHTNESS | SUPPORT_TRANSITION
|
||||||
|
|
||||||
|
|
||||||
async def async_setup_platform(hass, config,
|
async def async_setup_entry(hass, config_entry, async_add_entities):
|
||||||
async_add_entities, discovery_info=None):
|
"""Load Tradfri lights based on a config entry."""
|
||||||
"""Set up the IKEA Tradfri Light platform."""
|
gateway_id = config_entry.data[CONF_GATEWAY_ID]
|
||||||
if discovery_info is None:
|
api = hass.data[KEY_API][config_entry.entry_id]
|
||||||
return
|
gateway = hass.data[KEY_GATEWAY][config_entry.entry_id]
|
||||||
|
|
||||||
gateway_id = discovery_info['gateway']
|
devices_commands = await api(gateway.get_devices())
|
||||||
api = hass.data[KEY_API][gateway_id]
|
|
||||||
gateway = hass.data[KEY_GATEWAY][gateway_id]
|
|
||||||
|
|
||||||
devices_command = gateway.get_devices()
|
|
||||||
devices_commands = await api(devices_command)
|
|
||||||
devices = await api(devices_commands)
|
devices = await api(devices_commands)
|
||||||
lights = [dev for dev in devices if dev.has_light_control]
|
lights = [dev for dev in devices if dev.has_light_control]
|
||||||
if lights:
|
if lights:
|
||||||
async_add_entities(
|
async_add_entities(
|
||||||
TradfriLight(light, api, gateway_id) for light in lights)
|
TradfriLight(light, api, gateway_id) for light in lights)
|
||||||
|
|
||||||
allow_tradfri_groups = hass.data[KEY_TRADFRI_GROUPS][gateway_id]
|
if config_entry.data[CONF_IMPORT_GROUPS]:
|
||||||
if allow_tradfri_groups:
|
groups_commands = await api(gateway.get_groups())
|
||||||
groups_command = gateway.get_groups()
|
|
||||||
groups_commands = await api(groups_command)
|
|
||||||
groups = await api(groups_commands)
|
groups = await api(groups_commands)
|
||||||
if groups:
|
if groups:
|
||||||
async_add_entities(
|
async_add_entities(
|
||||||
|
|
|
@ -19,20 +19,14 @@ DEPENDENCIES = ['tradfri']
|
||||||
SCAN_INTERVAL = timedelta(minutes=5)
|
SCAN_INTERVAL = timedelta(minutes=5)
|
||||||
|
|
||||||
|
|
||||||
async def async_setup_platform(hass, config, async_add_entities,
|
async def async_setup_entry(hass, config_entry, async_add_entities):
|
||||||
discovery_info=None):
|
"""Set up a Tradfri config entry."""
|
||||||
"""Set up the IKEA Tradfri device platform."""
|
api = hass.data[KEY_API][config_entry.entry_id]
|
||||||
if discovery_info is None:
|
gateway = hass.data[KEY_GATEWAY][config_entry.entry_id]
|
||||||
return
|
|
||||||
|
|
||||||
gateway_id = discovery_info['gateway']
|
devices_commands = await api(gateway.get_devices())
|
||||||
api = hass.data[KEY_API][gateway_id]
|
|
||||||
gateway = hass.data[KEY_GATEWAY][gateway_id]
|
|
||||||
|
|
||||||
devices_command = gateway.get_devices()
|
|
||||||
devices_commands = await api(devices_command)
|
|
||||||
all_devices = await api(devices_commands)
|
all_devices = await api(devices_commands)
|
||||||
devices = [dev for dev in all_devices if not dev.has_light_control]
|
devices = (dev for dev in all_devices if not dev.has_light_control)
|
||||||
async_add_entities(TradfriDevice(device, api) for device in devices)
|
async_add_entities(TradfriDevice(device, api) for device in devices)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,173 +0,0 @@
|
||||||
"""
|
|
||||||
Support for IKEA Tradfri.
|
|
||||||
|
|
||||||
For more details about this component, please refer to the documentation at
|
|
||||||
https://home-assistant.io/components/ikea_tradfri/
|
|
||||||
"""
|
|
||||||
import logging
|
|
||||||
from uuid import uuid4
|
|
||||||
|
|
||||||
import voluptuous as vol
|
|
||||||
|
|
||||||
import homeassistant.helpers.config_validation as cv
|
|
||||||
from homeassistant.helpers import discovery
|
|
||||||
from homeassistant.const import CONF_HOST
|
|
||||||
from homeassistant.components.discovery import SERVICE_IKEA_TRADFRI
|
|
||||||
from homeassistant.util.json import load_json, save_json
|
|
||||||
|
|
||||||
REQUIREMENTS = ['pytradfri[async]==5.5.1']
|
|
||||||
|
|
||||||
DOMAIN = 'tradfri'
|
|
||||||
GATEWAY_IDENTITY = 'homeassistant'
|
|
||||||
CONFIG_FILE = '.tradfri_psk.conf'
|
|
||||||
KEY_CONFIG = 'tradfri_configuring'
|
|
||||||
KEY_GATEWAY = 'tradfri_gateway'
|
|
||||||
KEY_API = 'tradfri_api'
|
|
||||||
KEY_TRADFRI_GROUPS = 'tradfri_allow_tradfri_groups'
|
|
||||||
CONF_ALLOW_TRADFRI_GROUPS = 'allow_tradfri_groups'
|
|
||||||
DEFAULT_ALLOW_TRADFRI_GROUPS = True
|
|
||||||
|
|
||||||
CONFIG_SCHEMA = vol.Schema({
|
|
||||||
DOMAIN: vol.Schema({
|
|
||||||
vol.Inclusive(CONF_HOST, 'gateway'): cv.string,
|
|
||||||
vol.Optional(CONF_ALLOW_TRADFRI_GROUPS,
|
|
||||||
default=DEFAULT_ALLOW_TRADFRI_GROUPS): cv.boolean,
|
|
||||||
})
|
|
||||||
}, extra=vol.ALLOW_EXTRA)
|
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def request_configuration(hass, config, host):
|
|
||||||
"""Request configuration steps from the user."""
|
|
||||||
configurator = hass.components.configurator
|
|
||||||
hass.data.setdefault(KEY_CONFIG, {})
|
|
||||||
instance = hass.data[KEY_CONFIG].get(host)
|
|
||||||
|
|
||||||
# Configuration already in progress
|
|
||||||
if instance:
|
|
||||||
return
|
|
||||||
|
|
||||||
async def configuration_callback(callback_data):
|
|
||||||
"""Handle the submitted configuration."""
|
|
||||||
try:
|
|
||||||
from pytradfri.api.aiocoap_api import APIFactory
|
|
||||||
from pytradfri import RequestError
|
|
||||||
except ImportError:
|
|
||||||
_LOGGER.exception("Looks like something isn't installed!")
|
|
||||||
return
|
|
||||||
|
|
||||||
identity = uuid4().hex
|
|
||||||
security_code = callback_data.get('security_code')
|
|
||||||
|
|
||||||
api_factory = APIFactory(host, psk_id=identity, loop=hass.loop)
|
|
||||||
# Need To Fix: currently entering a wrong security code sends
|
|
||||||
# pytradfri aiocoap API into an endless loop.
|
|
||||||
# Should just raise a requestError or something.
|
|
||||||
try:
|
|
||||||
key = await api_factory.generate_psk(security_code)
|
|
||||||
except RequestError:
|
|
||||||
configurator.async_notify_errors(hass, instance,
|
|
||||||
"Security Code not accepted.")
|
|
||||||
return
|
|
||||||
|
|
||||||
res = await _setup_gateway(hass, config, host, identity, key,
|
|
||||||
DEFAULT_ALLOW_TRADFRI_GROUPS)
|
|
||||||
|
|
||||||
if not res:
|
|
||||||
configurator.async_notify_errors(hass, instance,
|
|
||||||
"Unable to connect.")
|
|
||||||
return
|
|
||||||
|
|
||||||
def success():
|
|
||||||
"""Set up was successful."""
|
|
||||||
conf = load_json(hass.config.path(CONFIG_FILE))
|
|
||||||
conf[host] = {'identity': identity,
|
|
||||||
'key': key}
|
|
||||||
save_json(hass.config.path(CONFIG_FILE), conf)
|
|
||||||
configurator.request_done(instance)
|
|
||||||
|
|
||||||
hass.async_add_job(success)
|
|
||||||
|
|
||||||
instance = configurator.request_config(
|
|
||||||
"IKEA Trådfri", configuration_callback,
|
|
||||||
description='Please enter the security code written at the bottom of '
|
|
||||||
'your IKEA Trådfri Gateway.',
|
|
||||||
submit_caption="Confirm",
|
|
||||||
fields=[{'id': 'security_code', 'name': 'Security Code',
|
|
||||||
'type': 'password'}]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def async_setup(hass, config):
|
|
||||||
"""Set up the Tradfri component."""
|
|
||||||
conf = config.get(DOMAIN, {})
|
|
||||||
host = conf.get(CONF_HOST)
|
|
||||||
allow_tradfri_groups = conf.get(CONF_ALLOW_TRADFRI_GROUPS)
|
|
||||||
known_hosts = await hass.async_add_job(load_json,
|
|
||||||
hass.config.path(CONFIG_FILE))
|
|
||||||
|
|
||||||
async def gateway_discovered(service, info,
|
|
||||||
allow_groups=DEFAULT_ALLOW_TRADFRI_GROUPS):
|
|
||||||
"""Run when a gateway is discovered."""
|
|
||||||
host = info['host']
|
|
||||||
|
|
||||||
if host in known_hosts:
|
|
||||||
# use fallbacks for old config style
|
|
||||||
# identity was hard coded as 'homeassistant'
|
|
||||||
identity = known_hosts[host].get('identity', 'homeassistant')
|
|
||||||
key = known_hosts[host].get('key')
|
|
||||||
await _setup_gateway(hass, config, host, identity, key,
|
|
||||||
allow_groups)
|
|
||||||
else:
|
|
||||||
hass.async_add_job(request_configuration, hass, config, host)
|
|
||||||
|
|
||||||
discovery.async_listen(hass, SERVICE_IKEA_TRADFRI, gateway_discovered)
|
|
||||||
|
|
||||||
if host:
|
|
||||||
await gateway_discovered(None,
|
|
||||||
{'host': host},
|
|
||||||
allow_tradfri_groups)
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
async def _setup_gateway(hass, hass_config, host, identity, key,
|
|
||||||
allow_tradfri_groups):
|
|
||||||
"""Create a gateway."""
|
|
||||||
from pytradfri import Gateway, RequestError # pylint: disable=import-error
|
|
||||||
try:
|
|
||||||
from pytradfri.api.aiocoap_api import APIFactory
|
|
||||||
except ImportError:
|
|
||||||
_LOGGER.exception("Looks like something isn't installed!")
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
factory = APIFactory(host, psk_id=identity, psk=key,
|
|
||||||
loop=hass.loop)
|
|
||||||
api = factory.request
|
|
||||||
gateway = Gateway()
|
|
||||||
gateway_info_result = await api(gateway.get_gateway_info())
|
|
||||||
except RequestError:
|
|
||||||
_LOGGER.exception("Tradfri setup failed.")
|
|
||||||
return False
|
|
||||||
|
|
||||||
gateway_id = gateway_info_result.id
|
|
||||||
hass.data.setdefault(KEY_API, {})
|
|
||||||
hass.data.setdefault(KEY_GATEWAY, {})
|
|
||||||
gateways = hass.data[KEY_GATEWAY]
|
|
||||||
hass.data[KEY_API][gateway_id] = api
|
|
||||||
|
|
||||||
hass.data.setdefault(KEY_TRADFRI_GROUPS, {})
|
|
||||||
tradfri_groups = hass.data[KEY_TRADFRI_GROUPS]
|
|
||||||
tradfri_groups[gateway_id] = allow_tradfri_groups
|
|
||||||
|
|
||||||
# Check if already set up
|
|
||||||
if gateway_id in gateways:
|
|
||||||
return True
|
|
||||||
|
|
||||||
gateways[gateway_id] = gateway
|
|
||||||
hass.async_create_task(discovery.async_load_platform(
|
|
||||||
hass, 'light', DOMAIN, {'gateway': gateway_id}, hass_config))
|
|
||||||
hass.async_create_task(discovery.async_load_platform(
|
|
||||||
hass, 'sensor', DOMAIN, {'gateway': gateway_id}, hass_config))
|
|
||||||
return True
|
|
23
homeassistant/components/tradfri/.translations/en.json
Normal file
23
homeassistant/components/tradfri/.translations/en.json
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
{
|
||||||
|
"config": {
|
||||||
|
"abort": {
|
||||||
|
"already_configured": "Bridge is already configured"
|
||||||
|
},
|
||||||
|
"error": {
|
||||||
|
"cannot_connect": "Unable to connect to the gateway.",
|
||||||
|
"invalid_key": "Failed to register with provided key. If this keeps happening, try restarting the gateway.",
|
||||||
|
"timeout": "Timeout validating the code."
|
||||||
|
},
|
||||||
|
"step": {
|
||||||
|
"auth": {
|
||||||
|
"data": {
|
||||||
|
"host": "Host",
|
||||||
|
"security_code": "Security Code"
|
||||||
|
},
|
||||||
|
"description": "You can find the security code on the back of your gateway.",
|
||||||
|
"title": "Enter security code"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"title": "IKEA TR\u00c5DFRI"
|
||||||
|
}
|
||||||
|
}
|
102
homeassistant/components/tradfri/__init__.py
Normal file
102
homeassistant/components/tradfri/__init__.py
Normal file
|
@ -0,0 +1,102 @@
|
||||||
|
"""
|
||||||
|
Support for IKEA Tradfri.
|
||||||
|
|
||||||
|
For more details about this component, please refer to the documentation at
|
||||||
|
https://home-assistant.io/components/ikea_tradfri/
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import voluptuous as vol
|
||||||
|
|
||||||
|
from homeassistant import config_entries
|
||||||
|
import homeassistant.helpers.config_validation as cv
|
||||||
|
from homeassistant.util.json import load_json
|
||||||
|
|
||||||
|
from .const import CONF_IMPORT_GROUPS, CONF_IDENTITY, CONF_HOST, CONF_KEY
|
||||||
|
|
||||||
|
from . import config_flow # noqa pylint_disable=unused-import
|
||||||
|
|
||||||
|
REQUIREMENTS = ['pytradfri[async]==5.5.1']
|
||||||
|
|
||||||
|
DOMAIN = 'tradfri'
|
||||||
|
CONFIG_FILE = '.tradfri_psk.conf'
|
||||||
|
KEY_GATEWAY = 'tradfri_gateway'
|
||||||
|
KEY_API = 'tradfri_api'
|
||||||
|
CONF_ALLOW_TRADFRI_GROUPS = 'allow_tradfri_groups'
|
||||||
|
DEFAULT_ALLOW_TRADFRI_GROUPS = True
|
||||||
|
|
||||||
|
CONFIG_SCHEMA = vol.Schema({
|
||||||
|
DOMAIN: vol.Schema({
|
||||||
|
vol.Inclusive(CONF_HOST, 'gateway'): cv.string,
|
||||||
|
vol.Optional(CONF_ALLOW_TRADFRI_GROUPS,
|
||||||
|
default=DEFAULT_ALLOW_TRADFRI_GROUPS): cv.boolean,
|
||||||
|
})
|
||||||
|
}, extra=vol.ALLOW_EXTRA)
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def async_setup(hass, config):
|
||||||
|
"""Set up the Tradfri component."""
|
||||||
|
conf = config.get(DOMAIN)
|
||||||
|
|
||||||
|
if conf is None:
|
||||||
|
return True
|
||||||
|
|
||||||
|
known_hosts = await hass.async_add_executor_job(
|
||||||
|
load_json, hass.config.path(CONFIG_FILE))
|
||||||
|
|
||||||
|
for host, info in known_hosts.items():
|
||||||
|
info[CONF_HOST] = host
|
||||||
|
info[CONF_IMPORT_GROUPS] = conf[CONF_ALLOW_TRADFRI_GROUPS]
|
||||||
|
|
||||||
|
hass.async_create_task(hass.config_entries.flow.async_init(
|
||||||
|
DOMAIN, context={'source': config_entries.SOURCE_IMPORT},
|
||||||
|
data=info
|
||||||
|
))
|
||||||
|
|
||||||
|
host = conf.get(CONF_HOST)
|
||||||
|
|
||||||
|
if host is None or host in known_hosts:
|
||||||
|
return True
|
||||||
|
|
||||||
|
hass.async_create_task(hass.config_entries.flow.async_init(
|
||||||
|
DOMAIN, context={'source': config_entries.SOURCE_IMPORT},
|
||||||
|
data={'host': host}
|
||||||
|
))
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
async def async_setup_entry(hass, entry):
|
||||||
|
"""Create a gateway."""
|
||||||
|
# host, identity, key, allow_tradfri_groups
|
||||||
|
from pytradfri import Gateway, RequestError # pylint: disable=import-error
|
||||||
|
from pytradfri.api.aiocoap_api import APIFactory
|
||||||
|
|
||||||
|
factory = APIFactory(
|
||||||
|
entry.data[CONF_HOST],
|
||||||
|
psk_id=entry.data[CONF_IDENTITY],
|
||||||
|
psk=entry.data[CONF_KEY],
|
||||||
|
loop=hass.loop
|
||||||
|
)
|
||||||
|
api = factory.request
|
||||||
|
gateway = Gateway()
|
||||||
|
|
||||||
|
try:
|
||||||
|
await api(gateway.get_gateway_info())
|
||||||
|
except RequestError:
|
||||||
|
_LOGGER.error("Tradfri setup failed.")
|
||||||
|
return False
|
||||||
|
|
||||||
|
hass.data.setdefault(KEY_API, {})[entry.entry_id] = api
|
||||||
|
hass.data.setdefault(KEY_GATEWAY, {})[entry.entry_id] = gateway
|
||||||
|
|
||||||
|
hass.async_create_task(hass.config_entries.async_forward_entry_setup(
|
||||||
|
entry, 'light'
|
||||||
|
))
|
||||||
|
hass.async_create_task(hass.config_entries.async_forward_entry_setup(
|
||||||
|
entry, 'sensor'
|
||||||
|
))
|
||||||
|
|
||||||
|
return True
|
172
homeassistant/components/tradfri/config_flow.py
Normal file
172
homeassistant/components/tradfri/config_flow.py
Normal file
|
@ -0,0 +1,172 @@
|
||||||
|
"""Config flow for Tradfri."""
|
||||||
|
import asyncio
|
||||||
|
from collections import OrderedDict
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import async_timeout
|
||||||
|
import voluptuous as vol
|
||||||
|
|
||||||
|
from homeassistant import config_entries
|
||||||
|
|
||||||
|
from .const import (
|
||||||
|
CONF_IMPORT_GROUPS, CONF_IDENTITY, CONF_HOST, CONF_KEY, CONF_GATEWAY_ID)
|
||||||
|
|
||||||
|
KEY_HOST = 'host'
|
||||||
|
KEY_SECURITY_CODE = 'security_code'
|
||||||
|
KEY_IMPORT_GROUPS = 'import_groups'
|
||||||
|
|
||||||
|
|
||||||
|
class AuthError(Exception):
|
||||||
|
"""Exception if authentication occurs."""
|
||||||
|
|
||||||
|
def __init__(self, code):
|
||||||
|
"""Initialize exception."""
|
||||||
|
super().__init__()
|
||||||
|
self.code = code
|
||||||
|
|
||||||
|
|
||||||
|
@config_entries.HANDLERS.register('tradfri')
|
||||||
|
class FlowHandler(config_entries.ConfigFlow):
|
||||||
|
"""Handle a config flow."""
|
||||||
|
|
||||||
|
VERSION = 1
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize flow."""
|
||||||
|
self._host = None
|
||||||
|
|
||||||
|
async def async_step_user(self, user_input=None):
|
||||||
|
"""Handle a flow initialized by the user."""
|
||||||
|
return await self.async_step_auth()
|
||||||
|
|
||||||
|
async def async_step_auth(self, user_input=None):
|
||||||
|
"""Handle the authentication with a gateway."""
|
||||||
|
errors = {}
|
||||||
|
|
||||||
|
if user_input is not None:
|
||||||
|
host = user_input.get(KEY_HOST, self._host)
|
||||||
|
try:
|
||||||
|
auth = await authenticate(
|
||||||
|
self.hass, host,
|
||||||
|
user_input[KEY_SECURITY_CODE])
|
||||||
|
|
||||||
|
# We don't ask for import group anymore as group state
|
||||||
|
# is not reliable, don't want to show that to the user.
|
||||||
|
auth[CONF_IMPORT_GROUPS] = False
|
||||||
|
|
||||||
|
return await self._entry_from_data(auth)
|
||||||
|
|
||||||
|
except AuthError as err:
|
||||||
|
if err.code == 'invalid_security_code':
|
||||||
|
errors[KEY_SECURITY_CODE] = err.code
|
||||||
|
else:
|
||||||
|
errors['base'] = err.code
|
||||||
|
|
||||||
|
fields = OrderedDict()
|
||||||
|
|
||||||
|
if self._host is None:
|
||||||
|
fields[vol.Required(KEY_HOST)] = str
|
||||||
|
|
||||||
|
fields[vol.Required(KEY_SECURITY_CODE)] = str
|
||||||
|
|
||||||
|
return self.async_show_form(
|
||||||
|
step_id='auth',
|
||||||
|
data_schema=vol.Schema(fields),
|
||||||
|
errors=errors,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def async_step_discovery(self, user_input):
|
||||||
|
"""Handle discovery."""
|
||||||
|
self._host = user_input['host']
|
||||||
|
return await self.async_step_auth()
|
||||||
|
|
||||||
|
async def async_step_import(self, user_input):
|
||||||
|
"""Import a config entry."""
|
||||||
|
for entry in self._async_current_entries():
|
||||||
|
if entry.data[CONF_HOST] == user_input['host']:
|
||||||
|
return self.async_abort(
|
||||||
|
reason='already_configured'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Happens if user has host directly in configuration.yaml
|
||||||
|
if 'key' not in user_input:
|
||||||
|
self._host = user_input['host']
|
||||||
|
return await self.async_step_auth()
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = await get_gateway_info(
|
||||||
|
self.hass, user_input['host'], user_input['identity'],
|
||||||
|
user_input['key'])
|
||||||
|
|
||||||
|
data[CONF_IMPORT_GROUPS] = user_input[CONF_IMPORT_GROUPS]
|
||||||
|
|
||||||
|
return await self._entry_from_data(data)
|
||||||
|
except AuthError:
|
||||||
|
# If we fail to connect, just pass it on to discovery
|
||||||
|
self._host = user_input['host']
|
||||||
|
return await self.async_step_auth()
|
||||||
|
|
||||||
|
async def _entry_from_data(self, data):
|
||||||
|
"""Create an entry from data."""
|
||||||
|
host = data[CONF_HOST]
|
||||||
|
gateway_id = data[CONF_GATEWAY_ID]
|
||||||
|
|
||||||
|
same_hub_entries = [entry.entry_id for entry
|
||||||
|
in self._async_current_entries()
|
||||||
|
if entry.data[CONF_GATEWAY_ID] == gateway_id or
|
||||||
|
entry.data[CONF_HOST] == host]
|
||||||
|
|
||||||
|
if same_hub_entries:
|
||||||
|
await asyncio.wait([self.hass.config_entries.async_remove(entry_id)
|
||||||
|
for entry_id in same_hub_entries])
|
||||||
|
|
||||||
|
return self.async_create_entry(
|
||||||
|
title=host,
|
||||||
|
data=data
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def authenticate(hass, host, security_code):
|
||||||
|
"""Authenticate with a Tradfri hub."""
|
||||||
|
from pytradfri.api.aiocoap_api import APIFactory
|
||||||
|
from pytradfri import RequestError
|
||||||
|
|
||||||
|
identity = uuid4().hex
|
||||||
|
|
||||||
|
api_factory = APIFactory(host, psk_id=identity, loop=hass.loop)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with async_timeout.timeout(5):
|
||||||
|
key = await api_factory.generate_psk(security_code)
|
||||||
|
except RequestError:
|
||||||
|
raise AuthError('invalid_security_code')
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
raise AuthError('timeout')
|
||||||
|
|
||||||
|
return await get_gateway_info(hass, host, identity, key)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_gateway_info(hass, host, identity, key):
|
||||||
|
"""Return info for the gateway."""
|
||||||
|
from pytradfri.api.aiocoap_api import APIFactory
|
||||||
|
from pytradfri import Gateway, RequestError
|
||||||
|
|
||||||
|
try:
|
||||||
|
factory = APIFactory(
|
||||||
|
host,
|
||||||
|
psk_id=identity,
|
||||||
|
psk=key,
|
||||||
|
loop=hass.loop
|
||||||
|
)
|
||||||
|
api = factory.request
|
||||||
|
gateway = Gateway()
|
||||||
|
gateway_info_result = await api(gateway.get_gateway_info())
|
||||||
|
except RequestError:
|
||||||
|
raise AuthError('cannot_connect')
|
||||||
|
|
||||||
|
return {
|
||||||
|
CONF_HOST: host,
|
||||||
|
CONF_IDENTITY: identity,
|
||||||
|
CONF_KEY: key,
|
||||||
|
CONF_GATEWAY_ID: gateway_info_result.id,
|
||||||
|
}
|
7
homeassistant/components/tradfri/const.py
Normal file
7
homeassistant/components/tradfri/const.py
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
"""Consts used by Tradfri."""
|
||||||
|
from homeassistant.const import CONF_HOST # noqa pylint: disable=unused-import
|
||||||
|
|
||||||
|
CONF_IMPORT_GROUPS = 'import_groups'
|
||||||
|
CONF_IDENTITY = 'identity'
|
||||||
|
CONF_KEY = 'key'
|
||||||
|
CONF_GATEWAY_ID = 'gateway_id'
|
23
homeassistant/components/tradfri/strings.json
Normal file
23
homeassistant/components/tradfri/strings.json
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
{
|
||||||
|
"config": {
|
||||||
|
"title": "IKEA TRÅDFRI",
|
||||||
|
"step": {
|
||||||
|
"auth": {
|
||||||
|
"title": "Enter security code",
|
||||||
|
"description": "You can find the security code on the back of your gateway.",
|
||||||
|
"data": {
|
||||||
|
"host": "Host",
|
||||||
|
"security_code": "Security Code"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"error": {
|
||||||
|
"invalid_key": "Failed to register with provided key. If this keeps happening, try restarting the gateway.",
|
||||||
|
"cannot_connect": "Unable to connect to the gateway.",
|
||||||
|
"timeout": "Timeout validating the code."
|
||||||
|
},
|
||||||
|
"abort": {
|
||||||
|
"already_configured": "Bridge is already configured"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -146,6 +146,7 @@ FLOWS = [
|
||||||
'nest',
|
'nest',
|
||||||
'openuv',
|
'openuv',
|
||||||
'sonos',
|
'sonos',
|
||||||
|
'tradfri',
|
||||||
'zone',
|
'zone',
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
@ -5,10 +5,10 @@ from unittest.mock import Mock, MagicMock, patch, PropertyMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pytradfri.device import Device, LightControl, Light
|
from pytradfri.device import Device, LightControl, Light
|
||||||
from pytradfri import RequestError
|
|
||||||
|
|
||||||
from homeassistant.components import tradfri
|
from homeassistant.components import tradfri
|
||||||
from homeassistant.setup import async_setup_component
|
|
||||||
|
from tests.common import MockConfigEntry
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_TEST_FEATURES = {'can_set_dimmer': False,
|
DEFAULT_TEST_FEATURES = {'can_set_dimmer': False,
|
||||||
|
@ -199,7 +199,7 @@ def mock_gateway():
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_api(mock_gateway):
|
def mock_api(mock_gateway):
|
||||||
"""Mock api."""
|
"""Mock api."""
|
||||||
async def api(self, command):
|
async def api(command):
|
||||||
"""Mock api function."""
|
"""Mock api function."""
|
||||||
# Store the data for "real" command objects.
|
# Store the data for "real" command objects.
|
||||||
if(hasattr(command, '_data') and not isinstance(command, Mock)):
|
if(hasattr(command, '_data') and not isinstance(command, Mock)):
|
||||||
|
@ -213,63 +213,20 @@ async def generate_psk(self, code):
|
||||||
return "mock"
|
return "mock"
|
||||||
|
|
||||||
|
|
||||||
async def setup_gateway(hass, mock_gateway, mock_api,
|
async def setup_gateway(hass, mock_gateway, mock_api):
|
||||||
generate_psk=generate_psk,
|
|
||||||
known_hosts=None):
|
|
||||||
"""Load the Tradfri platform with a mock gateway."""
|
"""Load the Tradfri platform with a mock gateway."""
|
||||||
def request_config(_, callback, description, submit_caption, fields):
|
entry = MockConfigEntry(domain=tradfri.DOMAIN, data={
|
||||||
"""Mock request_config."""
|
'host': 'mock-host',
|
||||||
hass.async_add_job(callback, {'security_code': 'mock'})
|
'identity': 'mock-identity',
|
||||||
|
'key': 'mock-key',
|
||||||
if known_hosts is None:
|
'import_groups': True,
|
||||||
known_hosts = {}
|
'gateway_id': 'mock-gateway-id',
|
||||||
|
})
|
||||||
with patch('pytradfri.api.aiocoap_api.APIFactory.generate_psk',
|
hass.data[tradfri.KEY_GATEWAY] = {entry.entry_id: mock_gateway}
|
||||||
generate_psk), \
|
hass.data[tradfri.KEY_API] = {entry.entry_id: mock_api}
|
||||||
patch('pytradfri.api.aiocoap_api.APIFactory.request', mock_api), \
|
await hass.config_entries.async_forward_entry_setup(
|
||||||
patch('pytradfri.Gateway', return_value=mock_gateway), \
|
entry, 'light'
|
||||||
patch.object(tradfri, 'load_json', return_value=known_hosts), \
|
)
|
||||||
patch.object(tradfri, 'save_json'), \
|
|
||||||
patch.object(hass.components.configurator, 'request_config',
|
|
||||||
request_config):
|
|
||||||
|
|
||||||
await async_setup_component(hass, tradfri.DOMAIN,
|
|
||||||
{
|
|
||||||
tradfri.DOMAIN: {
|
|
||||||
'host': 'mock-host',
|
|
||||||
'allow_tradfri_groups': True
|
|
||||||
}
|
|
||||||
})
|
|
||||||
await hass.async_block_till_done()
|
|
||||||
|
|
||||||
|
|
||||||
async def test_setup_gateway(hass, mock_gateway, mock_api):
|
|
||||||
"""Test that the gateway can be setup without errors."""
|
|
||||||
await setup_gateway(hass, mock_gateway, mock_api)
|
|
||||||
|
|
||||||
|
|
||||||
async def test_setup_gateway_known_host(hass, mock_gateway, mock_api):
|
|
||||||
"""Test gateway setup with a known host."""
|
|
||||||
await setup_gateway(hass, mock_gateway, mock_api,
|
|
||||||
known_hosts={
|
|
||||||
'mock-host': {
|
|
||||||
'identity': 'mock',
|
|
||||||
'key': 'mock-key'
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
async def test_incorrect_security_code(hass, mock_gateway, mock_api):
|
|
||||||
"""Test that an error is shown if the security code is incorrect."""
|
|
||||||
async def psk_error(self, code):
|
|
||||||
"""Raise RequestError when called."""
|
|
||||||
raise RequestError
|
|
||||||
|
|
||||||
with patch.object(hass.components.configurator, 'async_notify_errors') \
|
|
||||||
as notify_error:
|
|
||||||
await setup_gateway(hass, mock_gateway, mock_api,
|
|
||||||
generate_psk=psk_error)
|
|
||||||
assert len(notify_error.mock_calls) > 0
|
|
||||||
|
|
||||||
|
|
||||||
def mock_light(test_features={}, test_state={}, n=0):
|
def mock_light(test_features={}, test_state={}, n=0):
|
||||||
|
|
1
tests/components/tradfri/__init__.py
Normal file
1
tests/components/tradfri/__init__.py
Normal file
|
@ -0,0 +1 @@
|
||||||
|
"""Tests for the tradfri component."""
|
156
tests/components/tradfri/test_config_flow.py
Normal file
156
tests/components/tradfri/test_config_flow.py
Normal file
|
@ -0,0 +1,156 @@
|
||||||
|
"""Test the Tradfri config flow."""
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from homeassistant import data_entry_flow
|
||||||
|
from homeassistant.components.tradfri import config_flow
|
||||||
|
|
||||||
|
from tests.common import mock_coro
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_auth():
|
||||||
|
"""Mock authenticate."""
|
||||||
|
with patch('homeassistant.components.tradfri.config_flow.'
|
||||||
|
'authenticate') as mock_auth:
|
||||||
|
yield mock_auth
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_gateway_info():
|
||||||
|
"""Mock get_gateway_info."""
|
||||||
|
with patch('homeassistant.components.tradfri.config_flow.'
|
||||||
|
'get_gateway_info') as mock_gateway:
|
||||||
|
yield mock_gateway
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_entry_setup():
|
||||||
|
"""Mock entry setup."""
|
||||||
|
with patch('homeassistant.components.tradfri.'
|
||||||
|
'async_setup_entry') as mock_setup:
|
||||||
|
mock_setup.return_value = mock_coro(True)
|
||||||
|
yield mock_setup
|
||||||
|
|
||||||
|
|
||||||
|
async def test_user_connection_successful(hass, mock_auth, mock_entry_setup):
|
||||||
|
"""Test a successful connection."""
|
||||||
|
mock_auth.side_effect = lambda hass, host, code: mock_coro({
|
||||||
|
'host': host,
|
||||||
|
'gateway_id': 'bla'
|
||||||
|
})
|
||||||
|
|
||||||
|
flow = await hass.config_entries.flow.async_init(
|
||||||
|
'tradfri', context={'source': 'user'})
|
||||||
|
|
||||||
|
result = await hass.config_entries.flow.async_configure(flow['flow_id'], {
|
||||||
|
'host': '123.123.123.123',
|
||||||
|
'security_code': 'abcd',
|
||||||
|
})
|
||||||
|
|
||||||
|
assert len(mock_entry_setup.mock_calls) == 1
|
||||||
|
|
||||||
|
assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
||||||
|
assert result['result'].data == {
|
||||||
|
'host': '123.123.123.123',
|
||||||
|
'gateway_id': 'bla',
|
||||||
|
'import_groups': False
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_user_connection_timeout(hass, mock_auth, mock_entry_setup):
|
||||||
|
"""Test a connection timeout."""
|
||||||
|
mock_auth.side_effect = config_flow.AuthError('timeout')
|
||||||
|
|
||||||
|
flow = await hass.config_entries.flow.async_init(
|
||||||
|
'tradfri', context={'source': 'user'})
|
||||||
|
|
||||||
|
result = await hass.config_entries.flow.async_configure(flow['flow_id'], {
|
||||||
|
'host': '127.0.0.1',
|
||||||
|
'security_code': 'abcd',
|
||||||
|
})
|
||||||
|
|
||||||
|
assert len(mock_entry_setup.mock_calls) == 0
|
||||||
|
|
||||||
|
assert result['type'] == data_entry_flow.RESULT_TYPE_FORM
|
||||||
|
assert result['errors'] == {
|
||||||
|
'base': 'timeout'
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_user_connection_bad_key(hass, mock_auth, mock_entry_setup):
|
||||||
|
"""Test a connection with bad key."""
|
||||||
|
mock_auth.side_effect = config_flow.AuthError('invalid_security_code')
|
||||||
|
|
||||||
|
flow = await hass.config_entries.flow.async_init(
|
||||||
|
'tradfri', context={'source': 'user'})
|
||||||
|
|
||||||
|
result = await hass.config_entries.flow.async_configure(flow['flow_id'], {
|
||||||
|
'host': '127.0.0.1',
|
||||||
|
'security_code': 'abcd',
|
||||||
|
})
|
||||||
|
|
||||||
|
assert len(mock_entry_setup.mock_calls) == 0
|
||||||
|
|
||||||
|
assert result['type'] == data_entry_flow.RESULT_TYPE_FORM
|
||||||
|
assert result['errors'] == {
|
||||||
|
'security_code': 'invalid_security_code'
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_discovery_connection(hass, mock_auth, mock_entry_setup):
|
||||||
|
"""Test a connection via discovery."""
|
||||||
|
mock_auth.side_effect = lambda hass, host, code: mock_coro({
|
||||||
|
'host': host,
|
||||||
|
'gateway_id': 'bla'
|
||||||
|
})
|
||||||
|
|
||||||
|
flow = await hass.config_entries.flow.async_init(
|
||||||
|
'tradfri', context={'source': 'discovery'}, data={
|
||||||
|
'host': '123.123.123.123'
|
||||||
|
})
|
||||||
|
|
||||||
|
result = await hass.config_entries.flow.async_configure(flow['flow_id'], {
|
||||||
|
'security_code': 'abcd',
|
||||||
|
})
|
||||||
|
|
||||||
|
assert len(mock_entry_setup.mock_calls) == 1
|
||||||
|
|
||||||
|
assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
||||||
|
assert result['result'].data == {
|
||||||
|
'host': '123.123.123.123',
|
||||||
|
'gateway_id': 'bla',
|
||||||
|
'import_groups': False
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_import_connection(hass, mock_gateway_info, mock_entry_setup):
|
||||||
|
"""Test a connection via import."""
|
||||||
|
mock_gateway_info.side_effect = \
|
||||||
|
lambda hass, host, identity, key: mock_coro({
|
||||||
|
'host': host,
|
||||||
|
'identity': identity,
|
||||||
|
'key': key,
|
||||||
|
'gateway_id': 'mock-gateway'
|
||||||
|
})
|
||||||
|
|
||||||
|
result = await hass.config_entries.flow.async_init(
|
||||||
|
'tradfri', context={'source': 'import'}, data={
|
||||||
|
'host': '123.123.123.123',
|
||||||
|
'identity': 'mock-iden',
|
||||||
|
'key': 'mock-key',
|
||||||
|
'import_groups': True
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
||||||
|
assert result['result'].data == {
|
||||||
|
'host': '123.123.123.123',
|
||||||
|
'gateway_id': 'mock-gateway',
|
||||||
|
'identity': 'mock-iden',
|
||||||
|
'key': 'mock-key',
|
||||||
|
'import_groups': True
|
||||||
|
}
|
||||||
|
|
||||||
|
assert len(mock_gateway_info.mock_calls) == 1
|
||||||
|
assert len(mock_entry_setup.mock_calls) == 1
|
Loading…
Add table
Add a link
Reference in a new issue