Cleanup http (#12424)

* Clean up HTTP component

* Clean up HTTP mock

* Remove unused import

* Fix test

* Lint
This commit is contained in:
Paulus Schoutsen 2018-02-15 13:06:14 -08:00 committed by Pascal Vizeli
parent ad8fe8a93a
commit f32911d036
28 changed files with 811 additions and 1014 deletions

View file

@ -14,7 +14,7 @@ from homeassistant.const import (
EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP, EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP,
) )
from homeassistant.components.http import REQUIREMENTS # NOQA from homeassistant.components.http import REQUIREMENTS # NOQA
from homeassistant.components.http import HomeAssistantWSGI from homeassistant.components.http import HomeAssistantHTTP
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.deprecation import get_deprecated from homeassistant.helpers.deprecation import get_deprecated
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
@ -86,7 +86,7 @@ def setup(hass, yaml_config):
"""Activate the emulated_hue component.""" """Activate the emulated_hue component."""
config = Config(hass, yaml_config.get(DOMAIN, {})) config = Config(hass, yaml_config.get(DOMAIN, {}))
server = HomeAssistantWSGI( server = HomeAssistantHTTP(
hass, hass,
server_host=config.host_ip_addr, server_host=config.host_ip_addr,
server_port=config.listen_port, server_port=config.listen_port,

View file

@ -17,7 +17,7 @@ import jinja2
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.components.http import HomeAssistantView from homeassistant.components.http import HomeAssistantView
from homeassistant.components.http.auth import is_trusted_ip from homeassistant.components.http.const import KEY_AUTHENTICATED
from homeassistant.config import find_config_file, load_yaml_config_file from homeassistant.config import find_config_file, load_yaml_config_file
from homeassistant.const import CONF_NAME, EVENT_THEMES_UPDATED from homeassistant.const import CONF_NAME, EVENT_THEMES_UPDATED
from homeassistant.core import callback from homeassistant.core import callback
@ -490,7 +490,7 @@ class IndexView(HomeAssistantView):
panel_url = hass.data[DATA_PANELS][panel].webcomponent_url_es5 panel_url = hass.data[DATA_PANELS][panel].webcomponent_url_es5
no_auth = '1' no_auth = '1'
if hass.config.api.api_password and not is_trusted_ip(request): if hass.config.api.api_password and not request[KEY_AUTHENTICATED]:
# do not try to auto connect on load # do not try to auto connect on load
no_auth = '0' no_auth = '0'

View file

@ -12,35 +12,28 @@ import os
import ssl import ssl
from aiohttp import web from aiohttp import web
from aiohttp.hdrs import ACCEPT, ORIGIN, CONTENT_TYPE
from aiohttp.web_exceptions import HTTPUnauthorized, HTTPMovedPermanently from aiohttp.web_exceptions import HTTPUnauthorized, HTTPMovedPermanently
import voluptuous as vol import voluptuous as vol
from homeassistant.const import ( from homeassistant.const import (
SERVER_PORT, CONTENT_TYPE_JSON, HTTP_HEADER_HA_AUTH, SERVER_PORT, CONTENT_TYPE_JSON,
EVENT_HOMEASSISTANT_STOP, EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP, EVENT_HOMEASSISTANT_START,)
HTTP_HEADER_X_REQUESTED_WITH)
from homeassistant.core import is_callback from homeassistant.core import is_callback
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
import homeassistant.remote as rem import homeassistant.remote as rem
import homeassistant.util as hass_util import homeassistant.util as hass_util
from homeassistant.util.logging import HideSensitiveDataFilter from homeassistant.util.logging import HideSensitiveDataFilter
from .auth import auth_middleware from .auth import setup_auth
from .ban import ban_middleware from .ban import setup_bans
from .const import ( from .cors import setup_cors
KEY_BANS_ENABLED, KEY_AUTHENTICATED, KEY_LOGIN_THRESHOLD, from .real_ip import setup_real_ip
KEY_TRUSTED_NETWORKS, KEY_USE_X_FORWARDED_FOR) from .const import KEY_AUTHENTICATED, KEY_REAL_IP
from .static import ( from .static import (
CachingFileResponse, CachingStaticResource, staticresource_middleware) CachingFileResponse, CachingStaticResource, staticresource_middleware)
from .util import get_real_ip
REQUIREMENTS = ['aiohttp_cors==0.6.0'] REQUIREMENTS = ['aiohttp_cors==0.6.0']
ALLOWED_CORS_HEADERS = [
ORIGIN, ACCEPT, HTTP_HEADER_X_REQUESTED_WITH, CONTENT_TYPE,
HTTP_HEADER_HA_AUTH]
DOMAIN = 'http' DOMAIN = 'http'
CONF_API_PASSWORD = 'api_password' CONF_API_PASSWORD = 'api_password'
@ -127,7 +120,7 @@ def async_setup(hass, config):
logging.getLogger('aiohttp.access').addFilter( logging.getLogger('aiohttp.access').addFilter(
HideSensitiveDataFilter(api_password)) HideSensitiveDataFilter(api_password))
server = HomeAssistantWSGI( server = HomeAssistantHTTP(
hass, hass,
server_host=server_host, server_host=server_host,
server_port=server_port, server_port=server_port,
@ -173,25 +166,29 @@ def async_setup(hass, config):
return True return True
class HomeAssistantWSGI(object): class HomeAssistantHTTP(object):
"""WSGI server for Home Assistant.""" """HTTP server for Home Assistant."""
def __init__(self, hass, api_password, ssl_certificate, def __init__(self, hass, api_password, ssl_certificate,
ssl_key, server_host, server_port, cors_origins, ssl_key, server_host, server_port, cors_origins,
use_x_forwarded_for, trusted_networks, use_x_forwarded_for, trusted_networks,
login_threshold, is_ban_enabled): login_threshold, is_ban_enabled):
"""Initialize the WSGI Home Assistant server.""" """Initialize the HTTP Home Assistant server."""
middlewares = [auth_middleware, staticresource_middleware] app = self.app = web.Application(
middlewares=[staticresource_middleware])
# This order matters
setup_real_ip(app, use_x_forwarded_for)
if is_ban_enabled: if is_ban_enabled:
middlewares.insert(0, ban_middleware) setup_bans(hass, app, login_threshold)
self.app = web.Application(middlewares=middlewares) setup_auth(app, trusted_networks, api_password)
self.app['hass'] = hass
self.app[KEY_USE_X_FORWARDED_FOR] = use_x_forwarded_for if cors_origins:
self.app[KEY_TRUSTED_NETWORKS] = trusted_networks setup_cors(app, cors_origins)
self.app[KEY_BANS_ENABLED] = is_ban_enabled
self.app[KEY_LOGIN_THRESHOLD] = login_threshold app['hass'] = hass
self.hass = hass self.hass = hass
self.api_password = api_password self.api_password = api_password
@ -199,21 +196,10 @@ class HomeAssistantWSGI(object):
self.ssl_key = ssl_key self.ssl_key = ssl_key
self.server_host = server_host self.server_host = server_host
self.server_port = server_port self.server_port = server_port
self.is_ban_enabled = is_ban_enabled
self._handler = None self._handler = None
self.server = None self.server = None
if cors_origins:
import aiohttp_cors
self.cors = aiohttp_cors.setup(self.app, defaults={
host: aiohttp_cors.ResourceOptions(
allow_headers=ALLOWED_CORS_HEADERS,
allow_methods='*',
) for host in cors_origins
})
else:
self.cors = None
def register_view(self, view): def register_view(self, view):
"""Register a view with the WSGI server. """Register a view with the WSGI server.
@ -292,15 +278,7 @@ class HomeAssistantWSGI(object):
@asyncio.coroutine @asyncio.coroutine
def start(self): def start(self):
"""Start the WSGI server.""" """Start the WSGI server."""
cors_added = set() yield from self.app.startup()
if self.cors is not None:
for route in list(self.app.router.routes()):
if hasattr(route, 'resource'):
route = route.resource
if route in cors_added:
continue
self.cors.add(route)
cors_added.add(route)
if self.ssl_certificate: if self.ssl_certificate:
try: try:
@ -420,7 +398,7 @@ def request_handler_factory(view, handler):
raise HTTPUnauthorized() raise HTTPUnauthorized()
_LOGGER.info('Serving %s to %s (auth: %s)', _LOGGER.info('Serving %s to %s (auth: %s)',
request.path, get_real_ip(request), authenticated) request.path, request.get(KEY_REAL_IP), authenticated)
result = handler(request, **request.match_info) result = handler(request, **request.match_info)

View file

@ -7,55 +7,66 @@ import logging
from aiohttp import hdrs from aiohttp import hdrs
from aiohttp.web import middleware from aiohttp.web import middleware
from homeassistant.core import callback
from homeassistant.const import HTTP_HEADER_HA_AUTH from homeassistant.const import HTTP_HEADER_HA_AUTH
from .util import get_real_ip from .const import KEY_AUTHENTICATED, KEY_REAL_IP
from .const import KEY_TRUSTED_NETWORKS, KEY_AUTHENTICATED
DATA_API_PASSWORD = 'api_password' DATA_API_PASSWORD = 'api_password'
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@middleware @callback
@asyncio.coroutine def setup_auth(app, trusted_networks, api_password):
def auth_middleware(request, handler): """Create auth middleware for the app."""
"""Authenticate as middleware.""" @middleware
# If no password set, just always set authenticated=True @asyncio.coroutine
if request.app['hass'].http.api_password is None: def auth_middleware(request, handler):
request[KEY_AUTHENTICATED] = True """Authenticate as middleware."""
# If no password set, just always set authenticated=True
if api_password is None:
request[KEY_AUTHENTICATED] = True
return (yield from handler(request))
# Check authentication
authenticated = False
if (HTTP_HEADER_HA_AUTH in request.headers and
hmac.compare_digest(
api_password, request.headers[HTTP_HEADER_HA_AUTH])):
# A valid auth header has been set
authenticated = True
elif (DATA_API_PASSWORD in request.query and
hmac.compare_digest(api_password,
request.query[DATA_API_PASSWORD])):
authenticated = True
elif (hdrs.AUTHORIZATION in request.headers and
validate_authorization_header(api_password, request)):
authenticated = True
elif _is_trusted_ip(request, trusted_networks):
authenticated = True
request[KEY_AUTHENTICATED] = authenticated
return (yield from handler(request)) return (yield from handler(request))
# Check authentication @asyncio.coroutine
authenticated = False def auth_startup(app):
"""Initialize auth middleware when app starts up."""
app.middlewares.append(auth_middleware)
if (HTTP_HEADER_HA_AUTH in request.headers and app.on_startup.append(auth_startup)
validate_password(
request, request.headers[HTTP_HEADER_HA_AUTH])):
# A valid auth header has been set
authenticated = True
elif (DATA_API_PASSWORD in request.query and
validate_password(request, request.query[DATA_API_PASSWORD])):
authenticated = True
elif (hdrs.AUTHORIZATION in request.headers and
validate_authorization_header(request)):
authenticated = True
elif is_trusted_ip(request):
authenticated = True
request[KEY_AUTHENTICATED] = authenticated
return (yield from handler(request))
def is_trusted_ip(request): def _is_trusted_ip(request, trusted_networks):
"""Test if request is from a trusted ip.""" """Test if request is from a trusted ip."""
ip_addr = get_real_ip(request) ip_addr = request[KEY_REAL_IP]
return ip_addr and any( return any(
ip_addr in trusted_network for trusted_network ip_addr in trusted_network for trusted_network
in request.app[KEY_TRUSTED_NETWORKS]) in trusted_networks)
def validate_password(request, api_password): def validate_password(request, api_password):
@ -64,7 +75,7 @@ def validate_password(request, api_password):
api_password, request.app['hass'].http.api_password) api_password, request.app['hass'].http.api_password)
def validate_authorization_header(request): def validate_authorization_header(api_password, request):
"""Test an authorization header if valid password.""" """Test an authorization header if valid password."""
if hdrs.AUTHORIZATION not in request.headers: if hdrs.AUTHORIZATION not in request.headers:
return False return False
@ -80,4 +91,4 @@ def validate_authorization_header(request):
if username != 'homeassistant': if username != 'homeassistant':
return False return False
return validate_password(request, password) return hmac.compare_digest(api_password, password)

View file

@ -10,18 +10,20 @@ from aiohttp.web import middleware
from aiohttp.web_exceptions import HTTPForbidden, HTTPUnauthorized from aiohttp.web_exceptions import HTTPForbidden, HTTPUnauthorized
import voluptuous as vol import voluptuous as vol
from homeassistant.core import callback
from homeassistant.components import persistent_notification from homeassistant.components import persistent_notification
from homeassistant.config import load_yaml_config_file from homeassistant.config import load_yaml_config_file
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.util.yaml import dump from homeassistant.util.yaml import dump
from .const import ( from .const import KEY_REAL_IP
KEY_BANS_ENABLED, KEY_BANNED_IPS, KEY_LOGIN_THRESHOLD,
KEY_FAILED_LOGIN_ATTEMPTS)
from .util import get_real_ip
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
KEY_BANNED_IPS = 'ha_banned_ips'
KEY_FAILED_LOGIN_ATTEMPTS = 'ha_failed_login_attempts'
KEY_LOGIN_THRESHOLD = 'ha_login_threshold'
NOTIFICATION_ID_BAN = 'ip-ban' NOTIFICATION_ID_BAN = 'ip-ban'
NOTIFICATION_ID_LOGIN = 'http-login' NOTIFICATION_ID_LOGIN = 'http-login'
@ -33,21 +35,31 @@ SCHEMA_IP_BAN_ENTRY = vol.Schema({
}) })
@callback
def setup_bans(hass, app, login_threshold):
"""Create IP Ban middleware for the app."""
@asyncio.coroutine
def ban_startup(app):
"""Initialize bans when app starts up."""
app.middlewares.append(ban_middleware)
app[KEY_BANNED_IPS] = yield from hass.async_add_job(
load_ip_bans_config, hass.config.path(IP_BANS_FILE))
app[KEY_FAILED_LOGIN_ATTEMPTS] = defaultdict(int)
app[KEY_LOGIN_THRESHOLD] = login_threshold
app.on_startup.append(ban_startup)
@middleware @middleware
@asyncio.coroutine @asyncio.coroutine
def ban_middleware(request, handler): def ban_middleware(request, handler):
"""IP Ban middleware.""" """IP Ban middleware."""
if not request.app[KEY_BANS_ENABLED]: if KEY_BANNED_IPS not in request.app:
_LOGGER.error('IP Ban middleware loaded but banned IPs not loaded')
return (yield from handler(request)) return (yield from handler(request))
if KEY_BANNED_IPS not in request.app:
hass = request.app['hass']
request.app[KEY_BANNED_IPS] = yield from hass.async_add_job(
load_ip_bans_config, hass.config.path(IP_BANS_FILE))
# Verify if IP is not banned # Verify if IP is not banned
ip_address_ = get_real_ip(request) ip_address_ = request[KEY_REAL_IP]
is_banned = any(ip_ban.ip_address == ip_address_ is_banned = any(ip_ban.ip_address == ip_address_
for ip_ban in request.app[KEY_BANNED_IPS]) for ip_ban in request.app[KEY_BANNED_IPS])
@ -64,7 +76,7 @@ def ban_middleware(request, handler):
@asyncio.coroutine @asyncio.coroutine
def process_wrong_login(request): def process_wrong_login(request):
"""Process a wrong login attempt.""" """Process a wrong login attempt."""
remote_addr = get_real_ip(request) remote_addr = request[KEY_REAL_IP]
msg = ('Login attempt or request with invalid authentication ' msg = ('Login attempt or request with invalid authentication '
'from {}'.format(remote_addr)) 'from {}'.format(remote_addr))
@ -73,13 +85,11 @@ def process_wrong_login(request):
request.app['hass'], msg, 'Login attempt failed', request.app['hass'], msg, 'Login attempt failed',
NOTIFICATION_ID_LOGIN) NOTIFICATION_ID_LOGIN)
if (not request.app[KEY_BANS_ENABLED] or # Check if ban middleware is loaded
if (KEY_BANNED_IPS not in request.app or
request.app[KEY_LOGIN_THRESHOLD] < 1): request.app[KEY_LOGIN_THRESHOLD] < 1):
return return
if KEY_FAILED_LOGIN_ATTEMPTS not in request.app:
request.app[KEY_FAILED_LOGIN_ATTEMPTS] = defaultdict(int)
request.app[KEY_FAILED_LOGIN_ATTEMPTS][remote_addr] += 1 request.app[KEY_FAILED_LOGIN_ATTEMPTS][remote_addr] += 1
if (request.app[KEY_FAILED_LOGIN_ATTEMPTS][remote_addr] > if (request.app[KEY_FAILED_LOGIN_ATTEMPTS][remote_addr] >

View file

@ -1,11 +1,3 @@
"""HTTP specific constants.""" """HTTP specific constants."""
KEY_AUTHENTICATED = 'ha_authenticated' KEY_AUTHENTICATED = 'ha_authenticated'
KEY_USE_X_FORWARDED_FOR = 'ha_use_x_forwarded_for'
KEY_TRUSTED_NETWORKS = 'ha_trusted_networks'
KEY_REAL_IP = 'ha_real_ip' KEY_REAL_IP = 'ha_real_ip'
KEY_BANS_ENABLED = 'ha_bans_enabled'
KEY_BANNED_IPS = 'ha_banned_ips'
KEY_FAILED_LOGIN_ATTEMPTS = 'ha_failed_login_attempts'
KEY_LOGIN_THRESHOLD = 'ha_login_threshold'
HTTP_HEADER_X_FORWARDED_FOR = 'X-Forwarded-For'

View file

@ -0,0 +1,43 @@
"""Provide cors support for the HTTP component."""
import asyncio
from aiohttp.hdrs import ACCEPT, ORIGIN, CONTENT_TYPE
from homeassistant.const import (
HTTP_HEADER_X_REQUESTED_WITH, HTTP_HEADER_HA_AUTH)
from homeassistant.core import callback
ALLOWED_CORS_HEADERS = [
ORIGIN, ACCEPT, HTTP_HEADER_X_REQUESTED_WITH, CONTENT_TYPE,
HTTP_HEADER_HA_AUTH]
@callback
def setup_cors(app, origins):
"""Setup cors."""
import aiohttp_cors
cors = aiohttp_cors.setup(app, defaults={
host: aiohttp_cors.ResourceOptions(
allow_headers=ALLOWED_CORS_HEADERS,
allow_methods='*',
) for host in origins
})
@asyncio.coroutine
def cors_startup(app):
"""Initialize cors when app starts up."""
cors_added = set()
for route in list(app.router.routes()):
if hasattr(route, 'resource'):
route = route.resource
if route in cors_added:
continue
cors.add(route)
cors_added.add(route)
app.on_startup.append(cors_startup)

View file

@ -0,0 +1,35 @@
"""Middleware to fetch real IP."""
import asyncio
from ipaddress import ip_address
from aiohttp.web import middleware
from aiohttp.hdrs import X_FORWARDED_FOR
from homeassistant.core import callback
from .const import KEY_REAL_IP
@callback
def setup_real_ip(app, use_x_forwarded_for):
"""Create IP Ban middleware for the app."""
@middleware
@asyncio.coroutine
def real_ip_middleware(request, handler):
"""Real IP middleware."""
if (use_x_forwarded_for and
X_FORWARDED_FOR in request.headers):
request[KEY_REAL_IP] = ip_address(
request.headers.get(X_FORWARDED_FOR).split(',')[0])
else:
request[KEY_REAL_IP] = \
ip_address(request.transport.get_extra_info('peername')[0])
return (yield from handler(request))
@asyncio.coroutine
def app_startup(app):
"""Initialize bans when app starts up."""
app.middlewares.append(real_ip_middleware)
app.on_startup.append(app_startup)

View file

@ -1,25 +0,0 @@
"""HTTP utilities."""
from ipaddress import ip_address
from .const import (
KEY_REAL_IP, KEY_USE_X_FORWARDED_FOR, HTTP_HEADER_X_FORWARDED_FOR)
def get_real_ip(request):
"""Get IP address of client."""
if KEY_REAL_IP in request:
return request[KEY_REAL_IP]
if (request.app.get(KEY_USE_X_FORWARDED_FOR) and
HTTP_HEADER_X_FORWARDED_FOR in request.headers):
request[KEY_REAL_IP] = ip_address(
request.headers.get(HTTP_HEADER_X_FORWARDED_FOR).split(',')[0])
else:
peername = request.transport.get_extra_info('peername')
if peername:
request[KEY_REAL_IP] = ip_address(peername[0])
else:
request[KEY_REAL_IP] = None
return request[KEY_REAL_IP]

View file

@ -12,7 +12,7 @@ import logging
import voluptuous as vol import voluptuous as vol
from homeassistant.components.http import HomeAssistantView from homeassistant.components.http import HomeAssistantView
from homeassistant.components.http.util import get_real_ip from homeassistant.components.http.const import KEY_REAL_IP
from homeassistant.components.telegram_bot import ( from homeassistant.components.telegram_bot import (
CONF_ALLOWED_CHAT_IDS, BaseTelegramBotEntity, PLATFORM_SCHEMA) CONF_ALLOWED_CHAT_IDS, BaseTelegramBotEntity, PLATFORM_SCHEMA)
from homeassistant.const import ( from homeassistant.const import (
@ -110,7 +110,7 @@ class BotPushReceiver(HomeAssistantView, BaseTelegramBotEntity):
@asyncio.coroutine @asyncio.coroutine
def post(self, request): def post(self, request):
"""Accept the POST from telegram.""" """Accept the POST from telegram."""
real_ip = get_real_ip(request) real_ip = request[KEY_REAL_IP]
if not any(real_ip in net for net in self.trusted_networks): if not any(real_ip in net for net in self.trusted_networks):
_LOGGER.warning("Access denied from %s", real_ip) _LOGGER.warning("Access denied from %s", real_ip)
return self.json_message('Access denied', HTTP_UNAUTHORIZED) return self.json_message('Access denied', HTTP_UNAUTHORIZED)

View file

@ -9,8 +9,6 @@ import logging
import threading import threading
from contextlib import contextmanager from contextlib import contextmanager
from aiohttp import web
from homeassistant import core as ha, loader from homeassistant import core as ha, loader
from homeassistant.setup import setup_component, async_setup_component from homeassistant.setup import setup_component, async_setup_component
from homeassistant.config import async_process_component_config from homeassistant.config import async_process_component_config
@ -25,9 +23,6 @@ from homeassistant.const import (
EVENT_STATE_CHANGED, EVENT_PLATFORM_DISCOVERED, ATTR_SERVICE, EVENT_STATE_CHANGED, EVENT_PLATFORM_DISCOVERED, ATTR_SERVICE,
ATTR_DISCOVERED, SERVER_PORT, EVENT_HOMEASSISTANT_CLOSE) ATTR_DISCOVERED, SERVER_PORT, EVENT_HOMEASSISTANT_CLOSE)
from homeassistant.components import mqtt, recorder from homeassistant.components import mqtt, recorder
from homeassistant.components.http.auth import auth_middleware
from homeassistant.components.http.const import (
KEY_USE_X_FORWARDED_FOR, KEY_BANS_ENABLED, KEY_TRUSTED_NETWORKS)
from homeassistant.util.async import ( from homeassistant.util.async import (
run_callback_threadsafe, run_coroutine_threadsafe) run_callback_threadsafe, run_coroutine_threadsafe)
@ -262,35 +257,6 @@ def mock_state_change_event(hass, new_state, old_state=None):
hass.bus.fire(EVENT_STATE_CHANGED, event_data) hass.bus.fire(EVENT_STATE_CHANGED, event_data)
def mock_http_component(hass, api_password=None):
"""Mock the HTTP component."""
hass.http = MagicMock(api_password=api_password)
mock_component(hass, 'http')
hass.http.views = {}
def mock_register_view(view):
"""Store registered view."""
if isinstance(view, type):
# Instantiate the view, if needed
view = view()
hass.http.views[view.name] = view
hass.http.register_view = mock_register_view
def mock_http_component_app(hass, api_password=None):
"""Create an aiohttp.web.Application instance for testing."""
if 'http' not in hass.config.components:
mock_http_component(hass, api_password)
app = web.Application(middlewares=[auth_middleware])
app['hass'] = hass
app[KEY_USE_X_FORWARDED_FOR] = False
app[KEY_BANS_ENABLED] = False
app[KEY_TRUSTED_NETWORKS] = []
return app
@asyncio.coroutine @asyncio.coroutine
def async_mock_mqtt_component(hass, config=None): def async_mock_mqtt_component(hass, config=None):
"""Mock the MQTT component.""" """Mock the MQTT component."""

View file

@ -9,7 +9,7 @@ from uvcclient import nvr
from homeassistant.setup import setup_component from homeassistant.setup import setup_component
from homeassistant.components.camera import uvc from homeassistant.components.camera import uvc
from tests.common import get_test_home_assistant, mock_http_component from tests.common import get_test_home_assistant
class TestUVCSetup(unittest.TestCase): class TestUVCSetup(unittest.TestCase):
@ -18,7 +18,6 @@ class TestUVCSetup(unittest.TestCase):
def setUp(self): def setUp(self):
"""Setup things to be run when tests are started.""" """Setup things to be run when tests are started."""
self.hass = get_test_home_assistant() self.hass = get_test_home_assistant()
mock_http_component(self.hass)
def tearDown(self): def tearDown(self):
"""Stop everything that was started.""" """Stop everything that was started."""

View file

@ -14,7 +14,7 @@ def test_setup_check_env_prevents_load(hass, loop):
with patch.dict(os.environ, clear=True), \ with patch.dict(os.environ, clear=True), \
patch.object(config, 'SECTIONS', ['hassbian']), \ patch.object(config, 'SECTIONS', ['hassbian']), \
patch('homeassistant.components.http.' patch('homeassistant.components.http.'
'HomeAssistantWSGI.register_view') as reg_view: 'HomeAssistantHTTP.register_view') as reg_view:
loop.run_until_complete(async_setup_component(hass, 'config', {})) loop.run_until_complete(async_setup_component(hass, 'config', {}))
assert 'config' in hass.config.components assert 'config' in hass.config.components
assert reg_view.called is False assert reg_view.called is False
@ -25,7 +25,7 @@ def test_setup_check_env_works(hass, loop):
with patch.dict(os.environ, {'FORCE_HASSBIAN': '1'}), \ with patch.dict(os.environ, {'FORCE_HASSBIAN': '1'}), \
patch.object(config, 'SECTIONS', ['hassbian']), \ patch.object(config, 'SECTIONS', ['hassbian']), \
patch('homeassistant.components.http.' patch('homeassistant.components.http.'
'HomeAssistantWSGI.register_view') as reg_view: 'HomeAssistantHTTP.register_view') as reg_view:
loop.run_until_complete(async_setup_component(hass, 'config', {})) loop.run_until_complete(async_setup_component(hass, 'config', {}))
assert 'config' in hass.config.components assert 'config' in hass.config.components
assert len(reg_view.mock_calls) == 2 assert len(reg_view.mock_calls) == 2

View file

@ -2,19 +2,11 @@
import asyncio import asyncio
from unittest.mock import patch from unittest.mock import patch
import pytest
from homeassistant.const import EVENT_COMPONENT_LOADED from homeassistant.const import EVENT_COMPONENT_LOADED
from homeassistant.setup import async_setup_component, ATTR_COMPONENT from homeassistant.setup import async_setup_component, ATTR_COMPONENT
from homeassistant.components import config from homeassistant.components import config
from tests.common import mock_http_component, mock_coro, mock_component from tests.common import mock_coro, mock_component
@pytest.fixture(autouse=True)
def stub_http(hass):
"""Stub the HTTP component."""
mock_http_component(hass)
@asyncio.coroutine @asyncio.coroutine

View file

@ -3,28 +3,30 @@ import asyncio
import json import json
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest
from homeassistant.bootstrap import async_setup_component from homeassistant.bootstrap import async_setup_component
from homeassistant.components import config from homeassistant.components import config
from homeassistant.components.zwave import DATA_NETWORK, const from homeassistant.components.zwave import DATA_NETWORK, const
from homeassistant.components.config.zwave import (
ZWaveNodeValueView, ZWaveNodeGroupView, ZWaveNodeConfigView,
ZWaveUserCodeView, ZWaveConfigWriteView)
from tests.common import mock_http_component_app
from tests.mock.zwave import MockNode, MockValue, MockEntityValues from tests.mock.zwave import MockNode, MockValue, MockEntityValues
VIEW_NAME = 'api:config:zwave:device_config' VIEW_NAME = 'api:config:zwave:device_config'
@asyncio.coroutine @pytest.fixture
def test_get_device_config(hass, test_client): def client(loop, hass, test_client):
"""Test getting device config.""" """Client to communicate with Z-Wave config views."""
with patch.object(config, 'SECTIONS', ['zwave']): with patch.object(config, 'SECTIONS', ['zwave']):
yield from async_setup_component(hass, 'config', {}) loop.run_until_complete(async_setup_component(hass, 'config', {}))
client = yield from test_client(hass.http.app) return loop.run_until_complete(test_client(hass.http.app))
@asyncio.coroutine
def test_get_device_config(client):
"""Test getting device config."""
def mock_read(path): def mock_read(path):
"""Mock reading data.""" """Mock reading data."""
return { return {
@ -47,13 +49,8 @@ def test_get_device_config(hass, test_client):
@asyncio.coroutine @asyncio.coroutine
def test_update_device_config(hass, test_client): def test_update_device_config(client):
"""Test updating device config.""" """Test updating device config."""
with patch.object(config, 'SECTIONS', ['zwave']):
yield from async_setup_component(hass, 'config', {})
client = yield from test_client(hass.http.app)
orig_data = { orig_data = {
'hello.beer': { 'hello.beer': {
'ignored': True, 'ignored': True,
@ -90,13 +87,8 @@ def test_update_device_config(hass, test_client):
@asyncio.coroutine @asyncio.coroutine
def test_update_device_config_invalid_key(hass, test_client): def test_update_device_config_invalid_key(client):
"""Test updating device config.""" """Test updating device config."""
with patch.object(config, 'SECTIONS', ['zwave']):
yield from async_setup_component(hass, 'config', {})
client = yield from test_client(hass.http.app)
resp = yield from client.post( resp = yield from client.post(
'/api/config/zwave/device_config/invalid_entity', data=json.dumps({ '/api/config/zwave/device_config/invalid_entity', data=json.dumps({
'polling_intensity': 2 'polling_intensity': 2
@ -106,13 +98,8 @@ def test_update_device_config_invalid_key(hass, test_client):
@asyncio.coroutine @asyncio.coroutine
def test_update_device_config_invalid_data(hass, test_client): def test_update_device_config_invalid_data(client):
"""Test updating device config.""" """Test updating device config."""
with patch.object(config, 'SECTIONS', ['zwave']):
yield from async_setup_component(hass, 'config', {})
client = yield from test_client(hass.http.app)
resp = yield from client.post( resp = yield from client.post(
'/api/config/zwave/device_config/hello.beer', data=json.dumps({ '/api/config/zwave/device_config/hello.beer', data=json.dumps({
'invalid_option': 2 'invalid_option': 2
@ -122,13 +109,8 @@ def test_update_device_config_invalid_data(hass, test_client):
@asyncio.coroutine @asyncio.coroutine
def test_update_device_config_invalid_json(hass, test_client): def test_update_device_config_invalid_json(client):
"""Test updating device config.""" """Test updating device config."""
with patch.object(config, 'SECTIONS', ['zwave']):
yield from async_setup_component(hass, 'config', {})
client = yield from test_client(hass.http.app)
resp = yield from client.post( resp = yield from client.post(
'/api/config/zwave/device_config/hello.beer', data='not json') '/api/config/zwave/device_config/hello.beer', data='not json')
@ -136,11 +118,8 @@ def test_update_device_config_invalid_json(hass, test_client):
@asyncio.coroutine @asyncio.coroutine
def test_get_values(hass, test_client): def test_get_values(hass, client):
"""Test getting values on node.""" """Test getting values on node."""
app = mock_http_component_app(hass)
ZWaveNodeValueView().register(app.router)
node = MockNode(node_id=1) node = MockNode(node_id=1)
value = MockValue(value_id=123456, node=node, label='Test Label', value = MockValue(value_id=123456, node=node, label='Test Label',
instance=1, index=2, poll_intensity=4) instance=1, index=2, poll_intensity=4)
@ -150,8 +129,6 @@ def test_get_values(hass, test_client):
values2 = MockEntityValues(primary=value2) values2 = MockEntityValues(primary=value2)
hass.data[const.DATA_ENTITY_VALUES] = [values, values2] hass.data[const.DATA_ENTITY_VALUES] = [values, values2]
client = yield from test_client(app)
resp = yield from client.get('/api/zwave/values/1') resp = yield from client.get('/api/zwave/values/1')
assert resp.status == 200 assert resp.status == 200
@ -168,11 +145,8 @@ def test_get_values(hass, test_client):
@asyncio.coroutine @asyncio.coroutine
def test_get_groups(hass, test_client): def test_get_groups(hass, client):
"""Test getting groupdata on node.""" """Test getting groupdata on node."""
app = mock_http_component_app(hass)
ZWaveNodeGroupView().register(app.router)
network = hass.data[DATA_NETWORK] = MagicMock() network = hass.data[DATA_NETWORK] = MagicMock()
node = MockNode(node_id=2) node = MockNode(node_id=2)
node.groups.associations = 'assoc' node.groups.associations = 'assoc'
@ -182,8 +156,6 @@ def test_get_groups(hass, test_client):
node.groups = {1: node.groups} node.groups = {1: node.groups}
network.nodes = {2: node} network.nodes = {2: node}
client = yield from test_client(app)
resp = yield from client.get('/api/zwave/groups/2') resp = yield from client.get('/api/zwave/groups/2')
assert resp.status == 200 assert resp.status == 200
@ -200,18 +172,13 @@ def test_get_groups(hass, test_client):
@asyncio.coroutine @asyncio.coroutine
def test_get_groups_nogroups(hass, test_client): def test_get_groups_nogroups(hass, client):
"""Test getting groupdata on node with no groups.""" """Test getting groupdata on node with no groups."""
app = mock_http_component_app(hass)
ZWaveNodeGroupView().register(app.router)
network = hass.data[DATA_NETWORK] = MagicMock() network = hass.data[DATA_NETWORK] = MagicMock()
node = MockNode(node_id=2) node = MockNode(node_id=2)
network.nodes = {2: node} network.nodes = {2: node}
client = yield from test_client(app)
resp = yield from client.get('/api/zwave/groups/2') resp = yield from client.get('/api/zwave/groups/2')
assert resp.status == 200 assert resp.status == 200
@ -221,16 +188,11 @@ def test_get_groups_nogroups(hass, test_client):
@asyncio.coroutine @asyncio.coroutine
def test_get_groups_nonode(hass, test_client): def test_get_groups_nonode(hass, client):
"""Test getting groupdata on nonexisting node.""" """Test getting groupdata on nonexisting node."""
app = mock_http_component_app(hass)
ZWaveNodeGroupView().register(app.router)
network = hass.data[DATA_NETWORK] = MagicMock() network = hass.data[DATA_NETWORK] = MagicMock()
network.nodes = {1: 1, 5: 5} network.nodes = {1: 1, 5: 5}
client = yield from test_client(app)
resp = yield from client.get('/api/zwave/groups/2') resp = yield from client.get('/api/zwave/groups/2')
assert resp.status == 404 assert resp.status == 404
@ -240,11 +202,8 @@ def test_get_groups_nonode(hass, test_client):
@asyncio.coroutine @asyncio.coroutine
def test_get_config(hass, test_client): def test_get_config(hass, client):
"""Test getting config on node.""" """Test getting config on node."""
app = mock_http_component_app(hass)
ZWaveNodeConfigView().register(app.router)
network = hass.data[DATA_NETWORK] = MagicMock() network = hass.data[DATA_NETWORK] = MagicMock()
node = MockNode(node_id=2) node = MockNode(node_id=2)
value = MockValue( value = MockValue(
@ -261,8 +220,6 @@ def test_get_config(hass, test_client):
network.nodes = {2: node} network.nodes = {2: node}
node.get_values.return_value = node.values node.get_values.return_value = node.values
client = yield from test_client(app)
resp = yield from client.get('/api/zwave/config/2') resp = yield from client.get('/api/zwave/config/2')
assert resp.status == 200 assert resp.status == 200
@ -278,19 +235,14 @@ def test_get_config(hass, test_client):
@asyncio.coroutine @asyncio.coroutine
def test_get_config_noconfig_node(hass, test_client): def test_get_config_noconfig_node(hass, client):
"""Test getting config on node without config.""" """Test getting config on node without config."""
app = mock_http_component_app(hass)
ZWaveNodeConfigView().register(app.router)
network = hass.data[DATA_NETWORK] = MagicMock() network = hass.data[DATA_NETWORK] = MagicMock()
node = MockNode(node_id=2) node = MockNode(node_id=2)
network.nodes = {2: node} network.nodes = {2: node}
node.get_values.return_value = node.values node.get_values.return_value = node.values
client = yield from test_client(app)
resp = yield from client.get('/api/zwave/config/2') resp = yield from client.get('/api/zwave/config/2')
assert resp.status == 200 assert resp.status == 200
@ -300,16 +252,11 @@ def test_get_config_noconfig_node(hass, test_client):
@asyncio.coroutine @asyncio.coroutine
def test_get_config_nonode(hass, test_client): def test_get_config_nonode(hass, client):
"""Test getting config on nonexisting node.""" """Test getting config on nonexisting node."""
app = mock_http_component_app(hass)
ZWaveNodeConfigView().register(app.router)
network = hass.data[DATA_NETWORK] = MagicMock() network = hass.data[DATA_NETWORK] = MagicMock()
network.nodes = {1: 1, 5: 5} network.nodes = {1: 1, 5: 5}
client = yield from test_client(app)
resp = yield from client.get('/api/zwave/config/2') resp = yield from client.get('/api/zwave/config/2')
assert resp.status == 404 assert resp.status == 404
@ -319,16 +266,11 @@ def test_get_config_nonode(hass, test_client):
@asyncio.coroutine @asyncio.coroutine
def test_get_usercodes_nonode(hass, test_client): def test_get_usercodes_nonode(hass, client):
"""Test getting usercodes on nonexisting node.""" """Test getting usercodes on nonexisting node."""
app = mock_http_component_app(hass)
ZWaveUserCodeView().register(app.router)
network = hass.data[DATA_NETWORK] = MagicMock() network = hass.data[DATA_NETWORK] = MagicMock()
network.nodes = {1: 1, 5: 5} network.nodes = {1: 1, 5: 5}
client = yield from test_client(app)
resp = yield from client.get('/api/zwave/usercodes/2') resp = yield from client.get('/api/zwave/usercodes/2')
assert resp.status == 404 assert resp.status == 404
@ -338,11 +280,8 @@ def test_get_usercodes_nonode(hass, test_client):
@asyncio.coroutine @asyncio.coroutine
def test_get_usercodes(hass, test_client): def test_get_usercodes(hass, client):
"""Test getting usercodes on node.""" """Test getting usercodes on node."""
app = mock_http_component_app(hass)
ZWaveUserCodeView().register(app.router)
network = hass.data[DATA_NETWORK] = MagicMock() network = hass.data[DATA_NETWORK] = MagicMock()
node = MockNode(node_id=18, node = MockNode(node_id=18,
command_classes=[const.COMMAND_CLASS_USER_CODE]) command_classes=[const.COMMAND_CLASS_USER_CODE])
@ -356,8 +295,6 @@ def test_get_usercodes(hass, test_client):
network.nodes = {18: node} network.nodes = {18: node}
node.get_values.return_value = node.values node.get_values.return_value = node.values
client = yield from test_client(app)
resp = yield from client.get('/api/zwave/usercodes/18') resp = yield from client.get('/api/zwave/usercodes/18')
assert resp.status == 200 assert resp.status == 200
@ -369,19 +306,14 @@ def test_get_usercodes(hass, test_client):
@asyncio.coroutine @asyncio.coroutine
def test_get_usercode_nousercode_node(hass, test_client): def test_get_usercode_nousercode_node(hass, client):
"""Test getting usercodes on node without usercodes.""" """Test getting usercodes on node without usercodes."""
app = mock_http_component_app(hass)
ZWaveUserCodeView().register(app.router)
network = hass.data[DATA_NETWORK] = MagicMock() network = hass.data[DATA_NETWORK] = MagicMock()
node = MockNode(node_id=18) node = MockNode(node_id=18)
network.nodes = {18: node} network.nodes = {18: node}
node.get_values.return_value = node.values node.get_values.return_value = node.values
client = yield from test_client(app)
resp = yield from client.get('/api/zwave/usercodes/18') resp = yield from client.get('/api/zwave/usercodes/18')
assert resp.status == 200 assert resp.status == 200
@ -391,11 +323,8 @@ def test_get_usercode_nousercode_node(hass, test_client):
@asyncio.coroutine @asyncio.coroutine
def test_get_usercodes_no_genreuser(hass, test_client): def test_get_usercodes_no_genreuser(hass, client):
"""Test getting usercodes on node missing genre user.""" """Test getting usercodes on node missing genre user."""
app = mock_http_component_app(hass)
ZWaveUserCodeView().register(app.router)
network = hass.data[DATA_NETWORK] = MagicMock() network = hass.data[DATA_NETWORK] = MagicMock()
node = MockNode(node_id=18, node = MockNode(node_id=18,
command_classes=[const.COMMAND_CLASS_USER_CODE]) command_classes=[const.COMMAND_CLASS_USER_CODE])
@ -409,8 +338,6 @@ def test_get_usercodes_no_genreuser(hass, test_client):
network.nodes = {18: node} network.nodes = {18: node}
node.get_values.return_value = node.values node.get_values.return_value = node.values
client = yield from test_client(app)
resp = yield from client.get('/api/zwave/usercodes/18') resp = yield from client.get('/api/zwave/usercodes/18')
assert resp.status == 200 assert resp.status == 200
@ -420,13 +347,8 @@ def test_get_usercodes_no_genreuser(hass, test_client):
@asyncio.coroutine @asyncio.coroutine
def test_save_config_no_network(hass, test_client): def test_save_config_no_network(hass, client):
"""Test saving configuration without network data.""" """Test saving configuration without network data."""
app = mock_http_component_app(hass)
ZWaveConfigWriteView().register(app.router)
client = yield from test_client(app)
resp = yield from client.post('/api/zwave/saveconfig') resp = yield from client.post('/api/zwave/saveconfig')
assert resp.status == 404 assert resp.status == 404
@ -435,15 +357,10 @@ def test_save_config_no_network(hass, test_client):
@asyncio.coroutine @asyncio.coroutine
def test_save_config(hass, test_client): def test_save_config(hass, client):
"""Test saving configuration.""" """Test saving configuration."""
app = mock_http_component_app(hass)
ZWaveConfigWriteView().register(app.router)
network = hass.data[DATA_NETWORK] = MagicMock() network = hass.data[DATA_NETWORK] = MagicMock()
client = yield from test_client(app)
resp = yield from client.post('/api/zwave/saveconfig') resp = yield from client.post('/api/zwave/saveconfig')
assert resp.status == 200 assert resp.status == 200

View file

@ -5,11 +5,10 @@ import logging
from unittest.mock import patch, MagicMock from unittest.mock import patch, MagicMock
import aioautomatic import aioautomatic
from homeassistant.setup import async_setup_component
from homeassistant.components.device_tracker.automatic import ( from homeassistant.components.device_tracker.automatic import (
async_setup_scanner) async_setup_scanner)
from tests.common import mock_http_component
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -23,8 +22,7 @@ def test_invalid_credentials(
mock_open, mock_isfile, mock_makedirs, mock_json_dump, mock_json_load, mock_open, mock_isfile, mock_makedirs, mock_json_dump, mock_json_load,
mock_create_session, hass): mock_create_session, hass):
"""Test with invalid credentials.""" """Test with invalid credentials."""
mock_http_component(hass) hass.loop.run_until_complete(async_setup_component(hass, 'http', {}))
mock_json_load.return_value = {'refresh_token': 'bad_token'} mock_json_load.return_value = {'refresh_token': 'bad_token'}
@asyncio.coroutine @asyncio.coroutine
@ -59,8 +57,7 @@ def test_valid_credentials(
mock_open, mock_isfile, mock_makedirs, mock_json_dump, mock_json_load, mock_open, mock_isfile, mock_makedirs, mock_json_dump, mock_json_load,
mock_ws_connect, mock_create_session, hass): mock_ws_connect, mock_create_session, hass):
"""Test with valid credentials.""" """Test with valid credentials."""
mock_http_component(hass) hass.loop.run_until_complete(async_setup_component(hass, 'http', {}))
mock_json_load.return_value = {'refresh_token': 'good_token'} mock_json_load.return_value = {'refresh_token': 'good_token'}
session = MagicMock() session = MagicMock()

View file

@ -1 +1,38 @@
"""Tests for the HTTP component.""" """Tests for the HTTP component."""
import asyncio
from ipaddress import ip_address
from aiohttp import web
from homeassistant.components.http.const import KEY_REAL_IP
def mock_real_ip(app):
"""Inject middleware to mock real IP.
Returns a function to set the real IP.
"""
ip_to_mock = None
def set_ip_to_mock(value):
nonlocal ip_to_mock
ip_to_mock = value
@asyncio.coroutine
@web.middleware
def mock_real_ip(request, handler):
"""Mock Real IP middleware."""
nonlocal ip_to_mock
request[KEY_REAL_IP] = ip_address(ip_to_mock)
return (yield from handler(request))
@asyncio.coroutine
def real_ip_startup(app):
"""Startup of real ip."""
app.middlewares.insert(0, mock_real_ip)
app.on_startup.append(real_ip_startup)
return set_ip_to_mock

View file

@ -1,195 +1,156 @@
"""The tests for the Home Assistant HTTP component.""" """The tests for the Home Assistant HTTP component."""
# pylint: disable=protected-access # pylint: disable=protected-access
import asyncio import asyncio
from ipaddress import ip_address, ip_network from ipaddress import ip_network
from unittest.mock import patch from unittest.mock import patch
import aiohttp from aiohttp import BasicAuth, web
from aiohttp.web_exceptions import HTTPUnauthorized
import pytest import pytest
from homeassistant import const from homeassistant.const import HTTP_HEADER_HA_AUTH
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
import homeassistant.components.http as http from homeassistant.components.http.auth import setup_auth
from homeassistant.components.http.const import ( from homeassistant.components.http.real_ip import setup_real_ip
KEY_TRUSTED_NETWORKS, KEY_USE_X_FORWARDED_FOR, HTTP_HEADER_X_FORWARDED_FOR) from homeassistant.components.http.const import KEY_AUTHENTICATED
from . import mock_real_ip
API_PASSWORD = 'test1234' API_PASSWORD = 'test1234'
# Don't add 127.0.0.1/::1 as trusted, as it may interfere with other test cases # Don't add 127.0.0.1/::1 as trusted, as it may interfere with other test cases
TRUSTED_NETWORKS = ['192.0.2.0/24', '2001:DB8:ABCD::/48', '100.64.0.1', TRUSTED_NETWORKS = [
'FD01:DB8::1'] ip_network('192.0.2.0/24'),
ip_network('2001:DB8:ABCD::/48'),
ip_network('100.64.0.1'),
ip_network('FD01:DB8::1'),
]
TRUSTED_ADDRESSES = ['100.64.0.1', '192.0.2.100', 'FD01:DB8::1', TRUSTED_ADDRESSES = ['100.64.0.1', '192.0.2.100', 'FD01:DB8::1',
'2001:DB8:ABCD::1'] '2001:DB8:ABCD::1']
UNTRUSTED_ADDRESSES = ['198.51.100.1', '2001:DB8:FA1::1', '127.0.0.1', '::1'] UNTRUSTED_ADDRESSES = ['198.51.100.1', '2001:DB8:FA1::1', '127.0.0.1', '::1']
@pytest.fixture @asyncio.coroutine
def mock_api_client(hass, test_client): def mock_handler(request):
"""Start the Hass HTTP component.""" """Return if request was authenticated."""
hass.loop.run_until_complete(async_setup_component(hass, 'api', { if not request[KEY_AUTHENTICATED]:
'http': { raise HTTPUnauthorized
http.CONF_API_PASSWORD: API_PASSWORD, return web.Response(status=200)
}
}))
return hass.loop.run_until_complete(test_client(hass.http.app))
@pytest.fixture @pytest.fixture
def mock_trusted_networks(hass, mock_api_client): def app():
"""Mock trusted networks.""" """Fixture to setup a web.Application."""
hass.http.app[KEY_TRUSTED_NETWORKS] = [ app = web.Application()
ip_network(trusted_network) app.router.add_get('/', mock_handler)
for trusted_network in TRUSTED_NETWORKS] setup_real_ip(app, False)
return app
@asyncio.coroutine @asyncio.coroutine
def test_access_denied_without_password(mock_api_client): def test_auth_middleware_loaded_by_default(hass):
"""Test accessing to server from banned IP when feature is off."""
with patch('homeassistant.components.http.setup_auth') as mock_setup:
yield from async_setup_component(hass, 'http', {
'http': {}
})
assert len(mock_setup.mock_calls) == 1
@asyncio.coroutine
def test_access_without_password(app, test_client):
"""Test access without password.""" """Test access without password."""
resp = yield from mock_api_client.get(const.URL_API) setup_auth(app, [], None)
client = yield from test_client(app)
resp = yield from client.get('/')
assert resp.status == 200
@asyncio.coroutine
def test_access_with_password_in_header(app, test_client):
"""Test access with password in URL."""
setup_auth(app, [], API_PASSWORD)
client = yield from test_client(app)
req = yield from client.get(
'/', headers={HTTP_HEADER_HA_AUTH: API_PASSWORD})
assert req.status == 200
req = yield from client.get(
'/', headers={HTTP_HEADER_HA_AUTH: 'wrong-pass'})
assert req.status == 401
@asyncio.coroutine
def test_access_with_password_in_query(app, test_client):
"""Test access without password."""
setup_auth(app, [], API_PASSWORD)
client = yield from test_client(app)
resp = yield from client.get('/', params={
'api_password': API_PASSWORD
})
assert resp.status == 200
resp = yield from client.get('/')
assert resp.status == 401 assert resp.status == 401
resp = yield from client.get('/', params={
@asyncio.coroutine 'api_password': 'wrong-password'
def test_access_denied_with_wrong_password_in_header(mock_api_client):
"""Test access with wrong password."""
resp = yield from mock_api_client.get(const.URL_API, headers={
const.HTTP_HEADER_HA_AUTH: 'wrongpassword'
}) })
assert resp.status == 401 assert resp.status == 401
@asyncio.coroutine @asyncio.coroutine
def test_access_denied_with_x_forwarded_for(hass, mock_api_client, def test_basic_auth_works(app, test_client):
mock_trusted_networks):
"""Test access denied through the X-Forwarded-For http header."""
hass.http.use_x_forwarded_for = True
for remote_addr in UNTRUSTED_ADDRESSES:
resp = yield from mock_api_client.get(const.URL_API, headers={
HTTP_HEADER_X_FORWARDED_FOR: remote_addr})
assert resp.status == 401, \
"{} shouldn't be trusted".format(remote_addr)
@asyncio.coroutine
def test_access_denied_with_untrusted_ip(mock_api_client,
mock_trusted_networks):
"""Test access with an untrusted ip address."""
for remote_addr in UNTRUSTED_ADDRESSES:
with patch('homeassistant.components.http.'
'util.get_real_ip',
return_value=ip_address(remote_addr)):
resp = yield from mock_api_client.get(
const.URL_API, params={'api_password': ''})
assert resp.status == 401, \
"{} shouldn't be trusted".format(remote_addr)
@asyncio.coroutine
def test_access_with_password_in_header(mock_api_client, caplog):
"""Test access with password in URL."""
# Hide logging from requests package that we use to test logging
req = yield from mock_api_client.get(
const.URL_API, headers={const.HTTP_HEADER_HA_AUTH: API_PASSWORD})
assert req.status == 200
logs = caplog.text
assert const.URL_API in logs
assert API_PASSWORD not in logs
@asyncio.coroutine
def test_access_denied_with_wrong_password_in_url(mock_api_client):
"""Test access with wrong password."""
resp = yield from mock_api_client.get(
const.URL_API, params={'api_password': 'wrongpassword'})
assert resp.status == 401
@asyncio.coroutine
def test_access_with_password_in_url(mock_api_client, caplog):
"""Test access with password in URL."""
req = yield from mock_api_client.get(
const.URL_API, params={'api_password': API_PASSWORD})
assert req.status == 200
logs = caplog.text
assert const.URL_API in logs
assert API_PASSWORD not in logs
@asyncio.coroutine
def test_access_granted_with_x_forwarded_for(hass, mock_api_client, caplog,
mock_trusted_networks):
"""Test access denied through the X-Forwarded-For http header."""
hass.http.app[KEY_USE_X_FORWARDED_FOR] = True
for remote_addr in TRUSTED_ADDRESSES:
resp = yield from mock_api_client.get(const.URL_API, headers={
HTTP_HEADER_X_FORWARDED_FOR: remote_addr})
assert resp.status == 200, \
"{} should be trusted".format(remote_addr)
@asyncio.coroutine
def test_access_granted_with_trusted_ip(mock_api_client, caplog,
mock_trusted_networks):
"""Test access with trusted addresses."""
for remote_addr in TRUSTED_ADDRESSES:
with patch('homeassistant.components.http.'
'auth.get_real_ip',
return_value=ip_address(remote_addr)):
resp = yield from mock_api_client.get(
const.URL_API, params={'api_password': ''})
assert resp.status == 200, \
'{} should be trusted'.format(remote_addr)
@asyncio.coroutine
def test_basic_auth_works(mock_api_client, caplog):
"""Test access with basic authentication.""" """Test access with basic authentication."""
req = yield from mock_api_client.get( setup_auth(app, [], API_PASSWORD)
const.URL_API, client = yield from test_client(app)
auth=aiohttp.BasicAuth('homeassistant', API_PASSWORD))
req = yield from client.get(
'/',
auth=BasicAuth('homeassistant', API_PASSWORD))
assert req.status == 200 assert req.status == 200
assert const.URL_API in caplog.text
@asyncio.coroutine
def test_basic_auth_username_homeassistant(mock_api_client, caplog):
"""Test access with basic auth requires username homeassistant."""
req = yield from mock_api_client.get(
const.URL_API,
auth=aiohttp.BasicAuth('wrong_username', API_PASSWORD))
req = yield from client.get(
'/',
auth=BasicAuth('wrong_username', API_PASSWORD))
assert req.status == 401 assert req.status == 401
req = yield from client.get(
@asyncio.coroutine '/',
def test_basic_auth_wrong_password(mock_api_client, caplog): auth=BasicAuth('homeassistant', 'wrong password'))
"""Test access with basic auth not allowed with wrong password."""
req = yield from mock_api_client.get(
const.URL_API,
auth=aiohttp.BasicAuth('homeassistant', 'wrong password'))
assert req.status == 401 assert req.status == 401
req = yield from client.get(
@asyncio.coroutine '/',
def test_authorization_header_must_be_basic_type(mock_api_client, caplog):
"""Test only basic authorization is allowed for auth header."""
req = yield from mock_api_client.get(
const.URL_API,
headers={ headers={
'authorization': 'NotBasic abcdefg' 'authorization': 'NotBasic abcdefg'
}) })
assert req.status == 401 assert req.status == 401
@asyncio.coroutine
def test_access_with_trusted_ip(test_client):
"""Test access with an untrusted ip address."""
app = web.Application()
app.router.add_get('/', mock_handler)
setup_auth(app, TRUSTED_NETWORKS, 'some-pass')
set_mock_ip = mock_real_ip(app)
client = yield from test_client(app)
for remote_addr in UNTRUSTED_ADDRESSES:
set_mock_ip(remote_addr)
resp = yield from client.get('/')
assert resp.status == 401, \
"{} shouldn't be trusted".format(remote_addr)
for remote_addr in TRUSTED_ADDRESSES:
set_mock_ip(remote_addr)
resp = yield from client.get('/')
assert resp.status == 200, \
"{} should be trusted".format(remote_addr)

View file

@ -1,91 +1,96 @@
"""The tests for the Home Assistant HTTP component.""" """The tests for the Home Assistant HTTP component."""
# pylint: disable=protected-access # pylint: disable=protected-access
import asyncio import asyncio
from ipaddress import ip_address
from unittest.mock import patch, mock_open from unittest.mock import patch, mock_open
import pytest from aiohttp import web
from aiohttp.web_exceptions import HTTPUnauthorized
from homeassistant import const
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
import homeassistant.components.http as http import homeassistant.components.http as http
from homeassistant.components.http.const import ( from homeassistant.components.http.ban import (
KEY_BANS_ENABLED, KEY_LOGIN_THRESHOLD, KEY_BANNED_IPS) IpBan, IP_BANS_FILE, setup_bans, KEY_BANNED_IPS)
from homeassistant.components.http.ban import IpBan, IP_BANS_FILE
from . import mock_real_ip
API_PASSWORD = 'test1234'
BANNED_IPS = ['200.201.202.203', '100.64.0.2'] BANNED_IPS = ['200.201.202.203', '100.64.0.2']
@pytest.fixture
def mock_api_client(hass, test_client):
"""Start the Hass HTTP component."""
hass.loop.run_until_complete(async_setup_component(hass, 'api', {
'http': {
http.CONF_API_PASSWORD: API_PASSWORD,
}
}))
hass.http.app[KEY_BANNED_IPS] = [IpBan(banned_ip) for banned_ip
in BANNED_IPS]
return hass.loop.run_until_complete(test_client(hass.http.app))
@asyncio.coroutine @asyncio.coroutine
def test_access_from_banned_ip(hass, mock_api_client): def test_access_from_banned_ip(hass, test_client):
"""Test accessing to server from banned IP. Both trusted and not.""" """Test accessing to server from banned IP. Both trusted and not."""
hass.http.app[KEY_BANS_ENABLED] = True app = web.Application()
setup_bans(hass, app, 5)
set_real_ip = mock_real_ip(app)
with patch('homeassistant.components.http.ban.load_ip_bans_config',
return_value=[IpBan(banned_ip) for banned_ip
in BANNED_IPS]):
client = yield from test_client(app)
for remote_addr in BANNED_IPS: for remote_addr in BANNED_IPS:
with patch('homeassistant.components.http.' set_real_ip(remote_addr)
'ban.get_real_ip', resp = yield from client.get('/')
return_value=ip_address(remote_addr)): assert resp.status == 403
resp = yield from mock_api_client.get(
const.URL_API)
assert resp.status == 403
@asyncio.coroutine @asyncio.coroutine
def test_access_from_banned_ip_when_ban_is_off(hass, mock_api_client): def test_ban_middleware_not_loaded_by_config(hass):
"""Test accessing to server from banned IP when feature is off.""" """Test accessing to server from banned IP when feature is off."""
hass.http.app[KEY_BANS_ENABLED] = False with patch('homeassistant.components.http.setup_bans') as mock_setup:
for remote_addr in BANNED_IPS: yield from async_setup_component(hass, 'http', {
with patch('homeassistant.components.http.' 'http': {
'ban.get_real_ip', http.CONF_IP_BAN_ENABLED: False,
return_value=ip_address(remote_addr)): }
resp = yield from mock_api_client.get( })
const.URL_API,
headers={const.HTTP_HEADER_HA_AUTH: API_PASSWORD}) assert len(mock_setup.mock_calls) == 0
assert resp.status == 200
@asyncio.coroutine @asyncio.coroutine
def test_ip_bans_file_creation(hass, mock_api_client): def test_ban_middleware_loaded_by_default(hass):
"""Test accessing to server from banned IP when feature is off."""
with patch('homeassistant.components.http.setup_bans') as mock_setup:
yield from async_setup_component(hass, 'http', {
'http': {}
})
assert len(mock_setup.mock_calls) == 1
@asyncio.coroutine
def test_ip_bans_file_creation(hass, test_client):
"""Testing if banned IP file created.""" """Testing if banned IP file created."""
hass.http.app[KEY_BANS_ENABLED] = True app = web.Application()
hass.http.app[KEY_LOGIN_THRESHOLD] = 1 app['hass'] = hass
@asyncio.coroutine
def unauth_handler(request):
"""Return a mock web response."""
raise HTTPUnauthorized
app.router.add_get('/', unauth_handler)
setup_bans(hass, app, 1)
mock_real_ip(app)("200.201.202.204")
with patch('homeassistant.components.http.ban.load_ip_bans_config',
return_value=[IpBan(banned_ip) for banned_ip
in BANNED_IPS]):
client = yield from test_client(app)
m = mock_open() m = mock_open()
@asyncio.coroutine
def call_server():
with patch('homeassistant.components.http.'
'ban.get_real_ip',
return_value=ip_address("200.201.202.204")):
resp = yield from mock_api_client.get(
const.URL_API,
headers={const.HTTP_HEADER_HA_AUTH: 'Wrong password'})
return resp
with patch('homeassistant.components.http.ban.open', m, create=True): with patch('homeassistant.components.http.ban.open', m, create=True):
resp = yield from call_server() resp = yield from client.get('/')
assert resp.status == 401 assert resp.status == 401
assert len(hass.http.app[KEY_BANNED_IPS]) == len(BANNED_IPS) assert len(app[KEY_BANNED_IPS]) == len(BANNED_IPS)
assert m.call_count == 0 assert m.call_count == 0
resp = yield from call_server() resp = yield from client.get('/')
assert resp.status == 401 assert resp.status == 401
assert len(hass.http.app[KEY_BANNED_IPS]) == len(BANNED_IPS) + 1 assert len(app[KEY_BANNED_IPS]) == len(BANNED_IPS) + 1
m.assert_called_once_with(hass.config.path(IP_BANS_FILE), 'a') m.assert_called_once_with(hass.config.path(IP_BANS_FILE), 'a')
resp = yield from call_server() resp = yield from client.get('/')
assert resp.status == 403 assert resp.status == 403
assert m.call_count == 1 assert m.call_count == 1

View file

@ -0,0 +1,104 @@
"""Test cors for the HTTP component."""
import asyncio
from unittest.mock import patch
from aiohttp import web
from aiohttp.hdrs import (
ACCESS_CONTROL_ALLOW_ORIGIN,
ACCESS_CONTROL_ALLOW_HEADERS,
ACCESS_CONTROL_REQUEST_HEADERS,
ACCESS_CONTROL_REQUEST_METHOD,
ORIGIN
)
import pytest
from homeassistant.const import HTTP_HEADER_HA_AUTH
from homeassistant.setup import async_setup_component
from homeassistant.components.http.cors import setup_cors
TRUSTED_ORIGIN = 'https://home-assistant.io'
@asyncio.coroutine
def test_cors_middleware_not_loaded_by_default(hass):
"""Test accessing to server from banned IP when feature is off."""
with patch('homeassistant.components.http.setup_cors') as mock_setup:
yield from async_setup_component(hass, 'http', {
'http': {}
})
assert len(mock_setup.mock_calls) == 0
@asyncio.coroutine
def test_cors_middleware_loaded_from_config(hass):
"""Test accessing to server from banned IP when feature is off."""
with patch('homeassistant.components.http.setup_cors') as mock_setup:
yield from async_setup_component(hass, 'http', {
'http': {
'cors_allowed_origins': ['http://home-assistant.io']
}
})
assert len(mock_setup.mock_calls) == 1
@asyncio.coroutine
def mock_handler(request):
"""Return if request was authenticated."""
return web.Response(status=200)
@pytest.fixture
def client(loop, test_client):
"""Fixture to setup a web.Application."""
app = web.Application()
app.router.add_get('/', mock_handler)
setup_cors(app, [TRUSTED_ORIGIN])
return loop.run_until_complete(test_client(app))
@asyncio.coroutine
def test_cors_requests(client):
"""Test cross origin requests."""
req = yield from client.get('/', headers={
ORIGIN: TRUSTED_ORIGIN
})
assert req.status == 200
assert req.headers[ACCESS_CONTROL_ALLOW_ORIGIN] == \
TRUSTED_ORIGIN
# With password in URL
req = yield from client.get('/', params={
'api_password': 'some-pass'
}, headers={
ORIGIN: TRUSTED_ORIGIN
})
assert req.status == 200
assert req.headers[ACCESS_CONTROL_ALLOW_ORIGIN] == \
TRUSTED_ORIGIN
# With password in headers
req = yield from client.get('/', headers={
HTTP_HEADER_HA_AUTH: 'some-pass',
ORIGIN: TRUSTED_ORIGIN
})
assert req.status == 200
assert req.headers[ACCESS_CONTROL_ALLOW_ORIGIN] == \
TRUSTED_ORIGIN
@asyncio.coroutine
def test_cors_preflight_allowed(client):
"""Test cross origin resource sharing preflight (OPTIONS) request."""
req = yield from client.options('/', headers={
ORIGIN: TRUSTED_ORIGIN,
ACCESS_CONTROL_REQUEST_METHOD: 'GET',
ACCESS_CONTROL_REQUEST_HEADERS: 'x-ha-access'
})
assert req.status == 200
assert req.headers[ACCESS_CONTROL_ALLOW_ORIGIN] == TRUSTED_ORIGIN
assert req.headers[ACCESS_CONTROL_ALLOW_HEADERS] == \
HTTP_HEADER_HA_AUTH.upper()

View file

@ -1,124 +1,10 @@
"""The tests for the Home Assistant HTTP component.""" """The tests for the Home Assistant HTTP component."""
import asyncio import asyncio
from aiohttp.hdrs import ( from homeassistant.setup import async_setup_component
ORIGIN, ACCESS_CONTROL_ALLOW_ORIGIN, ACCESS_CONTROL_ALLOW_HEADERS,
ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS,
CONTENT_TYPE)
import requests
from tests.common import get_test_instance_port, get_test_home_assistant
from homeassistant import const, setup
import homeassistant.components.http as http import homeassistant.components.http as http
API_PASSWORD = 'test1234'
SERVER_PORT = get_test_instance_port()
HTTP_BASE = '127.0.0.1:{}'.format(SERVER_PORT)
HTTP_BASE_URL = 'http://{}'.format(HTTP_BASE)
HA_HEADERS = {
const.HTTP_HEADER_HA_AUTH: API_PASSWORD,
CONTENT_TYPE: const.CONTENT_TYPE_JSON,
}
CORS_ORIGINS = [HTTP_BASE_URL, HTTP_BASE]
hass = None
def _url(path=''):
"""Helper method to generate URLs."""
return HTTP_BASE_URL + path
# pylint: disable=invalid-name
def setUpModule():
"""Initialize a Home Assistant server."""
global hass
hass = get_test_home_assistant()
setup.setup_component(
hass, http.DOMAIN, {
http.DOMAIN: {
http.CONF_API_PASSWORD: API_PASSWORD,
http.CONF_SERVER_PORT: SERVER_PORT,
http.CONF_CORS_ORIGINS: CORS_ORIGINS,
}
}
)
setup.setup_component(hass, 'api')
# Registering static path as it caused CORS to blow up
hass.http.register_static_path(
'/custom_components', hass.config.path('custom_components'))
hass.start()
# pylint: disable=invalid-name
def tearDownModule():
"""Stop the Home Assistant server."""
hass.stop()
class TestCors:
"""Test HTTP component."""
def test_cors_allowed_with_password_in_url(self):
"""Test cross origin resource sharing with password in url."""
req = requests.get(_url(const.URL_API),
params={'api_password': API_PASSWORD},
headers={ORIGIN: HTTP_BASE_URL})
allow_origin = ACCESS_CONTROL_ALLOW_ORIGIN
assert req.status_code == 200
assert req.headers.get(allow_origin) == HTTP_BASE_URL
def test_cors_allowed_with_password_in_header(self):
"""Test cross origin resource sharing with password in header."""
headers = {
const.HTTP_HEADER_HA_AUTH: API_PASSWORD,
ORIGIN: HTTP_BASE_URL
}
req = requests.get(_url(const.URL_API), headers=headers)
allow_origin = ACCESS_CONTROL_ALLOW_ORIGIN
assert req.status_code == 200
assert req.headers.get(allow_origin) == HTTP_BASE_URL
def test_cors_denied_without_origin_header(self):
"""Test cross origin resource sharing with password in header."""
headers = {
const.HTTP_HEADER_HA_AUTH: API_PASSWORD
}
req = requests.get(_url(const.URL_API), headers=headers)
allow_origin = ACCESS_CONTROL_ALLOW_ORIGIN
allow_headers = ACCESS_CONTROL_ALLOW_HEADERS
assert req.status_code == 200
assert allow_origin not in req.headers
assert allow_headers not in req.headers
def test_cors_preflight_allowed(self):
"""Test cross origin resource sharing preflight (OPTIONS) request."""
headers = {
ORIGIN: HTTP_BASE_URL,
ACCESS_CONTROL_REQUEST_METHOD: 'GET',
ACCESS_CONTROL_REQUEST_HEADERS: 'x-ha-access'
}
req = requests.options(_url(const.URL_API), headers=headers)
allow_origin = ACCESS_CONTROL_ALLOW_ORIGIN
allow_headers = ACCESS_CONTROL_ALLOW_HEADERS
assert req.status_code == 200
assert req.headers.get(allow_origin) == HTTP_BASE_URL
assert req.headers.get(allow_headers) == \
const.HTTP_HEADER_HA_AUTH.upper()
class TestView(http.HomeAssistantView): class TestView(http.HomeAssistantView):
"""Test the HTTP views.""" """Test the HTTP views."""
@ -133,12 +19,12 @@ class TestView(http.HomeAssistantView):
@asyncio.coroutine @asyncio.coroutine
def test_registering_view_while_running(hass, test_client): def test_registering_view_while_running(hass, test_client, unused_port):
"""Test that we can register a view while the server is running.""" """Test that we can register a view while the server is running."""
yield from setup.async_setup_component( yield from async_setup_component(
hass, http.DOMAIN, { hass, http.DOMAIN, {
http.DOMAIN: { http.DOMAIN: {
http.CONF_SERVER_PORT: get_test_instance_port(), http.CONF_SERVER_PORT: unused_port(),
} }
} }
) )
@ -151,7 +37,7 @@ def test_registering_view_while_running(hass, test_client):
@asyncio.coroutine @asyncio.coroutine
def test_api_base_url_with_domain(hass): def test_api_base_url_with_domain(hass):
"""Test setting API URL.""" """Test setting API URL."""
result = yield from setup.async_setup_component(hass, 'http', { result = yield from async_setup_component(hass, 'http', {
'http': { 'http': {
'base_url': 'example.com' 'base_url': 'example.com'
} }
@ -163,7 +49,7 @@ def test_api_base_url_with_domain(hass):
@asyncio.coroutine @asyncio.coroutine
def test_api_base_url_with_ip(hass): def test_api_base_url_with_ip(hass):
"""Test setting api url.""" """Test setting api url."""
result = yield from setup.async_setup_component(hass, 'http', { result = yield from async_setup_component(hass, 'http', {
'http': { 'http': {
'server_host': '1.1.1.1' 'server_host': '1.1.1.1'
} }
@ -175,7 +61,7 @@ def test_api_base_url_with_ip(hass):
@asyncio.coroutine @asyncio.coroutine
def test_api_base_url_with_ip_port(hass): def test_api_base_url_with_ip_port(hass):
"""Test setting api url.""" """Test setting api url."""
result = yield from setup.async_setup_component(hass, 'http', { result = yield from async_setup_component(hass, 'http', {
'http': { 'http': {
'base_url': '1.1.1.1:8124' 'base_url': '1.1.1.1:8124'
} }
@ -187,9 +73,34 @@ def test_api_base_url_with_ip_port(hass):
@asyncio.coroutine @asyncio.coroutine
def test_api_no_base_url(hass): def test_api_no_base_url(hass):
"""Test setting api url.""" """Test setting api url."""
result = yield from setup.async_setup_component(hass, 'http', { result = yield from async_setup_component(hass, 'http', {
'http': { 'http': {
} }
}) })
assert result assert result
assert hass.config.api.base_url == 'http://127.0.0.1:8123' assert hass.config.api.base_url == 'http://127.0.0.1:8123'
@asyncio.coroutine
def test_not_log_password(hass, unused_port, test_client, caplog):
"""Test access with password doesn't get logged."""
result = yield from async_setup_component(hass, 'api', {
'http': {
http.CONF_SERVER_PORT: unused_port(),
http.CONF_API_PASSWORD: 'some-pass'
}
})
assert result
client = yield from test_client(hass.http.app)
resp = yield from client.get('/api/', params={
'api_password': 'some-pass'
})
assert resp.status == 200
logs = caplog.text
# Ensure we don't log API passwords
assert '/api/' in logs
assert 'some-pass' not in logs

View file

@ -0,0 +1,48 @@
"""Test real IP middleware."""
import asyncio
from aiohttp import web
from aiohttp.hdrs import X_FORWARDED_FOR
from homeassistant.components.http.real_ip import setup_real_ip
from homeassistant.components.http.const import KEY_REAL_IP
@asyncio.coroutine
def mock_handler(request):
"""Handler that returns the real IP as text."""
return web.Response(text=str(request[KEY_REAL_IP]))
@asyncio.coroutine
def test_ignore_x_forwarded_for(test_client):
"""Test that we get the IP from the transport."""
app = web.Application()
app.router.add_get('/', mock_handler)
setup_real_ip(app, False)
mock_api_client = yield from test_client(app)
resp = yield from mock_api_client.get('/', headers={
X_FORWARDED_FOR: '255.255.255.255'
})
assert resp.status == 200
text = yield from resp.text()
assert text != '255.255.255.255'
@asyncio.coroutine
def test_use_x_forwarded_for(test_client):
"""Test that we get the IP from the transport."""
app = web.Application()
app.router.add_get('/', mock_handler)
setup_real_ip(app, True)
mock_api_client = yield from test_client(app)
resp = yield from mock_api_client.get('/', headers={
X_FORWARDED_FOR: '255.255.255.255'
})
assert resp.status == 200
text = yield from resp.text()
assert text == '255.255.255.255'

View file

@ -4,8 +4,7 @@ from unittest.mock import Mock, MagicMock, patch
from homeassistant.setup import setup_component from homeassistant.setup import setup_component
import homeassistant.components.mqtt as mqtt import homeassistant.components.mqtt as mqtt
from tests.common import ( from tests.common import get_test_home_assistant, mock_coro
get_test_home_assistant, mock_coro, mock_http_component)
class TestMQTT: class TestMQTT:
@ -14,7 +13,9 @@ class TestMQTT:
def setup_method(self, method): def setup_method(self, method):
"""Setup things to be run when tests are started.""" """Setup things to be run when tests are started."""
self.hass = get_test_home_assistant() self.hass = get_test_home_assistant()
mock_http_component(self.hass, 'super_secret') setup_component(self.hass, 'http', {
'api_password': 'super_secret'
})
def teardown_method(self, method): def teardown_method(self, method):
"""Stop everything that was started.""" """Stop everything that was started."""

View file

@ -4,12 +4,10 @@ import json
from unittest.mock import patch, MagicMock, mock_open from unittest.mock import patch, MagicMock, mock_open
from aiohttp.hdrs import AUTHORIZATION from aiohttp.hdrs import AUTHORIZATION
from homeassistant.setup import async_setup_component
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.util.json import save_json
from homeassistant.components.notify import html5 from homeassistant.components.notify import html5
from tests.common import mock_http_component_app
CONFIG_FILE = 'file.conf' CONFIG_FILE = 'file.conf'
SUBSCRIPTION_1 = { SUBSCRIPTION_1 = {
@ -52,6 +50,23 @@ REGISTER_URL = '/api/notify.html5'
PUBLISH_URL = '/api/notify.html5/callback' PUBLISH_URL = '/api/notify.html5/callback'
@asyncio.coroutine
def mock_client(hass, test_client, registrations=None):
"""Create a test client for HTML5 views."""
if registrations is None:
registrations = {}
with patch('homeassistant.components.notify.html5._load_config',
return_value=registrations):
yield from async_setup_component(hass, 'notify', {
'notify': {
'platform': 'html5'
}
})
return (yield from test_client(hass.http.app))
class TestHtml5Notify(object): class TestHtml5Notify(object):
"""Tests for HTML5 notify platform.""" """Tests for HTML5 notify platform."""
@ -89,8 +104,6 @@ class TestHtml5Notify(object):
service.send_message('Hello', target=['device', 'non_existing'], service.send_message('Hello', target=['device', 'non_existing'],
data={'icon': 'beer.png'}) data={'icon': 'beer.png'})
print(mock_wp.mock_calls)
assert len(mock_wp.mock_calls) == 3 assert len(mock_wp.mock_calls) == 3
# WebPusher constructor # WebPusher constructor
@ -104,421 +117,224 @@ class TestHtml5Notify(object):
assert payload['body'] == 'Hello' assert payload['body'] == 'Hello'
assert payload['icon'] == 'beer.png' assert payload['icon'] == 'beer.png'
@asyncio.coroutine
def test_registering_new_device_view(self, loop, test_client):
"""Test that the HTML view works."""
hass = MagicMock()
expected = {
'unnamed device': SUBSCRIPTION_1,
}
hass.config.path.return_value = CONFIG_FILE @asyncio.coroutine
service = html5.get_service(hass, {}) def test_registering_new_device_view(hass, test_client):
"""Test that the HTML view works."""
client = yield from mock_client(hass, test_client)
assert service is not None with patch('homeassistant.components.notify.html5.save_json') as mock_save:
assert len(hass.mock_calls) == 3
view = hass.mock_calls[1][1][0]
assert view.json_path == hass.config.path.return_value
assert view.registrations == {}
hass.loop = loop
app = mock_http_component_app(hass)
view.register(app.router)
client = yield from test_client(app)
hass.http.is_banned_ip.return_value = False
resp = yield from client.post(REGISTER_URL, resp = yield from client.post(REGISTER_URL,
data=json.dumps(SUBSCRIPTION_1)) data=json.dumps(SUBSCRIPTION_1))
content = yield from resp.text() assert resp.status == 200
assert resp.status == 200, content assert len(mock_save.mock_calls) == 1
assert view.registrations == expected assert mock_save.mock_calls[0][1][1] == {
'unnamed device': SUBSCRIPTION_1,
}
hass.async_add_job.assert_called_with(save_json, CONFIG_FILE, expected)
@asyncio.coroutine @asyncio.coroutine
def test_registering_new_device_expiration_view(self, loop, test_client): def test_registering_new_device_expiration_view(hass, test_client):
"""Test that the HTML view works.""" """Test that the HTML view works."""
hass = MagicMock() client = yield from mock_client(hass, test_client)
expected = {
'unnamed device': SUBSCRIPTION_4,
}
hass.config.path.return_value = CONFIG_FILE with patch('homeassistant.components.notify.html5.save_json') as mock_save:
service = html5.get_service(hass, {})
assert service is not None
# assert hass.called
assert len(hass.mock_calls) == 3
view = hass.mock_calls[1][1][0]
assert view.json_path == hass.config.path.return_value
assert view.registrations == {}
hass.loop = loop
app = mock_http_component_app(hass)
view.register(app.router)
client = yield from test_client(app)
hass.http.is_banned_ip.return_value = False
resp = yield from client.post(REGISTER_URL, resp = yield from client.post(REGISTER_URL,
data=json.dumps(SUBSCRIPTION_4)) data=json.dumps(SUBSCRIPTION_4))
content = yield from resp.text() assert resp.status == 200
assert resp.status == 200, content assert mock_save.mock_calls[0][1][1] == {
assert view.registrations == expected 'unnamed device': SUBSCRIPTION_4,
}
hass.async_add_job.assert_called_with(save_json, CONFIG_FILE, expected)
@asyncio.coroutine @asyncio.coroutine
def test_registering_new_device_fails_view(self, loop, test_client): def test_registering_new_device_fails_view(hass, test_client):
"""Test subs. are not altered when registering a new device fails.""" """Test subs. are not altered when registering a new device fails."""
hass = MagicMock() registrations = {}
expected = {} client = yield from mock_client(hass, test_client, registrations)
hass.config.path.return_value = CONFIG_FILE
html5.get_service(hass, {})
view = hass.mock_calls[1][1][0]
hass.loop = loop
app = mock_http_component_app(hass)
view.register(app.router)
client = yield from test_client(app)
hass.http.is_banned_ip.return_value = False
hass.async_add_job.side_effect = HomeAssistantError()
with patch('homeassistant.components.notify.html5.save_json',
side_effect=HomeAssistantError()):
resp = yield from client.post(REGISTER_URL, resp = yield from client.post(REGISTER_URL,
data=json.dumps(SUBSCRIPTION_1)) data=json.dumps(SUBSCRIPTION_4))
content = yield from resp.text() assert resp.status == 500
assert resp.status == 500, content assert registrations == {}
assert view.registrations == expected
@asyncio.coroutine
def test_registering_existing_device_view(self, loop, test_client):
"""Test subscription is updated when registering existing device."""
hass = MagicMock()
expected = {
'unnamed device': SUBSCRIPTION_4,
}
hass.config.path.return_value = CONFIG_FILE @asyncio.coroutine
html5.get_service(hass, {}) def test_registering_existing_device_view(hass, test_client):
view = hass.mock_calls[1][1][0] """Test subscription is updated when registering existing device."""
registrations = {}
hass.loop = loop client = yield from mock_client(hass, test_client, registrations)
app = mock_http_component_app(hass)
view.register(app.router)
client = yield from test_client(app)
hass.http.is_banned_ip.return_value = False
with patch('homeassistant.components.notify.html5.save_json') as mock_save:
yield from client.post(REGISTER_URL, yield from client.post(REGISTER_URL,
data=json.dumps(SUBSCRIPTION_1)) data=json.dumps(SUBSCRIPTION_1))
resp = yield from client.post(REGISTER_URL, resp = yield from client.post(REGISTER_URL,
data=json.dumps(SUBSCRIPTION_4)) data=json.dumps(SUBSCRIPTION_4))
content = yield from resp.text() assert resp.status == 200
assert resp.status == 200, content assert mock_save.mock_calls[0][1][1] == {
assert view.registrations == expected 'unnamed device': SUBSCRIPTION_4,
}
assert registrations == {
'unnamed device': SUBSCRIPTION_4,
}
hass.async_add_job.assert_called_with(save_json, CONFIG_FILE, expected)
@asyncio.coroutine @asyncio.coroutine
def test_registering_existing_device_fails_view(self, loop, test_client): def test_registering_existing_device_fails_view(hass, test_client):
"""Test sub. is not updated when registering existing device fails.""" """Test sub. is not updated when registering existing device fails."""
hass = MagicMock() registrations = {}
expected = { client = yield from mock_client(hass, test_client, registrations)
'unnamed device': SUBSCRIPTION_1,
}
hass.config.path.return_value = CONFIG_FILE
html5.get_service(hass, {})
view = hass.mock_calls[1][1][0]
hass.loop = loop
app = mock_http_component_app(hass)
view.register(app.router)
client = yield from test_client(app)
hass.http.is_banned_ip.return_value = False
with patch('homeassistant.components.notify.html5.save_json') as mock_save:
yield from client.post(REGISTER_URL, yield from client.post(REGISTER_URL,
data=json.dumps(SUBSCRIPTION_1)) data=json.dumps(SUBSCRIPTION_1))
mock_save.side_effect = HomeAssistantError
hass.async_add_job.side_effect = HomeAssistantError()
resp = yield from client.post(REGISTER_URL, resp = yield from client.post(REGISTER_URL,
data=json.dumps(SUBSCRIPTION_4)) data=json.dumps(SUBSCRIPTION_4))
content = yield from resp.text() assert resp.status == 500
assert resp.status == 500, content assert registrations == {
assert view.registrations == expected 'unnamed device': SUBSCRIPTION_1,
}
@asyncio.coroutine
def test_registering_new_device_validation(self, loop, test_client):
"""Test various errors when registering a new device."""
hass = MagicMock()
m = mock_open() @asyncio.coroutine
with patch( def test_registering_new_device_validation(hass, test_client):
'homeassistant.util.json.open', """Test various errors when registering a new device."""
m, create=True client = yield from mock_client(hass, test_client)
):
hass.config.path.return_value = CONFIG_FILE
service = html5.get_service(hass, {})
assert service is not None resp = yield from client.post(REGISTER_URL, data=json.dumps({
'browser': 'invalid browser',
'subscription': 'sub info',
}))
assert resp.status == 400
# assert hass.called resp = yield from client.post(REGISTER_URL, data=json.dumps({
assert len(hass.mock_calls) == 3 'browser': 'chrome',
}))
assert resp.status == 400
view = hass.mock_calls[1][1][0] with patch('homeassistant.components.notify.html5.save_json',
return_value=False):
resp = yield from client.post(REGISTER_URL, data=json.dumps({
'browser': 'chrome',
'subscription': 'sub info',
}))
assert resp.status == 400
hass.loop = loop
app = mock_http_component_app(hass)
view.register(app.router)
client = yield from test_client(app)
hass.http.is_banned_ip.return_value = False
resp = yield from client.post(REGISTER_URL, data=json.dumps({ @asyncio.coroutine
'browser': 'invalid browser', def test_unregistering_device_view(hass, test_client):
'subscription': 'sub info', """Test that the HTML unregister view works."""
})) registrations = {
assert resp.status == 400 'some device': SUBSCRIPTION_1,
'other device': SUBSCRIPTION_2,
}
client = yield from mock_client(hass, test_client, registrations)
resp = yield from client.post(REGISTER_URL, data=json.dumps({ with patch('homeassistant.components.notify.html5.save_json') as mock_save:
'browser': 'chrome', resp = yield from client.delete(REGISTER_URL, data=json.dumps({
})) 'subscription': SUBSCRIPTION_1['subscription'],
assert resp.status == 400 }))
with patch('homeassistant.components.notify.html5.save_json', assert resp.status == 200
return_value=False): assert len(mock_save.mock_calls) == 1
# resp = view.post(Request(builder.get_environ())) assert registrations == {
resp = yield from client.post(REGISTER_URL, data=json.dumps({ 'other device': SUBSCRIPTION_2
'browser': 'chrome', }
'subscription': 'sub info',
}))
assert resp.status == 400
@asyncio.coroutine @asyncio.coroutine
def test_unregistering_device_view(self, loop, test_client): def test_unregister_device_view_handle_unknown_subscription(hass, test_client):
"""Test that the HTML unregister view works.""" """Test that the HTML unregister view handles unknown subscriptions."""
hass = MagicMock() registrations = {}
client = yield from mock_client(hass, test_client, registrations)
config = { with patch('homeassistant.components.notify.html5.save_json') as mock_save:
'some device': SUBSCRIPTION_1, resp = yield from client.delete(REGISTER_URL, data=json.dumps({
'other device': SUBSCRIPTION_2, 'subscription': SUBSCRIPTION_3['subscription']
} }))
m = mock_open(read_data=json.dumps(config)) assert resp.status == 200, resp.response
with patch( assert registrations == {}
'homeassistant.util.json.open', assert len(mock_save.mock_calls) == 0
m, create=True
):
hass.config.path.return_value = CONFIG_FILE
service = html5.get_service(hass, {})
assert service is not None
# assert hass.called @asyncio.coroutine
assert len(hass.mock_calls) == 3 def test_unregistering_device_view_handles_save_error(hass, test_client):
"""Test that the HTML unregister view handles save errors."""
registrations = {
'some device': SUBSCRIPTION_1,
'other device': SUBSCRIPTION_2,
}
client = yield from mock_client(hass, test_client, registrations)
view = hass.mock_calls[1][1][0] with patch('homeassistant.components.notify.html5.save_json',
assert view.json_path == hass.config.path.return_value side_effect=HomeAssistantError()):
assert view.registrations == config resp = yield from client.delete(REGISTER_URL, data=json.dumps({
'subscription': SUBSCRIPTION_1['subscription'],
}))
hass.loop = loop assert resp.status == 500, resp.response
app = mock_http_component_app(hass) assert registrations == {
view.register(app.router) 'some device': SUBSCRIPTION_1,
client = yield from test_client(app) 'other device': SUBSCRIPTION_2,
hass.http.is_banned_ip.return_value = False }
resp = yield from client.delete(REGISTER_URL, data=json.dumps({
'subscription': SUBSCRIPTION_1['subscription'],
}))
config.pop('some device') @asyncio.coroutine
def test_callback_view_no_jwt(hass, test_client):
"""Test that the notification callback view works without JWT."""
client = yield from mock_client(hass, test_client)
resp = yield from client.post(PUBLISH_URL, data=json.dumps({
'type': 'push',
'tag': '3bc28d69-0921-41f1-ac6a-7a627ba0aa72'
}))
assert resp.status == 200, resp.response assert resp.status == 401, resp.response
assert view.registrations == config
hass.async_add_job.assert_called_with(save_json, CONFIG_FILE,
config)
@asyncio.coroutine @asyncio.coroutine
def test_unregister_device_view_handle_unknown_subscription( def test_callback_view_with_jwt(hass, test_client):
self, loop, test_client): """Test that the notification callback view works with JWT."""
"""Test that the HTML unregister view handles unknown subscriptions.""" registrations = {
hass = MagicMock() 'device': SUBSCRIPTION_1
}
client = yield from mock_client(hass, test_client, registrations)
config = { with patch('pywebpush.WebPusher') as mock_wp:
'some device': SUBSCRIPTION_1, yield from hass.services.async_call('notify', 'notify', {
'other device': SUBSCRIPTION_2, 'message': 'Hello',
} 'target': ['device'],
'data': {'icon': 'beer.png'}
}, blocking=True)
m = mock_open(read_data=json.dumps(config)) assert len(mock_wp.mock_calls) == 3
with patch(
'homeassistant.util.json.open',
m, create=True
):
hass.config.path.return_value = CONFIG_FILE
service = html5.get_service(hass, {})
assert service is not None # WebPusher constructor
assert mock_wp.mock_calls[0][1][0] == \
SUBSCRIPTION_1['subscription']
# Third mock_call checks the status_code of the response.
assert mock_wp.mock_calls[2][0] == '().send().status_code.__eq__'
# assert hass.called # Call to send
assert len(hass.mock_calls) == 3 push_payload = json.loads(mock_wp.mock_calls[1][1][0])
view = hass.mock_calls[1][1][0] assert push_payload['body'] == 'Hello'
assert view.json_path == hass.config.path.return_value assert push_payload['icon'] == 'beer.png'
assert view.registrations == config
hass.loop = loop bearer_token = "Bearer {}".format(push_payload['data']['jwt'])
app = mock_http_component_app(hass)
view.register(app.router)
client = yield from test_client(app)
hass.http.is_banned_ip.return_value = False
resp = yield from client.delete(REGISTER_URL, data=json.dumps({ resp = yield from client.post(PUBLISH_URL, json={
'subscription': SUBSCRIPTION_3['subscription'] 'type': 'push',
})) }, headers={AUTHORIZATION: bearer_token})
assert resp.status == 200, resp.response assert resp.status == 200
assert view.registrations == config body = yield from resp.json()
assert body == {"event": "push", "status": "ok"}
hass.async_add_job.assert_not_called()
@asyncio.coroutine
def test_unregistering_device_view_handles_save_error(
self, loop, test_client):
"""Test that the HTML unregister view handles save errors."""
hass = MagicMock()
config = {
'some device': SUBSCRIPTION_1,
'other device': SUBSCRIPTION_2,
}
m = mock_open(read_data=json.dumps(config))
with patch(
'homeassistant.util.json.open',
m, create=True
):
hass.config.path.return_value = CONFIG_FILE
service = html5.get_service(hass, {})
assert service is not None
# assert hass.called
assert len(hass.mock_calls) == 3
view = hass.mock_calls[1][1][0]
assert view.json_path == hass.config.path.return_value
assert view.registrations == config
hass.loop = loop
app = mock_http_component_app(hass)
view.register(app.router)
client = yield from test_client(app)
hass.http.is_banned_ip.return_value = False
hass.async_add_job.side_effect = HomeAssistantError()
resp = yield from client.delete(REGISTER_URL, data=json.dumps({
'subscription': SUBSCRIPTION_1['subscription'],
}))
assert resp.status == 500, resp.response
assert view.registrations == config
@asyncio.coroutine
def test_callback_view_no_jwt(self, loop, test_client):
"""Test that the notification callback view works without JWT."""
hass = MagicMock()
m = mock_open()
with patch(
'homeassistant.util.json.open',
m, create=True
):
hass.config.path.return_value = CONFIG_FILE
service = html5.get_service(hass, {})
assert service is not None
# assert hass.called
assert len(hass.mock_calls) == 3
view = hass.mock_calls[2][1][0]
hass.loop = loop
app = mock_http_component_app(hass)
view.register(app.router)
client = yield from test_client(app)
hass.http.is_banned_ip.return_value = False
resp = yield from client.post(PUBLISH_URL, data=json.dumps({
'type': 'push',
'tag': '3bc28d69-0921-41f1-ac6a-7a627ba0aa72'
}))
assert resp.status == 401, resp.response
@asyncio.coroutine
def test_callback_view_with_jwt(self, loop, test_client):
"""Test that the notification callback view works with JWT."""
hass = MagicMock()
data = {
'device': SUBSCRIPTION_1
}
m = mock_open(read_data=json.dumps(data))
with patch(
'homeassistant.util.json.open',
m, create=True
):
hass.config.path.return_value = CONFIG_FILE
service = html5.get_service(hass, {'gcm_sender_id': '100'})
assert service is not None
# assert hass.called
assert len(hass.mock_calls) == 3
with patch('pywebpush.WebPusher') as mock_wp:
service.send_message(
'Hello', target=['device'], data={'icon': 'beer.png'})
assert len(mock_wp.mock_calls) == 3
# WebPusher constructor
assert mock_wp.mock_calls[0][1][0] == \
SUBSCRIPTION_1['subscription']
# Third mock_call checks the status_code of the response.
assert mock_wp.mock_calls[2][0] == '().send().status_code.__eq__'
# Call to send
push_payload = json.loads(mock_wp.mock_calls[1][1][0])
assert push_payload['body'] == 'Hello'
assert push_payload['icon'] == 'beer.png'
view = hass.mock_calls[2][1][0]
view.registrations = data
bearer_token = "Bearer {}".format(push_payload['data']['jwt'])
hass.loop = loop
app = mock_http_component_app(hass)
view.register(app.router)
client = yield from test_client(app)
hass.http.is_banned_ip.return_value = False
resp = yield from client.post(PUBLISH_URL, data=json.dumps({
'type': 'push',
}), headers={AUTHORIZATION: bearer_token})
assert resp.status == 200
body = yield from resp.json()
assert body == {"event": "push", "status": "ok"}

View file

@ -10,8 +10,7 @@ import homeassistant.util.dt as dt_util
from homeassistant.components import history, recorder from homeassistant.components import history, recorder
from tests.common import ( from tests.common import (
init_recorder_component, mock_http_component, mock_state_change_event, init_recorder_component, mock_state_change_event, get_test_home_assistant)
get_test_home_assistant)
class TestComponentHistory(unittest.TestCase): class TestComponentHistory(unittest.TestCase):
@ -38,7 +37,6 @@ class TestComponentHistory(unittest.TestCase):
def test_setup(self): def test_setup(self):
"""Test setup method of history.""" """Test setup method of history."""
mock_http_component(self.hass)
config = history.CONFIG_SCHEMA({ config = history.CONFIG_SCHEMA({
# ha.DOMAIN: {}, # ha.DOMAIN: {},
history.DOMAIN: { history.DOMAIN: {

View file

@ -14,7 +14,7 @@ from homeassistant.components import logbook
from homeassistant.setup import setup_component from homeassistant.setup import setup_component
from tests.common import ( from tests.common import (
mock_http_component, init_recorder_component, get_test_home_assistant) init_recorder_component, get_test_home_assistant)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -29,10 +29,7 @@ class TestComponentLogbook(unittest.TestCase):
"""Setup things to be run when tests are started.""" """Setup things to be run when tests are started."""
self.hass = get_test_home_assistant() self.hass = get_test_home_assistant()
init_recorder_component(self.hass) # Force an in memory DB init_recorder_component(self.hass) # Force an in memory DB
mock_http_component(self.hass) assert setup_component(self.hass, logbook.DOMAIN, self.EMPTY_CONFIG)
self.hass.config.components |= set(['frontend', 'recorder', 'api'])
assert setup_component(self.hass, logbook.DOMAIN,
self.EMPTY_CONFIG)
self.hass.start() self.hass.start()
def tearDown(self): def tearDown(self):

View file

@ -150,7 +150,6 @@ def test_api_update_fails(hass, test_client):
assert resp.status == 404 assert resp.status == 404
beer_id = hass.data['shopping_list'].items[0]['id'] beer_id = hass.data['shopping_list'].items[0]['id']
client = yield from test_client(hass.http.app)
resp = yield from client.post( resp = yield from client.post(
'/api/shopping_list/item/{}'.format(beer_id), json={ '/api/shopping_list/item/{}'.format(beer_id), json={
'name': 123, 'name': 123,

View file

@ -8,8 +8,9 @@ import pytest
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.components import websocket_api as wapi, frontend from homeassistant.components import websocket_api as wapi, frontend
from homeassistant.setup import async_setup_component
from tests.common import mock_http_component_app, mock_coro from tests.common import mock_coro
API_PASSWORD = 'test1234' API_PASSWORD = 'test1234'
@ -17,10 +18,10 @@ API_PASSWORD = 'test1234'
@pytest.fixture @pytest.fixture
def websocket_client(loop, hass, test_client): def websocket_client(loop, hass, test_client):
"""Websocket client fixture connected to websocket server.""" """Websocket client fixture connected to websocket server."""
websocket_app = mock_http_component_app(hass) assert loop.run_until_complete(
wapi.WebsocketAPIView().register(websocket_app.router) async_setup_component(hass, 'websocket_api'))
client = loop.run_until_complete(test_client(websocket_app)) client = loop.run_until_complete(test_client(hass.http.app))
ws = loop.run_until_complete(client.ws_connect(wapi.URL)) ws = loop.run_until_complete(client.ws_connect(wapi.URL))
auth_ok = loop.run_until_complete(ws.receive_json()) auth_ok = loop.run_until_complete(ws.receive_json())
@ -35,10 +36,14 @@ def websocket_client(loop, hass, test_client):
@pytest.fixture @pytest.fixture
def no_auth_websocket_client(hass, loop, test_client): def no_auth_websocket_client(hass, loop, test_client):
"""Websocket connection that requires authentication.""" """Websocket connection that requires authentication."""
websocket_app = mock_http_component_app(hass, API_PASSWORD) assert loop.run_until_complete(
wapi.WebsocketAPIView().register(websocket_app.router) async_setup_component(hass, 'websocket_api', {
'http': {
'api_password': API_PASSWORD
}
}))
client = loop.run_until_complete(test_client(websocket_app)) client = loop.run_until_complete(test_client(hass.http.app))
ws = loop.run_until_complete(client.ws_connect(wapi.URL)) ws = loop.run_until_complete(client.ws_connect(wapi.URL))
auth_ok = loop.run_until_complete(ws.receive_json()) auth_ok = loop.run_until_complete(ws.receive_json())