diff --git a/homeassistant/components/discovery.py b/homeassistant/components/discovery.py index 41cf3791256..22d7ae87b8d 100644 --- a/homeassistant/components/discovery.py +++ b/homeassistant/components/discovery.py @@ -48,6 +48,7 @@ CONFIG_ENTRY_HANDLERS = { SERVICE_DECONZ: 'deconz', 'google_cast': 'cast', SERVICE_HUE: 'hue', + SERVICE_IKEA_TRADFRI: 'tradfri', 'sonos': 'sonos', } @@ -55,7 +56,6 @@ SERVICE_HANDLERS = { SERVICE_HASS_IOS_APP: ('ios', None), SERVICE_NETGEAR: ('device_tracker', None), SERVICE_WEMO: ('wemo', None), - SERVICE_IKEA_TRADFRI: ('tradfri', None), SERVICE_HASSIO: ('hassio', None), SERVICE_AXIS: ('axis', None), SERVICE_APPLE_TV: ('apple_tv', None), diff --git a/homeassistant/components/light/tradfri.py b/homeassistant/components/light/tradfri.py index 0d12d095bb6..bd432b5dedc 100644 --- a/homeassistant/components/light/tradfri.py +++ b/homeassistant/components/light/tradfri.py @@ -13,8 +13,9 @@ from homeassistant.components.light import ( SUPPORT_COLOR, Light) from homeassistant.components.light import \ PLATFORM_SCHEMA as LIGHT_PLATFORM_SCHEMA -from homeassistant.components.tradfri import KEY_GATEWAY, KEY_TRADFRI_GROUPS, \ - KEY_API +from homeassistant.components.tradfri import KEY_GATEWAY, KEY_API +from homeassistant.components.tradfri.const import ( + CONF_IMPORT_GROUPS, CONF_GATEWAY_ID) import homeassistant.util.color as color_util _LOGGER = logging.getLogger(__name__) @@ -31,28 +32,21 @@ SUPPORTED_FEATURES = SUPPORT_TRANSITION SUPPORTED_GROUP_FEATURES = SUPPORT_BRIGHTNESS | SUPPORT_TRANSITION -async def async_setup_platform(hass, config, - async_add_entities, discovery_info=None): - """Set up the IKEA Tradfri Light platform.""" - if discovery_info is None: - return +async def async_setup_entry(hass, config_entry, async_add_entities): + """Load Tradfri lights based on a config entry.""" + gateway_id = config_entry.data[CONF_GATEWAY_ID] + api = hass.data[KEY_API][config_entry.entry_id] + gateway = hass.data[KEY_GATEWAY][config_entry.entry_id] - gateway_id = discovery_info['gateway'] - api = hass.data[KEY_API][gateway_id] - gateway = hass.data[KEY_GATEWAY][gateway_id] - - devices_command = gateway.get_devices() - devices_commands = await api(devices_command) + devices_commands = await api(gateway.get_devices()) devices = await api(devices_commands) lights = [dev for dev in devices if dev.has_light_control] if lights: async_add_entities( TradfriLight(light, api, gateway_id) for light in lights) - allow_tradfri_groups = hass.data[KEY_TRADFRI_GROUPS][gateway_id] - if allow_tradfri_groups: - groups_command = gateway.get_groups() - groups_commands = await api(groups_command) + if config_entry.data[CONF_IMPORT_GROUPS]: + groups_commands = await api(gateway.get_groups()) groups = await api(groups_commands) if groups: async_add_entities( diff --git a/homeassistant/components/sensor/tradfri.py b/homeassistant/components/sensor/tradfri.py index 0849169b747..86d0c1abc19 100644 --- a/homeassistant/components/sensor/tradfri.py +++ b/homeassistant/components/sensor/tradfri.py @@ -19,20 +19,14 @@ DEPENDENCIES = ['tradfri'] SCAN_INTERVAL = timedelta(minutes=5) -async def async_setup_platform(hass, config, async_add_entities, - discovery_info=None): - """Set up the IKEA Tradfri device platform.""" - if discovery_info is None: - return +async def async_setup_entry(hass, config_entry, async_add_entities): + """Set up a Tradfri config entry.""" + api = hass.data[KEY_API][config_entry.entry_id] + gateway = hass.data[KEY_GATEWAY][config_entry.entry_id] - gateway_id = discovery_info['gateway'] - api = hass.data[KEY_API][gateway_id] - gateway = hass.data[KEY_GATEWAY][gateway_id] - - devices_command = gateway.get_devices() - devices_commands = await api(devices_command) + devices_commands = await api(gateway.get_devices()) all_devices = await api(devices_commands) - devices = [dev for dev in all_devices if not dev.has_light_control] + devices = (dev for dev in all_devices if not dev.has_light_control) async_add_entities(TradfriDevice(device, api) for device in devices) diff --git a/homeassistant/components/tradfri.py b/homeassistant/components/tradfri.py deleted file mode 100644 index b2e41902552..00000000000 --- a/homeassistant/components/tradfri.py +++ /dev/null @@ -1,173 +0,0 @@ -""" -Support for IKEA Tradfri. - -For more details about this component, please refer to the documentation at -https://home-assistant.io/components/ikea_tradfri/ -""" -import logging -from uuid import uuid4 - -import voluptuous as vol - -import homeassistant.helpers.config_validation as cv -from homeassistant.helpers import discovery -from homeassistant.const import CONF_HOST -from homeassistant.components.discovery import SERVICE_IKEA_TRADFRI -from homeassistant.util.json import load_json, save_json - -REQUIREMENTS = ['pytradfri[async]==5.5.1'] - -DOMAIN = 'tradfri' -GATEWAY_IDENTITY = 'homeassistant' -CONFIG_FILE = '.tradfri_psk.conf' -KEY_CONFIG = 'tradfri_configuring' -KEY_GATEWAY = 'tradfri_gateway' -KEY_API = 'tradfri_api' -KEY_TRADFRI_GROUPS = 'tradfri_allow_tradfri_groups' -CONF_ALLOW_TRADFRI_GROUPS = 'allow_tradfri_groups' -DEFAULT_ALLOW_TRADFRI_GROUPS = True - -CONFIG_SCHEMA = vol.Schema({ - DOMAIN: vol.Schema({ - vol.Inclusive(CONF_HOST, 'gateway'): cv.string, - vol.Optional(CONF_ALLOW_TRADFRI_GROUPS, - default=DEFAULT_ALLOW_TRADFRI_GROUPS): cv.boolean, - }) -}, extra=vol.ALLOW_EXTRA) - -_LOGGER = logging.getLogger(__name__) - - -def request_configuration(hass, config, host): - """Request configuration steps from the user.""" - configurator = hass.components.configurator - hass.data.setdefault(KEY_CONFIG, {}) - instance = hass.data[KEY_CONFIG].get(host) - - # Configuration already in progress - if instance: - return - - async def configuration_callback(callback_data): - """Handle the submitted configuration.""" - try: - from pytradfri.api.aiocoap_api import APIFactory - from pytradfri import RequestError - except ImportError: - _LOGGER.exception("Looks like something isn't installed!") - return - - identity = uuid4().hex - security_code = callback_data.get('security_code') - - api_factory = APIFactory(host, psk_id=identity, loop=hass.loop) - # Need To Fix: currently entering a wrong security code sends - # pytradfri aiocoap API into an endless loop. - # Should just raise a requestError or something. - try: - key = await api_factory.generate_psk(security_code) - except RequestError: - configurator.async_notify_errors(hass, instance, - "Security Code not accepted.") - return - - res = await _setup_gateway(hass, config, host, identity, key, - DEFAULT_ALLOW_TRADFRI_GROUPS) - - if not res: - configurator.async_notify_errors(hass, instance, - "Unable to connect.") - return - - def success(): - """Set up was successful.""" - conf = load_json(hass.config.path(CONFIG_FILE)) - conf[host] = {'identity': identity, - 'key': key} - save_json(hass.config.path(CONFIG_FILE), conf) - configurator.request_done(instance) - - hass.async_add_job(success) - - instance = configurator.request_config( - "IKEA Trådfri", configuration_callback, - description='Please enter the security code written at the bottom of ' - 'your IKEA Trådfri Gateway.', - submit_caption="Confirm", - fields=[{'id': 'security_code', 'name': 'Security Code', - 'type': 'password'}] - ) - - -async def async_setup(hass, config): - """Set up the Tradfri component.""" - conf = config.get(DOMAIN, {}) - host = conf.get(CONF_HOST) - allow_tradfri_groups = conf.get(CONF_ALLOW_TRADFRI_GROUPS) - known_hosts = await hass.async_add_job(load_json, - hass.config.path(CONFIG_FILE)) - - async def gateway_discovered(service, info, - allow_groups=DEFAULT_ALLOW_TRADFRI_GROUPS): - """Run when a gateway is discovered.""" - host = info['host'] - - if host in known_hosts: - # use fallbacks for old config style - # identity was hard coded as 'homeassistant' - identity = known_hosts[host].get('identity', 'homeassistant') - key = known_hosts[host].get('key') - await _setup_gateway(hass, config, host, identity, key, - allow_groups) - else: - hass.async_add_job(request_configuration, hass, config, host) - - discovery.async_listen(hass, SERVICE_IKEA_TRADFRI, gateway_discovered) - - if host: - await gateway_discovered(None, - {'host': host}, - allow_tradfri_groups) - return True - - -async def _setup_gateway(hass, hass_config, host, identity, key, - allow_tradfri_groups): - """Create a gateway.""" - from pytradfri import Gateway, RequestError # pylint: disable=import-error - try: - from pytradfri.api.aiocoap_api import APIFactory - except ImportError: - _LOGGER.exception("Looks like something isn't installed!") - return False - - try: - factory = APIFactory(host, psk_id=identity, psk=key, - loop=hass.loop) - api = factory.request - gateway = Gateway() - gateway_info_result = await api(gateway.get_gateway_info()) - except RequestError: - _LOGGER.exception("Tradfri setup failed.") - return False - - gateway_id = gateway_info_result.id - hass.data.setdefault(KEY_API, {}) - hass.data.setdefault(KEY_GATEWAY, {}) - gateways = hass.data[KEY_GATEWAY] - hass.data[KEY_API][gateway_id] = api - - hass.data.setdefault(KEY_TRADFRI_GROUPS, {}) - tradfri_groups = hass.data[KEY_TRADFRI_GROUPS] - tradfri_groups[gateway_id] = allow_tradfri_groups - - # Check if already set up - if gateway_id in gateways: - return True - - gateways[gateway_id] = gateway - hass.async_create_task(discovery.async_load_platform( - hass, 'light', DOMAIN, {'gateway': gateway_id}, hass_config)) - hass.async_create_task(discovery.async_load_platform( - hass, 'sensor', DOMAIN, {'gateway': gateway_id}, hass_config)) - return True diff --git a/homeassistant/components/tradfri/.translations/en.json b/homeassistant/components/tradfri/.translations/en.json new file mode 100644 index 00000000000..7b0d2005c2a --- /dev/null +++ b/homeassistant/components/tradfri/.translations/en.json @@ -0,0 +1,23 @@ +{ + "config": { + "abort": { + "already_configured": "Bridge is already configured" + }, + "error": { + "cannot_connect": "Unable to connect to the gateway.", + "invalid_key": "Failed to register with provided key. If this keeps happening, try restarting the gateway.", + "timeout": "Timeout validating the code." + }, + "step": { + "auth": { + "data": { + "host": "Host", + "security_code": "Security Code" + }, + "description": "You can find the security code on the back of your gateway.", + "title": "Enter security code" + } + }, + "title": "IKEA TR\u00c5DFRI" + } +} \ No newline at end of file diff --git a/homeassistant/components/tradfri/__init__.py b/homeassistant/components/tradfri/__init__.py new file mode 100644 index 00000000000..771f2b44c3d --- /dev/null +++ b/homeassistant/components/tradfri/__init__.py @@ -0,0 +1,102 @@ +""" +Support for IKEA Tradfri. + +For more details about this component, please refer to the documentation at +https://home-assistant.io/components/ikea_tradfri/ +""" +import logging + +import voluptuous as vol + +from homeassistant import config_entries +import homeassistant.helpers.config_validation as cv +from homeassistant.util.json import load_json + +from .const import CONF_IMPORT_GROUPS, CONF_IDENTITY, CONF_HOST, CONF_KEY + +from . import config_flow # noqa pylint_disable=unused-import + +REQUIREMENTS = ['pytradfri[async]==5.5.1'] + +DOMAIN = 'tradfri' +CONFIG_FILE = '.tradfri_psk.conf' +KEY_GATEWAY = 'tradfri_gateway' +KEY_API = 'tradfri_api' +CONF_ALLOW_TRADFRI_GROUPS = 'allow_tradfri_groups' +DEFAULT_ALLOW_TRADFRI_GROUPS = True + +CONFIG_SCHEMA = vol.Schema({ + DOMAIN: vol.Schema({ + vol.Inclusive(CONF_HOST, 'gateway'): cv.string, + vol.Optional(CONF_ALLOW_TRADFRI_GROUPS, + default=DEFAULT_ALLOW_TRADFRI_GROUPS): cv.boolean, + }) +}, extra=vol.ALLOW_EXTRA) + +_LOGGER = logging.getLogger(__name__) + + +async def async_setup(hass, config): + """Set up the Tradfri component.""" + conf = config.get(DOMAIN) + + if conf is None: + return True + + known_hosts = await hass.async_add_executor_job( + load_json, hass.config.path(CONFIG_FILE)) + + for host, info in known_hosts.items(): + info[CONF_HOST] = host + info[CONF_IMPORT_GROUPS] = conf[CONF_ALLOW_TRADFRI_GROUPS] + + hass.async_create_task(hass.config_entries.flow.async_init( + DOMAIN, context={'source': config_entries.SOURCE_IMPORT}, + data=info + )) + + host = conf.get(CONF_HOST) + + if host is None or host in known_hosts: + return True + + hass.async_create_task(hass.config_entries.flow.async_init( + DOMAIN, context={'source': config_entries.SOURCE_IMPORT}, + data={'host': host} + )) + + return True + + +async def async_setup_entry(hass, entry): + """Create a gateway.""" + # host, identity, key, allow_tradfri_groups + from pytradfri import Gateway, RequestError # pylint: disable=import-error + from pytradfri.api.aiocoap_api import APIFactory + + factory = APIFactory( + entry.data[CONF_HOST], + psk_id=entry.data[CONF_IDENTITY], + psk=entry.data[CONF_KEY], + loop=hass.loop + ) + api = factory.request + gateway = Gateway() + + try: + await api(gateway.get_gateway_info()) + except RequestError: + _LOGGER.error("Tradfri setup failed.") + return False + + hass.data.setdefault(KEY_API, {})[entry.entry_id] = api + hass.data.setdefault(KEY_GATEWAY, {})[entry.entry_id] = gateway + + hass.async_create_task(hass.config_entries.async_forward_entry_setup( + entry, 'light' + )) + hass.async_create_task(hass.config_entries.async_forward_entry_setup( + entry, 'sensor' + )) + + return True diff --git a/homeassistant/components/tradfri/config_flow.py b/homeassistant/components/tradfri/config_flow.py new file mode 100644 index 00000000000..4de43c79e0c --- /dev/null +++ b/homeassistant/components/tradfri/config_flow.py @@ -0,0 +1,172 @@ +"""Config flow for Tradfri.""" +import asyncio +from collections import OrderedDict +from uuid import uuid4 + +import async_timeout +import voluptuous as vol + +from homeassistant import config_entries + +from .const import ( + CONF_IMPORT_GROUPS, CONF_IDENTITY, CONF_HOST, CONF_KEY, CONF_GATEWAY_ID) + +KEY_HOST = 'host' +KEY_SECURITY_CODE = 'security_code' +KEY_IMPORT_GROUPS = 'import_groups' + + +class AuthError(Exception): + """Exception if authentication occurs.""" + + def __init__(self, code): + """Initialize exception.""" + super().__init__() + self.code = code + + +@config_entries.HANDLERS.register('tradfri') +class FlowHandler(config_entries.ConfigFlow): + """Handle a config flow.""" + + VERSION = 1 + + def __init__(self): + """Initialize flow.""" + self._host = None + + async def async_step_user(self, user_input=None): + """Handle a flow initialized by the user.""" + return await self.async_step_auth() + + async def async_step_auth(self, user_input=None): + """Handle the authentication with a gateway.""" + errors = {} + + if user_input is not None: + host = user_input.get(KEY_HOST, self._host) + try: + auth = await authenticate( + self.hass, host, + user_input[KEY_SECURITY_CODE]) + + # We don't ask for import group anymore as group state + # is not reliable, don't want to show that to the user. + auth[CONF_IMPORT_GROUPS] = False + + return await self._entry_from_data(auth) + + except AuthError as err: + if err.code == 'invalid_security_code': + errors[KEY_SECURITY_CODE] = err.code + else: + errors['base'] = err.code + + fields = OrderedDict() + + if self._host is None: + fields[vol.Required(KEY_HOST)] = str + + fields[vol.Required(KEY_SECURITY_CODE)] = str + + return self.async_show_form( + step_id='auth', + data_schema=vol.Schema(fields), + errors=errors, + ) + + async def async_step_discovery(self, user_input): + """Handle discovery.""" + self._host = user_input['host'] + return await self.async_step_auth() + + async def async_step_import(self, user_input): + """Import a config entry.""" + for entry in self._async_current_entries(): + if entry.data[CONF_HOST] == user_input['host']: + return self.async_abort( + reason='already_configured' + ) + + # Happens if user has host directly in configuration.yaml + if 'key' not in user_input: + self._host = user_input['host'] + return await self.async_step_auth() + + try: + data = await get_gateway_info( + self.hass, user_input['host'], user_input['identity'], + user_input['key']) + + data[CONF_IMPORT_GROUPS] = user_input[CONF_IMPORT_GROUPS] + + return await self._entry_from_data(data) + except AuthError: + # If we fail to connect, just pass it on to discovery + self._host = user_input['host'] + return await self.async_step_auth() + + async def _entry_from_data(self, data): + """Create an entry from data.""" + host = data[CONF_HOST] + gateway_id = data[CONF_GATEWAY_ID] + + same_hub_entries = [entry.entry_id for entry + in self._async_current_entries() + if entry.data[CONF_GATEWAY_ID] == gateway_id or + entry.data[CONF_HOST] == host] + + if same_hub_entries: + await asyncio.wait([self.hass.config_entries.async_remove(entry_id) + for entry_id in same_hub_entries]) + + return self.async_create_entry( + title=host, + data=data + ) + + +async def authenticate(hass, host, security_code): + """Authenticate with a Tradfri hub.""" + from pytradfri.api.aiocoap_api import APIFactory + from pytradfri import RequestError + + identity = uuid4().hex + + api_factory = APIFactory(host, psk_id=identity, loop=hass.loop) + + try: + with async_timeout.timeout(5): + key = await api_factory.generate_psk(security_code) + except RequestError: + raise AuthError('invalid_security_code') + except asyncio.TimeoutError: + raise AuthError('timeout') + + return await get_gateway_info(hass, host, identity, key) + + +async def get_gateway_info(hass, host, identity, key): + """Return info for the gateway.""" + from pytradfri.api.aiocoap_api import APIFactory + from pytradfri import Gateway, RequestError + + try: + factory = APIFactory( + host, + psk_id=identity, + psk=key, + loop=hass.loop + ) + api = factory.request + gateway = Gateway() + gateway_info_result = await api(gateway.get_gateway_info()) + except RequestError: + raise AuthError('cannot_connect') + + return { + CONF_HOST: host, + CONF_IDENTITY: identity, + CONF_KEY: key, + CONF_GATEWAY_ID: gateway_info_result.id, + } diff --git a/homeassistant/components/tradfri/const.py b/homeassistant/components/tradfri/const.py new file mode 100644 index 00000000000..15177bc1a20 --- /dev/null +++ b/homeassistant/components/tradfri/const.py @@ -0,0 +1,7 @@ +"""Consts used by Tradfri.""" +from homeassistant.const import CONF_HOST # noqa pylint: disable=unused-import + +CONF_IMPORT_GROUPS = 'import_groups' +CONF_IDENTITY = 'identity' +CONF_KEY = 'key' +CONF_GATEWAY_ID = 'gateway_id' diff --git a/homeassistant/components/tradfri/strings.json b/homeassistant/components/tradfri/strings.json new file mode 100644 index 00000000000..38c58486a6a --- /dev/null +++ b/homeassistant/components/tradfri/strings.json @@ -0,0 +1,23 @@ +{ + "config": { + "title": "IKEA TRÅDFRI", + "step": { + "auth": { + "title": "Enter security code", + "description": "You can find the security code on the back of your gateway.", + "data": { + "host": "Host", + "security_code": "Security Code" + } + } + }, + "error": { + "invalid_key": "Failed to register with provided key. If this keeps happening, try restarting the gateway.", + "cannot_connect": "Unable to connect to the gateway.", + "timeout": "Timeout validating the code." + }, + "abort": { + "already_configured": "Bridge is already configured" + } + } +} diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index e4c4b5c0327..7763594e0e1 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -146,6 +146,7 @@ FLOWS = [ 'nest', 'openuv', 'sonos', + 'tradfri', 'zone', ] diff --git a/tests/components/light/test_tradfri.py b/tests/components/light/test_tradfri.py index 12c596f3f09..337031cf92c 100644 --- a/tests/components/light/test_tradfri.py +++ b/tests/components/light/test_tradfri.py @@ -5,10 +5,10 @@ from unittest.mock import Mock, MagicMock, patch, PropertyMock import pytest from pytradfri.device import Device, LightControl, Light -from pytradfri import RequestError from homeassistant.components import tradfri -from homeassistant.setup import async_setup_component + +from tests.common import MockConfigEntry DEFAULT_TEST_FEATURES = {'can_set_dimmer': False, @@ -199,7 +199,7 @@ def mock_gateway(): @pytest.fixture def mock_api(mock_gateway): """Mock api.""" - async def api(self, command): + async def api(command): """Mock api function.""" # Store the data for "real" command objects. if(hasattr(command, '_data') and not isinstance(command, Mock)): @@ -213,63 +213,20 @@ async def generate_psk(self, code): return "mock" -async def setup_gateway(hass, mock_gateway, mock_api, - generate_psk=generate_psk, - known_hosts=None): +async def setup_gateway(hass, mock_gateway, mock_api): """Load the Tradfri platform with a mock gateway.""" - def request_config(_, callback, description, submit_caption, fields): - """Mock request_config.""" - hass.async_add_job(callback, {'security_code': 'mock'}) - - if known_hosts is None: - known_hosts = {} - - with patch('pytradfri.api.aiocoap_api.APIFactory.generate_psk', - generate_psk), \ - patch('pytradfri.api.aiocoap_api.APIFactory.request', mock_api), \ - patch('pytradfri.Gateway', return_value=mock_gateway), \ - patch.object(tradfri, 'load_json', return_value=known_hosts), \ - patch.object(tradfri, 'save_json'), \ - patch.object(hass.components.configurator, 'request_config', - request_config): - - await async_setup_component(hass, tradfri.DOMAIN, - { - tradfri.DOMAIN: { - 'host': 'mock-host', - 'allow_tradfri_groups': True - } - }) - await hass.async_block_till_done() - - -async def test_setup_gateway(hass, mock_gateway, mock_api): - """Test that the gateway can be setup without errors.""" - await setup_gateway(hass, mock_gateway, mock_api) - - -async def test_setup_gateway_known_host(hass, mock_gateway, mock_api): - """Test gateway setup with a known host.""" - await setup_gateway(hass, mock_gateway, mock_api, - known_hosts={ - 'mock-host': { - 'identity': 'mock', - 'key': 'mock-key' - } - }) - - -async def test_incorrect_security_code(hass, mock_gateway, mock_api): - """Test that an error is shown if the security code is incorrect.""" - async def psk_error(self, code): - """Raise RequestError when called.""" - raise RequestError - - with patch.object(hass.components.configurator, 'async_notify_errors') \ - as notify_error: - await setup_gateway(hass, mock_gateway, mock_api, - generate_psk=psk_error) - assert len(notify_error.mock_calls) > 0 + entry = MockConfigEntry(domain=tradfri.DOMAIN, data={ + 'host': 'mock-host', + 'identity': 'mock-identity', + 'key': 'mock-key', + 'import_groups': True, + 'gateway_id': 'mock-gateway-id', + }) + hass.data[tradfri.KEY_GATEWAY] = {entry.entry_id: mock_gateway} + hass.data[tradfri.KEY_API] = {entry.entry_id: mock_api} + await hass.config_entries.async_forward_entry_setup( + entry, 'light' + ) def mock_light(test_features={}, test_state={}, n=0): diff --git a/tests/components/tradfri/__init__.py b/tests/components/tradfri/__init__.py new file mode 100644 index 00000000000..4d1b505abc9 --- /dev/null +++ b/tests/components/tradfri/__init__.py @@ -0,0 +1 @@ +"""Tests for the tradfri component.""" diff --git a/tests/components/tradfri/test_config_flow.py b/tests/components/tradfri/test_config_flow.py new file mode 100644 index 00000000000..4650fb5d9bc --- /dev/null +++ b/tests/components/tradfri/test_config_flow.py @@ -0,0 +1,156 @@ +"""Test the Tradfri config flow.""" +from unittest.mock import patch + +import pytest + +from homeassistant import data_entry_flow +from homeassistant.components.tradfri import config_flow + +from tests.common import mock_coro + + +@pytest.fixture +def mock_auth(): + """Mock authenticate.""" + with patch('homeassistant.components.tradfri.config_flow.' + 'authenticate') as mock_auth: + yield mock_auth + + +@pytest.fixture +def mock_gateway_info(): + """Mock get_gateway_info.""" + with patch('homeassistant.components.tradfri.config_flow.' + 'get_gateway_info') as mock_gateway: + yield mock_gateway + + +@pytest.fixture +def mock_entry_setup(): + """Mock entry setup.""" + with patch('homeassistant.components.tradfri.' + 'async_setup_entry') as mock_setup: + mock_setup.return_value = mock_coro(True) + yield mock_setup + + +async def test_user_connection_successful(hass, mock_auth, mock_entry_setup): + """Test a successful connection.""" + mock_auth.side_effect = lambda hass, host, code: mock_coro({ + 'host': host, + 'gateway_id': 'bla' + }) + + flow = await hass.config_entries.flow.async_init( + 'tradfri', context={'source': 'user'}) + + result = await hass.config_entries.flow.async_configure(flow['flow_id'], { + 'host': '123.123.123.123', + 'security_code': 'abcd', + }) + + assert len(mock_entry_setup.mock_calls) == 1 + + assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY + assert result['result'].data == { + 'host': '123.123.123.123', + 'gateway_id': 'bla', + 'import_groups': False + } + + +async def test_user_connection_timeout(hass, mock_auth, mock_entry_setup): + """Test a connection timeout.""" + mock_auth.side_effect = config_flow.AuthError('timeout') + + flow = await hass.config_entries.flow.async_init( + 'tradfri', context={'source': 'user'}) + + result = await hass.config_entries.flow.async_configure(flow['flow_id'], { + 'host': '127.0.0.1', + 'security_code': 'abcd', + }) + + assert len(mock_entry_setup.mock_calls) == 0 + + assert result['type'] == data_entry_flow.RESULT_TYPE_FORM + assert result['errors'] == { + 'base': 'timeout' + } + + +async def test_user_connection_bad_key(hass, mock_auth, mock_entry_setup): + """Test a connection with bad key.""" + mock_auth.side_effect = config_flow.AuthError('invalid_security_code') + + flow = await hass.config_entries.flow.async_init( + 'tradfri', context={'source': 'user'}) + + result = await hass.config_entries.flow.async_configure(flow['flow_id'], { + 'host': '127.0.0.1', + 'security_code': 'abcd', + }) + + assert len(mock_entry_setup.mock_calls) == 0 + + assert result['type'] == data_entry_flow.RESULT_TYPE_FORM + assert result['errors'] == { + 'security_code': 'invalid_security_code' + } + + +async def test_discovery_connection(hass, mock_auth, mock_entry_setup): + """Test a connection via discovery.""" + mock_auth.side_effect = lambda hass, host, code: mock_coro({ + 'host': host, + 'gateway_id': 'bla' + }) + + flow = await hass.config_entries.flow.async_init( + 'tradfri', context={'source': 'discovery'}, data={ + 'host': '123.123.123.123' + }) + + result = await hass.config_entries.flow.async_configure(flow['flow_id'], { + 'security_code': 'abcd', + }) + + assert len(mock_entry_setup.mock_calls) == 1 + + assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY + assert result['result'].data == { + 'host': '123.123.123.123', + 'gateway_id': 'bla', + 'import_groups': False + } + + +async def test_import_connection(hass, mock_gateway_info, mock_entry_setup): + """Test a connection via import.""" + mock_gateway_info.side_effect = \ + lambda hass, host, identity, key: mock_coro({ + 'host': host, + 'identity': identity, + 'key': key, + 'gateway_id': 'mock-gateway' + }) + + result = await hass.config_entries.flow.async_init( + 'tradfri', context={'source': 'import'}, data={ + 'host': '123.123.123.123', + 'identity': 'mock-iden', + 'key': 'mock-key', + 'import_groups': True + }) + + assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY + assert result['result'].data == { + 'host': '123.123.123.123', + 'gateway_id': 'mock-gateway', + 'identity': 'mock-iden', + 'key': 'mock-key', + 'import_groups': True + } + + assert len(mock_gateway_info.mock_calls) == 1 + assert len(mock_entry_setup.mock_calls) == 1