Move HomeAssistantView to separate file. Convert http to async syntax. [skip ci] (#12982)

* Move HomeAssistantView to separate file. Convert http to async syntax.

* pylint

* websocket api

* update emulated_hue for async/await

* Lint
This commit is contained in:
Boyi C 2018-03-09 09:51:49 +08:00 committed by Paulus Schoutsen
parent 2ee73ca911
commit 321eb2ec6f
17 changed files with 292 additions and 344 deletions

View file

@ -4,7 +4,6 @@ Support for local control of entities by emulating the Phillips Hue bridge.
For more details about this component, please refer to the documentation at For more details about this component, please refer to the documentation at
https://home-assistant.io/components/emulated_hue/ https://home-assistant.io/components/emulated_hue/
""" """
import asyncio
import logging import logging
import voluptuous as vol import voluptuous as vol
@ -111,17 +110,15 @@ def setup(hass, yaml_config):
config.upnp_bind_multicast, config.advertise_ip, config.upnp_bind_multicast, config.advertise_ip,
config.advertise_port) config.advertise_port)
@asyncio.coroutine async def stop_emulated_hue_bridge(event):
def stop_emulated_hue_bridge(event):
"""Stop the emulated hue bridge.""" """Stop the emulated hue bridge."""
upnp_listener.stop() upnp_listener.stop()
yield from server.stop() await server.stop()
@asyncio.coroutine async def start_emulated_hue_bridge(event):
def start_emulated_hue_bridge(event):
"""Start the emulated hue bridge.""" """Start the emulated hue bridge."""
upnp_listener.start() upnp_listener.start()
yield from server.start() await server.start()
hass.bus.async_listen_once( hass.bus.async_listen_once(
EVENT_HOMEASSISTANT_STOP, stop_emulated_hue_bridge) EVENT_HOMEASSISTANT_STOP, stop_emulated_hue_bridge)

View file

@ -4,21 +4,18 @@ This module provides WSGI application to serve the Home Assistant API.
For more details about this component, please refer to the documentation at For more details about this component, please refer to the documentation at
https://home-assistant.io/components/http/ https://home-assistant.io/components/http/
""" """
import asyncio
from ipaddress import ip_network from ipaddress import ip_network
import json
import logging import logging
import os import os
import ssl import ssl
from aiohttp import web from aiohttp import web
from aiohttp.web_exceptions import HTTPUnauthorized, HTTPMovedPermanently from aiohttp.web_exceptions import HTTPMovedPermanently
import voluptuous as vol import voluptuous as vol
from homeassistant.const import ( from homeassistant.const import (
SERVER_PORT, CONTENT_TYPE_JSON, EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP, SERVER_PORT)
EVENT_HOMEASSISTANT_STOP, EVENT_HOMEASSISTANT_START,)
from homeassistant.core import is_callback
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
import homeassistant.remote as rem import homeassistant.remote as rem
import homeassistant.util as hass_util import homeassistant.util as hass_util
@ -28,10 +25,13 @@ from .auth import setup_auth
from .ban import setup_bans from .ban import setup_bans
from .cors import setup_cors from .cors import setup_cors
from .real_ip import setup_real_ip from .real_ip import setup_real_ip
from .const import KEY_AUTHENTICATED, KEY_REAL_IP
from .static import ( from .static import (
CachingFileResponse, CachingStaticResource, staticresource_middleware) CachingFileResponse, CachingStaticResource, staticresource_middleware)
# Import as alias
from .const import KEY_AUTHENTICATED, KEY_REAL_IP # noqa
from .view import HomeAssistantView # noqa
REQUIREMENTS = ['aiohttp_cors==0.6.0'] REQUIREMENTS = ['aiohttp_cors==0.6.0']
DOMAIN = 'http' DOMAIN = 'http'
@ -98,8 +98,7 @@ CONFIG_SCHEMA = vol.Schema({
}, extra=vol.ALLOW_EXTRA) }, extra=vol.ALLOW_EXTRA)
@asyncio.coroutine async def async_setup(hass, config):
def async_setup(hass, config):
"""Set up the HTTP API and debug interface.""" """Set up the HTTP API and debug interface."""
conf = config.get(DOMAIN) conf = config.get(DOMAIN)
@ -135,16 +134,14 @@ def async_setup(hass, config):
is_ban_enabled=is_ban_enabled is_ban_enabled=is_ban_enabled
) )
@asyncio.coroutine async def stop_server(event):
def stop_server(event):
"""Stop the server.""" """Stop the server."""
yield from server.stop() await server.stop()
@asyncio.coroutine async def start_server(event):
def start_server(event):
"""Start the server.""" """Start the server."""
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, stop_server) hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, stop_server)
yield from server.start() await server.start()
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, start_server) hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, start_server)
@ -252,13 +249,11 @@ class HomeAssistantHTTP(object):
return return
if cache_headers: if cache_headers:
@asyncio.coroutine async def serve_file(request):
def serve_file(request):
"""Serve file from disk.""" """Serve file from disk."""
return CachingFileResponse(path) return CachingFileResponse(path)
else: else:
@asyncio.coroutine async def serve_file(request):
def serve_file(request):
"""Serve file from disk.""" """Serve file from disk."""
return web.FileResponse(path) return web.FileResponse(path)
@ -276,14 +271,13 @@ class HomeAssistantHTTP(object):
self.app.router.add_route('GET', url_pattern, serve_file) self.app.router.add_route('GET', url_pattern, serve_file)
@asyncio.coroutine async def start(self):
def start(self):
"""Start the WSGI server.""" """Start the WSGI server."""
# We misunderstood the startup signal. You're not allowed to change # We misunderstood the startup signal. You're not allowed to change
# anything during startup. Temp workaround. # anything during startup. Temp workaround.
# pylint: disable=protected-access # pylint: disable=protected-access
self.app._on_startup.freeze() self.app._on_startup.freeze()
yield from self.app.startup() await self.app.startup()
if self.ssl_certificate: if self.ssl_certificate:
try: try:
@ -308,121 +302,18 @@ class HomeAssistantHTTP(object):
self._handler = self.app.make_handler(loop=self.hass.loop) self._handler = self.app.make_handler(loop=self.hass.loop)
try: try:
self.server = yield from self.hass.loop.create_server( self.server = await self.hass.loop.create_server(
self._handler, self.server_host, self.server_port, ssl=context) self._handler, self.server_host, self.server_port, ssl=context)
except OSError as error: except OSError as error:
_LOGGER.error("Failed to create HTTP server at port %d: %s", _LOGGER.error("Failed to create HTTP server at port %d: %s",
self.server_port, error) self.server_port, error)
@asyncio.coroutine async def stop(self):
def stop(self):
"""Stop the WSGI server.""" """Stop the WSGI server."""
if self.server: if self.server:
self.server.close() self.server.close()
yield from self.server.wait_closed() await self.server.wait_closed()
yield from self.app.shutdown() await self.app.shutdown()
if self._handler: if self._handler:
yield from self._handler.shutdown(10) await self._handler.shutdown(10)
yield from self.app.cleanup() await self.app.cleanup()
class HomeAssistantView(object):
"""Base view for all views."""
url = None
extra_urls = []
requires_auth = True # Views inheriting from this class can override this
# pylint: disable=no-self-use
def json(self, result, status_code=200, headers=None):
"""Return a JSON response."""
msg = json.dumps(
result, sort_keys=True, cls=rem.JSONEncoder).encode('UTF-8')
response = web.Response(
body=msg, content_type=CONTENT_TYPE_JSON, status=status_code,
headers=headers)
response.enable_compression()
return response
def json_message(self, message, status_code=200, message_code=None,
headers=None):
"""Return a JSON message response."""
data = {'message': message}
if message_code is not None:
data['code'] = message_code
return self.json(data, status_code, headers=headers)
@asyncio.coroutine
# pylint: disable=no-self-use
def file(self, request, fil):
"""Return a file."""
assert isinstance(fil, str), 'only string paths allowed'
return web.FileResponse(fil)
def register(self, router):
"""Register the view with a router."""
assert self.url is not None, 'No url set for view'
urls = [self.url] + self.extra_urls
for method in ('get', 'post', 'delete', 'put'):
handler = getattr(self, method, None)
if not handler:
continue
handler = request_handler_factory(self, handler)
for url in urls:
router.add_route(method, url, handler)
# aiohttp_cors does not work with class based views
# self.app.router.add_route('*', self.url, self, name=self.name)
# for url in self.extra_urls:
# self.app.router.add_route('*', url, self)
def request_handler_factory(view, handler):
"""Wrap the handler classes."""
assert asyncio.iscoroutinefunction(handler) or is_callback(handler), \
"Handler should be a coroutine or a callback."
@asyncio.coroutine
def handle(request):
"""Handle incoming request."""
if not request.app['hass'].is_running:
return web.Response(status=503)
authenticated = request.get(KEY_AUTHENTICATED, False)
if view.requires_auth and not authenticated:
raise HTTPUnauthorized()
_LOGGER.info('Serving %s to %s (auth: %s)',
request.path, request.get(KEY_REAL_IP), authenticated)
result = handler(request, **request.match_info)
if asyncio.iscoroutine(result):
result = yield from result
if isinstance(result, web.StreamResponse):
# The method handler returned a ready-made Response, how nice of it
return result
status_code = 200
if isinstance(result, tuple):
result, status_code = result
if isinstance(result, str):
result = result.encode('utf-8')
elif result is None:
result = b''
elif not isinstance(result, bytes):
assert False, ('Result should be None, string, bytes or Response. '
'Got: {}').format(result)
return web.Response(body=result, status=status_code)
return handle

View file

@ -1,5 +1,5 @@
"""Authentication for HTTP component.""" """Authentication for HTTP component."""
import asyncio
import base64 import base64
import hmac import hmac
import logging import logging
@ -20,13 +20,12 @@ _LOGGER = logging.getLogger(__name__)
def setup_auth(app, trusted_networks, api_password): def setup_auth(app, trusted_networks, api_password):
"""Create auth middleware for the app.""" """Create auth middleware for the app."""
@middleware @middleware
@asyncio.coroutine async def auth_middleware(request, handler):
def auth_middleware(request, handler):
"""Authenticate as middleware.""" """Authenticate as middleware."""
# If no password set, just always set authenticated=True # If no password set, just always set authenticated=True
if api_password is None: if api_password is None:
request[KEY_AUTHENTICATED] = True request[KEY_AUTHENTICATED] = True
return (yield from handler(request)) return await handler(request)
# Check authentication # Check authentication
authenticated = False authenticated = False
@ -50,10 +49,9 @@ def setup_auth(app, trusted_networks, api_password):
authenticated = True authenticated = True
request[KEY_AUTHENTICATED] = authenticated request[KEY_AUTHENTICATED] = authenticated
return (yield from handler(request)) return await handler(request)
@asyncio.coroutine async def auth_startup(app):
def auth_startup(app):
"""Initialize auth middleware when app starts up.""" """Initialize auth middleware when app starts up."""
app.middlewares.append(auth_middleware) app.middlewares.append(auth_middleware)

View file

@ -1,5 +1,5 @@
"""Ban logic for HTTP component.""" """Ban logic for HTTP component."""
import asyncio
from collections import defaultdict from collections import defaultdict
from datetime import datetime from datetime import datetime
from ipaddress import ip_address from ipaddress import ip_address
@ -38,11 +38,10 @@ SCHEMA_IP_BAN_ENTRY = vol.Schema({
@callback @callback
def setup_bans(hass, app, login_threshold): def setup_bans(hass, app, login_threshold):
"""Create IP Ban middleware for the app.""" """Create IP Ban middleware for the app."""
@asyncio.coroutine async def ban_startup(app):
def ban_startup(app):
"""Initialize bans when app starts up.""" """Initialize bans when app starts up."""
app.middlewares.append(ban_middleware) app.middlewares.append(ban_middleware)
app[KEY_BANNED_IPS] = yield from hass.async_add_job( app[KEY_BANNED_IPS] = await hass.async_add_job(
load_ip_bans_config, hass.config.path(IP_BANS_FILE)) load_ip_bans_config, hass.config.path(IP_BANS_FILE))
app[KEY_FAILED_LOGIN_ATTEMPTS] = defaultdict(int) app[KEY_FAILED_LOGIN_ATTEMPTS] = defaultdict(int)
app[KEY_LOGIN_THRESHOLD] = login_threshold app[KEY_LOGIN_THRESHOLD] = login_threshold
@ -51,12 +50,11 @@ def setup_bans(hass, app, login_threshold):
@middleware @middleware
@asyncio.coroutine async def ban_middleware(request, handler):
def ban_middleware(request, handler):
"""IP Ban middleware.""" """IP Ban middleware."""
if KEY_BANNED_IPS not in request.app: if KEY_BANNED_IPS not in request.app:
_LOGGER.error('IP Ban middleware loaded but banned IPs not loaded') _LOGGER.error('IP Ban middleware loaded but banned IPs not loaded')
return (yield from handler(request)) return await handler(request)
# Verify if IP is not banned # Verify if IP is not banned
ip_address_ = request[KEY_REAL_IP] ip_address_ = request[KEY_REAL_IP]
@ -67,14 +65,13 @@ def ban_middleware(request, handler):
raise HTTPForbidden() raise HTTPForbidden()
try: try:
return (yield from handler(request)) return await handler(request)
except HTTPUnauthorized: except HTTPUnauthorized:
yield from process_wrong_login(request) await process_wrong_login(request)
raise raise
@asyncio.coroutine async def process_wrong_login(request):
def process_wrong_login(request):
"""Process a wrong login attempt.""" """Process a wrong login attempt."""
remote_addr = request[KEY_REAL_IP] remote_addr = request[KEY_REAL_IP]
@ -98,7 +95,7 @@ def process_wrong_login(request):
request.app[KEY_BANNED_IPS].append(new_ban) request.app[KEY_BANNED_IPS].append(new_ban)
hass = request.app['hass'] hass = request.app['hass']
yield from hass.async_add_job( await hass.async_add_job(
update_ip_bans_config, hass.config.path(IP_BANS_FILE), new_ban) update_ip_bans_config, hass.config.path(IP_BANS_FILE), new_ban)
_LOGGER.warning( _LOGGER.warning(

View file

@ -1,5 +1,5 @@
"""Provide cors support for the HTTP component.""" """Provide cors support for the HTTP component."""
import asyncio
from aiohttp.hdrs import ACCEPT, ORIGIN, CONTENT_TYPE from aiohttp.hdrs import ACCEPT, ORIGIN, CONTENT_TYPE
@ -27,8 +27,7 @@ def setup_cors(app, origins):
) for host in origins ) for host in origins
}) })
@asyncio.coroutine async def cors_startup(app):
def cors_startup(app):
"""Initialize cors when app starts up.""" """Initialize cors when app starts up."""
cors_added = set() cors_added = set()

View file

@ -1,5 +1,5 @@
"""Decorator for view methods to help with data validation.""" """Decorator for view methods to help with data validation."""
import asyncio
from functools import wraps from functools import wraps
import logging import logging
@ -24,16 +24,15 @@ class RequestDataValidator:
def __call__(self, method): def __call__(self, method):
"""Decorate a function.""" """Decorate a function."""
@asyncio.coroutine
@wraps(method) @wraps(method)
def wrapper(view, request, *args, **kwargs): async def wrapper(view, request, *args, **kwargs):
"""Wrap a request handler with data validation.""" """Wrap a request handler with data validation."""
data = None data = None
try: try:
data = yield from request.json() data = await request.json()
except ValueError: except ValueError:
if not self._allow_empty or \ if not self._allow_empty or \
(yield from request.content.read()) != b'': (await request.content.read()) != b'':
_LOGGER.error('Invalid JSON received.') _LOGGER.error('Invalid JSON received.')
return view.json_message('Invalid JSON.', 400) return view.json_message('Invalid JSON.', 400)
data = {} data = {}
@ -45,7 +44,7 @@ class RequestDataValidator:
return view.json_message( return view.json_message(
'Message format incorrect: {}'.format(err), 400) 'Message format incorrect: {}'.format(err), 400)
result = yield from method(view, request, *args, **kwargs) result = await method(view, request, *args, **kwargs)
return result return result
return wrapper return wrapper

View file

@ -1,5 +1,5 @@
"""Middleware to fetch real IP.""" """Middleware to fetch real IP."""
import asyncio
from ipaddress import ip_address from ipaddress import ip_address
from aiohttp.web import middleware from aiohttp.web import middleware
@ -14,8 +14,7 @@ from .const import KEY_REAL_IP
def setup_real_ip(app, use_x_forwarded_for): def setup_real_ip(app, use_x_forwarded_for):
"""Create IP Ban middleware for the app.""" """Create IP Ban middleware for the app."""
@middleware @middleware
@asyncio.coroutine async def real_ip_middleware(request, handler):
def real_ip_middleware(request, handler):
"""Real IP middleware.""" """Real IP middleware."""
if (use_x_forwarded_for and if (use_x_forwarded_for and
X_FORWARDED_FOR in request.headers): X_FORWARDED_FOR in request.headers):
@ -25,10 +24,9 @@ def setup_real_ip(app, use_x_forwarded_for):
request[KEY_REAL_IP] = \ request[KEY_REAL_IP] = \
ip_address(request.transport.get_extra_info('peername')[0]) ip_address(request.transport.get_extra_info('peername')[0])
return (yield from handler(request)) return await handler(request)
@asyncio.coroutine async def app_startup(app):
def app_startup(app):
"""Initialize bans when app starts up.""" """Initialize bans when app starts up."""
app.middlewares.append(real_ip_middleware) app.middlewares.append(real_ip_middleware)

View file

@ -1,5 +1,5 @@
"""Static file handling for HTTP component.""" """Static file handling for HTTP component."""
import asyncio
import re import re
from aiohttp import hdrs from aiohttp import hdrs
@ -14,8 +14,7 @@ _FINGERPRINT = re.compile(r'^(.+)-[a-z0-9]{32}\.(\w+)$', re.IGNORECASE)
class CachingStaticResource(StaticResource): class CachingStaticResource(StaticResource):
"""Static Resource handler that will add cache headers.""" """Static Resource handler that will add cache headers."""
@asyncio.coroutine async def _handle(self, request):
def _handle(self, request):
filename = URL(request.match_info['filename']).path filename = URL(request.match_info['filename']).path
try: try:
# PyLint is wrong about resolve not being a member. # PyLint is wrong about resolve not being a member.
@ -32,7 +31,7 @@ class CachingStaticResource(StaticResource):
raise HTTPNotFound() from error raise HTTPNotFound() from error
if filepath.is_dir(): if filepath.is_dir():
return (yield from super()._handle(request)) return await super()._handle(request)
elif filepath.is_file(): elif filepath.is_file():
return CachingFileResponse(filepath, chunk_size=self._chunk_size) return CachingFileResponse(filepath, chunk_size=self._chunk_size)
else: else:
@ -49,26 +48,24 @@ class CachingFileResponse(FileResponse):
orig_sendfile = self._sendfile orig_sendfile = self._sendfile
@asyncio.coroutine async def sendfile(request, fobj, count):
def sendfile(request, fobj, count):
"""Sendfile that includes a cache header.""" """Sendfile that includes a cache header."""
cache_time = 31 * 86400 # = 1 month cache_time = 31 * 86400 # = 1 month
self.headers[hdrs.CACHE_CONTROL] = "public, max-age={}".format( self.headers[hdrs.CACHE_CONTROL] = "public, max-age={}".format(
cache_time) cache_time)
yield from orig_sendfile(request, fobj, count) await orig_sendfile(request, fobj, count)
# Overwriting like this because __init__ can change implementation. # Overwriting like this because __init__ can change implementation.
self._sendfile = sendfile self._sendfile = sendfile
@middleware @middleware
@asyncio.coroutine async def staticresource_middleware(request, handler):
def staticresource_middleware(request, handler):
"""Middleware to strip out fingerprint from fingerprinted assets.""" """Middleware to strip out fingerprint from fingerprinted assets."""
path = request.path path = request.path
if not path.startswith('/static/') and not path.startswith('/frontend'): if not path.startswith('/static/') and not path.startswith('/frontend'):
return (yield from handler(request)) return await handler(request)
fingerprinted = _FINGERPRINT.match(request.match_info['filename']) fingerprinted = _FINGERPRINT.match(request.match_info['filename'])
@ -76,4 +73,4 @@ def staticresource_middleware(request, handler):
request.match_info['filename'] = \ request.match_info['filename'] = \
'{}.{}'.format(*fingerprinted.groups()) '{}.{}'.format(*fingerprinted.groups())
return (yield from handler(request)) return await handler(request)

View file

@ -0,0 +1,121 @@
"""
This module provides WSGI application to serve the Home Assistant API.
For more details about this component, please refer to the documentation at
https://home-assistant.io/components/http/
"""
import asyncio
import json
import logging
from aiohttp import web
from aiohttp.web_exceptions import HTTPUnauthorized
import homeassistant.remote as rem
from homeassistant.core import is_callback
from homeassistant.const import CONTENT_TYPE_JSON
from .const import KEY_AUTHENTICATED, KEY_REAL_IP
_LOGGER = logging.getLogger(__name__)
class HomeAssistantView(object):
"""Base view for all views."""
url = None
extra_urls = []
requires_auth = True # Views inheriting from this class can override this
# pylint: disable=no-self-use
def json(self, result, status_code=200, headers=None):
"""Return a JSON response."""
msg = json.dumps(
result, sort_keys=True, cls=rem.JSONEncoder).encode('UTF-8')
response = web.Response(
body=msg, content_type=CONTENT_TYPE_JSON, status=status_code,
headers=headers)
response.enable_compression()
return response
def json_message(self, message, status_code=200, message_code=None,
headers=None):
"""Return a JSON message response."""
data = {'message': message}
if message_code is not None:
data['code'] = message_code
return self.json(data, status_code, headers=headers)
# pylint: disable=no-self-use
async def file(self, request, fil):
"""Return a file."""
assert isinstance(fil, str), 'only string paths allowed'
return web.FileResponse(fil)
def register(self, router):
"""Register the view with a router."""
assert self.url is not None, 'No url set for view'
urls = [self.url] + self.extra_urls
for method in ('get', 'post', 'delete', 'put'):
handler = getattr(self, method, None)
if not handler:
continue
handler = request_handler_factory(self, handler)
for url in urls:
router.add_route(method, url, handler)
# aiohttp_cors does not work with class based views
# self.app.router.add_route('*', self.url, self, name=self.name)
# for url in self.extra_urls:
# self.app.router.add_route('*', url, self)
def request_handler_factory(view, handler):
"""Wrap the handler classes."""
assert asyncio.iscoroutinefunction(handler) or is_callback(handler), \
"Handler should be a coroutine or a callback."
async def handle(request):
"""Handle incoming request."""
if not request.app['hass'].is_running:
return web.Response(status=503)
authenticated = request.get(KEY_AUTHENTICATED, False)
if view.requires_auth and not authenticated:
raise HTTPUnauthorized()
_LOGGER.info('Serving %s to %s (auth: %s)',
request.path, request.get(KEY_REAL_IP), authenticated)
result = handler(request, **request.match_info)
if asyncio.iscoroutine(result):
result = await result
if isinstance(result, web.StreamResponse):
# The method handler returned a ready-made Response, how nice of it
return result
status_code = 200
if isinstance(result, tuple):
result, status_code = result
if isinstance(result, str):
result = result.encode('utf-8')
elif result is None:
result = b''
elif not isinstance(result, bytes):
assert False, ('Result should be None, string, bytes or Response. '
'Got: {}').format(result)
return web.Response(body=result, status=status_code)
return handle

View file

@ -191,8 +191,7 @@ def result_message(iden, result=None):
} }
@asyncio.coroutine async def async_setup(hass, config):
def async_setup(hass, config):
"""Initialize the websocket API.""" """Initialize the websocket API."""
hass.http.register_view(WebsocketAPIView) hass.http.register_view(WebsocketAPIView)
return True return True
@ -205,11 +204,10 @@ class WebsocketAPIView(HomeAssistantView):
url = URL url = URL
requires_auth = False requires_auth = False
@asyncio.coroutine async def get(self, request):
def get(self, request):
"""Handle an incoming websocket connection.""" """Handle an incoming websocket connection."""
# pylint: disable=no-self-use # pylint: disable=no-self-use
return ActiveConnection(request.app['hass'], request).handle() return await ActiveConnection(request.app['hass'], request).handle()
class ActiveConnection: class ActiveConnection:
@ -233,17 +231,16 @@ class ActiveConnection:
"""Print an error message.""" """Print an error message."""
_LOGGER.error("WS %s: %s %s", id(self.wsock), message1, message2) _LOGGER.error("WS %s: %s %s", id(self.wsock), message1, message2)
@asyncio.coroutine async def _writer(self):
def _writer(self):
"""Write outgoing messages.""" """Write outgoing messages."""
# Exceptions if Socket disconnected or cancelled by connection handler # Exceptions if Socket disconnected or cancelled by connection handler
with suppress(RuntimeError, *CANCELLATION_ERRORS): with suppress(RuntimeError, *CANCELLATION_ERRORS):
while not self.wsock.closed: while not self.wsock.closed:
message = yield from self.to_write.get() message = await self.to_write.get()
if message is None: if message is None:
break break
self.debug("Sending", message) self.debug("Sending", message)
yield from self.wsock.send_json(message, dumps=JSON_DUMP) await self.wsock.send_json(message, dumps=JSON_DUMP)
@callback @callback
def send_message_outside(self, message): def send_message_outside(self, message):
@ -266,12 +263,11 @@ class ActiveConnection:
self._handle_task.cancel() self._handle_task.cancel()
self._writer_task.cancel() self._writer_task.cancel()
@asyncio.coroutine async def handle(self):
def handle(self):
"""Handle the websocket connection.""" """Handle the websocket connection."""
request = self.request request = self.request
wsock = self.wsock = web.WebSocketResponse(heartbeat=55) wsock = self.wsock = web.WebSocketResponse(heartbeat=55)
yield from wsock.prepare(request) await wsock.prepare(request)
self.debug("Connected") self.debug("Connected")
# Get a reference to current task so we can cancel our connection # Get a reference to current task so we can cancel our connection
@ -294,8 +290,8 @@ class ActiveConnection:
authenticated = True authenticated = True
else: else:
yield from self.wsock.send_json(auth_required_message()) await self.wsock.send_json(auth_required_message())
msg = yield from wsock.receive_json() msg = await wsock.receive_json()
msg = AUTH_MESSAGE_SCHEMA(msg) msg = AUTH_MESSAGE_SCHEMA(msg)
if validate_password(request, msg['api_password']): if validate_password(request, msg['api_password']):
@ -303,18 +299,18 @@ class ActiveConnection:
else: else:
self.debug("Invalid password") self.debug("Invalid password")
yield from self.wsock.send_json( await self.wsock.send_json(
auth_invalid_message('Invalid password')) auth_invalid_message('Invalid password'))
if not authenticated: if not authenticated:
yield from process_wrong_login(request) await process_wrong_login(request)
return wsock return wsock
yield from self.wsock.send_json(auth_ok_message()) await self.wsock.send_json(auth_ok_message())
# ---------- AUTH PHASE OVER ---------- # ---------- AUTH PHASE OVER ----------
msg = yield from wsock.receive_json() msg = await wsock.receive_json()
last_id = 0 last_id = 0
while msg: while msg:
@ -332,7 +328,7 @@ class ActiveConnection:
getattr(self, handler_name)(msg) getattr(self, handler_name)(msg)
last_id = cur_id last_id = cur_id
msg = yield from wsock.receive_json() msg = await wsock.receive_json()
except vol.Invalid as err: except vol.Invalid as err:
error_msg = "Message incorrectly formatted: " error_msg = "Message incorrectly formatted: "
@ -394,11 +390,11 @@ class ActiveConnection:
self.to_write.put_nowait(final_message) self.to_write.put_nowait(final_message)
self.to_write.put_nowait(None) self.to_write.put_nowait(None)
# Make sure all error messages are written before closing # Make sure all error messages are written before closing
yield from self._writer_task await self._writer_task
except asyncio.QueueFull: except asyncio.QueueFull:
self._writer_task.cancel() self._writer_task.cancel()
yield from wsock.close() await wsock.close()
self.debug("Closed connection") self.debug("Closed connection")
return wsock return wsock
@ -410,8 +406,7 @@ class ActiveConnection:
""" """
msg = SUBSCRIBE_EVENTS_MESSAGE_SCHEMA(msg) msg = SUBSCRIBE_EVENTS_MESSAGE_SCHEMA(msg)
@asyncio.coroutine async def forward_events(event):
def forward_events(event):
"""Forward events to websocket.""" """Forward events to websocket."""
if event.event_type == EVENT_TIME_CHANGED: if event.event_type == EVENT_TIME_CHANGED:
return return
@ -447,10 +442,9 @@ class ActiveConnection:
""" """
msg = CALL_SERVICE_MESSAGE_SCHEMA(msg) msg = CALL_SERVICE_MESSAGE_SCHEMA(msg)
@asyncio.coroutine async def call_service_helper(msg):
def call_service_helper(msg):
"""Call a service and fire complete message.""" """Call a service and fire complete message."""
yield from self.hass.services.async_call( await self.hass.services.async_call(
msg['domain'], msg['service'], msg.get('service_data'), True) msg['domain'], msg['service'], msg.get('service_data'), True)
self.send_message_outside(result_message(msg['id'])) self.send_message_outside(result_message(msg['id']))
@ -473,10 +467,9 @@ class ActiveConnection:
""" """
msg = GET_SERVICES_MESSAGE_SCHEMA(msg) msg = GET_SERVICES_MESSAGE_SCHEMA(msg)
@asyncio.coroutine async def get_services_helper(msg):
def get_services_helper(msg):
"""Get available services and fire complete message.""" """Get available services and fire complete message."""
descriptions = yield from async_get_all_descriptions(self.hass) descriptions = await async_get_all_descriptions(self.hass)
self.send_message_outside(result_message(msg['id'], descriptions)) self.send_message_outside(result_message(msg['id'], descriptions))
self.hass.async_add_job(get_services_helper(msg)) self.hass.async_add_job(get_services_helper(msg))

View file

@ -1,5 +1,4 @@
"""Tests for the HTTP component.""" """Tests for the HTTP component."""
import asyncio
from ipaddress import ip_address from ipaddress import ip_address
from aiohttp import web from aiohttp import web
@ -18,18 +17,16 @@ def mock_real_ip(app):
nonlocal ip_to_mock nonlocal ip_to_mock
ip_to_mock = value ip_to_mock = value
@asyncio.coroutine
@web.middleware @web.middleware
def mock_real_ip(request, handler): async def mock_real_ip(request, handler):
"""Mock Real IP middleware.""" """Mock Real IP middleware."""
nonlocal ip_to_mock nonlocal ip_to_mock
request[KEY_REAL_IP] = ip_address(ip_to_mock) request[KEY_REAL_IP] = ip_address(ip_to_mock)
return (yield from handler(request)) return (await handler(request))
@asyncio.coroutine async def real_ip_startup(app):
def real_ip_startup(app):
"""Startup of real ip.""" """Startup of real ip."""
app.middlewares.insert(0, mock_real_ip) app.middlewares.insert(0, mock_real_ip)

View file

@ -1,6 +1,5 @@
"""The tests for the Home Assistant HTTP component.""" """The tests for the Home Assistant HTTP component."""
# pylint: disable=protected-access # pylint: disable=protected-access
import asyncio
from ipaddress import ip_network from ipaddress import ip_network
from unittest.mock import patch from unittest.mock import patch
@ -30,8 +29,7 @@ TRUSTED_ADDRESSES = ['100.64.0.1', '192.0.2.100', 'FD01:DB8::1',
UNTRUSTED_ADDRESSES = ['198.51.100.1', '2001:DB8:FA1::1', '127.0.0.1', '::1'] UNTRUSTED_ADDRESSES = ['198.51.100.1', '2001:DB8:FA1::1', '127.0.0.1', '::1']
@asyncio.coroutine async def mock_handler(request):
def mock_handler(request):
"""Return if request was authenticated.""" """Return if request was authenticated."""
if not request[KEY_AUTHENTICATED]: if not request[KEY_AUTHENTICATED]:
raise HTTPUnauthorized raise HTTPUnauthorized
@ -47,84 +45,79 @@ def app():
return app return app
@asyncio.coroutine async def test_auth_middleware_loaded_by_default(hass):
def test_auth_middleware_loaded_by_default(hass):
"""Test accessing to server from banned IP when feature is off.""" """Test accessing to server from banned IP when feature is off."""
with patch('homeassistant.components.http.setup_auth') as mock_setup: with patch('homeassistant.components.http.setup_auth') as mock_setup:
yield from async_setup_component(hass, 'http', { await async_setup_component(hass, 'http', {
'http': {} 'http': {}
}) })
assert len(mock_setup.mock_calls) == 1 assert len(mock_setup.mock_calls) == 1
@asyncio.coroutine async def test_access_without_password(app, test_client):
def test_access_without_password(app, test_client):
"""Test access without password.""" """Test access without password."""
setup_auth(app, [], None) setup_auth(app, [], None)
client = yield from test_client(app) client = await test_client(app)
resp = yield from client.get('/') resp = await client.get('/')
assert resp.status == 200 assert resp.status == 200
@asyncio.coroutine async def test_access_with_password_in_header(app, test_client):
def test_access_with_password_in_header(app, test_client):
"""Test access with password in URL.""" """Test access with password in URL."""
setup_auth(app, [], API_PASSWORD) setup_auth(app, [], API_PASSWORD)
client = yield from test_client(app) client = await test_client(app)
req = yield from client.get( req = await client.get(
'/', headers={HTTP_HEADER_HA_AUTH: API_PASSWORD}) '/', headers={HTTP_HEADER_HA_AUTH: API_PASSWORD})
assert req.status == 200 assert req.status == 200
req = yield from client.get( req = await client.get(
'/', headers={HTTP_HEADER_HA_AUTH: 'wrong-pass'}) '/', headers={HTTP_HEADER_HA_AUTH: 'wrong-pass'})
assert req.status == 401 assert req.status == 401
@asyncio.coroutine async def test_access_with_password_in_query(app, test_client):
def test_access_with_password_in_query(app, test_client):
"""Test access without password.""" """Test access without password."""
setup_auth(app, [], API_PASSWORD) setup_auth(app, [], API_PASSWORD)
client = yield from test_client(app) client = await test_client(app)
resp = yield from client.get('/', params={ resp = await client.get('/', params={
'api_password': API_PASSWORD 'api_password': API_PASSWORD
}) })
assert resp.status == 200 assert resp.status == 200
resp = yield from client.get('/') resp = await client.get('/')
assert resp.status == 401 assert resp.status == 401
resp = yield from client.get('/', params={ resp = await client.get('/', params={
'api_password': 'wrong-password' 'api_password': 'wrong-password'
}) })
assert resp.status == 401 assert resp.status == 401
@asyncio.coroutine async def test_basic_auth_works(app, test_client):
def test_basic_auth_works(app, test_client):
"""Test access with basic authentication.""" """Test access with basic authentication."""
setup_auth(app, [], API_PASSWORD) setup_auth(app, [], API_PASSWORD)
client = yield from test_client(app) client = await test_client(app)
req = yield from client.get( req = await client.get(
'/', '/',
auth=BasicAuth('homeassistant', API_PASSWORD)) auth=BasicAuth('homeassistant', API_PASSWORD))
assert req.status == 200 assert req.status == 200
req = yield from client.get( req = await client.get(
'/', '/',
auth=BasicAuth('wrong_username', API_PASSWORD)) auth=BasicAuth('wrong_username', API_PASSWORD))
assert req.status == 401 assert req.status == 401
req = yield from client.get( req = await client.get(
'/', '/',
auth=BasicAuth('homeassistant', 'wrong password')) auth=BasicAuth('homeassistant', 'wrong password'))
assert req.status == 401 assert req.status == 401
req = yield from client.get( req = await client.get(
'/', '/',
headers={ headers={
'authorization': 'NotBasic abcdefg' 'authorization': 'NotBasic abcdefg'
@ -132,8 +125,7 @@ def test_basic_auth_works(app, test_client):
assert req.status == 401 assert req.status == 401
@asyncio.coroutine async def test_access_with_trusted_ip(test_client):
def test_access_with_trusted_ip(test_client):
"""Test access with an untrusted ip address.""" """Test access with an untrusted ip address."""
app = web.Application() app = web.Application()
app.router.add_get('/', mock_handler) app.router.add_get('/', mock_handler)
@ -141,16 +133,16 @@ def test_access_with_trusted_ip(test_client):
setup_auth(app, TRUSTED_NETWORKS, 'some-pass') setup_auth(app, TRUSTED_NETWORKS, 'some-pass')
set_mock_ip = mock_real_ip(app) set_mock_ip = mock_real_ip(app)
client = yield from test_client(app) client = await test_client(app)
for remote_addr in UNTRUSTED_ADDRESSES: for remote_addr in UNTRUSTED_ADDRESSES:
set_mock_ip(remote_addr) set_mock_ip(remote_addr)
resp = yield from client.get('/') resp = await client.get('/')
assert resp.status == 401, \ assert resp.status == 401, \
"{} shouldn't be trusted".format(remote_addr) "{} shouldn't be trusted".format(remote_addr)
for remote_addr in TRUSTED_ADDRESSES: for remote_addr in TRUSTED_ADDRESSES:
set_mock_ip(remote_addr) set_mock_ip(remote_addr)
resp = yield from client.get('/') resp = await client.get('/')
assert resp.status == 200, \ assert resp.status == 200, \
"{} should be trusted".format(remote_addr) "{} should be trusted".format(remote_addr)

View file

@ -1,6 +1,5 @@
"""The tests for the Home Assistant HTTP component.""" """The tests for the Home Assistant HTTP component."""
# pylint: disable=protected-access # pylint: disable=protected-access
import asyncio
from unittest.mock import patch, mock_open from unittest.mock import patch, mock_open
from aiohttp import web from aiohttp import web
@ -16,8 +15,7 @@ from . import mock_real_ip
BANNED_IPS = ['200.201.202.203', '100.64.0.2'] BANNED_IPS = ['200.201.202.203', '100.64.0.2']
@asyncio.coroutine async def test_access_from_banned_ip(hass, test_client):
def test_access_from_banned_ip(hass, test_client):
"""Test accessing to server from banned IP. Both trusted and not.""" """Test accessing to server from banned IP. Both trusted and not."""
app = web.Application() app = web.Application()
setup_bans(hass, app, 5) setup_bans(hass, app, 5)
@ -26,19 +24,18 @@ def test_access_from_banned_ip(hass, test_client):
with patch('homeassistant.components.http.ban.load_ip_bans_config', with patch('homeassistant.components.http.ban.load_ip_bans_config',
return_value=[IpBan(banned_ip) for banned_ip return_value=[IpBan(banned_ip) for banned_ip
in BANNED_IPS]): in BANNED_IPS]):
client = yield from test_client(app) client = await test_client(app)
for remote_addr in BANNED_IPS: for remote_addr in BANNED_IPS:
set_real_ip(remote_addr) set_real_ip(remote_addr)
resp = yield from client.get('/') resp = await client.get('/')
assert resp.status == 403 assert resp.status == 403
@asyncio.coroutine async def test_ban_middleware_not_loaded_by_config(hass):
def test_ban_middleware_not_loaded_by_config(hass):
"""Test accessing to server from banned IP when feature is off.""" """Test accessing to server from banned IP when feature is off."""
with patch('homeassistant.components.http.setup_bans') as mock_setup: with patch('homeassistant.components.http.setup_bans') as mock_setup:
yield from async_setup_component(hass, 'http', { await async_setup_component(hass, 'http', {
'http': { 'http': {
http.CONF_IP_BAN_ENABLED: False, http.CONF_IP_BAN_ENABLED: False,
} }
@ -47,25 +44,22 @@ def test_ban_middleware_not_loaded_by_config(hass):
assert len(mock_setup.mock_calls) == 0 assert len(mock_setup.mock_calls) == 0
@asyncio.coroutine async def test_ban_middleware_loaded_by_default(hass):
def test_ban_middleware_loaded_by_default(hass):
"""Test accessing to server from banned IP when feature is off.""" """Test accessing to server from banned IP when feature is off."""
with patch('homeassistant.components.http.setup_bans') as mock_setup: with patch('homeassistant.components.http.setup_bans') as mock_setup:
yield from async_setup_component(hass, 'http', { await async_setup_component(hass, 'http', {
'http': {} 'http': {}
}) })
assert len(mock_setup.mock_calls) == 1 assert len(mock_setup.mock_calls) == 1
@asyncio.coroutine async def test_ip_bans_file_creation(hass, test_client):
def test_ip_bans_file_creation(hass, test_client):
"""Testing if banned IP file created.""" """Testing if banned IP file created."""
app = web.Application() app = web.Application()
app['hass'] = hass app['hass'] = hass
@asyncio.coroutine async def unauth_handler(request):
def unauth_handler(request):
"""Return a mock web response.""" """Return a mock web response."""
raise HTTPUnauthorized raise HTTPUnauthorized
@ -76,21 +70,21 @@ def test_ip_bans_file_creation(hass, test_client):
with patch('homeassistant.components.http.ban.load_ip_bans_config', with patch('homeassistant.components.http.ban.load_ip_bans_config',
return_value=[IpBan(banned_ip) for banned_ip return_value=[IpBan(banned_ip) for banned_ip
in BANNED_IPS]): in BANNED_IPS]):
client = yield from test_client(app) client = await test_client(app)
m = mock_open() m = mock_open()
with patch('homeassistant.components.http.ban.open', m, create=True): with patch('homeassistant.components.http.ban.open', m, create=True):
resp = yield from client.get('/') resp = await client.get('/')
assert resp.status == 401 assert resp.status == 401
assert len(app[KEY_BANNED_IPS]) == len(BANNED_IPS) assert len(app[KEY_BANNED_IPS]) == len(BANNED_IPS)
assert m.call_count == 0 assert m.call_count == 0
resp = yield from client.get('/') resp = await client.get('/')
assert resp.status == 401 assert resp.status == 401
assert len(app[KEY_BANNED_IPS]) == len(BANNED_IPS) + 1 assert len(app[KEY_BANNED_IPS]) == len(BANNED_IPS) + 1
m.assert_called_once_with(hass.config.path(IP_BANS_FILE), 'a') m.assert_called_once_with(hass.config.path(IP_BANS_FILE), 'a')
resp = yield from client.get('/') resp = await client.get('/')
assert resp.status == 403 assert resp.status == 403
assert m.call_count == 1 assert m.call_count == 1

View file

@ -1,5 +1,4 @@
"""Test cors for the HTTP component.""" """Test cors for the HTTP component."""
import asyncio
from unittest.mock import patch from unittest.mock import patch
from aiohttp import web from aiohttp import web
@ -20,22 +19,20 @@ from homeassistant.components.http.cors import setup_cors
TRUSTED_ORIGIN = 'https://home-assistant.io' TRUSTED_ORIGIN = 'https://home-assistant.io'
@asyncio.coroutine async def test_cors_middleware_not_loaded_by_default(hass):
def test_cors_middleware_not_loaded_by_default(hass):
"""Test accessing to server from banned IP when feature is off.""" """Test accessing to server from banned IP when feature is off."""
with patch('homeassistant.components.http.setup_cors') as mock_setup: with patch('homeassistant.components.http.setup_cors') as mock_setup:
yield from async_setup_component(hass, 'http', { await async_setup_component(hass, 'http', {
'http': {} 'http': {}
}) })
assert len(mock_setup.mock_calls) == 0 assert len(mock_setup.mock_calls) == 0
@asyncio.coroutine async def test_cors_middleware_loaded_from_config(hass):
def test_cors_middleware_loaded_from_config(hass):
"""Test accessing to server from banned IP when feature is off.""" """Test accessing to server from banned IP when feature is off."""
with patch('homeassistant.components.http.setup_cors') as mock_setup: with patch('homeassistant.components.http.setup_cors') as mock_setup:
yield from async_setup_component(hass, 'http', { await async_setup_component(hass, 'http', {
'http': { 'http': {
'cors_allowed_origins': ['http://home-assistant.io'] 'cors_allowed_origins': ['http://home-assistant.io']
} }
@ -44,8 +41,7 @@ def test_cors_middleware_loaded_from_config(hass):
assert len(mock_setup.mock_calls) == 1 assert len(mock_setup.mock_calls) == 1
@asyncio.coroutine async def mock_handler(request):
def mock_handler(request):
"""Return if request was authenticated.""" """Return if request was authenticated."""
return web.Response(status=200) return web.Response(status=200)
@ -59,10 +55,9 @@ def client(loop, test_client):
return loop.run_until_complete(test_client(app)) return loop.run_until_complete(test_client(app))
@asyncio.coroutine async def test_cors_requests(client):
def test_cors_requests(client):
"""Test cross origin requests.""" """Test cross origin requests."""
req = yield from client.get('/', headers={ req = await client.get('/', headers={
ORIGIN: TRUSTED_ORIGIN ORIGIN: TRUSTED_ORIGIN
}) })
assert req.status == 200 assert req.status == 200
@ -70,7 +65,7 @@ def test_cors_requests(client):
TRUSTED_ORIGIN TRUSTED_ORIGIN
# With password in URL # With password in URL
req = yield from client.get('/', params={ req = await client.get('/', params={
'api_password': 'some-pass' 'api_password': 'some-pass'
}, headers={ }, headers={
ORIGIN: TRUSTED_ORIGIN ORIGIN: TRUSTED_ORIGIN
@ -80,7 +75,7 @@ def test_cors_requests(client):
TRUSTED_ORIGIN TRUSTED_ORIGIN
# With password in headers # With password in headers
req = yield from client.get('/', headers={ req = await client.get('/', headers={
HTTP_HEADER_HA_AUTH: 'some-pass', HTTP_HEADER_HA_AUTH: 'some-pass',
ORIGIN: TRUSTED_ORIGIN ORIGIN: TRUSTED_ORIGIN
}) })
@ -89,10 +84,9 @@ def test_cors_requests(client):
TRUSTED_ORIGIN TRUSTED_ORIGIN
@asyncio.coroutine async def test_cors_preflight_allowed(client):
def test_cors_preflight_allowed(client):
"""Test cross origin resource sharing preflight (OPTIONS) request.""" """Test cross origin resource sharing preflight (OPTIONS) request."""
req = yield from client.options('/', headers={ req = await client.options('/', headers={
ORIGIN: TRUSTED_ORIGIN, ORIGIN: TRUSTED_ORIGIN,
ACCESS_CONTROL_REQUEST_METHOD: 'GET', ACCESS_CONTROL_REQUEST_METHOD: 'GET',
ACCESS_CONTROL_REQUEST_HEADERS: 'x-ha-access' ACCESS_CONTROL_REQUEST_HEADERS: 'x-ha-access'

View file

@ -1,5 +1,4 @@
"""Test data validator decorator.""" """Test data validator decorator."""
import asyncio
from unittest.mock import Mock from unittest.mock import Mock
from aiohttp import web from aiohttp import web
@ -9,8 +8,7 @@ from homeassistant.components.http import HomeAssistantView
from homeassistant.components.http.data_validator import RequestDataValidator from homeassistant.components.http.data_validator import RequestDataValidator
@asyncio.coroutine async def get_client(test_client, validator):
def get_client(test_client, validator):
"""Generate a client that hits a view decorated with validator.""" """Generate a client that hits a view decorated with validator."""
app = web.Application() app = web.Application()
app['hass'] = Mock(is_running=True) app['hass'] = Mock(is_running=True)
@ -20,58 +18,55 @@ def get_client(test_client, validator):
name = 'test' name = 'test'
requires_auth = False requires_auth = False
@asyncio.coroutine
@validator @validator
def post(self, request, data): async def post(self, request, data):
"""Test method.""" """Test method."""
return b'' return b''
TestView().register(app.router) TestView().register(app.router)
client = yield from test_client(app) client = await test_client(app)
return client return client
@asyncio.coroutine async def test_validator(test_client):
def test_validator(test_client):
"""Test the validator.""" """Test the validator."""
client = yield from get_client( client = await get_client(
test_client, RequestDataValidator(vol.Schema({ test_client, RequestDataValidator(vol.Schema({
vol.Required('test'): str vol.Required('test'): str
}))) })))
resp = yield from client.post('/', json={ resp = await client.post('/', json={
'test': 'bla' 'test': 'bla'
}) })
assert resp.status == 200 assert resp.status == 200
resp = yield from client.post('/', json={ resp = await client.post('/', json={
'test': 100 'test': 100
}) })
assert resp.status == 400 assert resp.status == 400
resp = yield from client.post('/') resp = await client.post('/')
assert resp.status == 400 assert resp.status == 400
@asyncio.coroutine async def test_validator_allow_empty(test_client):
def test_validator_allow_empty(test_client):
"""Test the validator with empty data.""" """Test the validator with empty data."""
client = yield from get_client( client = await get_client(
test_client, RequestDataValidator(vol.Schema({ test_client, RequestDataValidator(vol.Schema({
# Although we allow empty, our schema should still be able # Although we allow empty, our schema should still be able
# to validate an empty dict. # to validate an empty dict.
vol.Optional('test'): str vol.Optional('test'): str
}), allow_empty=True)) }), allow_empty=True))
resp = yield from client.post('/', json={ resp = await client.post('/', json={
'test': 'bla' 'test': 'bla'
}) })
assert resp.status == 200 assert resp.status == 200
resp = yield from client.post('/', json={ resp = await client.post('/', json={
'test': 100 'test': 100
}) })
assert resp.status == 400 assert resp.status == 400
resp = yield from client.post('/') resp = await client.post('/')
assert resp.status == 200 assert resp.status == 200

View file

@ -1,6 +1,4 @@
"""The tests for the Home Assistant HTTP component.""" """The tests for the Home Assistant HTTP component."""
import asyncio
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
import homeassistant.components.http as http import homeassistant.components.http as http
@ -12,16 +10,14 @@ class TestView(http.HomeAssistantView):
name = 'test' name = 'test'
url = '/hello' url = '/hello'
@asyncio.coroutine async def get(self, request):
def get(self, request):
"""Return a get request.""" """Return a get request."""
return 'hello' return 'hello'
@asyncio.coroutine async def test_registering_view_while_running(hass, test_client, unused_port):
def test_registering_view_while_running(hass, test_client, unused_port):
"""Test that we can register a view while the server is running.""" """Test that we can register a view while the server is running."""
yield from async_setup_component( await async_setup_component(
hass, http.DOMAIN, { hass, http.DOMAIN, {
http.DOMAIN: { http.DOMAIN: {
http.CONF_SERVER_PORT: unused_port(), http.CONF_SERVER_PORT: unused_port(),
@ -29,15 +25,14 @@ def test_registering_view_while_running(hass, test_client, unused_port):
} }
) )
yield from hass.async_start() await hass.async_start()
# This raises a RuntimeError if app is frozen # This raises a RuntimeError if app is frozen
hass.http.register_view(TestView) hass.http.register_view(TestView)
@asyncio.coroutine async def test_api_base_url_with_domain(hass):
def test_api_base_url_with_domain(hass):
"""Test setting API URL.""" """Test setting API URL."""
result = yield from async_setup_component(hass, 'http', { result = await async_setup_component(hass, 'http', {
'http': { 'http': {
'base_url': 'example.com' 'base_url': 'example.com'
} }
@ -46,10 +41,9 @@ def test_api_base_url_with_domain(hass):
assert hass.config.api.base_url == 'http://example.com' assert hass.config.api.base_url == 'http://example.com'
@asyncio.coroutine async def test_api_base_url_with_ip(hass):
def test_api_base_url_with_ip(hass):
"""Test setting api url.""" """Test setting api url."""
result = yield from async_setup_component(hass, 'http', { result = await async_setup_component(hass, 'http', {
'http': { 'http': {
'server_host': '1.1.1.1' 'server_host': '1.1.1.1'
} }
@ -58,10 +52,9 @@ def test_api_base_url_with_ip(hass):
assert hass.config.api.base_url == 'http://1.1.1.1:8123' assert hass.config.api.base_url == 'http://1.1.1.1:8123'
@asyncio.coroutine async def test_api_base_url_with_ip_port(hass):
def test_api_base_url_with_ip_port(hass):
"""Test setting api url.""" """Test setting api url."""
result = yield from async_setup_component(hass, 'http', { result = await async_setup_component(hass, 'http', {
'http': { 'http': {
'base_url': '1.1.1.1:8124' 'base_url': '1.1.1.1:8124'
} }
@ -70,10 +63,9 @@ def test_api_base_url_with_ip_port(hass):
assert hass.config.api.base_url == 'http://1.1.1.1:8124' assert hass.config.api.base_url == 'http://1.1.1.1:8124'
@asyncio.coroutine async def test_api_no_base_url(hass):
def test_api_no_base_url(hass):
"""Test setting api url.""" """Test setting api url."""
result = yield from async_setup_component(hass, 'http', { result = await async_setup_component(hass, 'http', {
'http': { 'http': {
} }
}) })
@ -81,10 +73,9 @@ def test_api_no_base_url(hass):
assert hass.config.api.base_url == 'http://127.0.0.1:8123' assert hass.config.api.base_url == 'http://127.0.0.1:8123'
@asyncio.coroutine async def test_not_log_password(hass, unused_port, test_client, caplog):
def test_not_log_password(hass, unused_port, test_client, caplog):
"""Test access with password doesn't get logged.""" """Test access with password doesn't get logged."""
result = yield from async_setup_component(hass, 'api', { result = await async_setup_component(hass, 'api', {
'http': { 'http': {
http.CONF_SERVER_PORT: unused_port(), http.CONF_SERVER_PORT: unused_port(),
http.CONF_API_PASSWORD: 'some-pass' http.CONF_API_PASSWORD: 'some-pass'
@ -92,9 +83,9 @@ def test_not_log_password(hass, unused_port, test_client, caplog):
}) })
assert result assert result
client = yield from test_client(hass.http.app) client = await test_client(hass.http.app)
resp = yield from client.get('/api/', params={ resp = await client.get('/api/', params={
'api_password': 'some-pass' 'api_password': 'some-pass'
}) })

View file

@ -1,6 +1,4 @@
"""Test real IP middleware.""" """Test real IP middleware."""
import asyncio
from aiohttp import web from aiohttp import web
from aiohttp.hdrs import X_FORWARDED_FOR from aiohttp.hdrs import X_FORWARDED_FOR
@ -8,41 +6,38 @@ from homeassistant.components.http.real_ip import setup_real_ip
from homeassistant.components.http.const import KEY_REAL_IP from homeassistant.components.http.const import KEY_REAL_IP
@asyncio.coroutine async def mock_handler(request):
def mock_handler(request):
"""Handler that returns the real IP as text.""" """Handler that returns the real IP as text."""
return web.Response(text=str(request[KEY_REAL_IP])) return web.Response(text=str(request[KEY_REAL_IP]))
@asyncio.coroutine async def test_ignore_x_forwarded_for(test_client):
def test_ignore_x_forwarded_for(test_client):
"""Test that we get the IP from the transport.""" """Test that we get the IP from the transport."""
app = web.Application() app = web.Application()
app.router.add_get('/', mock_handler) app.router.add_get('/', mock_handler)
setup_real_ip(app, False) setup_real_ip(app, False)
mock_api_client = yield from test_client(app) mock_api_client = await test_client(app)
resp = yield from mock_api_client.get('/', headers={ resp = await mock_api_client.get('/', headers={
X_FORWARDED_FOR: '255.255.255.255' X_FORWARDED_FOR: '255.255.255.255'
}) })
assert resp.status == 200 assert resp.status == 200
text = yield from resp.text() text = await resp.text()
assert text != '255.255.255.255' assert text != '255.255.255.255'
@asyncio.coroutine async def test_use_x_forwarded_for(test_client):
def test_use_x_forwarded_for(test_client):
"""Test that we get the IP from the transport.""" """Test that we get the IP from the transport."""
app = web.Application() app = web.Application()
app.router.add_get('/', mock_handler) app.router.add_get('/', mock_handler)
setup_real_ip(app, True) setup_real_ip(app, True)
mock_api_client = yield from test_client(app) mock_api_client = await test_client(app)
resp = yield from mock_api_client.get('/', headers={ resp = await mock_api_client.get('/', headers={
X_FORWARDED_FOR: '255.255.255.255' X_FORWARDED_FOR: '255.255.255.255'
}) })
assert resp.status == 200 assert resp.status == 200
text = yield from resp.text() text = await resp.text()
assert text == '255.255.255.255' assert text == '255.255.255.255'