Move HomeAssistantView to separate file. Convert http to async syntax. [skip ci] (#12982)
* Move HomeAssistantView to separate file. Convert http to async syntax. * pylint * websocket api * update emulated_hue for async/await * Lint
This commit is contained in:
parent
2ee73ca911
commit
321eb2ec6f
17 changed files with 292 additions and 344 deletions
|
@ -4,7 +4,6 @@ Support for local control of entities by emulating the Phillips Hue bridge.
|
||||||
For more details about this component, please refer to the documentation at
|
For more details about this component, please refer to the documentation at
|
||||||
https://home-assistant.io/components/emulated_hue/
|
https://home-assistant.io/components/emulated_hue/
|
||||||
"""
|
"""
|
||||||
import asyncio
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
@ -111,17 +110,15 @@ def setup(hass, yaml_config):
|
||||||
config.upnp_bind_multicast, config.advertise_ip,
|
config.upnp_bind_multicast, config.advertise_ip,
|
||||||
config.advertise_port)
|
config.advertise_port)
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def stop_emulated_hue_bridge(event):
|
||||||
def stop_emulated_hue_bridge(event):
|
|
||||||
"""Stop the emulated hue bridge."""
|
"""Stop the emulated hue bridge."""
|
||||||
upnp_listener.stop()
|
upnp_listener.stop()
|
||||||
yield from server.stop()
|
await server.stop()
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def start_emulated_hue_bridge(event):
|
||||||
def start_emulated_hue_bridge(event):
|
|
||||||
"""Start the emulated hue bridge."""
|
"""Start the emulated hue bridge."""
|
||||||
upnp_listener.start()
|
upnp_listener.start()
|
||||||
yield from server.start()
|
await server.start()
|
||||||
hass.bus.async_listen_once(
|
hass.bus.async_listen_once(
|
||||||
EVENT_HOMEASSISTANT_STOP, stop_emulated_hue_bridge)
|
EVENT_HOMEASSISTANT_STOP, stop_emulated_hue_bridge)
|
||||||
|
|
||||||
|
|
|
@ -4,21 +4,18 @@ This module provides WSGI application to serve the Home Assistant API.
|
||||||
For more details about this component, please refer to the documentation at
|
For more details about this component, please refer to the documentation at
|
||||||
https://home-assistant.io/components/http/
|
https://home-assistant.io/components/http/
|
||||||
"""
|
"""
|
||||||
import asyncio
|
|
||||||
from ipaddress import ip_network
|
from ipaddress import ip_network
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import ssl
|
import ssl
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from aiohttp.web_exceptions import HTTPUnauthorized, HTTPMovedPermanently
|
from aiohttp.web_exceptions import HTTPMovedPermanently
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.const import (
|
from homeassistant.const import (
|
||||||
SERVER_PORT, CONTENT_TYPE_JSON,
|
EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP, SERVER_PORT)
|
||||||
EVENT_HOMEASSISTANT_STOP, EVENT_HOMEASSISTANT_START,)
|
|
||||||
from homeassistant.core import is_callback
|
|
||||||
import homeassistant.helpers.config_validation as cv
|
import homeassistant.helpers.config_validation as cv
|
||||||
import homeassistant.remote as rem
|
import homeassistant.remote as rem
|
||||||
import homeassistant.util as hass_util
|
import homeassistant.util as hass_util
|
||||||
|
@ -28,10 +25,13 @@ from .auth import setup_auth
|
||||||
from .ban import setup_bans
|
from .ban import setup_bans
|
||||||
from .cors import setup_cors
|
from .cors import setup_cors
|
||||||
from .real_ip import setup_real_ip
|
from .real_ip import setup_real_ip
|
||||||
from .const import KEY_AUTHENTICATED, KEY_REAL_IP
|
|
||||||
from .static import (
|
from .static import (
|
||||||
CachingFileResponse, CachingStaticResource, staticresource_middleware)
|
CachingFileResponse, CachingStaticResource, staticresource_middleware)
|
||||||
|
|
||||||
|
# Import as alias
|
||||||
|
from .const import KEY_AUTHENTICATED, KEY_REAL_IP # noqa
|
||||||
|
from .view import HomeAssistantView # noqa
|
||||||
|
|
||||||
REQUIREMENTS = ['aiohttp_cors==0.6.0']
|
REQUIREMENTS = ['aiohttp_cors==0.6.0']
|
||||||
|
|
||||||
DOMAIN = 'http'
|
DOMAIN = 'http'
|
||||||
|
@ -98,8 +98,7 @@ CONFIG_SCHEMA = vol.Schema({
|
||||||
}, extra=vol.ALLOW_EXTRA)
|
}, extra=vol.ALLOW_EXTRA)
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def async_setup(hass, config):
|
||||||
def async_setup(hass, config):
|
|
||||||
"""Set up the HTTP API and debug interface."""
|
"""Set up the HTTP API and debug interface."""
|
||||||
conf = config.get(DOMAIN)
|
conf = config.get(DOMAIN)
|
||||||
|
|
||||||
|
@ -135,16 +134,14 @@ def async_setup(hass, config):
|
||||||
is_ban_enabled=is_ban_enabled
|
is_ban_enabled=is_ban_enabled
|
||||||
)
|
)
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def stop_server(event):
|
||||||
def stop_server(event):
|
|
||||||
"""Stop the server."""
|
"""Stop the server."""
|
||||||
yield from server.stop()
|
await server.stop()
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def start_server(event):
|
||||||
def start_server(event):
|
|
||||||
"""Start the server."""
|
"""Start the server."""
|
||||||
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, stop_server)
|
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, stop_server)
|
||||||
yield from server.start()
|
await server.start()
|
||||||
|
|
||||||
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, start_server)
|
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, start_server)
|
||||||
|
|
||||||
|
@ -252,13 +249,11 @@ class HomeAssistantHTTP(object):
|
||||||
return
|
return
|
||||||
|
|
||||||
if cache_headers:
|
if cache_headers:
|
||||||
@asyncio.coroutine
|
async def serve_file(request):
|
||||||
def serve_file(request):
|
|
||||||
"""Serve file from disk."""
|
"""Serve file from disk."""
|
||||||
return CachingFileResponse(path)
|
return CachingFileResponse(path)
|
||||||
else:
|
else:
|
||||||
@asyncio.coroutine
|
async def serve_file(request):
|
||||||
def serve_file(request):
|
|
||||||
"""Serve file from disk."""
|
"""Serve file from disk."""
|
||||||
return web.FileResponse(path)
|
return web.FileResponse(path)
|
||||||
|
|
||||||
|
@ -276,14 +271,13 @@ class HomeAssistantHTTP(object):
|
||||||
|
|
||||||
self.app.router.add_route('GET', url_pattern, serve_file)
|
self.app.router.add_route('GET', url_pattern, serve_file)
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def start(self):
|
||||||
def start(self):
|
|
||||||
"""Start the WSGI server."""
|
"""Start the WSGI server."""
|
||||||
# We misunderstood the startup signal. You're not allowed to change
|
# We misunderstood the startup signal. You're not allowed to change
|
||||||
# anything during startup. Temp workaround.
|
# anything during startup. Temp workaround.
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
self.app._on_startup.freeze()
|
self.app._on_startup.freeze()
|
||||||
yield from self.app.startup()
|
await self.app.startup()
|
||||||
|
|
||||||
if self.ssl_certificate:
|
if self.ssl_certificate:
|
||||||
try:
|
try:
|
||||||
|
@ -308,121 +302,18 @@ class HomeAssistantHTTP(object):
|
||||||
self._handler = self.app.make_handler(loop=self.hass.loop)
|
self._handler = self.app.make_handler(loop=self.hass.loop)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.server = yield from self.hass.loop.create_server(
|
self.server = await self.hass.loop.create_server(
|
||||||
self._handler, self.server_host, self.server_port, ssl=context)
|
self._handler, self.server_host, self.server_port, ssl=context)
|
||||||
except OSError as error:
|
except OSError as error:
|
||||||
_LOGGER.error("Failed to create HTTP server at port %d: %s",
|
_LOGGER.error("Failed to create HTTP server at port %d: %s",
|
||||||
self.server_port, error)
|
self.server_port, error)
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def stop(self):
|
||||||
def stop(self):
|
|
||||||
"""Stop the WSGI server."""
|
"""Stop the WSGI server."""
|
||||||
if self.server:
|
if self.server:
|
||||||
self.server.close()
|
self.server.close()
|
||||||
yield from self.server.wait_closed()
|
await self.server.wait_closed()
|
||||||
yield from self.app.shutdown()
|
await self.app.shutdown()
|
||||||
if self._handler:
|
if self._handler:
|
||||||
yield from self._handler.shutdown(10)
|
await self._handler.shutdown(10)
|
||||||
yield from self.app.cleanup()
|
await self.app.cleanup()
|
||||||
|
|
||||||
|
|
||||||
class HomeAssistantView(object):
|
|
||||||
"""Base view for all views."""
|
|
||||||
|
|
||||||
url = None
|
|
||||||
extra_urls = []
|
|
||||||
requires_auth = True # Views inheriting from this class can override this
|
|
||||||
|
|
||||||
# pylint: disable=no-self-use
|
|
||||||
def json(self, result, status_code=200, headers=None):
|
|
||||||
"""Return a JSON response."""
|
|
||||||
msg = json.dumps(
|
|
||||||
result, sort_keys=True, cls=rem.JSONEncoder).encode('UTF-8')
|
|
||||||
response = web.Response(
|
|
||||||
body=msg, content_type=CONTENT_TYPE_JSON, status=status_code,
|
|
||||||
headers=headers)
|
|
||||||
response.enable_compression()
|
|
||||||
return response
|
|
||||||
|
|
||||||
def json_message(self, message, status_code=200, message_code=None,
|
|
||||||
headers=None):
|
|
||||||
"""Return a JSON message response."""
|
|
||||||
data = {'message': message}
|
|
||||||
if message_code is not None:
|
|
||||||
data['code'] = message_code
|
|
||||||
return self.json(data, status_code, headers=headers)
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
|
||||||
# pylint: disable=no-self-use
|
|
||||||
def file(self, request, fil):
|
|
||||||
"""Return a file."""
|
|
||||||
assert isinstance(fil, str), 'only string paths allowed'
|
|
||||||
return web.FileResponse(fil)
|
|
||||||
|
|
||||||
def register(self, router):
|
|
||||||
"""Register the view with a router."""
|
|
||||||
assert self.url is not None, 'No url set for view'
|
|
||||||
urls = [self.url] + self.extra_urls
|
|
||||||
|
|
||||||
for method in ('get', 'post', 'delete', 'put'):
|
|
||||||
handler = getattr(self, method, None)
|
|
||||||
|
|
||||||
if not handler:
|
|
||||||
continue
|
|
||||||
|
|
||||||
handler = request_handler_factory(self, handler)
|
|
||||||
|
|
||||||
for url in urls:
|
|
||||||
router.add_route(method, url, handler)
|
|
||||||
|
|
||||||
# aiohttp_cors does not work with class based views
|
|
||||||
# self.app.router.add_route('*', self.url, self, name=self.name)
|
|
||||||
|
|
||||||
# for url in self.extra_urls:
|
|
||||||
# self.app.router.add_route('*', url, self)
|
|
||||||
|
|
||||||
|
|
||||||
def request_handler_factory(view, handler):
|
|
||||||
"""Wrap the handler classes."""
|
|
||||||
assert asyncio.iscoroutinefunction(handler) or is_callback(handler), \
|
|
||||||
"Handler should be a coroutine or a callback."
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
|
||||||
def handle(request):
|
|
||||||
"""Handle incoming request."""
|
|
||||||
if not request.app['hass'].is_running:
|
|
||||||
return web.Response(status=503)
|
|
||||||
|
|
||||||
authenticated = request.get(KEY_AUTHENTICATED, False)
|
|
||||||
|
|
||||||
if view.requires_auth and not authenticated:
|
|
||||||
raise HTTPUnauthorized()
|
|
||||||
|
|
||||||
_LOGGER.info('Serving %s to %s (auth: %s)',
|
|
||||||
request.path, request.get(KEY_REAL_IP), authenticated)
|
|
||||||
|
|
||||||
result = handler(request, **request.match_info)
|
|
||||||
|
|
||||||
if asyncio.iscoroutine(result):
|
|
||||||
result = yield from result
|
|
||||||
|
|
||||||
if isinstance(result, web.StreamResponse):
|
|
||||||
# The method handler returned a ready-made Response, how nice of it
|
|
||||||
return result
|
|
||||||
|
|
||||||
status_code = 200
|
|
||||||
|
|
||||||
if isinstance(result, tuple):
|
|
||||||
result, status_code = result
|
|
||||||
|
|
||||||
if isinstance(result, str):
|
|
||||||
result = result.encode('utf-8')
|
|
||||||
elif result is None:
|
|
||||||
result = b''
|
|
||||||
elif not isinstance(result, bytes):
|
|
||||||
assert False, ('Result should be None, string, bytes or Response. '
|
|
||||||
'Got: {}').format(result)
|
|
||||||
|
|
||||||
return web.Response(body=result, status=status_code)
|
|
||||||
|
|
||||||
return handle
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
"""Authentication for HTTP component."""
|
"""Authentication for HTTP component."""
|
||||||
import asyncio
|
|
||||||
import base64
|
import base64
|
||||||
import hmac
|
import hmac
|
||||||
import logging
|
import logging
|
||||||
|
@ -20,13 +20,12 @@ _LOGGER = logging.getLogger(__name__)
|
||||||
def setup_auth(app, trusted_networks, api_password):
|
def setup_auth(app, trusted_networks, api_password):
|
||||||
"""Create auth middleware for the app."""
|
"""Create auth middleware for the app."""
|
||||||
@middleware
|
@middleware
|
||||||
@asyncio.coroutine
|
async def auth_middleware(request, handler):
|
||||||
def auth_middleware(request, handler):
|
|
||||||
"""Authenticate as middleware."""
|
"""Authenticate as middleware."""
|
||||||
# If no password set, just always set authenticated=True
|
# If no password set, just always set authenticated=True
|
||||||
if api_password is None:
|
if api_password is None:
|
||||||
request[KEY_AUTHENTICATED] = True
|
request[KEY_AUTHENTICATED] = True
|
||||||
return (yield from handler(request))
|
return await handler(request)
|
||||||
|
|
||||||
# Check authentication
|
# Check authentication
|
||||||
authenticated = False
|
authenticated = False
|
||||||
|
@ -50,10 +49,9 @@ def setup_auth(app, trusted_networks, api_password):
|
||||||
authenticated = True
|
authenticated = True
|
||||||
|
|
||||||
request[KEY_AUTHENTICATED] = authenticated
|
request[KEY_AUTHENTICATED] = authenticated
|
||||||
return (yield from handler(request))
|
return await handler(request)
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def auth_startup(app):
|
||||||
def auth_startup(app):
|
|
||||||
"""Initialize auth middleware when app starts up."""
|
"""Initialize auth middleware when app starts up."""
|
||||||
app.middlewares.append(auth_middleware)
|
app.middlewares.append(auth_middleware)
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
"""Ban logic for HTTP component."""
|
"""Ban logic for HTTP component."""
|
||||||
import asyncio
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from ipaddress import ip_address
|
from ipaddress import ip_address
|
||||||
|
@ -38,11 +38,10 @@ SCHEMA_IP_BAN_ENTRY = vol.Schema({
|
||||||
@callback
|
@callback
|
||||||
def setup_bans(hass, app, login_threshold):
|
def setup_bans(hass, app, login_threshold):
|
||||||
"""Create IP Ban middleware for the app."""
|
"""Create IP Ban middleware for the app."""
|
||||||
@asyncio.coroutine
|
async def ban_startup(app):
|
||||||
def ban_startup(app):
|
|
||||||
"""Initialize bans when app starts up."""
|
"""Initialize bans when app starts up."""
|
||||||
app.middlewares.append(ban_middleware)
|
app.middlewares.append(ban_middleware)
|
||||||
app[KEY_BANNED_IPS] = yield from hass.async_add_job(
|
app[KEY_BANNED_IPS] = await hass.async_add_job(
|
||||||
load_ip_bans_config, hass.config.path(IP_BANS_FILE))
|
load_ip_bans_config, hass.config.path(IP_BANS_FILE))
|
||||||
app[KEY_FAILED_LOGIN_ATTEMPTS] = defaultdict(int)
|
app[KEY_FAILED_LOGIN_ATTEMPTS] = defaultdict(int)
|
||||||
app[KEY_LOGIN_THRESHOLD] = login_threshold
|
app[KEY_LOGIN_THRESHOLD] = login_threshold
|
||||||
|
@ -51,12 +50,11 @@ def setup_bans(hass, app, login_threshold):
|
||||||
|
|
||||||
|
|
||||||
@middleware
|
@middleware
|
||||||
@asyncio.coroutine
|
async def ban_middleware(request, handler):
|
||||||
def ban_middleware(request, handler):
|
|
||||||
"""IP Ban middleware."""
|
"""IP Ban middleware."""
|
||||||
if KEY_BANNED_IPS not in request.app:
|
if KEY_BANNED_IPS not in request.app:
|
||||||
_LOGGER.error('IP Ban middleware loaded but banned IPs not loaded')
|
_LOGGER.error('IP Ban middleware loaded but banned IPs not loaded')
|
||||||
return (yield from handler(request))
|
return await handler(request)
|
||||||
|
|
||||||
# Verify if IP is not banned
|
# Verify if IP is not banned
|
||||||
ip_address_ = request[KEY_REAL_IP]
|
ip_address_ = request[KEY_REAL_IP]
|
||||||
|
@ -67,14 +65,13 @@ def ban_middleware(request, handler):
|
||||||
raise HTTPForbidden()
|
raise HTTPForbidden()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return (yield from handler(request))
|
return await handler(request)
|
||||||
except HTTPUnauthorized:
|
except HTTPUnauthorized:
|
||||||
yield from process_wrong_login(request)
|
await process_wrong_login(request)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def process_wrong_login(request):
|
||||||
def process_wrong_login(request):
|
|
||||||
"""Process a wrong login attempt."""
|
"""Process a wrong login attempt."""
|
||||||
remote_addr = request[KEY_REAL_IP]
|
remote_addr = request[KEY_REAL_IP]
|
||||||
|
|
||||||
|
@ -98,7 +95,7 @@ def process_wrong_login(request):
|
||||||
request.app[KEY_BANNED_IPS].append(new_ban)
|
request.app[KEY_BANNED_IPS].append(new_ban)
|
||||||
|
|
||||||
hass = request.app['hass']
|
hass = request.app['hass']
|
||||||
yield from hass.async_add_job(
|
await hass.async_add_job(
|
||||||
update_ip_bans_config, hass.config.path(IP_BANS_FILE), new_ban)
|
update_ip_bans_config, hass.config.path(IP_BANS_FILE), new_ban)
|
||||||
|
|
||||||
_LOGGER.warning(
|
_LOGGER.warning(
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
"""Provide cors support for the HTTP component."""
|
"""Provide cors support for the HTTP component."""
|
||||||
import asyncio
|
|
||||||
|
|
||||||
from aiohttp.hdrs import ACCEPT, ORIGIN, CONTENT_TYPE
|
from aiohttp.hdrs import ACCEPT, ORIGIN, CONTENT_TYPE
|
||||||
|
|
||||||
|
@ -27,8 +27,7 @@ def setup_cors(app, origins):
|
||||||
) for host in origins
|
) for host in origins
|
||||||
})
|
})
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def cors_startup(app):
|
||||||
def cors_startup(app):
|
|
||||||
"""Initialize cors when app starts up."""
|
"""Initialize cors when app starts up."""
|
||||||
cors_added = set()
|
cors_added = set()
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
"""Decorator for view methods to help with data validation."""
|
"""Decorator for view methods to help with data validation."""
|
||||||
import asyncio
|
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
@ -24,16 +24,15 @@ class RequestDataValidator:
|
||||||
|
|
||||||
def __call__(self, method):
|
def __call__(self, method):
|
||||||
"""Decorate a function."""
|
"""Decorate a function."""
|
||||||
@asyncio.coroutine
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def wrapper(view, request, *args, **kwargs):
|
async def wrapper(view, request, *args, **kwargs):
|
||||||
"""Wrap a request handler with data validation."""
|
"""Wrap a request handler with data validation."""
|
||||||
data = None
|
data = None
|
||||||
try:
|
try:
|
||||||
data = yield from request.json()
|
data = await request.json()
|
||||||
except ValueError:
|
except ValueError:
|
||||||
if not self._allow_empty or \
|
if not self._allow_empty or \
|
||||||
(yield from request.content.read()) != b'':
|
(await request.content.read()) != b'':
|
||||||
_LOGGER.error('Invalid JSON received.')
|
_LOGGER.error('Invalid JSON received.')
|
||||||
return view.json_message('Invalid JSON.', 400)
|
return view.json_message('Invalid JSON.', 400)
|
||||||
data = {}
|
data = {}
|
||||||
|
@ -45,7 +44,7 @@ class RequestDataValidator:
|
||||||
return view.json_message(
|
return view.json_message(
|
||||||
'Message format incorrect: {}'.format(err), 400)
|
'Message format incorrect: {}'.format(err), 400)
|
||||||
|
|
||||||
result = yield from method(view, request, *args, **kwargs)
|
result = await method(view, request, *args, **kwargs)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
"""Middleware to fetch real IP."""
|
"""Middleware to fetch real IP."""
|
||||||
import asyncio
|
|
||||||
from ipaddress import ip_address
|
from ipaddress import ip_address
|
||||||
|
|
||||||
from aiohttp.web import middleware
|
from aiohttp.web import middleware
|
||||||
|
@ -14,8 +14,7 @@ from .const import KEY_REAL_IP
|
||||||
def setup_real_ip(app, use_x_forwarded_for):
|
def setup_real_ip(app, use_x_forwarded_for):
|
||||||
"""Create IP Ban middleware for the app."""
|
"""Create IP Ban middleware for the app."""
|
||||||
@middleware
|
@middleware
|
||||||
@asyncio.coroutine
|
async def real_ip_middleware(request, handler):
|
||||||
def real_ip_middleware(request, handler):
|
|
||||||
"""Real IP middleware."""
|
"""Real IP middleware."""
|
||||||
if (use_x_forwarded_for and
|
if (use_x_forwarded_for and
|
||||||
X_FORWARDED_FOR in request.headers):
|
X_FORWARDED_FOR in request.headers):
|
||||||
|
@ -25,10 +24,9 @@ def setup_real_ip(app, use_x_forwarded_for):
|
||||||
request[KEY_REAL_IP] = \
|
request[KEY_REAL_IP] = \
|
||||||
ip_address(request.transport.get_extra_info('peername')[0])
|
ip_address(request.transport.get_extra_info('peername')[0])
|
||||||
|
|
||||||
return (yield from handler(request))
|
return await handler(request)
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def app_startup(app):
|
||||||
def app_startup(app):
|
|
||||||
"""Initialize bans when app starts up."""
|
"""Initialize bans when app starts up."""
|
||||||
app.middlewares.append(real_ip_middleware)
|
app.middlewares.append(real_ip_middleware)
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
"""Static file handling for HTTP component."""
|
"""Static file handling for HTTP component."""
|
||||||
import asyncio
|
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from aiohttp import hdrs
|
from aiohttp import hdrs
|
||||||
|
@ -14,8 +14,7 @@ _FINGERPRINT = re.compile(r'^(.+)-[a-z0-9]{32}\.(\w+)$', re.IGNORECASE)
|
||||||
class CachingStaticResource(StaticResource):
|
class CachingStaticResource(StaticResource):
|
||||||
"""Static Resource handler that will add cache headers."""
|
"""Static Resource handler that will add cache headers."""
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def _handle(self, request):
|
||||||
def _handle(self, request):
|
|
||||||
filename = URL(request.match_info['filename']).path
|
filename = URL(request.match_info['filename']).path
|
||||||
try:
|
try:
|
||||||
# PyLint is wrong about resolve not being a member.
|
# PyLint is wrong about resolve not being a member.
|
||||||
|
@ -32,7 +31,7 @@ class CachingStaticResource(StaticResource):
|
||||||
raise HTTPNotFound() from error
|
raise HTTPNotFound() from error
|
||||||
|
|
||||||
if filepath.is_dir():
|
if filepath.is_dir():
|
||||||
return (yield from super()._handle(request))
|
return await super()._handle(request)
|
||||||
elif filepath.is_file():
|
elif filepath.is_file():
|
||||||
return CachingFileResponse(filepath, chunk_size=self._chunk_size)
|
return CachingFileResponse(filepath, chunk_size=self._chunk_size)
|
||||||
else:
|
else:
|
||||||
|
@ -49,26 +48,24 @@ class CachingFileResponse(FileResponse):
|
||||||
|
|
||||||
orig_sendfile = self._sendfile
|
orig_sendfile = self._sendfile
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def sendfile(request, fobj, count):
|
||||||
def sendfile(request, fobj, count):
|
|
||||||
"""Sendfile that includes a cache header."""
|
"""Sendfile that includes a cache header."""
|
||||||
cache_time = 31 * 86400 # = 1 month
|
cache_time = 31 * 86400 # = 1 month
|
||||||
self.headers[hdrs.CACHE_CONTROL] = "public, max-age={}".format(
|
self.headers[hdrs.CACHE_CONTROL] = "public, max-age={}".format(
|
||||||
cache_time)
|
cache_time)
|
||||||
|
|
||||||
yield from orig_sendfile(request, fobj, count)
|
await orig_sendfile(request, fobj, count)
|
||||||
|
|
||||||
# Overwriting like this because __init__ can change implementation.
|
# Overwriting like this because __init__ can change implementation.
|
||||||
self._sendfile = sendfile
|
self._sendfile = sendfile
|
||||||
|
|
||||||
|
|
||||||
@middleware
|
@middleware
|
||||||
@asyncio.coroutine
|
async def staticresource_middleware(request, handler):
|
||||||
def staticresource_middleware(request, handler):
|
|
||||||
"""Middleware to strip out fingerprint from fingerprinted assets."""
|
"""Middleware to strip out fingerprint from fingerprinted assets."""
|
||||||
path = request.path
|
path = request.path
|
||||||
if not path.startswith('/static/') and not path.startswith('/frontend'):
|
if not path.startswith('/static/') and not path.startswith('/frontend'):
|
||||||
return (yield from handler(request))
|
return await handler(request)
|
||||||
|
|
||||||
fingerprinted = _FINGERPRINT.match(request.match_info['filename'])
|
fingerprinted = _FINGERPRINT.match(request.match_info['filename'])
|
||||||
|
|
||||||
|
@ -76,4 +73,4 @@ def staticresource_middleware(request, handler):
|
||||||
request.match_info['filename'] = \
|
request.match_info['filename'] = \
|
||||||
'{}.{}'.format(*fingerprinted.groups())
|
'{}.{}'.format(*fingerprinted.groups())
|
||||||
|
|
||||||
return (yield from handler(request))
|
return await handler(request)
|
||||||
|
|
121
homeassistant/components/http/view.py
Normal file
121
homeassistant/components/http/view.py
Normal file
|
@ -0,0 +1,121 @@
|
||||||
|
"""
|
||||||
|
This module provides WSGI application to serve the Home Assistant API.
|
||||||
|
|
||||||
|
For more details about this component, please refer to the documentation at
|
||||||
|
https://home-assistant.io/components/http/
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from aiohttp import web
|
||||||
|
from aiohttp.web_exceptions import HTTPUnauthorized
|
||||||
|
|
||||||
|
import homeassistant.remote as rem
|
||||||
|
from homeassistant.core import is_callback
|
||||||
|
from homeassistant.const import CONTENT_TYPE_JSON
|
||||||
|
|
||||||
|
from .const import KEY_AUTHENTICATED, KEY_REAL_IP
|
||||||
|
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class HomeAssistantView(object):
|
||||||
|
"""Base view for all views."""
|
||||||
|
|
||||||
|
url = None
|
||||||
|
extra_urls = []
|
||||||
|
requires_auth = True # Views inheriting from this class can override this
|
||||||
|
|
||||||
|
# pylint: disable=no-self-use
|
||||||
|
def json(self, result, status_code=200, headers=None):
|
||||||
|
"""Return a JSON response."""
|
||||||
|
msg = json.dumps(
|
||||||
|
result, sort_keys=True, cls=rem.JSONEncoder).encode('UTF-8')
|
||||||
|
response = web.Response(
|
||||||
|
body=msg, content_type=CONTENT_TYPE_JSON, status=status_code,
|
||||||
|
headers=headers)
|
||||||
|
response.enable_compression()
|
||||||
|
return response
|
||||||
|
|
||||||
|
def json_message(self, message, status_code=200, message_code=None,
|
||||||
|
headers=None):
|
||||||
|
"""Return a JSON message response."""
|
||||||
|
data = {'message': message}
|
||||||
|
if message_code is not None:
|
||||||
|
data['code'] = message_code
|
||||||
|
return self.json(data, status_code, headers=headers)
|
||||||
|
|
||||||
|
# pylint: disable=no-self-use
|
||||||
|
async def file(self, request, fil):
|
||||||
|
"""Return a file."""
|
||||||
|
assert isinstance(fil, str), 'only string paths allowed'
|
||||||
|
return web.FileResponse(fil)
|
||||||
|
|
||||||
|
def register(self, router):
|
||||||
|
"""Register the view with a router."""
|
||||||
|
assert self.url is not None, 'No url set for view'
|
||||||
|
urls = [self.url] + self.extra_urls
|
||||||
|
|
||||||
|
for method in ('get', 'post', 'delete', 'put'):
|
||||||
|
handler = getattr(self, method, None)
|
||||||
|
|
||||||
|
if not handler:
|
||||||
|
continue
|
||||||
|
|
||||||
|
handler = request_handler_factory(self, handler)
|
||||||
|
|
||||||
|
for url in urls:
|
||||||
|
router.add_route(method, url, handler)
|
||||||
|
|
||||||
|
# aiohttp_cors does not work with class based views
|
||||||
|
# self.app.router.add_route('*', self.url, self, name=self.name)
|
||||||
|
|
||||||
|
# for url in self.extra_urls:
|
||||||
|
# self.app.router.add_route('*', url, self)
|
||||||
|
|
||||||
|
|
||||||
|
def request_handler_factory(view, handler):
|
||||||
|
"""Wrap the handler classes."""
|
||||||
|
assert asyncio.iscoroutinefunction(handler) or is_callback(handler), \
|
||||||
|
"Handler should be a coroutine or a callback."
|
||||||
|
|
||||||
|
async def handle(request):
|
||||||
|
"""Handle incoming request."""
|
||||||
|
if not request.app['hass'].is_running:
|
||||||
|
return web.Response(status=503)
|
||||||
|
|
||||||
|
authenticated = request.get(KEY_AUTHENTICATED, False)
|
||||||
|
|
||||||
|
if view.requires_auth and not authenticated:
|
||||||
|
raise HTTPUnauthorized()
|
||||||
|
|
||||||
|
_LOGGER.info('Serving %s to %s (auth: %s)',
|
||||||
|
request.path, request.get(KEY_REAL_IP), authenticated)
|
||||||
|
|
||||||
|
result = handler(request, **request.match_info)
|
||||||
|
|
||||||
|
if asyncio.iscoroutine(result):
|
||||||
|
result = await result
|
||||||
|
|
||||||
|
if isinstance(result, web.StreamResponse):
|
||||||
|
# The method handler returned a ready-made Response, how nice of it
|
||||||
|
return result
|
||||||
|
|
||||||
|
status_code = 200
|
||||||
|
|
||||||
|
if isinstance(result, tuple):
|
||||||
|
result, status_code = result
|
||||||
|
|
||||||
|
if isinstance(result, str):
|
||||||
|
result = result.encode('utf-8')
|
||||||
|
elif result is None:
|
||||||
|
result = b''
|
||||||
|
elif not isinstance(result, bytes):
|
||||||
|
assert False, ('Result should be None, string, bytes or Response. '
|
||||||
|
'Got: {}').format(result)
|
||||||
|
|
||||||
|
return web.Response(body=result, status=status_code)
|
||||||
|
|
||||||
|
return handle
|
|
@ -191,8 +191,7 @@ def result_message(iden, result=None):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def async_setup(hass, config):
|
||||||
def async_setup(hass, config):
|
|
||||||
"""Initialize the websocket API."""
|
"""Initialize the websocket API."""
|
||||||
hass.http.register_view(WebsocketAPIView)
|
hass.http.register_view(WebsocketAPIView)
|
||||||
return True
|
return True
|
||||||
|
@ -205,11 +204,10 @@ class WebsocketAPIView(HomeAssistantView):
|
||||||
url = URL
|
url = URL
|
||||||
requires_auth = False
|
requires_auth = False
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def get(self, request):
|
||||||
def get(self, request):
|
|
||||||
"""Handle an incoming websocket connection."""
|
"""Handle an incoming websocket connection."""
|
||||||
# pylint: disable=no-self-use
|
# pylint: disable=no-self-use
|
||||||
return ActiveConnection(request.app['hass'], request).handle()
|
return await ActiveConnection(request.app['hass'], request).handle()
|
||||||
|
|
||||||
|
|
||||||
class ActiveConnection:
|
class ActiveConnection:
|
||||||
|
@ -233,17 +231,16 @@ class ActiveConnection:
|
||||||
"""Print an error message."""
|
"""Print an error message."""
|
||||||
_LOGGER.error("WS %s: %s %s", id(self.wsock), message1, message2)
|
_LOGGER.error("WS %s: %s %s", id(self.wsock), message1, message2)
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def _writer(self):
|
||||||
def _writer(self):
|
|
||||||
"""Write outgoing messages."""
|
"""Write outgoing messages."""
|
||||||
# Exceptions if Socket disconnected or cancelled by connection handler
|
# Exceptions if Socket disconnected or cancelled by connection handler
|
||||||
with suppress(RuntimeError, *CANCELLATION_ERRORS):
|
with suppress(RuntimeError, *CANCELLATION_ERRORS):
|
||||||
while not self.wsock.closed:
|
while not self.wsock.closed:
|
||||||
message = yield from self.to_write.get()
|
message = await self.to_write.get()
|
||||||
if message is None:
|
if message is None:
|
||||||
break
|
break
|
||||||
self.debug("Sending", message)
|
self.debug("Sending", message)
|
||||||
yield from self.wsock.send_json(message, dumps=JSON_DUMP)
|
await self.wsock.send_json(message, dumps=JSON_DUMP)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def send_message_outside(self, message):
|
def send_message_outside(self, message):
|
||||||
|
@ -266,12 +263,11 @@ class ActiveConnection:
|
||||||
self._handle_task.cancel()
|
self._handle_task.cancel()
|
||||||
self._writer_task.cancel()
|
self._writer_task.cancel()
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def handle(self):
|
||||||
def handle(self):
|
|
||||||
"""Handle the websocket connection."""
|
"""Handle the websocket connection."""
|
||||||
request = self.request
|
request = self.request
|
||||||
wsock = self.wsock = web.WebSocketResponse(heartbeat=55)
|
wsock = self.wsock = web.WebSocketResponse(heartbeat=55)
|
||||||
yield from wsock.prepare(request)
|
await wsock.prepare(request)
|
||||||
self.debug("Connected")
|
self.debug("Connected")
|
||||||
|
|
||||||
# Get a reference to current task so we can cancel our connection
|
# Get a reference to current task so we can cancel our connection
|
||||||
|
@ -294,8 +290,8 @@ class ActiveConnection:
|
||||||
authenticated = True
|
authenticated = True
|
||||||
|
|
||||||
else:
|
else:
|
||||||
yield from self.wsock.send_json(auth_required_message())
|
await self.wsock.send_json(auth_required_message())
|
||||||
msg = yield from wsock.receive_json()
|
msg = await wsock.receive_json()
|
||||||
msg = AUTH_MESSAGE_SCHEMA(msg)
|
msg = AUTH_MESSAGE_SCHEMA(msg)
|
||||||
|
|
||||||
if validate_password(request, msg['api_password']):
|
if validate_password(request, msg['api_password']):
|
||||||
|
@ -303,18 +299,18 @@ class ActiveConnection:
|
||||||
|
|
||||||
else:
|
else:
|
||||||
self.debug("Invalid password")
|
self.debug("Invalid password")
|
||||||
yield from self.wsock.send_json(
|
await self.wsock.send_json(
|
||||||
auth_invalid_message('Invalid password'))
|
auth_invalid_message('Invalid password'))
|
||||||
|
|
||||||
if not authenticated:
|
if not authenticated:
|
||||||
yield from process_wrong_login(request)
|
await process_wrong_login(request)
|
||||||
return wsock
|
return wsock
|
||||||
|
|
||||||
yield from self.wsock.send_json(auth_ok_message())
|
await self.wsock.send_json(auth_ok_message())
|
||||||
|
|
||||||
# ---------- AUTH PHASE OVER ----------
|
# ---------- AUTH PHASE OVER ----------
|
||||||
|
|
||||||
msg = yield from wsock.receive_json()
|
msg = await wsock.receive_json()
|
||||||
last_id = 0
|
last_id = 0
|
||||||
|
|
||||||
while msg:
|
while msg:
|
||||||
|
@ -332,7 +328,7 @@ class ActiveConnection:
|
||||||
getattr(self, handler_name)(msg)
|
getattr(self, handler_name)(msg)
|
||||||
|
|
||||||
last_id = cur_id
|
last_id = cur_id
|
||||||
msg = yield from wsock.receive_json()
|
msg = await wsock.receive_json()
|
||||||
|
|
||||||
except vol.Invalid as err:
|
except vol.Invalid as err:
|
||||||
error_msg = "Message incorrectly formatted: "
|
error_msg = "Message incorrectly formatted: "
|
||||||
|
@ -394,11 +390,11 @@ class ActiveConnection:
|
||||||
self.to_write.put_nowait(final_message)
|
self.to_write.put_nowait(final_message)
|
||||||
self.to_write.put_nowait(None)
|
self.to_write.put_nowait(None)
|
||||||
# Make sure all error messages are written before closing
|
# Make sure all error messages are written before closing
|
||||||
yield from self._writer_task
|
await self._writer_task
|
||||||
except asyncio.QueueFull:
|
except asyncio.QueueFull:
|
||||||
self._writer_task.cancel()
|
self._writer_task.cancel()
|
||||||
|
|
||||||
yield from wsock.close()
|
await wsock.close()
|
||||||
self.debug("Closed connection")
|
self.debug("Closed connection")
|
||||||
|
|
||||||
return wsock
|
return wsock
|
||||||
|
@ -410,8 +406,7 @@ class ActiveConnection:
|
||||||
"""
|
"""
|
||||||
msg = SUBSCRIBE_EVENTS_MESSAGE_SCHEMA(msg)
|
msg = SUBSCRIBE_EVENTS_MESSAGE_SCHEMA(msg)
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def forward_events(event):
|
||||||
def forward_events(event):
|
|
||||||
"""Forward events to websocket."""
|
"""Forward events to websocket."""
|
||||||
if event.event_type == EVENT_TIME_CHANGED:
|
if event.event_type == EVENT_TIME_CHANGED:
|
||||||
return
|
return
|
||||||
|
@ -447,10 +442,9 @@ class ActiveConnection:
|
||||||
"""
|
"""
|
||||||
msg = CALL_SERVICE_MESSAGE_SCHEMA(msg)
|
msg = CALL_SERVICE_MESSAGE_SCHEMA(msg)
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def call_service_helper(msg):
|
||||||
def call_service_helper(msg):
|
|
||||||
"""Call a service and fire complete message."""
|
"""Call a service and fire complete message."""
|
||||||
yield from self.hass.services.async_call(
|
await self.hass.services.async_call(
|
||||||
msg['domain'], msg['service'], msg.get('service_data'), True)
|
msg['domain'], msg['service'], msg.get('service_data'), True)
|
||||||
self.send_message_outside(result_message(msg['id']))
|
self.send_message_outside(result_message(msg['id']))
|
||||||
|
|
||||||
|
@ -473,10 +467,9 @@ class ActiveConnection:
|
||||||
"""
|
"""
|
||||||
msg = GET_SERVICES_MESSAGE_SCHEMA(msg)
|
msg = GET_SERVICES_MESSAGE_SCHEMA(msg)
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def get_services_helper(msg):
|
||||||
def get_services_helper(msg):
|
|
||||||
"""Get available services and fire complete message."""
|
"""Get available services and fire complete message."""
|
||||||
descriptions = yield from async_get_all_descriptions(self.hass)
|
descriptions = await async_get_all_descriptions(self.hass)
|
||||||
self.send_message_outside(result_message(msg['id'], descriptions))
|
self.send_message_outside(result_message(msg['id'], descriptions))
|
||||||
|
|
||||||
self.hass.async_add_job(get_services_helper(msg))
|
self.hass.async_add_job(get_services_helper(msg))
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
"""Tests for the HTTP component."""
|
"""Tests for the HTTP component."""
|
||||||
import asyncio
|
|
||||||
from ipaddress import ip_address
|
from ipaddress import ip_address
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
@ -18,18 +17,16 @@ def mock_real_ip(app):
|
||||||
nonlocal ip_to_mock
|
nonlocal ip_to_mock
|
||||||
ip_to_mock = value
|
ip_to_mock = value
|
||||||
|
|
||||||
@asyncio.coroutine
|
|
||||||
@web.middleware
|
@web.middleware
|
||||||
def mock_real_ip(request, handler):
|
async def mock_real_ip(request, handler):
|
||||||
"""Mock Real IP middleware."""
|
"""Mock Real IP middleware."""
|
||||||
nonlocal ip_to_mock
|
nonlocal ip_to_mock
|
||||||
|
|
||||||
request[KEY_REAL_IP] = ip_address(ip_to_mock)
|
request[KEY_REAL_IP] = ip_address(ip_to_mock)
|
||||||
|
|
||||||
return (yield from handler(request))
|
return (await handler(request))
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def real_ip_startup(app):
|
||||||
def real_ip_startup(app):
|
|
||||||
"""Startup of real ip."""
|
"""Startup of real ip."""
|
||||||
app.middlewares.insert(0, mock_real_ip)
|
app.middlewares.insert(0, mock_real_ip)
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
"""The tests for the Home Assistant HTTP component."""
|
"""The tests for the Home Assistant HTTP component."""
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
import asyncio
|
|
||||||
from ipaddress import ip_network
|
from ipaddress import ip_network
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
@ -30,8 +29,7 @@ TRUSTED_ADDRESSES = ['100.64.0.1', '192.0.2.100', 'FD01:DB8::1',
|
||||||
UNTRUSTED_ADDRESSES = ['198.51.100.1', '2001:DB8:FA1::1', '127.0.0.1', '::1']
|
UNTRUSTED_ADDRESSES = ['198.51.100.1', '2001:DB8:FA1::1', '127.0.0.1', '::1']
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def mock_handler(request):
|
||||||
def mock_handler(request):
|
|
||||||
"""Return if request was authenticated."""
|
"""Return if request was authenticated."""
|
||||||
if not request[KEY_AUTHENTICATED]:
|
if not request[KEY_AUTHENTICATED]:
|
||||||
raise HTTPUnauthorized
|
raise HTTPUnauthorized
|
||||||
|
@ -47,84 +45,79 @@ def app():
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def test_auth_middleware_loaded_by_default(hass):
|
||||||
def test_auth_middleware_loaded_by_default(hass):
|
|
||||||
"""Test accessing to server from banned IP when feature is off."""
|
"""Test accessing to server from banned IP when feature is off."""
|
||||||
with patch('homeassistant.components.http.setup_auth') as mock_setup:
|
with patch('homeassistant.components.http.setup_auth') as mock_setup:
|
||||||
yield from async_setup_component(hass, 'http', {
|
await async_setup_component(hass, 'http', {
|
||||||
'http': {}
|
'http': {}
|
||||||
})
|
})
|
||||||
|
|
||||||
assert len(mock_setup.mock_calls) == 1
|
assert len(mock_setup.mock_calls) == 1
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def test_access_without_password(app, test_client):
|
||||||
def test_access_without_password(app, test_client):
|
|
||||||
"""Test access without password."""
|
"""Test access without password."""
|
||||||
setup_auth(app, [], None)
|
setup_auth(app, [], None)
|
||||||
client = yield from test_client(app)
|
client = await test_client(app)
|
||||||
|
|
||||||
resp = yield from client.get('/')
|
resp = await client.get('/')
|
||||||
assert resp.status == 200
|
assert resp.status == 200
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def test_access_with_password_in_header(app, test_client):
|
||||||
def test_access_with_password_in_header(app, test_client):
|
|
||||||
"""Test access with password in URL."""
|
"""Test access with password in URL."""
|
||||||
setup_auth(app, [], API_PASSWORD)
|
setup_auth(app, [], API_PASSWORD)
|
||||||
client = yield from test_client(app)
|
client = await test_client(app)
|
||||||
|
|
||||||
req = yield from client.get(
|
req = await client.get(
|
||||||
'/', headers={HTTP_HEADER_HA_AUTH: API_PASSWORD})
|
'/', headers={HTTP_HEADER_HA_AUTH: API_PASSWORD})
|
||||||
assert req.status == 200
|
assert req.status == 200
|
||||||
|
|
||||||
req = yield from client.get(
|
req = await client.get(
|
||||||
'/', headers={HTTP_HEADER_HA_AUTH: 'wrong-pass'})
|
'/', headers={HTTP_HEADER_HA_AUTH: 'wrong-pass'})
|
||||||
assert req.status == 401
|
assert req.status == 401
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def test_access_with_password_in_query(app, test_client):
|
||||||
def test_access_with_password_in_query(app, test_client):
|
|
||||||
"""Test access without password."""
|
"""Test access without password."""
|
||||||
setup_auth(app, [], API_PASSWORD)
|
setup_auth(app, [], API_PASSWORD)
|
||||||
client = yield from test_client(app)
|
client = await test_client(app)
|
||||||
|
|
||||||
resp = yield from client.get('/', params={
|
resp = await client.get('/', params={
|
||||||
'api_password': API_PASSWORD
|
'api_password': API_PASSWORD
|
||||||
})
|
})
|
||||||
assert resp.status == 200
|
assert resp.status == 200
|
||||||
|
|
||||||
resp = yield from client.get('/')
|
resp = await client.get('/')
|
||||||
assert resp.status == 401
|
assert resp.status == 401
|
||||||
|
|
||||||
resp = yield from client.get('/', params={
|
resp = await client.get('/', params={
|
||||||
'api_password': 'wrong-password'
|
'api_password': 'wrong-password'
|
||||||
})
|
})
|
||||||
assert resp.status == 401
|
assert resp.status == 401
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def test_basic_auth_works(app, test_client):
|
||||||
def test_basic_auth_works(app, test_client):
|
|
||||||
"""Test access with basic authentication."""
|
"""Test access with basic authentication."""
|
||||||
setup_auth(app, [], API_PASSWORD)
|
setup_auth(app, [], API_PASSWORD)
|
||||||
client = yield from test_client(app)
|
client = await test_client(app)
|
||||||
|
|
||||||
req = yield from client.get(
|
req = await client.get(
|
||||||
'/',
|
'/',
|
||||||
auth=BasicAuth('homeassistant', API_PASSWORD))
|
auth=BasicAuth('homeassistant', API_PASSWORD))
|
||||||
assert req.status == 200
|
assert req.status == 200
|
||||||
|
|
||||||
req = yield from client.get(
|
req = await client.get(
|
||||||
'/',
|
'/',
|
||||||
auth=BasicAuth('wrong_username', API_PASSWORD))
|
auth=BasicAuth('wrong_username', API_PASSWORD))
|
||||||
assert req.status == 401
|
assert req.status == 401
|
||||||
|
|
||||||
req = yield from client.get(
|
req = await client.get(
|
||||||
'/',
|
'/',
|
||||||
auth=BasicAuth('homeassistant', 'wrong password'))
|
auth=BasicAuth('homeassistant', 'wrong password'))
|
||||||
assert req.status == 401
|
assert req.status == 401
|
||||||
|
|
||||||
req = yield from client.get(
|
req = await client.get(
|
||||||
'/',
|
'/',
|
||||||
headers={
|
headers={
|
||||||
'authorization': 'NotBasic abcdefg'
|
'authorization': 'NotBasic abcdefg'
|
||||||
|
@ -132,8 +125,7 @@ def test_basic_auth_works(app, test_client):
|
||||||
assert req.status == 401
|
assert req.status == 401
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def test_access_with_trusted_ip(test_client):
|
||||||
def test_access_with_trusted_ip(test_client):
|
|
||||||
"""Test access with an untrusted ip address."""
|
"""Test access with an untrusted ip address."""
|
||||||
app = web.Application()
|
app = web.Application()
|
||||||
app.router.add_get('/', mock_handler)
|
app.router.add_get('/', mock_handler)
|
||||||
|
@ -141,16 +133,16 @@ def test_access_with_trusted_ip(test_client):
|
||||||
setup_auth(app, TRUSTED_NETWORKS, 'some-pass')
|
setup_auth(app, TRUSTED_NETWORKS, 'some-pass')
|
||||||
|
|
||||||
set_mock_ip = mock_real_ip(app)
|
set_mock_ip = mock_real_ip(app)
|
||||||
client = yield from test_client(app)
|
client = await test_client(app)
|
||||||
|
|
||||||
for remote_addr in UNTRUSTED_ADDRESSES:
|
for remote_addr in UNTRUSTED_ADDRESSES:
|
||||||
set_mock_ip(remote_addr)
|
set_mock_ip(remote_addr)
|
||||||
resp = yield from client.get('/')
|
resp = await client.get('/')
|
||||||
assert resp.status == 401, \
|
assert resp.status == 401, \
|
||||||
"{} shouldn't be trusted".format(remote_addr)
|
"{} shouldn't be trusted".format(remote_addr)
|
||||||
|
|
||||||
for remote_addr in TRUSTED_ADDRESSES:
|
for remote_addr in TRUSTED_ADDRESSES:
|
||||||
set_mock_ip(remote_addr)
|
set_mock_ip(remote_addr)
|
||||||
resp = yield from client.get('/')
|
resp = await client.get('/')
|
||||||
assert resp.status == 200, \
|
assert resp.status == 200, \
|
||||||
"{} should be trusted".format(remote_addr)
|
"{} should be trusted".format(remote_addr)
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
"""The tests for the Home Assistant HTTP component."""
|
"""The tests for the Home Assistant HTTP component."""
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
import asyncio
|
|
||||||
from unittest.mock import patch, mock_open
|
from unittest.mock import patch, mock_open
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
@ -16,8 +15,7 @@ from . import mock_real_ip
|
||||||
BANNED_IPS = ['200.201.202.203', '100.64.0.2']
|
BANNED_IPS = ['200.201.202.203', '100.64.0.2']
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def test_access_from_banned_ip(hass, test_client):
|
||||||
def test_access_from_banned_ip(hass, test_client):
|
|
||||||
"""Test accessing to server from banned IP. Both trusted and not."""
|
"""Test accessing to server from banned IP. Both trusted and not."""
|
||||||
app = web.Application()
|
app = web.Application()
|
||||||
setup_bans(hass, app, 5)
|
setup_bans(hass, app, 5)
|
||||||
|
@ -26,19 +24,18 @@ def test_access_from_banned_ip(hass, test_client):
|
||||||
with patch('homeassistant.components.http.ban.load_ip_bans_config',
|
with patch('homeassistant.components.http.ban.load_ip_bans_config',
|
||||||
return_value=[IpBan(banned_ip) for banned_ip
|
return_value=[IpBan(banned_ip) for banned_ip
|
||||||
in BANNED_IPS]):
|
in BANNED_IPS]):
|
||||||
client = yield from test_client(app)
|
client = await test_client(app)
|
||||||
|
|
||||||
for remote_addr in BANNED_IPS:
|
for remote_addr in BANNED_IPS:
|
||||||
set_real_ip(remote_addr)
|
set_real_ip(remote_addr)
|
||||||
resp = yield from client.get('/')
|
resp = await client.get('/')
|
||||||
assert resp.status == 403
|
assert resp.status == 403
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def test_ban_middleware_not_loaded_by_config(hass):
|
||||||
def test_ban_middleware_not_loaded_by_config(hass):
|
|
||||||
"""Test accessing to server from banned IP when feature is off."""
|
"""Test accessing to server from banned IP when feature is off."""
|
||||||
with patch('homeassistant.components.http.setup_bans') as mock_setup:
|
with patch('homeassistant.components.http.setup_bans') as mock_setup:
|
||||||
yield from async_setup_component(hass, 'http', {
|
await async_setup_component(hass, 'http', {
|
||||||
'http': {
|
'http': {
|
||||||
http.CONF_IP_BAN_ENABLED: False,
|
http.CONF_IP_BAN_ENABLED: False,
|
||||||
}
|
}
|
||||||
|
@ -47,25 +44,22 @@ def test_ban_middleware_not_loaded_by_config(hass):
|
||||||
assert len(mock_setup.mock_calls) == 0
|
assert len(mock_setup.mock_calls) == 0
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def test_ban_middleware_loaded_by_default(hass):
|
||||||
def test_ban_middleware_loaded_by_default(hass):
|
|
||||||
"""Test accessing to server from banned IP when feature is off."""
|
"""Test accessing to server from banned IP when feature is off."""
|
||||||
with patch('homeassistant.components.http.setup_bans') as mock_setup:
|
with patch('homeassistant.components.http.setup_bans') as mock_setup:
|
||||||
yield from async_setup_component(hass, 'http', {
|
await async_setup_component(hass, 'http', {
|
||||||
'http': {}
|
'http': {}
|
||||||
})
|
})
|
||||||
|
|
||||||
assert len(mock_setup.mock_calls) == 1
|
assert len(mock_setup.mock_calls) == 1
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def test_ip_bans_file_creation(hass, test_client):
|
||||||
def test_ip_bans_file_creation(hass, test_client):
|
|
||||||
"""Testing if banned IP file created."""
|
"""Testing if banned IP file created."""
|
||||||
app = web.Application()
|
app = web.Application()
|
||||||
app['hass'] = hass
|
app['hass'] = hass
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def unauth_handler(request):
|
||||||
def unauth_handler(request):
|
|
||||||
"""Return a mock web response."""
|
"""Return a mock web response."""
|
||||||
raise HTTPUnauthorized
|
raise HTTPUnauthorized
|
||||||
|
|
||||||
|
@ -76,21 +70,21 @@ def test_ip_bans_file_creation(hass, test_client):
|
||||||
with patch('homeassistant.components.http.ban.load_ip_bans_config',
|
with patch('homeassistant.components.http.ban.load_ip_bans_config',
|
||||||
return_value=[IpBan(banned_ip) for banned_ip
|
return_value=[IpBan(banned_ip) for banned_ip
|
||||||
in BANNED_IPS]):
|
in BANNED_IPS]):
|
||||||
client = yield from test_client(app)
|
client = await test_client(app)
|
||||||
|
|
||||||
m = mock_open()
|
m = mock_open()
|
||||||
|
|
||||||
with patch('homeassistant.components.http.ban.open', m, create=True):
|
with patch('homeassistant.components.http.ban.open', m, create=True):
|
||||||
resp = yield from client.get('/')
|
resp = await client.get('/')
|
||||||
assert resp.status == 401
|
assert resp.status == 401
|
||||||
assert len(app[KEY_BANNED_IPS]) == len(BANNED_IPS)
|
assert len(app[KEY_BANNED_IPS]) == len(BANNED_IPS)
|
||||||
assert m.call_count == 0
|
assert m.call_count == 0
|
||||||
|
|
||||||
resp = yield from client.get('/')
|
resp = await client.get('/')
|
||||||
assert resp.status == 401
|
assert resp.status == 401
|
||||||
assert len(app[KEY_BANNED_IPS]) == len(BANNED_IPS) + 1
|
assert len(app[KEY_BANNED_IPS]) == len(BANNED_IPS) + 1
|
||||||
m.assert_called_once_with(hass.config.path(IP_BANS_FILE), 'a')
|
m.assert_called_once_with(hass.config.path(IP_BANS_FILE), 'a')
|
||||||
|
|
||||||
resp = yield from client.get('/')
|
resp = await client.get('/')
|
||||||
assert resp.status == 403
|
assert resp.status == 403
|
||||||
assert m.call_count == 1
|
assert m.call_count == 1
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
"""Test cors for the HTTP component."""
|
"""Test cors for the HTTP component."""
|
||||||
import asyncio
|
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
@ -20,22 +19,20 @@ from homeassistant.components.http.cors import setup_cors
|
||||||
TRUSTED_ORIGIN = 'https://home-assistant.io'
|
TRUSTED_ORIGIN = 'https://home-assistant.io'
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def test_cors_middleware_not_loaded_by_default(hass):
|
||||||
def test_cors_middleware_not_loaded_by_default(hass):
|
|
||||||
"""Test accessing to server from banned IP when feature is off."""
|
"""Test accessing to server from banned IP when feature is off."""
|
||||||
with patch('homeassistant.components.http.setup_cors') as mock_setup:
|
with patch('homeassistant.components.http.setup_cors') as mock_setup:
|
||||||
yield from async_setup_component(hass, 'http', {
|
await async_setup_component(hass, 'http', {
|
||||||
'http': {}
|
'http': {}
|
||||||
})
|
})
|
||||||
|
|
||||||
assert len(mock_setup.mock_calls) == 0
|
assert len(mock_setup.mock_calls) == 0
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def test_cors_middleware_loaded_from_config(hass):
|
||||||
def test_cors_middleware_loaded_from_config(hass):
|
|
||||||
"""Test accessing to server from banned IP when feature is off."""
|
"""Test accessing to server from banned IP when feature is off."""
|
||||||
with patch('homeassistant.components.http.setup_cors') as mock_setup:
|
with patch('homeassistant.components.http.setup_cors') as mock_setup:
|
||||||
yield from async_setup_component(hass, 'http', {
|
await async_setup_component(hass, 'http', {
|
||||||
'http': {
|
'http': {
|
||||||
'cors_allowed_origins': ['http://home-assistant.io']
|
'cors_allowed_origins': ['http://home-assistant.io']
|
||||||
}
|
}
|
||||||
|
@ -44,8 +41,7 @@ def test_cors_middleware_loaded_from_config(hass):
|
||||||
assert len(mock_setup.mock_calls) == 1
|
assert len(mock_setup.mock_calls) == 1
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def mock_handler(request):
|
||||||
def mock_handler(request):
|
|
||||||
"""Return if request was authenticated."""
|
"""Return if request was authenticated."""
|
||||||
return web.Response(status=200)
|
return web.Response(status=200)
|
||||||
|
|
||||||
|
@ -59,10 +55,9 @@ def client(loop, test_client):
|
||||||
return loop.run_until_complete(test_client(app))
|
return loop.run_until_complete(test_client(app))
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def test_cors_requests(client):
|
||||||
def test_cors_requests(client):
|
|
||||||
"""Test cross origin requests."""
|
"""Test cross origin requests."""
|
||||||
req = yield from client.get('/', headers={
|
req = await client.get('/', headers={
|
||||||
ORIGIN: TRUSTED_ORIGIN
|
ORIGIN: TRUSTED_ORIGIN
|
||||||
})
|
})
|
||||||
assert req.status == 200
|
assert req.status == 200
|
||||||
|
@ -70,7 +65,7 @@ def test_cors_requests(client):
|
||||||
TRUSTED_ORIGIN
|
TRUSTED_ORIGIN
|
||||||
|
|
||||||
# With password in URL
|
# With password in URL
|
||||||
req = yield from client.get('/', params={
|
req = await client.get('/', params={
|
||||||
'api_password': 'some-pass'
|
'api_password': 'some-pass'
|
||||||
}, headers={
|
}, headers={
|
||||||
ORIGIN: TRUSTED_ORIGIN
|
ORIGIN: TRUSTED_ORIGIN
|
||||||
|
@ -80,7 +75,7 @@ def test_cors_requests(client):
|
||||||
TRUSTED_ORIGIN
|
TRUSTED_ORIGIN
|
||||||
|
|
||||||
# With password in headers
|
# With password in headers
|
||||||
req = yield from client.get('/', headers={
|
req = await client.get('/', headers={
|
||||||
HTTP_HEADER_HA_AUTH: 'some-pass',
|
HTTP_HEADER_HA_AUTH: 'some-pass',
|
||||||
ORIGIN: TRUSTED_ORIGIN
|
ORIGIN: TRUSTED_ORIGIN
|
||||||
})
|
})
|
||||||
|
@ -89,10 +84,9 @@ def test_cors_requests(client):
|
||||||
TRUSTED_ORIGIN
|
TRUSTED_ORIGIN
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def test_cors_preflight_allowed(client):
|
||||||
def test_cors_preflight_allowed(client):
|
|
||||||
"""Test cross origin resource sharing preflight (OPTIONS) request."""
|
"""Test cross origin resource sharing preflight (OPTIONS) request."""
|
||||||
req = yield from client.options('/', headers={
|
req = await client.options('/', headers={
|
||||||
ORIGIN: TRUSTED_ORIGIN,
|
ORIGIN: TRUSTED_ORIGIN,
|
||||||
ACCESS_CONTROL_REQUEST_METHOD: 'GET',
|
ACCESS_CONTROL_REQUEST_METHOD: 'GET',
|
||||||
ACCESS_CONTROL_REQUEST_HEADERS: 'x-ha-access'
|
ACCESS_CONTROL_REQUEST_HEADERS: 'x-ha-access'
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
"""Test data validator decorator."""
|
"""Test data validator decorator."""
|
||||||
import asyncio
|
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
@ -9,8 +8,7 @@ from homeassistant.components.http import HomeAssistantView
|
||||||
from homeassistant.components.http.data_validator import RequestDataValidator
|
from homeassistant.components.http.data_validator import RequestDataValidator
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def get_client(test_client, validator):
|
||||||
def get_client(test_client, validator):
|
|
||||||
"""Generate a client that hits a view decorated with validator."""
|
"""Generate a client that hits a view decorated with validator."""
|
||||||
app = web.Application()
|
app = web.Application()
|
||||||
app['hass'] = Mock(is_running=True)
|
app['hass'] = Mock(is_running=True)
|
||||||
|
@ -20,58 +18,55 @@ def get_client(test_client, validator):
|
||||||
name = 'test'
|
name = 'test'
|
||||||
requires_auth = False
|
requires_auth = False
|
||||||
|
|
||||||
@asyncio.coroutine
|
|
||||||
@validator
|
@validator
|
||||||
def post(self, request, data):
|
async def post(self, request, data):
|
||||||
"""Test method."""
|
"""Test method."""
|
||||||
return b''
|
return b''
|
||||||
|
|
||||||
TestView().register(app.router)
|
TestView().register(app.router)
|
||||||
client = yield from test_client(app)
|
client = await test_client(app)
|
||||||
return client
|
return client
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def test_validator(test_client):
|
||||||
def test_validator(test_client):
|
|
||||||
"""Test the validator."""
|
"""Test the validator."""
|
||||||
client = yield from get_client(
|
client = await get_client(
|
||||||
test_client, RequestDataValidator(vol.Schema({
|
test_client, RequestDataValidator(vol.Schema({
|
||||||
vol.Required('test'): str
|
vol.Required('test'): str
|
||||||
})))
|
})))
|
||||||
|
|
||||||
resp = yield from client.post('/', json={
|
resp = await client.post('/', json={
|
||||||
'test': 'bla'
|
'test': 'bla'
|
||||||
})
|
})
|
||||||
assert resp.status == 200
|
assert resp.status == 200
|
||||||
|
|
||||||
resp = yield from client.post('/', json={
|
resp = await client.post('/', json={
|
||||||
'test': 100
|
'test': 100
|
||||||
})
|
})
|
||||||
assert resp.status == 400
|
assert resp.status == 400
|
||||||
|
|
||||||
resp = yield from client.post('/')
|
resp = await client.post('/')
|
||||||
assert resp.status == 400
|
assert resp.status == 400
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def test_validator_allow_empty(test_client):
|
||||||
def test_validator_allow_empty(test_client):
|
|
||||||
"""Test the validator with empty data."""
|
"""Test the validator with empty data."""
|
||||||
client = yield from get_client(
|
client = await get_client(
|
||||||
test_client, RequestDataValidator(vol.Schema({
|
test_client, RequestDataValidator(vol.Schema({
|
||||||
# Although we allow empty, our schema should still be able
|
# Although we allow empty, our schema should still be able
|
||||||
# to validate an empty dict.
|
# to validate an empty dict.
|
||||||
vol.Optional('test'): str
|
vol.Optional('test'): str
|
||||||
}), allow_empty=True))
|
}), allow_empty=True))
|
||||||
|
|
||||||
resp = yield from client.post('/', json={
|
resp = await client.post('/', json={
|
||||||
'test': 'bla'
|
'test': 'bla'
|
||||||
})
|
})
|
||||||
assert resp.status == 200
|
assert resp.status == 200
|
||||||
|
|
||||||
resp = yield from client.post('/', json={
|
resp = await client.post('/', json={
|
||||||
'test': 100
|
'test': 100
|
||||||
})
|
})
|
||||||
assert resp.status == 400
|
assert resp.status == 400
|
||||||
|
|
||||||
resp = yield from client.post('/')
|
resp = await client.post('/')
|
||||||
assert resp.status == 200
|
assert resp.status == 200
|
||||||
|
|
|
@ -1,6 +1,4 @@
|
||||||
"""The tests for the Home Assistant HTTP component."""
|
"""The tests for the Home Assistant HTTP component."""
|
||||||
import asyncio
|
|
||||||
|
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
import homeassistant.components.http as http
|
import homeassistant.components.http as http
|
||||||
|
@ -12,16 +10,14 @@ class TestView(http.HomeAssistantView):
|
||||||
name = 'test'
|
name = 'test'
|
||||||
url = '/hello'
|
url = '/hello'
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def get(self, request):
|
||||||
def get(self, request):
|
|
||||||
"""Return a get request."""
|
"""Return a get request."""
|
||||||
return 'hello'
|
return 'hello'
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def test_registering_view_while_running(hass, test_client, unused_port):
|
||||||
def test_registering_view_while_running(hass, test_client, unused_port):
|
|
||||||
"""Test that we can register a view while the server is running."""
|
"""Test that we can register a view while the server is running."""
|
||||||
yield from async_setup_component(
|
await async_setup_component(
|
||||||
hass, http.DOMAIN, {
|
hass, http.DOMAIN, {
|
||||||
http.DOMAIN: {
|
http.DOMAIN: {
|
||||||
http.CONF_SERVER_PORT: unused_port(),
|
http.CONF_SERVER_PORT: unused_port(),
|
||||||
|
@ -29,15 +25,14 @@ def test_registering_view_while_running(hass, test_client, unused_port):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
yield from hass.async_start()
|
await hass.async_start()
|
||||||
# This raises a RuntimeError if app is frozen
|
# This raises a RuntimeError if app is frozen
|
||||||
hass.http.register_view(TestView)
|
hass.http.register_view(TestView)
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def test_api_base_url_with_domain(hass):
|
||||||
def test_api_base_url_with_domain(hass):
|
|
||||||
"""Test setting API URL."""
|
"""Test setting API URL."""
|
||||||
result = yield from async_setup_component(hass, 'http', {
|
result = await async_setup_component(hass, 'http', {
|
||||||
'http': {
|
'http': {
|
||||||
'base_url': 'example.com'
|
'base_url': 'example.com'
|
||||||
}
|
}
|
||||||
|
@ -46,10 +41,9 @@ def test_api_base_url_with_domain(hass):
|
||||||
assert hass.config.api.base_url == 'http://example.com'
|
assert hass.config.api.base_url == 'http://example.com'
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def test_api_base_url_with_ip(hass):
|
||||||
def test_api_base_url_with_ip(hass):
|
|
||||||
"""Test setting api url."""
|
"""Test setting api url."""
|
||||||
result = yield from async_setup_component(hass, 'http', {
|
result = await async_setup_component(hass, 'http', {
|
||||||
'http': {
|
'http': {
|
||||||
'server_host': '1.1.1.1'
|
'server_host': '1.1.1.1'
|
||||||
}
|
}
|
||||||
|
@ -58,10 +52,9 @@ def test_api_base_url_with_ip(hass):
|
||||||
assert hass.config.api.base_url == 'http://1.1.1.1:8123'
|
assert hass.config.api.base_url == 'http://1.1.1.1:8123'
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def test_api_base_url_with_ip_port(hass):
|
||||||
def test_api_base_url_with_ip_port(hass):
|
|
||||||
"""Test setting api url."""
|
"""Test setting api url."""
|
||||||
result = yield from async_setup_component(hass, 'http', {
|
result = await async_setup_component(hass, 'http', {
|
||||||
'http': {
|
'http': {
|
||||||
'base_url': '1.1.1.1:8124'
|
'base_url': '1.1.1.1:8124'
|
||||||
}
|
}
|
||||||
|
@ -70,10 +63,9 @@ def test_api_base_url_with_ip_port(hass):
|
||||||
assert hass.config.api.base_url == 'http://1.1.1.1:8124'
|
assert hass.config.api.base_url == 'http://1.1.1.1:8124'
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def test_api_no_base_url(hass):
|
||||||
def test_api_no_base_url(hass):
|
|
||||||
"""Test setting api url."""
|
"""Test setting api url."""
|
||||||
result = yield from async_setup_component(hass, 'http', {
|
result = await async_setup_component(hass, 'http', {
|
||||||
'http': {
|
'http': {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -81,10 +73,9 @@ def test_api_no_base_url(hass):
|
||||||
assert hass.config.api.base_url == 'http://127.0.0.1:8123'
|
assert hass.config.api.base_url == 'http://127.0.0.1:8123'
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def test_not_log_password(hass, unused_port, test_client, caplog):
|
||||||
def test_not_log_password(hass, unused_port, test_client, caplog):
|
|
||||||
"""Test access with password doesn't get logged."""
|
"""Test access with password doesn't get logged."""
|
||||||
result = yield from async_setup_component(hass, 'api', {
|
result = await async_setup_component(hass, 'api', {
|
||||||
'http': {
|
'http': {
|
||||||
http.CONF_SERVER_PORT: unused_port(),
|
http.CONF_SERVER_PORT: unused_port(),
|
||||||
http.CONF_API_PASSWORD: 'some-pass'
|
http.CONF_API_PASSWORD: 'some-pass'
|
||||||
|
@ -92,9 +83,9 @@ def test_not_log_password(hass, unused_port, test_client, caplog):
|
||||||
})
|
})
|
||||||
assert result
|
assert result
|
||||||
|
|
||||||
client = yield from test_client(hass.http.app)
|
client = await test_client(hass.http.app)
|
||||||
|
|
||||||
resp = yield from client.get('/api/', params={
|
resp = await client.get('/api/', params={
|
||||||
'api_password': 'some-pass'
|
'api_password': 'some-pass'
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,4 @@
|
||||||
"""Test real IP middleware."""
|
"""Test real IP middleware."""
|
||||||
import asyncio
|
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from aiohttp.hdrs import X_FORWARDED_FOR
|
from aiohttp.hdrs import X_FORWARDED_FOR
|
||||||
|
|
||||||
|
@ -8,41 +6,38 @@ from homeassistant.components.http.real_ip import setup_real_ip
|
||||||
from homeassistant.components.http.const import KEY_REAL_IP
|
from homeassistant.components.http.const import KEY_REAL_IP
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def mock_handler(request):
|
||||||
def mock_handler(request):
|
|
||||||
"""Handler that returns the real IP as text."""
|
"""Handler that returns the real IP as text."""
|
||||||
return web.Response(text=str(request[KEY_REAL_IP]))
|
return web.Response(text=str(request[KEY_REAL_IP]))
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def test_ignore_x_forwarded_for(test_client):
|
||||||
def test_ignore_x_forwarded_for(test_client):
|
|
||||||
"""Test that we get the IP from the transport."""
|
"""Test that we get the IP from the transport."""
|
||||||
app = web.Application()
|
app = web.Application()
|
||||||
app.router.add_get('/', mock_handler)
|
app.router.add_get('/', mock_handler)
|
||||||
setup_real_ip(app, False)
|
setup_real_ip(app, False)
|
||||||
|
|
||||||
mock_api_client = yield from test_client(app)
|
mock_api_client = await test_client(app)
|
||||||
|
|
||||||
resp = yield from mock_api_client.get('/', headers={
|
resp = await mock_api_client.get('/', headers={
|
||||||
X_FORWARDED_FOR: '255.255.255.255'
|
X_FORWARDED_FOR: '255.255.255.255'
|
||||||
})
|
})
|
||||||
assert resp.status == 200
|
assert resp.status == 200
|
||||||
text = yield from resp.text()
|
text = await resp.text()
|
||||||
assert text != '255.255.255.255'
|
assert text != '255.255.255.255'
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def test_use_x_forwarded_for(test_client):
|
||||||
def test_use_x_forwarded_for(test_client):
|
|
||||||
"""Test that we get the IP from the transport."""
|
"""Test that we get the IP from the transport."""
|
||||||
app = web.Application()
|
app = web.Application()
|
||||||
app.router.add_get('/', mock_handler)
|
app.router.add_get('/', mock_handler)
|
||||||
setup_real_ip(app, True)
|
setup_real_ip(app, True)
|
||||||
|
|
||||||
mock_api_client = yield from test_client(app)
|
mock_api_client = await test_client(app)
|
||||||
|
|
||||||
resp = yield from mock_api_client.get('/', headers={
|
resp = await mock_api_client.get('/', headers={
|
||||||
X_FORWARDED_FOR: '255.255.255.255'
|
X_FORWARDED_FOR: '255.255.255.255'
|
||||||
})
|
})
|
||||||
assert resp.status == 200
|
assert resp.status == 200
|
||||||
text = yield from resp.text()
|
text = await resp.text()
|
||||||
assert text == '255.255.255.255'
|
assert text == '255.255.255.255'
|
||||||
|
|
Loading…
Add table
Reference in a new issue