Add config flow for Hue (#12830)
* Add config flow for Hue * Upgrade to aiohue 0.2 * Fix tests * Add tests * Add aiohue to test requirements * Bump aiohue dependency * Lint * Lint * Fix aiohttp mock * Lint * Fix tests
This commit is contained in:
parent
d06807c634
commit
67c49a7662
12 changed files with 389 additions and 65 deletions
|
@ -163,7 +163,7 @@ class ConfigManagerFlowResourceView(HomeAssistantView):
|
|||
hass = request.app['hass']
|
||||
|
||||
try:
|
||||
hass.config_entries.async_abort(flow_id)
|
||||
hass.config_entries.flow.async_abort(flow_id)
|
||||
except config_entries.UnknownFlow:
|
||||
return self.json_message('Invalid flow specified', 404)
|
||||
|
||||
|
|
|
@ -4,20 +4,24 @@ This component provides basic support for the Philips Hue system.
|
|||
For more details about this component, please refer to the documentation at
|
||||
https://home-assistant.io/components/hue/
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
from functools import partial
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
|
||||
import async_timeout
|
||||
import requests
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components.discovery import SERVICE_HUE
|
||||
from homeassistant.const import CONF_FILENAME, CONF_HOST
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers import discovery
|
||||
from homeassistant.helpers import discovery, aiohttp_client
|
||||
from homeassistant import config_entries
|
||||
|
||||
REQUIREMENTS = ['phue==1.0']
|
||||
REQUIREMENTS = ['phue==1.0', 'aiohue==0.3.0']
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -133,13 +137,14 @@ def bridge_discovered(hass, service, discovery_info):
|
|||
|
||||
|
||||
def setup_bridge(host, hass, filename=None, allow_unreachable=False,
|
||||
allow_in_emulated_hue=True, allow_hue_groups=True):
|
||||
allow_in_emulated_hue=True, allow_hue_groups=True,
|
||||
username=None):
|
||||
"""Set up a given Hue bridge."""
|
||||
# Only register a device once
|
||||
if socket.gethostbyname(host) in hass.data[DOMAIN]:
|
||||
return
|
||||
|
||||
bridge = HueBridge(host, hass, filename, allow_unreachable,
|
||||
bridge = HueBridge(host, hass, filename, username, allow_unreachable,
|
||||
allow_in_emulated_hue, allow_hue_groups)
|
||||
bridge.setup()
|
||||
|
||||
|
@ -164,13 +169,14 @@ def _find_host_from_config(hass, filename=PHUE_CONFIG_FILE):
|
|||
class HueBridge(object):
|
||||
"""Manages a single Hue bridge."""
|
||||
|
||||
def __init__(self, host, hass, filename, allow_unreachable=False,
|
||||
def __init__(self, host, hass, filename, username, allow_unreachable=False,
|
||||
allow_in_emulated_hue=True, allow_hue_groups=True):
|
||||
"""Initialize the system."""
|
||||
self.host = host
|
||||
self.bridge_id = socket.gethostbyname(host)
|
||||
self.hass = hass
|
||||
self.filename = filename
|
||||
self.username = username
|
||||
self.allow_unreachable = allow_unreachable
|
||||
self.allow_in_emulated_hue = allow_in_emulated_hue
|
||||
self.allow_hue_groups = allow_hue_groups
|
||||
|
@ -189,10 +195,14 @@ class HueBridge(object):
|
|||
import phue
|
||||
|
||||
try:
|
||||
self.bridge = phue.Bridge(
|
||||
self.host,
|
||||
config_file_path=self.hass.config.path(self.filename))
|
||||
except (ConnectionRefusedError, OSError): # Wrong host was given
|
||||
kwargs = {}
|
||||
if self.username is not None:
|
||||
kwargs['username'] = self.username
|
||||
if self.filename is not None:
|
||||
kwargs['config_file_path'] = \
|
||||
self.hass.config.path(self.filename)
|
||||
self.bridge = phue.Bridge(self.host, **kwargs)
|
||||
except OSError: # Wrong host was given
|
||||
_LOGGER.error("Error connecting to the Hue bridge at %s",
|
||||
self.host)
|
||||
return
|
||||
|
@ -204,6 +214,7 @@ class HueBridge(object):
|
|||
except Exception: # pylint: disable=broad-except
|
||||
_LOGGER.exception("Unknown error connecting with Hue bridge at %s",
|
||||
self.host)
|
||||
return
|
||||
|
||||
# If we came here and configuring this host, mark as done
|
||||
if self.config_request_id:
|
||||
|
@ -260,3 +271,112 @@ class HueBridge(object):
|
|||
def set_group(self, light_id, command):
|
||||
"""Change light settings for a group. See phue for detail."""
|
||||
return self.bridge.set_group(light_id, command)
|
||||
|
||||
|
||||
@config_entries.HANDLERS.register(DOMAIN)
|
||||
class HueFlowHandler(config_entries.ConfigFlowHandler):
|
||||
"""Handle a Hue config flow."""
|
||||
|
||||
VERSION = 1
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the Hue flow."""
|
||||
self.host = None
|
||||
|
||||
@property
|
||||
def _websession(self):
|
||||
"""Return a websession.
|
||||
|
||||
Cannot assign in init because hass variable is not set yet.
|
||||
"""
|
||||
return aiohttp_client.async_get_clientsession(self.hass)
|
||||
|
||||
async def async_step_init(self, user_input=None):
|
||||
"""Handle a flow start."""
|
||||
from aiohue.discovery import discover_nupnp
|
||||
|
||||
if user_input is not None:
|
||||
self.host = user_input['host']
|
||||
return await self.async_step_link()
|
||||
|
||||
try:
|
||||
with async_timeout.timeout(5):
|
||||
bridges = await discover_nupnp(websession=self._websession)
|
||||
except asyncio.TimeoutError:
|
||||
return self.async_abort(
|
||||
reason='Unable to discover Hue bridges.'
|
||||
)
|
||||
|
||||
if not bridges:
|
||||
return self.async_abort(
|
||||
reason='No Philips Hue bridges discovered.'
|
||||
)
|
||||
|
||||
# Find already configured hosts
|
||||
configured_hosts = set(
|
||||
entry.data['host'] for entry
|
||||
in self.hass.config_entries.async_entries(DOMAIN))
|
||||
|
||||
hosts = [bridge.host for bridge in bridges
|
||||
if bridge.host not in configured_hosts]
|
||||
|
||||
if not hosts:
|
||||
return self.async_abort(
|
||||
reason='All Philips Hue bridges are already configured.'
|
||||
)
|
||||
|
||||
elif len(hosts) == 1:
|
||||
self.host = hosts[0]
|
||||
return await self.async_step_link()
|
||||
|
||||
return self.async_show_form(
|
||||
step_id='init',
|
||||
title='Pick Hue Bridge',
|
||||
data_schema=vol.Schema({
|
||||
vol.Required('host'): vol.In(hosts)
|
||||
})
|
||||
)
|
||||
|
||||
async def async_step_link(self, user_input=None):
|
||||
"""Attempt to link with the Hue bridge."""
|
||||
import aiohue
|
||||
errors = {}
|
||||
|
||||
if user_input is not None:
|
||||
bridge = aiohue.Bridge(self.host, websession=self._websession)
|
||||
try:
|
||||
with async_timeout.timeout(5):
|
||||
# Create auth token
|
||||
await bridge.create_user('home-assistant')
|
||||
# Fetches name and id
|
||||
await bridge.initialize()
|
||||
except (asyncio.TimeoutError, aiohue.RequestError,
|
||||
aiohue.LinkButtonNotPressed):
|
||||
errors['base'] = 'Failed to register, please try again.'
|
||||
except aiohue.AiohueException:
|
||||
errors['base'] = 'Unknown linking error occurred.'
|
||||
_LOGGER.exception('Uknown Hue linking error occurred')
|
||||
else:
|
||||
return self.async_create_entry(
|
||||
title=bridge.config.name,
|
||||
data={
|
||||
'host': bridge.host,
|
||||
'bridge_id': bridge.config.bridgeid,
|
||||
'username': bridge.username,
|
||||
}
|
||||
)
|
||||
|
||||
return self.async_show_form(
|
||||
step_id='link',
|
||||
title='Link Hub',
|
||||
description=CONFIG_INSTRUCTIONS,
|
||||
errors=errors,
|
||||
)
|
||||
|
||||
|
||||
async def async_setup_entry(hass, entry):
|
||||
"""Set up a bridge for a config entry."""
|
||||
await hass.async_add_job(partial(
|
||||
setup_bridge, entry.data['host'], hass,
|
||||
username=entry.data['username']))
|
||||
return True
|
||||
|
|
|
@ -219,7 +219,8 @@ class SpcWebGateway:
|
|||
url = self._build_url(resource)
|
||||
try:
|
||||
_LOGGER.debug("Attempting to retrieve SPC data from %s", url)
|
||||
session = aiohttp.ClientSession()
|
||||
session = \
|
||||
self._hass.helpers.aiohttp_client.async_get_clientsession()
|
||||
with async_timeout.timeout(10, loop=self._hass.loop):
|
||||
action = session.get if use_get else session.put
|
||||
response = yield from action(url)
|
||||
|
|
|
@ -126,7 +126,8 @@ _LOGGER = logging.getLogger(__name__)
|
|||
HANDLERS = Registry()
|
||||
# Components that have config flows. In future we will auto-generate this list.
|
||||
FLOWS = [
|
||||
'config_entry_example'
|
||||
'config_entry_example',
|
||||
'hue',
|
||||
]
|
||||
|
||||
SOURCE_USER = 'user'
|
||||
|
|
|
@ -35,14 +35,7 @@ def async_get_clientsession(hass, verify_ssl=True):
|
|||
key = DATA_CLIENTSESSION_NOTVERIFY
|
||||
|
||||
if key not in hass.data:
|
||||
connector = _async_get_connector(hass, verify_ssl)
|
||||
clientsession = aiohttp.ClientSession(
|
||||
loop=hass.loop,
|
||||
connector=connector,
|
||||
headers={USER_AGENT: SERVER_SOFTWARE}
|
||||
)
|
||||
_async_register_clientsession_shutdown(hass, clientsession)
|
||||
hass.data[key] = clientsession
|
||||
hass.data[key] = async_create_clientsession(hass, verify_ssl)
|
||||
|
||||
return hass.data[key]
|
||||
|
||||
|
|
|
@ -75,6 +75,9 @@ aiodns==1.1.1
|
|||
# homeassistant.components.http
|
||||
aiohttp_cors==0.6.0
|
||||
|
||||
# homeassistant.components.hue
|
||||
aiohue==0.3.0
|
||||
|
||||
# homeassistant.components.sensor.imap
|
||||
aioimaplib==0.7.13
|
||||
|
||||
|
|
|
@ -34,6 +34,9 @@ aioautomatic==0.6.5
|
|||
# homeassistant.components.http
|
||||
aiohttp_cors==0.6.0
|
||||
|
||||
# homeassistant.components.hue
|
||||
aiohue==0.3.0
|
||||
|
||||
# homeassistant.components.notify.apns
|
||||
apns2==0.3.0
|
||||
|
||||
|
|
|
@ -37,6 +37,7 @@ COMMENT_REQUIREMENTS = (
|
|||
TEST_REQUIREMENTS = (
|
||||
'aioautomatic',
|
||||
'aiohttp_cors',
|
||||
'aiohue',
|
||||
'apns2',
|
||||
'caldav',
|
||||
'coinmarketcap',
|
||||
|
|
|
@ -4,13 +4,17 @@ import logging
|
|||
import unittest
|
||||
from unittest.mock import call, MagicMock, patch
|
||||
|
||||
import aiohue
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import configurator, hue
|
||||
from homeassistant.const import CONF_FILENAME, CONF_HOST
|
||||
from homeassistant.setup import setup_component, async_setup_component
|
||||
|
||||
from tests.common import (
|
||||
assert_setup_component, get_test_home_assistant, get_test_config_dir,
|
||||
MockDependency
|
||||
MockDependency, MockConfigEntry, mock_coro
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
@ -212,7 +216,8 @@ class TestHueBridge(unittest.TestCase):
|
|||
mock_bridge = mock_phue.Bridge
|
||||
mock_bridge.side_effect = ConnectionRefusedError()
|
||||
|
||||
bridge = hue.HueBridge('localhost', self.hass, hue.PHUE_CONFIG_FILE)
|
||||
bridge = hue.HueBridge(
|
||||
'localhost', self.hass, hue.PHUE_CONFIG_FILE, None)
|
||||
bridge.setup()
|
||||
self.assertFalse(bridge.configured)
|
||||
self.assertTrue(bridge.config_request_id is None)
|
||||
|
@ -228,7 +233,8 @@ class TestHueBridge(unittest.TestCase):
|
|||
mock_phue.PhueRegistrationException = Exception
|
||||
mock_bridge.side_effect = mock_phue.PhueRegistrationException(1, 2)
|
||||
|
||||
bridge = hue.HueBridge('localhost', self.hass, hue.PHUE_CONFIG_FILE)
|
||||
bridge = hue.HueBridge(
|
||||
'localhost', self.hass, hue.PHUE_CONFIG_FILE, None)
|
||||
bridge.setup()
|
||||
self.assertFalse(bridge.configured)
|
||||
self.assertFalse(bridge.config_request_id is None)
|
||||
|
@ -250,7 +256,8 @@ class TestHueBridge(unittest.TestCase):
|
|||
None,
|
||||
]
|
||||
|
||||
bridge = hue.HueBridge('localhost', self.hass, hue.PHUE_CONFIG_FILE)
|
||||
bridge = hue.HueBridge(
|
||||
'localhost', self.hass, hue.PHUE_CONFIG_FILE, None)
|
||||
bridge.setup()
|
||||
self.assertFalse(bridge.configured)
|
||||
self.assertFalse(bridge.config_request_id is None)
|
||||
|
@ -291,7 +298,8 @@ class TestHueBridge(unittest.TestCase):
|
|||
ConnectionRefusedError(),
|
||||
]
|
||||
|
||||
bridge = hue.HueBridge('localhost', self.hass, hue.PHUE_CONFIG_FILE)
|
||||
bridge = hue.HueBridge(
|
||||
'localhost', self.hass, hue.PHUE_CONFIG_FILE, None)
|
||||
bridge.setup()
|
||||
self.assertFalse(bridge.configured)
|
||||
self.assertFalse(bridge.config_request_id is None)
|
||||
|
@ -332,7 +340,8 @@ class TestHueBridge(unittest.TestCase):
|
|||
mock_phue.PhueRegistrationException(1, 2),
|
||||
]
|
||||
|
||||
bridge = hue.HueBridge('localhost', self.hass, hue.PHUE_CONFIG_FILE)
|
||||
bridge = hue.HueBridge(
|
||||
'localhost', self.hass, hue.PHUE_CONFIG_FILE, None)
|
||||
bridge.setup()
|
||||
self.assertFalse(bridge.configured)
|
||||
self.assertFalse(bridge.config_request_id is None)
|
||||
|
@ -364,7 +373,7 @@ class TestHueBridge(unittest.TestCase):
|
|||
"""Test the hue_activate_scene service."""
|
||||
with patch('homeassistant.helpers.discovery.load_platform'):
|
||||
bridge = hue.HueBridge('localhost', self.hass,
|
||||
hue.PHUE_CONFIG_FILE)
|
||||
hue.PHUE_CONFIG_FILE, None)
|
||||
bridge.setup()
|
||||
|
||||
# No args
|
||||
|
@ -393,15 +402,187 @@ class TestHueBridge(unittest.TestCase):
|
|||
bridge.bridge.run_scene.assert_called_once_with('group', 'scene')
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_setup_no_host(hass, requests_mock):
|
||||
async def test_setup_no_host(hass, requests_mock):
|
||||
"""No host specified in any way."""
|
||||
requests_mock.get(hue.API_NUPNP, json=[])
|
||||
with MockDependency('phue') as mock_phue:
|
||||
result = yield from async_setup_component(
|
||||
result = await async_setup_component(
|
||||
hass, hue.DOMAIN, {hue.DOMAIN: {}})
|
||||
assert result
|
||||
|
||||
mock_phue.Bridge.assert_not_called()
|
||||
|
||||
assert hass.data[hue.DOMAIN] == {}
|
||||
|
||||
|
||||
async def test_flow_works(hass, aioclient_mock):
|
||||
"""Test config flow ."""
|
||||
aioclient_mock.get(hue.API_NUPNP, json=[
|
||||
{'internalipaddress': '1.2.3.4', 'id': 'bla'}
|
||||
])
|
||||
|
||||
flow = hue.HueFlowHandler()
|
||||
flow.hass = hass
|
||||
await flow.async_step_init()
|
||||
|
||||
with patch('aiohue.Bridge') as mock_bridge:
|
||||
def mock_constructor(host, websession):
|
||||
mock_bridge.host = host
|
||||
return mock_bridge
|
||||
|
||||
mock_bridge.side_effect = mock_constructor
|
||||
mock_bridge.username = 'username-abc'
|
||||
mock_bridge.config.name = 'Mock Bridge'
|
||||
mock_bridge.config.bridgeid = 'bridge-id-1234'
|
||||
mock_bridge.create_user.return_value = mock_coro()
|
||||
mock_bridge.initialize.return_value = mock_coro()
|
||||
|
||||
result = await flow.async_step_link(user_input={})
|
||||
|
||||
assert mock_bridge.host == '1.2.3.4'
|
||||
assert len(mock_bridge.create_user.mock_calls) == 1
|
||||
assert len(mock_bridge.initialize.mock_calls) == 1
|
||||
|
||||
assert result['type'] == 'create_entry'
|
||||
assert result['title'] == 'Mock Bridge'
|
||||
assert result['data'] == {
|
||||
'host': '1.2.3.4',
|
||||
'bridge_id': 'bridge-id-1234',
|
||||
'username': 'username-abc'
|
||||
}
|
||||
|
||||
|
||||
async def test_flow_no_discovered_bridges(hass, aioclient_mock):
|
||||
"""Test config flow discovers no bridges."""
|
||||
aioclient_mock.get(hue.API_NUPNP, json=[])
|
||||
flow = hue.HueFlowHandler()
|
||||
flow.hass = hass
|
||||
|
||||
result = await flow.async_step_init()
|
||||
assert result['type'] == 'abort'
|
||||
|
||||
|
||||
async def test_flow_all_discovered_bridges_exist(hass, aioclient_mock):
|
||||
"""Test config flow discovers only already configured bridges."""
|
||||
aioclient_mock.get(hue.API_NUPNP, json=[
|
||||
{'internalipaddress': '1.2.3.4', 'id': 'bla'}
|
||||
])
|
||||
MockConfigEntry(domain='hue', data={
|
||||
'host': '1.2.3.4'
|
||||
}).add_to_hass(hass)
|
||||
flow = hue.HueFlowHandler()
|
||||
flow.hass = hass
|
||||
|
||||
result = await flow.async_step_init()
|
||||
assert result['type'] == 'abort'
|
||||
|
||||
|
||||
async def test_flow_one_bridge_discovered(hass, aioclient_mock):
|
||||
"""Test config flow discovers one bridge."""
|
||||
aioclient_mock.get(hue.API_NUPNP, json=[
|
||||
{'internalipaddress': '1.2.3.4', 'id': 'bla'}
|
||||
])
|
||||
flow = hue.HueFlowHandler()
|
||||
flow.hass = hass
|
||||
|
||||
result = await flow.async_step_init()
|
||||
assert result['type'] == 'form'
|
||||
assert result['step_id'] == 'link'
|
||||
|
||||
|
||||
async def test_flow_two_bridges_discovered(hass, aioclient_mock):
|
||||
"""Test config flow discovers two bridges."""
|
||||
aioclient_mock.get(hue.API_NUPNP, json=[
|
||||
{'internalipaddress': '1.2.3.4', 'id': 'bla'},
|
||||
{'internalipaddress': '5.6.7.8', 'id': 'beer'}
|
||||
])
|
||||
flow = hue.HueFlowHandler()
|
||||
flow.hass = hass
|
||||
|
||||
result = await flow.async_step_init()
|
||||
assert result['type'] == 'form'
|
||||
assert result['step_id'] == 'init'
|
||||
|
||||
with pytest.raises(vol.Invalid):
|
||||
assert result['data_schema']({'host': '0.0.0.0'})
|
||||
|
||||
result['data_schema']({'host': '1.2.3.4'})
|
||||
result['data_schema']({'host': '5.6.7.8'})
|
||||
|
||||
|
||||
async def test_flow_two_bridges_discovered_one_new(hass, aioclient_mock):
|
||||
"""Test config flow discovers two bridges."""
|
||||
aioclient_mock.get(hue.API_NUPNP, json=[
|
||||
{'internalipaddress': '1.2.3.4', 'id': 'bla'},
|
||||
{'internalipaddress': '5.6.7.8', 'id': 'beer'}
|
||||
])
|
||||
MockConfigEntry(domain='hue', data={
|
||||
'host': '1.2.3.4'
|
||||
}).add_to_hass(hass)
|
||||
flow = hue.HueFlowHandler()
|
||||
flow.hass = hass
|
||||
|
||||
result = await flow.async_step_init()
|
||||
assert result['type'] == 'form'
|
||||
assert result['step_id'] == 'link'
|
||||
assert flow.host == '5.6.7.8'
|
||||
|
||||
|
||||
async def test_flow_timeout_discovery(hass):
|
||||
"""Test config flow ."""
|
||||
flow = hue.HueFlowHandler()
|
||||
flow.hass = hass
|
||||
|
||||
with patch('aiohue.discovery.discover_nupnp',
|
||||
side_effect=asyncio.TimeoutError):
|
||||
result = await flow.async_step_init()
|
||||
|
||||
assert result['type'] == 'abort'
|
||||
|
||||
|
||||
async def test_flow_link_timeout(hass):
|
||||
"""Test config flow ."""
|
||||
flow = hue.HueFlowHandler()
|
||||
flow.hass = hass
|
||||
|
||||
with patch('aiohue.Bridge.create_user',
|
||||
side_effect=asyncio.TimeoutError):
|
||||
result = await flow.async_step_link({})
|
||||
|
||||
assert result['type'] == 'form'
|
||||
assert result['step_id'] == 'link'
|
||||
assert result['errors'] == {
|
||||
'base': 'Failed to register, please try again.'
|
||||
}
|
||||
|
||||
|
||||
async def test_flow_link_button_not_pressed(hass):
|
||||
"""Test config flow ."""
|
||||
flow = hue.HueFlowHandler()
|
||||
flow.hass = hass
|
||||
|
||||
with patch('aiohue.Bridge.create_user',
|
||||
side_effect=aiohue.LinkButtonNotPressed):
|
||||
result = await flow.async_step_link({})
|
||||
|
||||
assert result['type'] == 'form'
|
||||
assert result['step_id'] == 'link'
|
||||
assert result['errors'] == {
|
||||
'base': 'Failed to register, please try again.'
|
||||
}
|
||||
|
||||
|
||||
async def test_flow_link_unknown_host(hass):
|
||||
"""Test config flow ."""
|
||||
flow = hue.HueFlowHandler()
|
||||
flow.hass = hass
|
||||
|
||||
with patch('aiohue.Bridge.create_user',
|
||||
side_effect=aiohue.RequestError):
|
||||
result = await flow.async_step_link({})
|
||||
|
||||
assert result['type'] == 'form'
|
||||
assert result['step_id'] == 'link'
|
||||
assert result['errors'] == {
|
||||
'base': 'Failed to register, please try again.'
|
||||
}
|
||||
|
|
1
tests/test_util/__init__.py
Normal file
1
tests/test_util/__init__.py
Normal file
|
@ -0,0 +1 @@
|
|||
"""Tests for the test utilities."""
|
|
@ -1,11 +1,13 @@
|
|||
"""Aiohttp test utils."""
|
||||
import asyncio
|
||||
from contextlib import contextmanager
|
||||
import functools
|
||||
import json as _json
|
||||
import re
|
||||
from unittest import mock
|
||||
from urllib.parse import urlparse, parse_qs
|
||||
import yarl
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
from aiohttp import ClientSession
|
||||
from yarl import URL
|
||||
|
||||
from aiohttp.client_exceptions import ClientResponseError
|
||||
|
||||
|
@ -31,14 +33,17 @@ class AiohttpClientMocker:
|
|||
exc=None,
|
||||
cookies=None):
|
||||
"""Mock a request."""
|
||||
if json:
|
||||
if json is not None:
|
||||
text = _json.dumps(json)
|
||||
if text:
|
||||
if text is not None:
|
||||
content = text.encode('utf-8')
|
||||
if content is None:
|
||||
content = b''
|
||||
|
||||
if not isinstance(url, re._pattern_type):
|
||||
url = URL(url)
|
||||
if params:
|
||||
url = str(yarl.URL(url).with_query(params))
|
||||
url = url.with_query(params)
|
||||
|
||||
self._mocks.append(AiohttpClientMockResponse(
|
||||
method, url, status, content, cookies, exc, headers))
|
||||
|
@ -74,13 +79,21 @@ class AiohttpClientMocker:
|
|||
self._cookies.clear()
|
||||
self.mock_calls.clear()
|
||||
|
||||
@asyncio.coroutine
|
||||
# pylint: disable=unused-variable
|
||||
def match_request(self, method, url, *, data=None, auth=None, params=None,
|
||||
headers=None, allow_redirects=None, timeout=None,
|
||||
json=None):
|
||||
def create_session(self, loop):
|
||||
"""Create a ClientSession that is bound to this mocker."""
|
||||
session = ClientSession(loop=loop)
|
||||
session._request = self.match_request
|
||||
return session
|
||||
|
||||
async def match_request(self, method, url, *, data=None, auth=None,
|
||||
params=None, headers=None, allow_redirects=None,
|
||||
timeout=None, json=None):
|
||||
"""Match a request against pre-registered requests."""
|
||||
data = data or json
|
||||
url = URL(url)
|
||||
if params:
|
||||
url = url.with_query(params)
|
||||
|
||||
for response in self._mocks:
|
||||
if response.match_request(method, url, params):
|
||||
self.mock_calls.append((method, url, data, headers))
|
||||
|
@ -101,8 +114,6 @@ class AiohttpClientMockResponse:
|
|||
"""Initialize a fake response."""
|
||||
self.method = method
|
||||
self._url = url
|
||||
self._url_parts = (None if hasattr(url, 'search')
|
||||
else urlparse(url.lower()))
|
||||
self.status = status
|
||||
self.response = response
|
||||
self.exc = exc
|
||||
|
@ -133,25 +144,17 @@ class AiohttpClientMockResponse:
|
|||
if method.lower() != self.method.lower():
|
||||
return False
|
||||
|
||||
if params:
|
||||
url = str(yarl.URL(url).with_query(params))
|
||||
|
||||
# regular expression matching
|
||||
if self._url_parts is None:
|
||||
return self._url.search(url) is not None
|
||||
if isinstance(self._url, re._pattern_type):
|
||||
return self._url.search(str(url)) is not None
|
||||
|
||||
req = urlparse(url.lower())
|
||||
|
||||
if self._url_parts.scheme and req.scheme != self._url_parts.scheme:
|
||||
return False
|
||||
if self._url_parts.netloc and req.netloc != self._url_parts.netloc:
|
||||
return False
|
||||
if (req.path or '/') != (self._url_parts.path or '/'):
|
||||
if (self._url.scheme != url.scheme or self._url.host != url.host or
|
||||
self._url.path != url.path):
|
||||
return False
|
||||
|
||||
# Ensure all query components in matcher are present in the request
|
||||
request_qs = parse_qs(req.query)
|
||||
matcher_qs = parse_qs(self._url_parts.query)
|
||||
request_qs = parse_qs(url.query_string)
|
||||
matcher_qs = parse_qs(self._url.query_string)
|
||||
for key, vals in matcher_qs.items():
|
||||
for val in vals:
|
||||
try:
|
||||
|
@ -207,12 +210,7 @@ def mock_aiohttp_client():
|
|||
"""Context manager to mock aiohttp client."""
|
||||
mocker = AiohttpClientMocker()
|
||||
|
||||
with mock.patch('aiohttp.ClientSession') as mock_session:
|
||||
instance = mock_session()
|
||||
instance.request = mocker.match_request
|
||||
|
||||
for method in ('get', 'post', 'put', 'options', 'delete'):
|
||||
setattr(instance, method,
|
||||
functools.partial(mocker.match_request, method))
|
||||
|
||||
with mock.patch(
|
||||
'homeassistant.helpers.aiohttp_client.async_create_clientsession',
|
||||
side_effect=lambda hass, *args: mocker.create_session(hass.loop)):
|
||||
yield mocker
|
||||
|
|
22
tests/test_util/test_aiohttp.py
Normal file
22
tests/test_util/test_aiohttp.py
Normal file
|
@ -0,0 +1,22 @@
|
|||
"""Tests for our aiohttp mocker."""
|
||||
from .aiohttp import AiohttpClientMocker
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
async def test_matching_url():
|
||||
"""Test we can match urls."""
|
||||
mocker = AiohttpClientMocker()
|
||||
mocker.get('http://example.com')
|
||||
await mocker.match_request('get', 'http://example.com/')
|
||||
|
||||
mocker.clear_requests()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
await mocker.match_request('get', 'http://example.com/')
|
||||
|
||||
mocker.clear_requests()
|
||||
|
||||
mocker.get('http://example.com?a=1')
|
||||
await mocker.match_request('get', 'http://example.com/',
|
||||
params={'a': 1, 'b': 2})
|
Loading…
Add table
Reference in a new issue