Add config flow for Hue (#12830)

* Add config flow for Hue

* Upgrade to aiohue 0.2

* Fix tests

* Add tests

* Add aiohue to test requirements

* Bump aiohue dependency

* Lint

* Lint

* Fix aiohttp mock

* Lint

* Fix tests
This commit is contained in:
Paulus Schoutsen 2018-03-03 21:28:04 -08:00 committed by GitHub
parent d06807c634
commit 67c49a7662
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 389 additions and 65 deletions

View file

@ -163,7 +163,7 @@ class ConfigManagerFlowResourceView(HomeAssistantView):
hass = request.app['hass'] hass = request.app['hass']
try: try:
hass.config_entries.async_abort(flow_id) hass.config_entries.flow.async_abort(flow_id)
except config_entries.UnknownFlow: except config_entries.UnknownFlow:
return self.json_message('Invalid flow specified', 404) return self.json_message('Invalid flow specified', 404)

View file

@ -4,20 +4,24 @@ This component provides basic support for the Philips Hue system.
For more details about this component, please refer to the documentation at For more details about this component, please refer to the documentation at
https://home-assistant.io/components/hue/ https://home-assistant.io/components/hue/
""" """
import asyncio
import json import json
from functools import partial
import logging import logging
import os import os
import socket import socket
import async_timeout
import requests import requests
import voluptuous as vol import voluptuous as vol
from homeassistant.components.discovery import SERVICE_HUE from homeassistant.components.discovery import SERVICE_HUE
from homeassistant.const import CONF_FILENAME, CONF_HOST from homeassistant.const import CONF_FILENAME, CONF_HOST
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers import discovery from homeassistant.helpers import discovery, aiohttp_client
from homeassistant import config_entries
REQUIREMENTS = ['phue==1.0'] REQUIREMENTS = ['phue==1.0', 'aiohue==0.3.0']
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -133,13 +137,14 @@ def bridge_discovered(hass, service, discovery_info):
def setup_bridge(host, hass, filename=None, allow_unreachable=False, def setup_bridge(host, hass, filename=None, allow_unreachable=False,
allow_in_emulated_hue=True, allow_hue_groups=True): allow_in_emulated_hue=True, allow_hue_groups=True,
username=None):
"""Set up a given Hue bridge.""" """Set up a given Hue bridge."""
# Only register a device once # Only register a device once
if socket.gethostbyname(host) in hass.data[DOMAIN]: if socket.gethostbyname(host) in hass.data[DOMAIN]:
return return
bridge = HueBridge(host, hass, filename, allow_unreachable, bridge = HueBridge(host, hass, filename, username, allow_unreachable,
allow_in_emulated_hue, allow_hue_groups) allow_in_emulated_hue, allow_hue_groups)
bridge.setup() bridge.setup()
@ -164,13 +169,14 @@ def _find_host_from_config(hass, filename=PHUE_CONFIG_FILE):
class HueBridge(object): class HueBridge(object):
"""Manages a single Hue bridge.""" """Manages a single Hue bridge."""
def __init__(self, host, hass, filename, allow_unreachable=False, def __init__(self, host, hass, filename, username, allow_unreachable=False,
allow_in_emulated_hue=True, allow_hue_groups=True): allow_in_emulated_hue=True, allow_hue_groups=True):
"""Initialize the system.""" """Initialize the system."""
self.host = host self.host = host
self.bridge_id = socket.gethostbyname(host) self.bridge_id = socket.gethostbyname(host)
self.hass = hass self.hass = hass
self.filename = filename self.filename = filename
self.username = username
self.allow_unreachable = allow_unreachable self.allow_unreachable = allow_unreachable
self.allow_in_emulated_hue = allow_in_emulated_hue self.allow_in_emulated_hue = allow_in_emulated_hue
self.allow_hue_groups = allow_hue_groups self.allow_hue_groups = allow_hue_groups
@ -189,10 +195,14 @@ class HueBridge(object):
import phue import phue
try: try:
self.bridge = phue.Bridge( kwargs = {}
self.host, if self.username is not None:
config_file_path=self.hass.config.path(self.filename)) kwargs['username'] = self.username
except (ConnectionRefusedError, OSError): # Wrong host was given if self.filename is not None:
kwargs['config_file_path'] = \
self.hass.config.path(self.filename)
self.bridge = phue.Bridge(self.host, **kwargs)
except OSError: # Wrong host was given
_LOGGER.error("Error connecting to the Hue bridge at %s", _LOGGER.error("Error connecting to the Hue bridge at %s",
self.host) self.host)
return return
@ -204,6 +214,7 @@ class HueBridge(object):
except Exception: # pylint: disable=broad-except except Exception: # pylint: disable=broad-except
_LOGGER.exception("Unknown error connecting with Hue bridge at %s", _LOGGER.exception("Unknown error connecting with Hue bridge at %s",
self.host) self.host)
return
# If we came here and configuring this host, mark as done # If we came here and configuring this host, mark as done
if self.config_request_id: if self.config_request_id:
@ -260,3 +271,112 @@ class HueBridge(object):
def set_group(self, light_id, command): def set_group(self, light_id, command):
"""Change light settings for a group. See phue for detail.""" """Change light settings for a group. See phue for detail."""
return self.bridge.set_group(light_id, command) return self.bridge.set_group(light_id, command)
@config_entries.HANDLERS.register(DOMAIN)
class HueFlowHandler(config_entries.ConfigFlowHandler):
"""Handle a Hue config flow."""
VERSION = 1
def __init__(self):
"""Initialize the Hue flow."""
self.host = None
@property
def _websession(self):
"""Return a websession.
Cannot assign in init because hass variable is not set yet.
"""
return aiohttp_client.async_get_clientsession(self.hass)
async def async_step_init(self, user_input=None):
"""Handle a flow start."""
from aiohue.discovery import discover_nupnp
if user_input is not None:
self.host = user_input['host']
return await self.async_step_link()
try:
with async_timeout.timeout(5):
bridges = await discover_nupnp(websession=self._websession)
except asyncio.TimeoutError:
return self.async_abort(
reason='Unable to discover Hue bridges.'
)
if not bridges:
return self.async_abort(
reason='No Philips Hue bridges discovered.'
)
# Find already configured hosts
configured_hosts = set(
entry.data['host'] for entry
in self.hass.config_entries.async_entries(DOMAIN))
hosts = [bridge.host for bridge in bridges
if bridge.host not in configured_hosts]
if not hosts:
return self.async_abort(
reason='All Philips Hue bridges are already configured.'
)
elif len(hosts) == 1:
self.host = hosts[0]
return await self.async_step_link()
return self.async_show_form(
step_id='init',
title='Pick Hue Bridge',
data_schema=vol.Schema({
vol.Required('host'): vol.In(hosts)
})
)
async def async_step_link(self, user_input=None):
"""Attempt to link with the Hue bridge."""
import aiohue
errors = {}
if user_input is not None:
bridge = aiohue.Bridge(self.host, websession=self._websession)
try:
with async_timeout.timeout(5):
# Create auth token
await bridge.create_user('home-assistant')
# Fetches name and id
await bridge.initialize()
except (asyncio.TimeoutError, aiohue.RequestError,
aiohue.LinkButtonNotPressed):
errors['base'] = 'Failed to register, please try again.'
except aiohue.AiohueException:
errors['base'] = 'Unknown linking error occurred.'
_LOGGER.exception('Uknown Hue linking error occurred')
else:
return self.async_create_entry(
title=bridge.config.name,
data={
'host': bridge.host,
'bridge_id': bridge.config.bridgeid,
'username': bridge.username,
}
)
return self.async_show_form(
step_id='link',
title='Link Hub',
description=CONFIG_INSTRUCTIONS,
errors=errors,
)
async def async_setup_entry(hass, entry):
"""Set up a bridge for a config entry."""
await hass.async_add_job(partial(
setup_bridge, entry.data['host'], hass,
username=entry.data['username']))
return True

View file

@ -219,7 +219,8 @@ class SpcWebGateway:
url = self._build_url(resource) url = self._build_url(resource)
try: try:
_LOGGER.debug("Attempting to retrieve SPC data from %s", url) _LOGGER.debug("Attempting to retrieve SPC data from %s", url)
session = aiohttp.ClientSession() session = \
self._hass.helpers.aiohttp_client.async_get_clientsession()
with async_timeout.timeout(10, loop=self._hass.loop): with async_timeout.timeout(10, loop=self._hass.loop):
action = session.get if use_get else session.put action = session.get if use_get else session.put
response = yield from action(url) response = yield from action(url)

View file

@ -126,7 +126,8 @@ _LOGGER = logging.getLogger(__name__)
HANDLERS = Registry() HANDLERS = Registry()
# Components that have config flows. In future we will auto-generate this list. # Components that have config flows. In future we will auto-generate this list.
FLOWS = [ FLOWS = [
'config_entry_example' 'config_entry_example',
'hue',
] ]
SOURCE_USER = 'user' SOURCE_USER = 'user'

View file

@ -35,14 +35,7 @@ def async_get_clientsession(hass, verify_ssl=True):
key = DATA_CLIENTSESSION_NOTVERIFY key = DATA_CLIENTSESSION_NOTVERIFY
if key not in hass.data: if key not in hass.data:
connector = _async_get_connector(hass, verify_ssl) hass.data[key] = async_create_clientsession(hass, verify_ssl)
clientsession = aiohttp.ClientSession(
loop=hass.loop,
connector=connector,
headers={USER_AGENT: SERVER_SOFTWARE}
)
_async_register_clientsession_shutdown(hass, clientsession)
hass.data[key] = clientsession
return hass.data[key] return hass.data[key]

View file

@ -75,6 +75,9 @@ aiodns==1.1.1
# homeassistant.components.http # homeassistant.components.http
aiohttp_cors==0.6.0 aiohttp_cors==0.6.0
# homeassistant.components.hue
aiohue==0.3.0
# homeassistant.components.sensor.imap # homeassistant.components.sensor.imap
aioimaplib==0.7.13 aioimaplib==0.7.13

View file

@ -34,6 +34,9 @@ aioautomatic==0.6.5
# homeassistant.components.http # homeassistant.components.http
aiohttp_cors==0.6.0 aiohttp_cors==0.6.0
# homeassistant.components.hue
aiohue==0.3.0
# homeassistant.components.notify.apns # homeassistant.components.notify.apns
apns2==0.3.0 apns2==0.3.0

View file

@ -37,6 +37,7 @@ COMMENT_REQUIREMENTS = (
TEST_REQUIREMENTS = ( TEST_REQUIREMENTS = (
'aioautomatic', 'aioautomatic',
'aiohttp_cors', 'aiohttp_cors',
'aiohue',
'apns2', 'apns2',
'caldav', 'caldav',
'coinmarketcap', 'coinmarketcap',

View file

@ -4,13 +4,17 @@ import logging
import unittest import unittest
from unittest.mock import call, MagicMock, patch from unittest.mock import call, MagicMock, patch
import aiohue
import pytest
import voluptuous as vol
from homeassistant.components import configurator, hue from homeassistant.components import configurator, hue
from homeassistant.const import CONF_FILENAME, CONF_HOST from homeassistant.const import CONF_FILENAME, CONF_HOST
from homeassistant.setup import setup_component, async_setup_component from homeassistant.setup import setup_component, async_setup_component
from tests.common import ( from tests.common import (
assert_setup_component, get_test_home_assistant, get_test_config_dir, assert_setup_component, get_test_home_assistant, get_test_config_dir,
MockDependency MockDependency, MockConfigEntry, mock_coro
) )
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -212,7 +216,8 @@ class TestHueBridge(unittest.TestCase):
mock_bridge = mock_phue.Bridge mock_bridge = mock_phue.Bridge
mock_bridge.side_effect = ConnectionRefusedError() mock_bridge.side_effect = ConnectionRefusedError()
bridge = hue.HueBridge('localhost', self.hass, hue.PHUE_CONFIG_FILE) bridge = hue.HueBridge(
'localhost', self.hass, hue.PHUE_CONFIG_FILE, None)
bridge.setup() bridge.setup()
self.assertFalse(bridge.configured) self.assertFalse(bridge.configured)
self.assertTrue(bridge.config_request_id is None) self.assertTrue(bridge.config_request_id is None)
@ -228,7 +233,8 @@ class TestHueBridge(unittest.TestCase):
mock_phue.PhueRegistrationException = Exception mock_phue.PhueRegistrationException = Exception
mock_bridge.side_effect = mock_phue.PhueRegistrationException(1, 2) mock_bridge.side_effect = mock_phue.PhueRegistrationException(1, 2)
bridge = hue.HueBridge('localhost', self.hass, hue.PHUE_CONFIG_FILE) bridge = hue.HueBridge(
'localhost', self.hass, hue.PHUE_CONFIG_FILE, None)
bridge.setup() bridge.setup()
self.assertFalse(bridge.configured) self.assertFalse(bridge.configured)
self.assertFalse(bridge.config_request_id is None) self.assertFalse(bridge.config_request_id is None)
@ -250,7 +256,8 @@ class TestHueBridge(unittest.TestCase):
None, None,
] ]
bridge = hue.HueBridge('localhost', self.hass, hue.PHUE_CONFIG_FILE) bridge = hue.HueBridge(
'localhost', self.hass, hue.PHUE_CONFIG_FILE, None)
bridge.setup() bridge.setup()
self.assertFalse(bridge.configured) self.assertFalse(bridge.configured)
self.assertFalse(bridge.config_request_id is None) self.assertFalse(bridge.config_request_id is None)
@ -291,7 +298,8 @@ class TestHueBridge(unittest.TestCase):
ConnectionRefusedError(), ConnectionRefusedError(),
] ]
bridge = hue.HueBridge('localhost', self.hass, hue.PHUE_CONFIG_FILE) bridge = hue.HueBridge(
'localhost', self.hass, hue.PHUE_CONFIG_FILE, None)
bridge.setup() bridge.setup()
self.assertFalse(bridge.configured) self.assertFalse(bridge.configured)
self.assertFalse(bridge.config_request_id is None) self.assertFalse(bridge.config_request_id is None)
@ -332,7 +340,8 @@ class TestHueBridge(unittest.TestCase):
mock_phue.PhueRegistrationException(1, 2), mock_phue.PhueRegistrationException(1, 2),
] ]
bridge = hue.HueBridge('localhost', self.hass, hue.PHUE_CONFIG_FILE) bridge = hue.HueBridge(
'localhost', self.hass, hue.PHUE_CONFIG_FILE, None)
bridge.setup() bridge.setup()
self.assertFalse(bridge.configured) self.assertFalse(bridge.configured)
self.assertFalse(bridge.config_request_id is None) self.assertFalse(bridge.config_request_id is None)
@ -364,7 +373,7 @@ class TestHueBridge(unittest.TestCase):
"""Test the hue_activate_scene service.""" """Test the hue_activate_scene service."""
with patch('homeassistant.helpers.discovery.load_platform'): with patch('homeassistant.helpers.discovery.load_platform'):
bridge = hue.HueBridge('localhost', self.hass, bridge = hue.HueBridge('localhost', self.hass,
hue.PHUE_CONFIG_FILE) hue.PHUE_CONFIG_FILE, None)
bridge.setup() bridge.setup()
# No args # No args
@ -393,15 +402,187 @@ class TestHueBridge(unittest.TestCase):
bridge.bridge.run_scene.assert_called_once_with('group', 'scene') bridge.bridge.run_scene.assert_called_once_with('group', 'scene')
@asyncio.coroutine async def test_setup_no_host(hass, requests_mock):
def test_setup_no_host(hass, requests_mock):
"""No host specified in any way.""" """No host specified in any way."""
requests_mock.get(hue.API_NUPNP, json=[]) requests_mock.get(hue.API_NUPNP, json=[])
with MockDependency('phue') as mock_phue: with MockDependency('phue') as mock_phue:
result = yield from async_setup_component( result = await async_setup_component(
hass, hue.DOMAIN, {hue.DOMAIN: {}}) hass, hue.DOMAIN, {hue.DOMAIN: {}})
assert result assert result
mock_phue.Bridge.assert_not_called() mock_phue.Bridge.assert_not_called()
assert hass.data[hue.DOMAIN] == {} assert hass.data[hue.DOMAIN] == {}
async def test_flow_works(hass, aioclient_mock):
"""Test config flow ."""
aioclient_mock.get(hue.API_NUPNP, json=[
{'internalipaddress': '1.2.3.4', 'id': 'bla'}
])
flow = hue.HueFlowHandler()
flow.hass = hass
await flow.async_step_init()
with patch('aiohue.Bridge') as mock_bridge:
def mock_constructor(host, websession):
mock_bridge.host = host
return mock_bridge
mock_bridge.side_effect = mock_constructor
mock_bridge.username = 'username-abc'
mock_bridge.config.name = 'Mock Bridge'
mock_bridge.config.bridgeid = 'bridge-id-1234'
mock_bridge.create_user.return_value = mock_coro()
mock_bridge.initialize.return_value = mock_coro()
result = await flow.async_step_link(user_input={})
assert mock_bridge.host == '1.2.3.4'
assert len(mock_bridge.create_user.mock_calls) == 1
assert len(mock_bridge.initialize.mock_calls) == 1
assert result['type'] == 'create_entry'
assert result['title'] == 'Mock Bridge'
assert result['data'] == {
'host': '1.2.3.4',
'bridge_id': 'bridge-id-1234',
'username': 'username-abc'
}
async def test_flow_no_discovered_bridges(hass, aioclient_mock):
"""Test config flow discovers no bridges."""
aioclient_mock.get(hue.API_NUPNP, json=[])
flow = hue.HueFlowHandler()
flow.hass = hass
result = await flow.async_step_init()
assert result['type'] == 'abort'
async def test_flow_all_discovered_bridges_exist(hass, aioclient_mock):
"""Test config flow discovers only already configured bridges."""
aioclient_mock.get(hue.API_NUPNP, json=[
{'internalipaddress': '1.2.3.4', 'id': 'bla'}
])
MockConfigEntry(domain='hue', data={
'host': '1.2.3.4'
}).add_to_hass(hass)
flow = hue.HueFlowHandler()
flow.hass = hass
result = await flow.async_step_init()
assert result['type'] == 'abort'
async def test_flow_one_bridge_discovered(hass, aioclient_mock):
"""Test config flow discovers one bridge."""
aioclient_mock.get(hue.API_NUPNP, json=[
{'internalipaddress': '1.2.3.4', 'id': 'bla'}
])
flow = hue.HueFlowHandler()
flow.hass = hass
result = await flow.async_step_init()
assert result['type'] == 'form'
assert result['step_id'] == 'link'
async def test_flow_two_bridges_discovered(hass, aioclient_mock):
"""Test config flow discovers two bridges."""
aioclient_mock.get(hue.API_NUPNP, json=[
{'internalipaddress': '1.2.3.4', 'id': 'bla'},
{'internalipaddress': '5.6.7.8', 'id': 'beer'}
])
flow = hue.HueFlowHandler()
flow.hass = hass
result = await flow.async_step_init()
assert result['type'] == 'form'
assert result['step_id'] == 'init'
with pytest.raises(vol.Invalid):
assert result['data_schema']({'host': '0.0.0.0'})
result['data_schema']({'host': '1.2.3.4'})
result['data_schema']({'host': '5.6.7.8'})
async def test_flow_two_bridges_discovered_one_new(hass, aioclient_mock):
"""Test config flow discovers two bridges."""
aioclient_mock.get(hue.API_NUPNP, json=[
{'internalipaddress': '1.2.3.4', 'id': 'bla'},
{'internalipaddress': '5.6.7.8', 'id': 'beer'}
])
MockConfigEntry(domain='hue', data={
'host': '1.2.3.4'
}).add_to_hass(hass)
flow = hue.HueFlowHandler()
flow.hass = hass
result = await flow.async_step_init()
assert result['type'] == 'form'
assert result['step_id'] == 'link'
assert flow.host == '5.6.7.8'
async def test_flow_timeout_discovery(hass):
"""Test config flow ."""
flow = hue.HueFlowHandler()
flow.hass = hass
with patch('aiohue.discovery.discover_nupnp',
side_effect=asyncio.TimeoutError):
result = await flow.async_step_init()
assert result['type'] == 'abort'
async def test_flow_link_timeout(hass):
"""Test config flow ."""
flow = hue.HueFlowHandler()
flow.hass = hass
with patch('aiohue.Bridge.create_user',
side_effect=asyncio.TimeoutError):
result = await flow.async_step_link({})
assert result['type'] == 'form'
assert result['step_id'] == 'link'
assert result['errors'] == {
'base': 'Failed to register, please try again.'
}
async def test_flow_link_button_not_pressed(hass):
"""Test config flow ."""
flow = hue.HueFlowHandler()
flow.hass = hass
with patch('aiohue.Bridge.create_user',
side_effect=aiohue.LinkButtonNotPressed):
result = await flow.async_step_link({})
assert result['type'] == 'form'
assert result['step_id'] == 'link'
assert result['errors'] == {
'base': 'Failed to register, please try again.'
}
async def test_flow_link_unknown_host(hass):
"""Test config flow ."""
flow = hue.HueFlowHandler()
flow.hass = hass
with patch('aiohue.Bridge.create_user',
side_effect=aiohue.RequestError):
result = await flow.async_step_link({})
assert result['type'] == 'form'
assert result['step_id'] == 'link'
assert result['errors'] == {
'base': 'Failed to register, please try again.'
}

View file

@ -0,0 +1 @@
"""Tests for the test utilities."""

View file

@ -1,11 +1,13 @@
"""Aiohttp test utils.""" """Aiohttp test utils."""
import asyncio import asyncio
from contextlib import contextmanager from contextlib import contextmanager
import functools
import json as _json import json as _json
import re
from unittest import mock from unittest import mock
from urllib.parse import urlparse, parse_qs from urllib.parse import parse_qs
import yarl
from aiohttp import ClientSession
from yarl import URL
from aiohttp.client_exceptions import ClientResponseError from aiohttp.client_exceptions import ClientResponseError
@ -31,14 +33,17 @@ class AiohttpClientMocker:
exc=None, exc=None,
cookies=None): cookies=None):
"""Mock a request.""" """Mock a request."""
if json: if json is not None:
text = _json.dumps(json) text = _json.dumps(json)
if text: if text is not None:
content = text.encode('utf-8') content = text.encode('utf-8')
if content is None: if content is None:
content = b'' content = b''
if not isinstance(url, re._pattern_type):
url = URL(url)
if params: if params:
url = str(yarl.URL(url).with_query(params)) url = url.with_query(params)
self._mocks.append(AiohttpClientMockResponse( self._mocks.append(AiohttpClientMockResponse(
method, url, status, content, cookies, exc, headers)) method, url, status, content, cookies, exc, headers))
@ -74,13 +79,21 @@ class AiohttpClientMocker:
self._cookies.clear() self._cookies.clear()
self.mock_calls.clear() self.mock_calls.clear()
@asyncio.coroutine def create_session(self, loop):
# pylint: disable=unused-variable """Create a ClientSession that is bound to this mocker."""
def match_request(self, method, url, *, data=None, auth=None, params=None, session = ClientSession(loop=loop)
headers=None, allow_redirects=None, timeout=None, session._request = self.match_request
json=None): return session
async def match_request(self, method, url, *, data=None, auth=None,
params=None, headers=None, allow_redirects=None,
timeout=None, json=None):
"""Match a request against pre-registered requests.""" """Match a request against pre-registered requests."""
data = data or json data = data or json
url = URL(url)
if params:
url = url.with_query(params)
for response in self._mocks: for response in self._mocks:
if response.match_request(method, url, params): if response.match_request(method, url, params):
self.mock_calls.append((method, url, data, headers)) self.mock_calls.append((method, url, data, headers))
@ -101,8 +114,6 @@ class AiohttpClientMockResponse:
"""Initialize a fake response.""" """Initialize a fake response."""
self.method = method self.method = method
self._url = url self._url = url
self._url_parts = (None if hasattr(url, 'search')
else urlparse(url.lower()))
self.status = status self.status = status
self.response = response self.response = response
self.exc = exc self.exc = exc
@ -133,25 +144,17 @@ class AiohttpClientMockResponse:
if method.lower() != self.method.lower(): if method.lower() != self.method.lower():
return False return False
if params:
url = str(yarl.URL(url).with_query(params))
# regular expression matching # regular expression matching
if self._url_parts is None: if isinstance(self._url, re._pattern_type):
return self._url.search(url) is not None return self._url.search(str(url)) is not None
req = urlparse(url.lower()) if (self._url.scheme != url.scheme or self._url.host != url.host or
self._url.path != url.path):
if self._url_parts.scheme and req.scheme != self._url_parts.scheme:
return False
if self._url_parts.netloc and req.netloc != self._url_parts.netloc:
return False
if (req.path or '/') != (self._url_parts.path or '/'):
return False return False
# Ensure all query components in matcher are present in the request # Ensure all query components in matcher are present in the request
request_qs = parse_qs(req.query) request_qs = parse_qs(url.query_string)
matcher_qs = parse_qs(self._url_parts.query) matcher_qs = parse_qs(self._url.query_string)
for key, vals in matcher_qs.items(): for key, vals in matcher_qs.items():
for val in vals: for val in vals:
try: try:
@ -207,12 +210,7 @@ def mock_aiohttp_client():
"""Context manager to mock aiohttp client.""" """Context manager to mock aiohttp client."""
mocker = AiohttpClientMocker() mocker = AiohttpClientMocker()
with mock.patch('aiohttp.ClientSession') as mock_session: with mock.patch(
instance = mock_session() 'homeassistant.helpers.aiohttp_client.async_create_clientsession',
instance.request = mocker.match_request side_effect=lambda hass, *args: mocker.create_session(hass.loop)):
for method in ('get', 'post', 'put', 'options', 'delete'):
setattr(instance, method,
functools.partial(mocker.match_request, method))
yield mocker yield mocker

View file

@ -0,0 +1,22 @@
"""Tests for our aiohttp mocker."""
from .aiohttp import AiohttpClientMocker
import pytest
async def test_matching_url():
"""Test we can match urls."""
mocker = AiohttpClientMocker()
mocker.get('http://example.com')
await mocker.match_request('get', 'http://example.com/')
mocker.clear_requests()
with pytest.raises(AssertionError):
await mocker.match_request('get', 'http://example.com/')
mocker.clear_requests()
mocker.get('http://example.com?a=1')
await mocker.match_request('get', 'http://example.com/',
params={'a': 1, 'b': 2})