Config flow tradfri (#16665)

* Fix comments

* Add config flow tests

* Fix Tradfri light tests

* Lint

* Remove import group from config flow

* fix stale comments
This commit is contained in:
Paulus Schoutsen 2018-09-19 21:21:43 +02:00 committed by GitHub
parent 3160fa5de8
commit a1c524d372
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 519 additions and 262 deletions

View file

@ -48,6 +48,7 @@ CONFIG_ENTRY_HANDLERS = {
SERVICE_DECONZ: 'deconz', SERVICE_DECONZ: 'deconz',
'google_cast': 'cast', 'google_cast': 'cast',
SERVICE_HUE: 'hue', SERVICE_HUE: 'hue',
SERVICE_IKEA_TRADFRI: 'tradfri',
'sonos': 'sonos', 'sonos': 'sonos',
} }
@ -55,7 +56,6 @@ SERVICE_HANDLERS = {
SERVICE_HASS_IOS_APP: ('ios', None), SERVICE_HASS_IOS_APP: ('ios', None),
SERVICE_NETGEAR: ('device_tracker', None), SERVICE_NETGEAR: ('device_tracker', None),
SERVICE_WEMO: ('wemo', None), SERVICE_WEMO: ('wemo', None),
SERVICE_IKEA_TRADFRI: ('tradfri', None),
SERVICE_HASSIO: ('hassio', None), SERVICE_HASSIO: ('hassio', None),
SERVICE_AXIS: ('axis', None), SERVICE_AXIS: ('axis', None),
SERVICE_APPLE_TV: ('apple_tv', None), SERVICE_APPLE_TV: ('apple_tv', None),

View file

@ -13,8 +13,9 @@ from homeassistant.components.light import (
SUPPORT_COLOR, Light) SUPPORT_COLOR, Light)
from homeassistant.components.light import \ from homeassistant.components.light import \
PLATFORM_SCHEMA as LIGHT_PLATFORM_SCHEMA PLATFORM_SCHEMA as LIGHT_PLATFORM_SCHEMA
from homeassistant.components.tradfri import KEY_GATEWAY, KEY_TRADFRI_GROUPS, \ from homeassistant.components.tradfri import KEY_GATEWAY, KEY_API
KEY_API from homeassistant.components.tradfri.const import (
CONF_IMPORT_GROUPS, CONF_GATEWAY_ID)
import homeassistant.util.color as color_util import homeassistant.util.color as color_util
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -31,28 +32,21 @@ SUPPORTED_FEATURES = SUPPORT_TRANSITION
SUPPORTED_GROUP_FEATURES = SUPPORT_BRIGHTNESS | SUPPORT_TRANSITION SUPPORTED_GROUP_FEATURES = SUPPORT_BRIGHTNESS | SUPPORT_TRANSITION
async def async_setup_platform(hass, config, async def async_setup_entry(hass, config_entry, async_add_entities):
async_add_entities, discovery_info=None): """Load Tradfri lights based on a config entry."""
"""Set up the IKEA Tradfri Light platform.""" gateway_id = config_entry.data[CONF_GATEWAY_ID]
if discovery_info is None: api = hass.data[KEY_API][config_entry.entry_id]
return gateway = hass.data[KEY_GATEWAY][config_entry.entry_id]
gateway_id = discovery_info['gateway'] devices_commands = await api(gateway.get_devices())
api = hass.data[KEY_API][gateway_id]
gateway = hass.data[KEY_GATEWAY][gateway_id]
devices_command = gateway.get_devices()
devices_commands = await api(devices_command)
devices = await api(devices_commands) devices = await api(devices_commands)
lights = [dev for dev in devices if dev.has_light_control] lights = [dev for dev in devices if dev.has_light_control]
if lights: if lights:
async_add_entities( async_add_entities(
TradfriLight(light, api, gateway_id) for light in lights) TradfriLight(light, api, gateway_id) for light in lights)
allow_tradfri_groups = hass.data[KEY_TRADFRI_GROUPS][gateway_id] if config_entry.data[CONF_IMPORT_GROUPS]:
if allow_tradfri_groups: groups_commands = await api(gateway.get_groups())
groups_command = gateway.get_groups()
groups_commands = await api(groups_command)
groups = await api(groups_commands) groups = await api(groups_commands)
if groups: if groups:
async_add_entities( async_add_entities(

View file

@ -19,20 +19,14 @@ DEPENDENCIES = ['tradfri']
SCAN_INTERVAL = timedelta(minutes=5) SCAN_INTERVAL = timedelta(minutes=5)
async def async_setup_platform(hass, config, async_add_entities, async def async_setup_entry(hass, config_entry, async_add_entities):
discovery_info=None): """Set up a Tradfri config entry."""
"""Set up the IKEA Tradfri device platform.""" api = hass.data[KEY_API][config_entry.entry_id]
if discovery_info is None: gateway = hass.data[KEY_GATEWAY][config_entry.entry_id]
return
gateway_id = discovery_info['gateway'] devices_commands = await api(gateway.get_devices())
api = hass.data[KEY_API][gateway_id]
gateway = hass.data[KEY_GATEWAY][gateway_id]
devices_command = gateway.get_devices()
devices_commands = await api(devices_command)
all_devices = await api(devices_commands) all_devices = await api(devices_commands)
devices = [dev for dev in all_devices if not dev.has_light_control] devices = (dev for dev in all_devices if not dev.has_light_control)
async_add_entities(TradfriDevice(device, api) for device in devices) async_add_entities(TradfriDevice(device, api) for device in devices)

View file

@ -1,173 +0,0 @@
"""
Support for IKEA Tradfri.
For more details about this component, please refer to the documentation at
https://home-assistant.io/components/ikea_tradfri/
"""
import logging
from uuid import uuid4
import voluptuous as vol
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers import discovery
from homeassistant.const import CONF_HOST
from homeassistant.components.discovery import SERVICE_IKEA_TRADFRI
from homeassistant.util.json import load_json, save_json
REQUIREMENTS = ['pytradfri[async]==5.5.1']
DOMAIN = 'tradfri'
GATEWAY_IDENTITY = 'homeassistant'
CONFIG_FILE = '.tradfri_psk.conf'
KEY_CONFIG = 'tradfri_configuring'
KEY_GATEWAY = 'tradfri_gateway'
KEY_API = 'tradfri_api'
KEY_TRADFRI_GROUPS = 'tradfri_allow_tradfri_groups'
CONF_ALLOW_TRADFRI_GROUPS = 'allow_tradfri_groups'
DEFAULT_ALLOW_TRADFRI_GROUPS = True
CONFIG_SCHEMA = vol.Schema({
DOMAIN: vol.Schema({
vol.Inclusive(CONF_HOST, 'gateway'): cv.string,
vol.Optional(CONF_ALLOW_TRADFRI_GROUPS,
default=DEFAULT_ALLOW_TRADFRI_GROUPS): cv.boolean,
})
}, extra=vol.ALLOW_EXTRA)
_LOGGER = logging.getLogger(__name__)
def request_configuration(hass, config, host):
"""Request configuration steps from the user."""
configurator = hass.components.configurator
hass.data.setdefault(KEY_CONFIG, {})
instance = hass.data[KEY_CONFIG].get(host)
# Configuration already in progress
if instance:
return
async def configuration_callback(callback_data):
"""Handle the submitted configuration."""
try:
from pytradfri.api.aiocoap_api import APIFactory
from pytradfri import RequestError
except ImportError:
_LOGGER.exception("Looks like something isn't installed!")
return
identity = uuid4().hex
security_code = callback_data.get('security_code')
api_factory = APIFactory(host, psk_id=identity, loop=hass.loop)
# Need To Fix: currently entering a wrong security code sends
# pytradfri aiocoap API into an endless loop.
# Should just raise a requestError or something.
try:
key = await api_factory.generate_psk(security_code)
except RequestError:
configurator.async_notify_errors(hass, instance,
"Security Code not accepted.")
return
res = await _setup_gateway(hass, config, host, identity, key,
DEFAULT_ALLOW_TRADFRI_GROUPS)
if not res:
configurator.async_notify_errors(hass, instance,
"Unable to connect.")
return
def success():
"""Set up was successful."""
conf = load_json(hass.config.path(CONFIG_FILE))
conf[host] = {'identity': identity,
'key': key}
save_json(hass.config.path(CONFIG_FILE), conf)
configurator.request_done(instance)
hass.async_add_job(success)
instance = configurator.request_config(
"IKEA Trådfri", configuration_callback,
description='Please enter the security code written at the bottom of '
'your IKEA Trådfri Gateway.',
submit_caption="Confirm",
fields=[{'id': 'security_code', 'name': 'Security Code',
'type': 'password'}]
)
async def async_setup(hass, config):
"""Set up the Tradfri component."""
conf = config.get(DOMAIN, {})
host = conf.get(CONF_HOST)
allow_tradfri_groups = conf.get(CONF_ALLOW_TRADFRI_GROUPS)
known_hosts = await hass.async_add_job(load_json,
hass.config.path(CONFIG_FILE))
async def gateway_discovered(service, info,
allow_groups=DEFAULT_ALLOW_TRADFRI_GROUPS):
"""Run when a gateway is discovered."""
host = info['host']
if host in known_hosts:
# use fallbacks for old config style
# identity was hard coded as 'homeassistant'
identity = known_hosts[host].get('identity', 'homeassistant')
key = known_hosts[host].get('key')
await _setup_gateway(hass, config, host, identity, key,
allow_groups)
else:
hass.async_add_job(request_configuration, hass, config, host)
discovery.async_listen(hass, SERVICE_IKEA_TRADFRI, gateway_discovered)
if host:
await gateway_discovered(None,
{'host': host},
allow_tradfri_groups)
return True
async def _setup_gateway(hass, hass_config, host, identity, key,
allow_tradfri_groups):
"""Create a gateway."""
from pytradfri import Gateway, RequestError # pylint: disable=import-error
try:
from pytradfri.api.aiocoap_api import APIFactory
except ImportError:
_LOGGER.exception("Looks like something isn't installed!")
return False
try:
factory = APIFactory(host, psk_id=identity, psk=key,
loop=hass.loop)
api = factory.request
gateway = Gateway()
gateway_info_result = await api(gateway.get_gateway_info())
except RequestError:
_LOGGER.exception("Tradfri setup failed.")
return False
gateway_id = gateway_info_result.id
hass.data.setdefault(KEY_API, {})
hass.data.setdefault(KEY_GATEWAY, {})
gateways = hass.data[KEY_GATEWAY]
hass.data[KEY_API][gateway_id] = api
hass.data.setdefault(KEY_TRADFRI_GROUPS, {})
tradfri_groups = hass.data[KEY_TRADFRI_GROUPS]
tradfri_groups[gateway_id] = allow_tradfri_groups
# Check if already set up
if gateway_id in gateways:
return True
gateways[gateway_id] = gateway
hass.async_create_task(discovery.async_load_platform(
hass, 'light', DOMAIN, {'gateway': gateway_id}, hass_config))
hass.async_create_task(discovery.async_load_platform(
hass, 'sensor', DOMAIN, {'gateway': gateway_id}, hass_config))
return True

View file

@ -0,0 +1,23 @@
{
"config": {
"abort": {
"already_configured": "Bridge is already configured"
},
"error": {
"cannot_connect": "Unable to connect to the gateway.",
"invalid_key": "Failed to register with provided key. If this keeps happening, try restarting the gateway.",
"timeout": "Timeout validating the code."
},
"step": {
"auth": {
"data": {
"host": "Host",
"security_code": "Security Code"
},
"description": "You can find the security code on the back of your gateway.",
"title": "Enter security code"
}
},
"title": "IKEA TR\u00c5DFRI"
}
}

View file

@ -0,0 +1,102 @@
"""
Support for IKEA Tradfri.
For more details about this component, please refer to the documentation at
https://home-assistant.io/components/ikea_tradfri/
"""
import logging
import voluptuous as vol
from homeassistant import config_entries
import homeassistant.helpers.config_validation as cv
from homeassistant.util.json import load_json
from .const import CONF_IMPORT_GROUPS, CONF_IDENTITY, CONF_HOST, CONF_KEY
from . import config_flow # noqa pylint_disable=unused-import
REQUIREMENTS = ['pytradfri[async]==5.5.1']
DOMAIN = 'tradfri'
CONFIG_FILE = '.tradfri_psk.conf'
KEY_GATEWAY = 'tradfri_gateway'
KEY_API = 'tradfri_api'
CONF_ALLOW_TRADFRI_GROUPS = 'allow_tradfri_groups'
DEFAULT_ALLOW_TRADFRI_GROUPS = True
CONFIG_SCHEMA = vol.Schema({
DOMAIN: vol.Schema({
vol.Inclusive(CONF_HOST, 'gateway'): cv.string,
vol.Optional(CONF_ALLOW_TRADFRI_GROUPS,
default=DEFAULT_ALLOW_TRADFRI_GROUPS): cv.boolean,
})
}, extra=vol.ALLOW_EXTRA)
_LOGGER = logging.getLogger(__name__)
async def async_setup(hass, config):
"""Set up the Tradfri component."""
conf = config.get(DOMAIN)
if conf is None:
return True
known_hosts = await hass.async_add_executor_job(
load_json, hass.config.path(CONFIG_FILE))
for host, info in known_hosts.items():
info[CONF_HOST] = host
info[CONF_IMPORT_GROUPS] = conf[CONF_ALLOW_TRADFRI_GROUPS]
hass.async_create_task(hass.config_entries.flow.async_init(
DOMAIN, context={'source': config_entries.SOURCE_IMPORT},
data=info
))
host = conf.get(CONF_HOST)
if host is None or host in known_hosts:
return True
hass.async_create_task(hass.config_entries.flow.async_init(
DOMAIN, context={'source': config_entries.SOURCE_IMPORT},
data={'host': host}
))
return True
async def async_setup_entry(hass, entry):
"""Create a gateway."""
# host, identity, key, allow_tradfri_groups
from pytradfri import Gateway, RequestError # pylint: disable=import-error
from pytradfri.api.aiocoap_api import APIFactory
factory = APIFactory(
entry.data[CONF_HOST],
psk_id=entry.data[CONF_IDENTITY],
psk=entry.data[CONF_KEY],
loop=hass.loop
)
api = factory.request
gateway = Gateway()
try:
await api(gateway.get_gateway_info())
except RequestError:
_LOGGER.error("Tradfri setup failed.")
return False
hass.data.setdefault(KEY_API, {})[entry.entry_id] = api
hass.data.setdefault(KEY_GATEWAY, {})[entry.entry_id] = gateway
hass.async_create_task(hass.config_entries.async_forward_entry_setup(
entry, 'light'
))
hass.async_create_task(hass.config_entries.async_forward_entry_setup(
entry, 'sensor'
))
return True

View file

@ -0,0 +1,172 @@
"""Config flow for Tradfri."""
import asyncio
from collections import OrderedDict
from uuid import uuid4
import async_timeout
import voluptuous as vol
from homeassistant import config_entries
from .const import (
CONF_IMPORT_GROUPS, CONF_IDENTITY, CONF_HOST, CONF_KEY, CONF_GATEWAY_ID)
KEY_HOST = 'host'
KEY_SECURITY_CODE = 'security_code'
KEY_IMPORT_GROUPS = 'import_groups'
class AuthError(Exception):
"""Exception if authentication occurs."""
def __init__(self, code):
"""Initialize exception."""
super().__init__()
self.code = code
@config_entries.HANDLERS.register('tradfri')
class FlowHandler(config_entries.ConfigFlow):
"""Handle a config flow."""
VERSION = 1
def __init__(self):
"""Initialize flow."""
self._host = None
async def async_step_user(self, user_input=None):
"""Handle a flow initialized by the user."""
return await self.async_step_auth()
async def async_step_auth(self, user_input=None):
"""Handle the authentication with a gateway."""
errors = {}
if user_input is not None:
host = user_input.get(KEY_HOST, self._host)
try:
auth = await authenticate(
self.hass, host,
user_input[KEY_SECURITY_CODE])
# We don't ask for import group anymore as group state
# is not reliable, don't want to show that to the user.
auth[CONF_IMPORT_GROUPS] = False
return await self._entry_from_data(auth)
except AuthError as err:
if err.code == 'invalid_security_code':
errors[KEY_SECURITY_CODE] = err.code
else:
errors['base'] = err.code
fields = OrderedDict()
if self._host is None:
fields[vol.Required(KEY_HOST)] = str
fields[vol.Required(KEY_SECURITY_CODE)] = str
return self.async_show_form(
step_id='auth',
data_schema=vol.Schema(fields),
errors=errors,
)
async def async_step_discovery(self, user_input):
"""Handle discovery."""
self._host = user_input['host']
return await self.async_step_auth()
async def async_step_import(self, user_input):
"""Import a config entry."""
for entry in self._async_current_entries():
if entry.data[CONF_HOST] == user_input['host']:
return self.async_abort(
reason='already_configured'
)
# Happens if user has host directly in configuration.yaml
if 'key' not in user_input:
self._host = user_input['host']
return await self.async_step_auth()
try:
data = await get_gateway_info(
self.hass, user_input['host'], user_input['identity'],
user_input['key'])
data[CONF_IMPORT_GROUPS] = user_input[CONF_IMPORT_GROUPS]
return await self._entry_from_data(data)
except AuthError:
# If we fail to connect, just pass it on to discovery
self._host = user_input['host']
return await self.async_step_auth()
async def _entry_from_data(self, data):
"""Create an entry from data."""
host = data[CONF_HOST]
gateway_id = data[CONF_GATEWAY_ID]
same_hub_entries = [entry.entry_id for entry
in self._async_current_entries()
if entry.data[CONF_GATEWAY_ID] == gateway_id or
entry.data[CONF_HOST] == host]
if same_hub_entries:
await asyncio.wait([self.hass.config_entries.async_remove(entry_id)
for entry_id in same_hub_entries])
return self.async_create_entry(
title=host,
data=data
)
async def authenticate(hass, host, security_code):
"""Authenticate with a Tradfri hub."""
from pytradfri.api.aiocoap_api import APIFactory
from pytradfri import RequestError
identity = uuid4().hex
api_factory = APIFactory(host, psk_id=identity, loop=hass.loop)
try:
with async_timeout.timeout(5):
key = await api_factory.generate_psk(security_code)
except RequestError:
raise AuthError('invalid_security_code')
except asyncio.TimeoutError:
raise AuthError('timeout')
return await get_gateway_info(hass, host, identity, key)
async def get_gateway_info(hass, host, identity, key):
"""Return info for the gateway."""
from pytradfri.api.aiocoap_api import APIFactory
from pytradfri import Gateway, RequestError
try:
factory = APIFactory(
host,
psk_id=identity,
psk=key,
loop=hass.loop
)
api = factory.request
gateway = Gateway()
gateway_info_result = await api(gateway.get_gateway_info())
except RequestError:
raise AuthError('cannot_connect')
return {
CONF_HOST: host,
CONF_IDENTITY: identity,
CONF_KEY: key,
CONF_GATEWAY_ID: gateway_info_result.id,
}

View file

@ -0,0 +1,7 @@
"""Consts used by Tradfri."""
from homeassistant.const import CONF_HOST # noqa pylint: disable=unused-import
CONF_IMPORT_GROUPS = 'import_groups'
CONF_IDENTITY = 'identity'
CONF_KEY = 'key'
CONF_GATEWAY_ID = 'gateway_id'

View file

@ -0,0 +1,23 @@
{
"config": {
"title": "IKEA TRÅDFRI",
"step": {
"auth": {
"title": "Enter security code",
"description": "You can find the security code on the back of your gateway.",
"data": {
"host": "Host",
"security_code": "Security Code"
}
}
},
"error": {
"invalid_key": "Failed to register with provided key. If this keeps happening, try restarting the gateway.",
"cannot_connect": "Unable to connect to the gateway.",
"timeout": "Timeout validating the code."
},
"abort": {
"already_configured": "Bridge is already configured"
}
}
}

View file

@ -146,6 +146,7 @@ FLOWS = [
'nest', 'nest',
'openuv', 'openuv',
'sonos', 'sonos',
'tradfri',
'zone', 'zone',
] ]

View file

@ -5,10 +5,10 @@ from unittest.mock import Mock, MagicMock, patch, PropertyMock
import pytest import pytest
from pytradfri.device import Device, LightControl, Light from pytradfri.device import Device, LightControl, Light
from pytradfri import RequestError
from homeassistant.components import tradfri from homeassistant.components import tradfri
from homeassistant.setup import async_setup_component
from tests.common import MockConfigEntry
DEFAULT_TEST_FEATURES = {'can_set_dimmer': False, DEFAULT_TEST_FEATURES = {'can_set_dimmer': False,
@ -199,7 +199,7 @@ def mock_gateway():
@pytest.fixture @pytest.fixture
def mock_api(mock_gateway): def mock_api(mock_gateway):
"""Mock api.""" """Mock api."""
async def api(self, command): async def api(command):
"""Mock api function.""" """Mock api function."""
# Store the data for "real" command objects. # Store the data for "real" command objects.
if(hasattr(command, '_data') and not isinstance(command, Mock)): if(hasattr(command, '_data') and not isinstance(command, Mock)):
@ -213,63 +213,20 @@ async def generate_psk(self, code):
return "mock" return "mock"
async def setup_gateway(hass, mock_gateway, mock_api, async def setup_gateway(hass, mock_gateway, mock_api):
generate_psk=generate_psk,
known_hosts=None):
"""Load the Tradfri platform with a mock gateway.""" """Load the Tradfri platform with a mock gateway."""
def request_config(_, callback, description, submit_caption, fields): entry = MockConfigEntry(domain=tradfri.DOMAIN, data={
"""Mock request_config.""" 'host': 'mock-host',
hass.async_add_job(callback, {'security_code': 'mock'}) 'identity': 'mock-identity',
'key': 'mock-key',
if known_hosts is None: 'import_groups': True,
known_hosts = {} 'gateway_id': 'mock-gateway-id',
})
with patch('pytradfri.api.aiocoap_api.APIFactory.generate_psk', hass.data[tradfri.KEY_GATEWAY] = {entry.entry_id: mock_gateway}
generate_psk), \ hass.data[tradfri.KEY_API] = {entry.entry_id: mock_api}
patch('pytradfri.api.aiocoap_api.APIFactory.request', mock_api), \ await hass.config_entries.async_forward_entry_setup(
patch('pytradfri.Gateway', return_value=mock_gateway), \ entry, 'light'
patch.object(tradfri, 'load_json', return_value=known_hosts), \ )
patch.object(tradfri, 'save_json'), \
patch.object(hass.components.configurator, 'request_config',
request_config):
await async_setup_component(hass, tradfri.DOMAIN,
{
tradfri.DOMAIN: {
'host': 'mock-host',
'allow_tradfri_groups': True
}
})
await hass.async_block_till_done()
async def test_setup_gateway(hass, mock_gateway, mock_api):
"""Test that the gateway can be setup without errors."""
await setup_gateway(hass, mock_gateway, mock_api)
async def test_setup_gateway_known_host(hass, mock_gateway, mock_api):
"""Test gateway setup with a known host."""
await setup_gateway(hass, mock_gateway, mock_api,
known_hosts={
'mock-host': {
'identity': 'mock',
'key': 'mock-key'
}
})
async def test_incorrect_security_code(hass, mock_gateway, mock_api):
"""Test that an error is shown if the security code is incorrect."""
async def psk_error(self, code):
"""Raise RequestError when called."""
raise RequestError
with patch.object(hass.components.configurator, 'async_notify_errors') \
as notify_error:
await setup_gateway(hass, mock_gateway, mock_api,
generate_psk=psk_error)
assert len(notify_error.mock_calls) > 0
def mock_light(test_features={}, test_state={}, n=0): def mock_light(test_features={}, test_state={}, n=0):

View file

@ -0,0 +1 @@
"""Tests for the tradfri component."""

View file

@ -0,0 +1,156 @@
"""Test the Tradfri config flow."""
from unittest.mock import patch
import pytest
from homeassistant import data_entry_flow
from homeassistant.components.tradfri import config_flow
from tests.common import mock_coro
@pytest.fixture
def mock_auth():
"""Mock authenticate."""
with patch('homeassistant.components.tradfri.config_flow.'
'authenticate') as mock_auth:
yield mock_auth
@pytest.fixture
def mock_gateway_info():
"""Mock get_gateway_info."""
with patch('homeassistant.components.tradfri.config_flow.'
'get_gateway_info') as mock_gateway:
yield mock_gateway
@pytest.fixture
def mock_entry_setup():
"""Mock entry setup."""
with patch('homeassistant.components.tradfri.'
'async_setup_entry') as mock_setup:
mock_setup.return_value = mock_coro(True)
yield mock_setup
async def test_user_connection_successful(hass, mock_auth, mock_entry_setup):
"""Test a successful connection."""
mock_auth.side_effect = lambda hass, host, code: mock_coro({
'host': host,
'gateway_id': 'bla'
})
flow = await hass.config_entries.flow.async_init(
'tradfri', context={'source': 'user'})
result = await hass.config_entries.flow.async_configure(flow['flow_id'], {
'host': '123.123.123.123',
'security_code': 'abcd',
})
assert len(mock_entry_setup.mock_calls) == 1
assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
assert result['result'].data == {
'host': '123.123.123.123',
'gateway_id': 'bla',
'import_groups': False
}
async def test_user_connection_timeout(hass, mock_auth, mock_entry_setup):
"""Test a connection timeout."""
mock_auth.side_effect = config_flow.AuthError('timeout')
flow = await hass.config_entries.flow.async_init(
'tradfri', context={'source': 'user'})
result = await hass.config_entries.flow.async_configure(flow['flow_id'], {
'host': '127.0.0.1',
'security_code': 'abcd',
})
assert len(mock_entry_setup.mock_calls) == 0
assert result['type'] == data_entry_flow.RESULT_TYPE_FORM
assert result['errors'] == {
'base': 'timeout'
}
async def test_user_connection_bad_key(hass, mock_auth, mock_entry_setup):
"""Test a connection with bad key."""
mock_auth.side_effect = config_flow.AuthError('invalid_security_code')
flow = await hass.config_entries.flow.async_init(
'tradfri', context={'source': 'user'})
result = await hass.config_entries.flow.async_configure(flow['flow_id'], {
'host': '127.0.0.1',
'security_code': 'abcd',
})
assert len(mock_entry_setup.mock_calls) == 0
assert result['type'] == data_entry_flow.RESULT_TYPE_FORM
assert result['errors'] == {
'security_code': 'invalid_security_code'
}
async def test_discovery_connection(hass, mock_auth, mock_entry_setup):
"""Test a connection via discovery."""
mock_auth.side_effect = lambda hass, host, code: mock_coro({
'host': host,
'gateway_id': 'bla'
})
flow = await hass.config_entries.flow.async_init(
'tradfri', context={'source': 'discovery'}, data={
'host': '123.123.123.123'
})
result = await hass.config_entries.flow.async_configure(flow['flow_id'], {
'security_code': 'abcd',
})
assert len(mock_entry_setup.mock_calls) == 1
assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
assert result['result'].data == {
'host': '123.123.123.123',
'gateway_id': 'bla',
'import_groups': False
}
async def test_import_connection(hass, mock_gateway_info, mock_entry_setup):
"""Test a connection via import."""
mock_gateway_info.side_effect = \
lambda hass, host, identity, key: mock_coro({
'host': host,
'identity': identity,
'key': key,
'gateway_id': 'mock-gateway'
})
result = await hass.config_entries.flow.async_init(
'tradfri', context={'source': 'import'}, data={
'host': '123.123.123.123',
'identity': 'mock-iden',
'key': 'mock-key',
'import_groups': True
})
assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
assert result['result'].data == {
'host': '123.123.123.123',
'gateway_id': 'mock-gateway',
'identity': 'mock-iden',
'key': 'mock-key',
'import_groups': True
}
assert len(mock_gateway_info.mock_calls) == 1
assert len(mock_entry_setup.mock_calls) == 1