Cleanup http (#12424)
* Clean up HTTP component * Clean up HTTP mock * Remove unused import * Fix test * Lint
This commit is contained in:
parent
ad8fe8a93a
commit
f32911d036
28 changed files with 811 additions and 1014 deletions
|
@ -14,7 +14,7 @@ from homeassistant.const import (
|
||||||
EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP,
|
EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP,
|
||||||
)
|
)
|
||||||
from homeassistant.components.http import REQUIREMENTS # NOQA
|
from homeassistant.components.http import REQUIREMENTS # NOQA
|
||||||
from homeassistant.components.http import HomeAssistantWSGI
|
from homeassistant.components.http import HomeAssistantHTTP
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers.deprecation import get_deprecated
|
from homeassistant.helpers.deprecation import get_deprecated
|
||||||
import homeassistant.helpers.config_validation as cv
|
import homeassistant.helpers.config_validation as cv
|
||||||
|
@ -86,7 +86,7 @@ def setup(hass, yaml_config):
|
||||||
"""Activate the emulated_hue component."""
|
"""Activate the emulated_hue component."""
|
||||||
config = Config(hass, yaml_config.get(DOMAIN, {}))
|
config = Config(hass, yaml_config.get(DOMAIN, {}))
|
||||||
|
|
||||||
server = HomeAssistantWSGI(
|
server = HomeAssistantHTTP(
|
||||||
hass,
|
hass,
|
||||||
server_host=config.host_ip_addr,
|
server_host=config.host_ip_addr,
|
||||||
server_port=config.listen_port,
|
server_port=config.listen_port,
|
||||||
|
|
|
@ -17,7 +17,7 @@ import jinja2
|
||||||
|
|
||||||
import homeassistant.helpers.config_validation as cv
|
import homeassistant.helpers.config_validation as cv
|
||||||
from homeassistant.components.http import HomeAssistantView
|
from homeassistant.components.http import HomeAssistantView
|
||||||
from homeassistant.components.http.auth import is_trusted_ip
|
from homeassistant.components.http.const import KEY_AUTHENTICATED
|
||||||
from homeassistant.config import find_config_file, load_yaml_config_file
|
from homeassistant.config import find_config_file, load_yaml_config_file
|
||||||
from homeassistant.const import CONF_NAME, EVENT_THEMES_UPDATED
|
from homeassistant.const import CONF_NAME, EVENT_THEMES_UPDATED
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import callback
|
||||||
|
@ -490,7 +490,7 @@ class IndexView(HomeAssistantView):
|
||||||
panel_url = hass.data[DATA_PANELS][panel].webcomponent_url_es5
|
panel_url = hass.data[DATA_PANELS][panel].webcomponent_url_es5
|
||||||
|
|
||||||
no_auth = '1'
|
no_auth = '1'
|
||||||
if hass.config.api.api_password and not is_trusted_ip(request):
|
if hass.config.api.api_password and not request[KEY_AUTHENTICATED]:
|
||||||
# do not try to auto connect on load
|
# do not try to auto connect on load
|
||||||
no_auth = '0'
|
no_auth = '0'
|
||||||
|
|
||||||
|
|
|
@ -12,35 +12,28 @@ import os
|
||||||
import ssl
|
import ssl
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from aiohttp.hdrs import ACCEPT, ORIGIN, CONTENT_TYPE
|
|
||||||
from aiohttp.web_exceptions import HTTPUnauthorized, HTTPMovedPermanently
|
from aiohttp.web_exceptions import HTTPUnauthorized, HTTPMovedPermanently
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.const import (
|
from homeassistant.const import (
|
||||||
SERVER_PORT, CONTENT_TYPE_JSON, HTTP_HEADER_HA_AUTH,
|
SERVER_PORT, CONTENT_TYPE_JSON,
|
||||||
EVENT_HOMEASSISTANT_STOP, EVENT_HOMEASSISTANT_START,
|
EVENT_HOMEASSISTANT_STOP, EVENT_HOMEASSISTANT_START,)
|
||||||
HTTP_HEADER_X_REQUESTED_WITH)
|
|
||||||
from homeassistant.core import is_callback
|
from homeassistant.core import is_callback
|
||||||
import homeassistant.helpers.config_validation as cv
|
import homeassistant.helpers.config_validation as cv
|
||||||
import homeassistant.remote as rem
|
import homeassistant.remote as rem
|
||||||
import homeassistant.util as hass_util
|
import homeassistant.util as hass_util
|
||||||
from homeassistant.util.logging import HideSensitiveDataFilter
|
from homeassistant.util.logging import HideSensitiveDataFilter
|
||||||
|
|
||||||
from .auth import auth_middleware
|
from .auth import setup_auth
|
||||||
from .ban import ban_middleware
|
from .ban import setup_bans
|
||||||
from .const import (
|
from .cors import setup_cors
|
||||||
KEY_BANS_ENABLED, KEY_AUTHENTICATED, KEY_LOGIN_THRESHOLD,
|
from .real_ip import setup_real_ip
|
||||||
KEY_TRUSTED_NETWORKS, KEY_USE_X_FORWARDED_FOR)
|
from .const import KEY_AUTHENTICATED, KEY_REAL_IP
|
||||||
from .static import (
|
from .static import (
|
||||||
CachingFileResponse, CachingStaticResource, staticresource_middleware)
|
CachingFileResponse, CachingStaticResource, staticresource_middleware)
|
||||||
from .util import get_real_ip
|
|
||||||
|
|
||||||
REQUIREMENTS = ['aiohttp_cors==0.6.0']
|
REQUIREMENTS = ['aiohttp_cors==0.6.0']
|
||||||
|
|
||||||
ALLOWED_CORS_HEADERS = [
|
|
||||||
ORIGIN, ACCEPT, HTTP_HEADER_X_REQUESTED_WITH, CONTENT_TYPE,
|
|
||||||
HTTP_HEADER_HA_AUTH]
|
|
||||||
|
|
||||||
DOMAIN = 'http'
|
DOMAIN = 'http'
|
||||||
|
|
||||||
CONF_API_PASSWORD = 'api_password'
|
CONF_API_PASSWORD = 'api_password'
|
||||||
|
@ -127,7 +120,7 @@ def async_setup(hass, config):
|
||||||
logging.getLogger('aiohttp.access').addFilter(
|
logging.getLogger('aiohttp.access').addFilter(
|
||||||
HideSensitiveDataFilter(api_password))
|
HideSensitiveDataFilter(api_password))
|
||||||
|
|
||||||
server = HomeAssistantWSGI(
|
server = HomeAssistantHTTP(
|
||||||
hass,
|
hass,
|
||||||
server_host=server_host,
|
server_host=server_host,
|
||||||
server_port=server_port,
|
server_port=server_port,
|
||||||
|
@ -173,25 +166,29 @@ def async_setup(hass, config):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
class HomeAssistantWSGI(object):
|
class HomeAssistantHTTP(object):
|
||||||
"""WSGI server for Home Assistant."""
|
"""HTTP server for Home Assistant."""
|
||||||
|
|
||||||
def __init__(self, hass, api_password, ssl_certificate,
|
def __init__(self, hass, api_password, ssl_certificate,
|
||||||
ssl_key, server_host, server_port, cors_origins,
|
ssl_key, server_host, server_port, cors_origins,
|
||||||
use_x_forwarded_for, trusted_networks,
|
use_x_forwarded_for, trusted_networks,
|
||||||
login_threshold, is_ban_enabled):
|
login_threshold, is_ban_enabled):
|
||||||
"""Initialize the WSGI Home Assistant server."""
|
"""Initialize the HTTP Home Assistant server."""
|
||||||
middlewares = [auth_middleware, staticresource_middleware]
|
app = self.app = web.Application(
|
||||||
|
middlewares=[staticresource_middleware])
|
||||||
|
|
||||||
|
# This order matters
|
||||||
|
setup_real_ip(app, use_x_forwarded_for)
|
||||||
|
|
||||||
if is_ban_enabled:
|
if is_ban_enabled:
|
||||||
middlewares.insert(0, ban_middleware)
|
setup_bans(hass, app, login_threshold)
|
||||||
|
|
||||||
self.app = web.Application(middlewares=middlewares)
|
setup_auth(app, trusted_networks, api_password)
|
||||||
self.app['hass'] = hass
|
|
||||||
self.app[KEY_USE_X_FORWARDED_FOR] = use_x_forwarded_for
|
if cors_origins:
|
||||||
self.app[KEY_TRUSTED_NETWORKS] = trusted_networks
|
setup_cors(app, cors_origins)
|
||||||
self.app[KEY_BANS_ENABLED] = is_ban_enabled
|
|
||||||
self.app[KEY_LOGIN_THRESHOLD] = login_threshold
|
app['hass'] = hass
|
||||||
|
|
||||||
self.hass = hass
|
self.hass = hass
|
||||||
self.api_password = api_password
|
self.api_password = api_password
|
||||||
|
@ -199,21 +196,10 @@ class HomeAssistantWSGI(object):
|
||||||
self.ssl_key = ssl_key
|
self.ssl_key = ssl_key
|
||||||
self.server_host = server_host
|
self.server_host = server_host
|
||||||
self.server_port = server_port
|
self.server_port = server_port
|
||||||
|
self.is_ban_enabled = is_ban_enabled
|
||||||
self._handler = None
|
self._handler = None
|
||||||
self.server = None
|
self.server = None
|
||||||
|
|
||||||
if cors_origins:
|
|
||||||
import aiohttp_cors
|
|
||||||
|
|
||||||
self.cors = aiohttp_cors.setup(self.app, defaults={
|
|
||||||
host: aiohttp_cors.ResourceOptions(
|
|
||||||
allow_headers=ALLOWED_CORS_HEADERS,
|
|
||||||
allow_methods='*',
|
|
||||||
) for host in cors_origins
|
|
||||||
})
|
|
||||||
else:
|
|
||||||
self.cors = None
|
|
||||||
|
|
||||||
def register_view(self, view):
|
def register_view(self, view):
|
||||||
"""Register a view with the WSGI server.
|
"""Register a view with the WSGI server.
|
||||||
|
|
||||||
|
@ -292,15 +278,7 @@ class HomeAssistantWSGI(object):
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def start(self):
|
def start(self):
|
||||||
"""Start the WSGI server."""
|
"""Start the WSGI server."""
|
||||||
cors_added = set()
|
yield from self.app.startup()
|
||||||
if self.cors is not None:
|
|
||||||
for route in list(self.app.router.routes()):
|
|
||||||
if hasattr(route, 'resource'):
|
|
||||||
route = route.resource
|
|
||||||
if route in cors_added:
|
|
||||||
continue
|
|
||||||
self.cors.add(route)
|
|
||||||
cors_added.add(route)
|
|
||||||
|
|
||||||
if self.ssl_certificate:
|
if self.ssl_certificate:
|
||||||
try:
|
try:
|
||||||
|
@ -420,7 +398,7 @@ def request_handler_factory(view, handler):
|
||||||
raise HTTPUnauthorized()
|
raise HTTPUnauthorized()
|
||||||
|
|
||||||
_LOGGER.info('Serving %s to %s (auth: %s)',
|
_LOGGER.info('Serving %s to %s (auth: %s)',
|
||||||
request.path, get_real_ip(request), authenticated)
|
request.path, request.get(KEY_REAL_IP), authenticated)
|
||||||
|
|
||||||
result = handler(request, **request.match_info)
|
result = handler(request, **request.match_info)
|
||||||
|
|
||||||
|
|
|
@ -7,55 +7,66 @@ import logging
|
||||||
from aiohttp import hdrs
|
from aiohttp import hdrs
|
||||||
from aiohttp.web import middleware
|
from aiohttp.web import middleware
|
||||||
|
|
||||||
|
from homeassistant.core import callback
|
||||||
from homeassistant.const import HTTP_HEADER_HA_AUTH
|
from homeassistant.const import HTTP_HEADER_HA_AUTH
|
||||||
from .util import get_real_ip
|
from .const import KEY_AUTHENTICATED, KEY_REAL_IP
|
||||||
from .const import KEY_TRUSTED_NETWORKS, KEY_AUTHENTICATED
|
|
||||||
|
|
||||||
DATA_API_PASSWORD = 'api_password'
|
DATA_API_PASSWORD = 'api_password'
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@middleware
|
@callback
|
||||||
@asyncio.coroutine
|
def setup_auth(app, trusted_networks, api_password):
|
||||||
def auth_middleware(request, handler):
|
"""Create auth middleware for the app."""
|
||||||
"""Authenticate as middleware."""
|
@middleware
|
||||||
# If no password set, just always set authenticated=True
|
@asyncio.coroutine
|
||||||
if request.app['hass'].http.api_password is None:
|
def auth_middleware(request, handler):
|
||||||
request[KEY_AUTHENTICATED] = True
|
"""Authenticate as middleware."""
|
||||||
|
# If no password set, just always set authenticated=True
|
||||||
|
if api_password is None:
|
||||||
|
request[KEY_AUTHENTICATED] = True
|
||||||
|
return (yield from handler(request))
|
||||||
|
|
||||||
|
# Check authentication
|
||||||
|
authenticated = False
|
||||||
|
|
||||||
|
if (HTTP_HEADER_HA_AUTH in request.headers and
|
||||||
|
hmac.compare_digest(
|
||||||
|
api_password, request.headers[HTTP_HEADER_HA_AUTH])):
|
||||||
|
# A valid auth header has been set
|
||||||
|
authenticated = True
|
||||||
|
|
||||||
|
elif (DATA_API_PASSWORD in request.query and
|
||||||
|
hmac.compare_digest(api_password,
|
||||||
|
request.query[DATA_API_PASSWORD])):
|
||||||
|
authenticated = True
|
||||||
|
|
||||||
|
elif (hdrs.AUTHORIZATION in request.headers and
|
||||||
|
validate_authorization_header(api_password, request)):
|
||||||
|
authenticated = True
|
||||||
|
|
||||||
|
elif _is_trusted_ip(request, trusted_networks):
|
||||||
|
authenticated = True
|
||||||
|
|
||||||
|
request[KEY_AUTHENTICATED] = authenticated
|
||||||
return (yield from handler(request))
|
return (yield from handler(request))
|
||||||
|
|
||||||
# Check authentication
|
@asyncio.coroutine
|
||||||
authenticated = False
|
def auth_startup(app):
|
||||||
|
"""Initialize auth middleware when app starts up."""
|
||||||
|
app.middlewares.append(auth_middleware)
|
||||||
|
|
||||||
if (HTTP_HEADER_HA_AUTH in request.headers and
|
app.on_startup.append(auth_startup)
|
||||||
validate_password(
|
|
||||||
request, request.headers[HTTP_HEADER_HA_AUTH])):
|
|
||||||
# A valid auth header has been set
|
|
||||||
authenticated = True
|
|
||||||
|
|
||||||
elif (DATA_API_PASSWORD in request.query and
|
|
||||||
validate_password(request, request.query[DATA_API_PASSWORD])):
|
|
||||||
authenticated = True
|
|
||||||
|
|
||||||
elif (hdrs.AUTHORIZATION in request.headers and
|
|
||||||
validate_authorization_header(request)):
|
|
||||||
authenticated = True
|
|
||||||
|
|
||||||
elif is_trusted_ip(request):
|
|
||||||
authenticated = True
|
|
||||||
|
|
||||||
request[KEY_AUTHENTICATED] = authenticated
|
|
||||||
return (yield from handler(request))
|
|
||||||
|
|
||||||
|
|
||||||
def is_trusted_ip(request):
|
def _is_trusted_ip(request, trusted_networks):
|
||||||
"""Test if request is from a trusted ip."""
|
"""Test if request is from a trusted ip."""
|
||||||
ip_addr = get_real_ip(request)
|
ip_addr = request[KEY_REAL_IP]
|
||||||
|
|
||||||
return ip_addr and any(
|
return any(
|
||||||
ip_addr in trusted_network for trusted_network
|
ip_addr in trusted_network for trusted_network
|
||||||
in request.app[KEY_TRUSTED_NETWORKS])
|
in trusted_networks)
|
||||||
|
|
||||||
|
|
||||||
def validate_password(request, api_password):
|
def validate_password(request, api_password):
|
||||||
|
@ -64,7 +75,7 @@ def validate_password(request, api_password):
|
||||||
api_password, request.app['hass'].http.api_password)
|
api_password, request.app['hass'].http.api_password)
|
||||||
|
|
||||||
|
|
||||||
def validate_authorization_header(request):
|
def validate_authorization_header(api_password, request):
|
||||||
"""Test an authorization header if valid password."""
|
"""Test an authorization header if valid password."""
|
||||||
if hdrs.AUTHORIZATION not in request.headers:
|
if hdrs.AUTHORIZATION not in request.headers:
|
||||||
return False
|
return False
|
||||||
|
@ -80,4 +91,4 @@ def validate_authorization_header(request):
|
||||||
if username != 'homeassistant':
|
if username != 'homeassistant':
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return validate_password(request, password)
|
return hmac.compare_digest(api_password, password)
|
||||||
|
|
|
@ -10,18 +10,20 @@ from aiohttp.web import middleware
|
||||||
from aiohttp.web_exceptions import HTTPForbidden, HTTPUnauthorized
|
from aiohttp.web_exceptions import HTTPForbidden, HTTPUnauthorized
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
|
from homeassistant.core import callback
|
||||||
from homeassistant.components import persistent_notification
|
from homeassistant.components import persistent_notification
|
||||||
from homeassistant.config import load_yaml_config_file
|
from homeassistant.config import load_yaml_config_file
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
import homeassistant.helpers.config_validation as cv
|
import homeassistant.helpers.config_validation as cv
|
||||||
from homeassistant.util.yaml import dump
|
from homeassistant.util.yaml import dump
|
||||||
from .const import (
|
from .const import KEY_REAL_IP
|
||||||
KEY_BANS_ENABLED, KEY_BANNED_IPS, KEY_LOGIN_THRESHOLD,
|
|
||||||
KEY_FAILED_LOGIN_ATTEMPTS)
|
|
||||||
from .util import get_real_ip
|
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
KEY_BANNED_IPS = 'ha_banned_ips'
|
||||||
|
KEY_FAILED_LOGIN_ATTEMPTS = 'ha_failed_login_attempts'
|
||||||
|
KEY_LOGIN_THRESHOLD = 'ha_login_threshold'
|
||||||
|
|
||||||
NOTIFICATION_ID_BAN = 'ip-ban'
|
NOTIFICATION_ID_BAN = 'ip-ban'
|
||||||
NOTIFICATION_ID_LOGIN = 'http-login'
|
NOTIFICATION_ID_LOGIN = 'http-login'
|
||||||
|
|
||||||
|
@ -33,21 +35,31 @@ SCHEMA_IP_BAN_ENTRY = vol.Schema({
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def setup_bans(hass, app, login_threshold):
|
||||||
|
"""Create IP Ban middleware for the app."""
|
||||||
|
@asyncio.coroutine
|
||||||
|
def ban_startup(app):
|
||||||
|
"""Initialize bans when app starts up."""
|
||||||
|
app.middlewares.append(ban_middleware)
|
||||||
|
app[KEY_BANNED_IPS] = yield from hass.async_add_job(
|
||||||
|
load_ip_bans_config, hass.config.path(IP_BANS_FILE))
|
||||||
|
app[KEY_FAILED_LOGIN_ATTEMPTS] = defaultdict(int)
|
||||||
|
app[KEY_LOGIN_THRESHOLD] = login_threshold
|
||||||
|
|
||||||
|
app.on_startup.append(ban_startup)
|
||||||
|
|
||||||
|
|
||||||
@middleware
|
@middleware
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def ban_middleware(request, handler):
|
def ban_middleware(request, handler):
|
||||||
"""IP Ban middleware."""
|
"""IP Ban middleware."""
|
||||||
if not request.app[KEY_BANS_ENABLED]:
|
if KEY_BANNED_IPS not in request.app:
|
||||||
|
_LOGGER.error('IP Ban middleware loaded but banned IPs not loaded')
|
||||||
return (yield from handler(request))
|
return (yield from handler(request))
|
||||||
|
|
||||||
if KEY_BANNED_IPS not in request.app:
|
|
||||||
hass = request.app['hass']
|
|
||||||
request.app[KEY_BANNED_IPS] = yield from hass.async_add_job(
|
|
||||||
load_ip_bans_config, hass.config.path(IP_BANS_FILE))
|
|
||||||
|
|
||||||
# Verify if IP is not banned
|
# Verify if IP is not banned
|
||||||
ip_address_ = get_real_ip(request)
|
ip_address_ = request[KEY_REAL_IP]
|
||||||
|
|
||||||
is_banned = any(ip_ban.ip_address == ip_address_
|
is_banned = any(ip_ban.ip_address == ip_address_
|
||||||
for ip_ban in request.app[KEY_BANNED_IPS])
|
for ip_ban in request.app[KEY_BANNED_IPS])
|
||||||
|
|
||||||
|
@ -64,7 +76,7 @@ def ban_middleware(request, handler):
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def process_wrong_login(request):
|
def process_wrong_login(request):
|
||||||
"""Process a wrong login attempt."""
|
"""Process a wrong login attempt."""
|
||||||
remote_addr = get_real_ip(request)
|
remote_addr = request[KEY_REAL_IP]
|
||||||
|
|
||||||
msg = ('Login attempt or request with invalid authentication '
|
msg = ('Login attempt or request with invalid authentication '
|
||||||
'from {}'.format(remote_addr))
|
'from {}'.format(remote_addr))
|
||||||
|
@ -73,13 +85,11 @@ def process_wrong_login(request):
|
||||||
request.app['hass'], msg, 'Login attempt failed',
|
request.app['hass'], msg, 'Login attempt failed',
|
||||||
NOTIFICATION_ID_LOGIN)
|
NOTIFICATION_ID_LOGIN)
|
||||||
|
|
||||||
if (not request.app[KEY_BANS_ENABLED] or
|
# Check if ban middleware is loaded
|
||||||
|
if (KEY_BANNED_IPS not in request.app or
|
||||||
request.app[KEY_LOGIN_THRESHOLD] < 1):
|
request.app[KEY_LOGIN_THRESHOLD] < 1):
|
||||||
return
|
return
|
||||||
|
|
||||||
if KEY_FAILED_LOGIN_ATTEMPTS not in request.app:
|
|
||||||
request.app[KEY_FAILED_LOGIN_ATTEMPTS] = defaultdict(int)
|
|
||||||
|
|
||||||
request.app[KEY_FAILED_LOGIN_ATTEMPTS][remote_addr] += 1
|
request.app[KEY_FAILED_LOGIN_ATTEMPTS][remote_addr] += 1
|
||||||
|
|
||||||
if (request.app[KEY_FAILED_LOGIN_ATTEMPTS][remote_addr] >
|
if (request.app[KEY_FAILED_LOGIN_ATTEMPTS][remote_addr] >
|
||||||
|
|
|
@ -1,11 +1,3 @@
|
||||||
"""HTTP specific constants."""
|
"""HTTP specific constants."""
|
||||||
KEY_AUTHENTICATED = 'ha_authenticated'
|
KEY_AUTHENTICATED = 'ha_authenticated'
|
||||||
KEY_USE_X_FORWARDED_FOR = 'ha_use_x_forwarded_for'
|
|
||||||
KEY_TRUSTED_NETWORKS = 'ha_trusted_networks'
|
|
||||||
KEY_REAL_IP = 'ha_real_ip'
|
KEY_REAL_IP = 'ha_real_ip'
|
||||||
KEY_BANS_ENABLED = 'ha_bans_enabled'
|
|
||||||
KEY_BANNED_IPS = 'ha_banned_ips'
|
|
||||||
KEY_FAILED_LOGIN_ATTEMPTS = 'ha_failed_login_attempts'
|
|
||||||
KEY_LOGIN_THRESHOLD = 'ha_login_threshold'
|
|
||||||
|
|
||||||
HTTP_HEADER_X_FORWARDED_FOR = 'X-Forwarded-For'
|
|
||||||
|
|
43
homeassistant/components/http/cors.py
Normal file
43
homeassistant/components/http/cors.py
Normal file
|
@ -0,0 +1,43 @@
|
||||||
|
"""Provide cors support for the HTTP component."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from aiohttp.hdrs import ACCEPT, ORIGIN, CONTENT_TYPE
|
||||||
|
|
||||||
|
from homeassistant.const import (
|
||||||
|
HTTP_HEADER_X_REQUESTED_WITH, HTTP_HEADER_HA_AUTH)
|
||||||
|
|
||||||
|
|
||||||
|
from homeassistant.core import callback
|
||||||
|
|
||||||
|
|
||||||
|
ALLOWED_CORS_HEADERS = [
|
||||||
|
ORIGIN, ACCEPT, HTTP_HEADER_X_REQUESTED_WITH, CONTENT_TYPE,
|
||||||
|
HTTP_HEADER_HA_AUTH]
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def setup_cors(app, origins):
|
||||||
|
"""Setup cors."""
|
||||||
|
import aiohttp_cors
|
||||||
|
|
||||||
|
cors = aiohttp_cors.setup(app, defaults={
|
||||||
|
host: aiohttp_cors.ResourceOptions(
|
||||||
|
allow_headers=ALLOWED_CORS_HEADERS,
|
||||||
|
allow_methods='*',
|
||||||
|
) for host in origins
|
||||||
|
})
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def cors_startup(app):
|
||||||
|
"""Initialize cors when app starts up."""
|
||||||
|
cors_added = set()
|
||||||
|
|
||||||
|
for route in list(app.router.routes()):
|
||||||
|
if hasattr(route, 'resource'):
|
||||||
|
route = route.resource
|
||||||
|
if route in cors_added:
|
||||||
|
continue
|
||||||
|
cors.add(route)
|
||||||
|
cors_added.add(route)
|
||||||
|
|
||||||
|
app.on_startup.append(cors_startup)
|
35
homeassistant/components/http/real_ip.py
Normal file
35
homeassistant/components/http/real_ip.py
Normal file
|
@ -0,0 +1,35 @@
|
||||||
|
"""Middleware to fetch real IP."""
|
||||||
|
import asyncio
|
||||||
|
from ipaddress import ip_address
|
||||||
|
|
||||||
|
from aiohttp.web import middleware
|
||||||
|
from aiohttp.hdrs import X_FORWARDED_FOR
|
||||||
|
|
||||||
|
from homeassistant.core import callback
|
||||||
|
|
||||||
|
from .const import KEY_REAL_IP
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def setup_real_ip(app, use_x_forwarded_for):
|
||||||
|
"""Create IP Ban middleware for the app."""
|
||||||
|
@middleware
|
||||||
|
@asyncio.coroutine
|
||||||
|
def real_ip_middleware(request, handler):
|
||||||
|
"""Real IP middleware."""
|
||||||
|
if (use_x_forwarded_for and
|
||||||
|
X_FORWARDED_FOR in request.headers):
|
||||||
|
request[KEY_REAL_IP] = ip_address(
|
||||||
|
request.headers.get(X_FORWARDED_FOR).split(',')[0])
|
||||||
|
else:
|
||||||
|
request[KEY_REAL_IP] = \
|
||||||
|
ip_address(request.transport.get_extra_info('peername')[0])
|
||||||
|
|
||||||
|
return (yield from handler(request))
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def app_startup(app):
|
||||||
|
"""Initialize bans when app starts up."""
|
||||||
|
app.middlewares.append(real_ip_middleware)
|
||||||
|
|
||||||
|
app.on_startup.append(app_startup)
|
|
@ -1,25 +0,0 @@
|
||||||
"""HTTP utilities."""
|
|
||||||
from ipaddress import ip_address
|
|
||||||
|
|
||||||
from .const import (
|
|
||||||
KEY_REAL_IP, KEY_USE_X_FORWARDED_FOR, HTTP_HEADER_X_FORWARDED_FOR)
|
|
||||||
|
|
||||||
|
|
||||||
def get_real_ip(request):
|
|
||||||
"""Get IP address of client."""
|
|
||||||
if KEY_REAL_IP in request:
|
|
||||||
return request[KEY_REAL_IP]
|
|
||||||
|
|
||||||
if (request.app.get(KEY_USE_X_FORWARDED_FOR) and
|
|
||||||
HTTP_HEADER_X_FORWARDED_FOR in request.headers):
|
|
||||||
request[KEY_REAL_IP] = ip_address(
|
|
||||||
request.headers.get(HTTP_HEADER_X_FORWARDED_FOR).split(',')[0])
|
|
||||||
else:
|
|
||||||
peername = request.transport.get_extra_info('peername')
|
|
||||||
|
|
||||||
if peername:
|
|
||||||
request[KEY_REAL_IP] = ip_address(peername[0])
|
|
||||||
else:
|
|
||||||
request[KEY_REAL_IP] = None
|
|
||||||
|
|
||||||
return request[KEY_REAL_IP]
|
|
|
@ -12,7 +12,7 @@ import logging
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.components.http import HomeAssistantView
|
from homeassistant.components.http import HomeAssistantView
|
||||||
from homeassistant.components.http.util import get_real_ip
|
from homeassistant.components.http.const import KEY_REAL_IP
|
||||||
from homeassistant.components.telegram_bot import (
|
from homeassistant.components.telegram_bot import (
|
||||||
CONF_ALLOWED_CHAT_IDS, BaseTelegramBotEntity, PLATFORM_SCHEMA)
|
CONF_ALLOWED_CHAT_IDS, BaseTelegramBotEntity, PLATFORM_SCHEMA)
|
||||||
from homeassistant.const import (
|
from homeassistant.const import (
|
||||||
|
@ -110,7 +110,7 @@ class BotPushReceiver(HomeAssistantView, BaseTelegramBotEntity):
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def post(self, request):
|
def post(self, request):
|
||||||
"""Accept the POST from telegram."""
|
"""Accept the POST from telegram."""
|
||||||
real_ip = get_real_ip(request)
|
real_ip = request[KEY_REAL_IP]
|
||||||
if not any(real_ip in net for net in self.trusted_networks):
|
if not any(real_ip in net for net in self.trusted_networks):
|
||||||
_LOGGER.warning("Access denied from %s", real_ip)
|
_LOGGER.warning("Access denied from %s", real_ip)
|
||||||
return self.json_message('Access denied', HTTP_UNAUTHORIZED)
|
return self.json_message('Access denied', HTTP_UNAUTHORIZED)
|
||||||
|
|
|
@ -9,8 +9,6 @@ import logging
|
||||||
import threading
|
import threading
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
|
||||||
from aiohttp import web
|
|
||||||
|
|
||||||
from homeassistant import core as ha, loader
|
from homeassistant import core as ha, loader
|
||||||
from homeassistant.setup import setup_component, async_setup_component
|
from homeassistant.setup import setup_component, async_setup_component
|
||||||
from homeassistant.config import async_process_component_config
|
from homeassistant.config import async_process_component_config
|
||||||
|
@ -25,9 +23,6 @@ from homeassistant.const import (
|
||||||
EVENT_STATE_CHANGED, EVENT_PLATFORM_DISCOVERED, ATTR_SERVICE,
|
EVENT_STATE_CHANGED, EVENT_PLATFORM_DISCOVERED, ATTR_SERVICE,
|
||||||
ATTR_DISCOVERED, SERVER_PORT, EVENT_HOMEASSISTANT_CLOSE)
|
ATTR_DISCOVERED, SERVER_PORT, EVENT_HOMEASSISTANT_CLOSE)
|
||||||
from homeassistant.components import mqtt, recorder
|
from homeassistant.components import mqtt, recorder
|
||||||
from homeassistant.components.http.auth import auth_middleware
|
|
||||||
from homeassistant.components.http.const import (
|
|
||||||
KEY_USE_X_FORWARDED_FOR, KEY_BANS_ENABLED, KEY_TRUSTED_NETWORKS)
|
|
||||||
from homeassistant.util.async import (
|
from homeassistant.util.async import (
|
||||||
run_callback_threadsafe, run_coroutine_threadsafe)
|
run_callback_threadsafe, run_coroutine_threadsafe)
|
||||||
|
|
||||||
|
@ -262,35 +257,6 @@ def mock_state_change_event(hass, new_state, old_state=None):
|
||||||
hass.bus.fire(EVENT_STATE_CHANGED, event_data)
|
hass.bus.fire(EVENT_STATE_CHANGED, event_data)
|
||||||
|
|
||||||
|
|
||||||
def mock_http_component(hass, api_password=None):
|
|
||||||
"""Mock the HTTP component."""
|
|
||||||
hass.http = MagicMock(api_password=api_password)
|
|
||||||
mock_component(hass, 'http')
|
|
||||||
hass.http.views = {}
|
|
||||||
|
|
||||||
def mock_register_view(view):
|
|
||||||
"""Store registered view."""
|
|
||||||
if isinstance(view, type):
|
|
||||||
# Instantiate the view, if needed
|
|
||||||
view = view()
|
|
||||||
|
|
||||||
hass.http.views[view.name] = view
|
|
||||||
|
|
||||||
hass.http.register_view = mock_register_view
|
|
||||||
|
|
||||||
|
|
||||||
def mock_http_component_app(hass, api_password=None):
|
|
||||||
"""Create an aiohttp.web.Application instance for testing."""
|
|
||||||
if 'http' not in hass.config.components:
|
|
||||||
mock_http_component(hass, api_password)
|
|
||||||
app = web.Application(middlewares=[auth_middleware])
|
|
||||||
app['hass'] = hass
|
|
||||||
app[KEY_USE_X_FORWARDED_FOR] = False
|
|
||||||
app[KEY_BANS_ENABLED] = False
|
|
||||||
app[KEY_TRUSTED_NETWORKS] = []
|
|
||||||
return app
|
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def async_mock_mqtt_component(hass, config=None):
|
def async_mock_mqtt_component(hass, config=None):
|
||||||
"""Mock the MQTT component."""
|
"""Mock the MQTT component."""
|
||||||
|
|
|
@ -9,7 +9,7 @@ from uvcclient import nvr
|
||||||
|
|
||||||
from homeassistant.setup import setup_component
|
from homeassistant.setup import setup_component
|
||||||
from homeassistant.components.camera import uvc
|
from homeassistant.components.camera import uvc
|
||||||
from tests.common import get_test_home_assistant, mock_http_component
|
from tests.common import get_test_home_assistant
|
||||||
|
|
||||||
|
|
||||||
class TestUVCSetup(unittest.TestCase):
|
class TestUVCSetup(unittest.TestCase):
|
||||||
|
@ -18,7 +18,6 @@ class TestUVCSetup(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
"""Setup things to be run when tests are started."""
|
"""Setup things to be run when tests are started."""
|
||||||
self.hass = get_test_home_assistant()
|
self.hass = get_test_home_assistant()
|
||||||
mock_http_component(self.hass)
|
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
"""Stop everything that was started."""
|
"""Stop everything that was started."""
|
||||||
|
|
|
@ -14,7 +14,7 @@ def test_setup_check_env_prevents_load(hass, loop):
|
||||||
with patch.dict(os.environ, clear=True), \
|
with patch.dict(os.environ, clear=True), \
|
||||||
patch.object(config, 'SECTIONS', ['hassbian']), \
|
patch.object(config, 'SECTIONS', ['hassbian']), \
|
||||||
patch('homeassistant.components.http.'
|
patch('homeassistant.components.http.'
|
||||||
'HomeAssistantWSGI.register_view') as reg_view:
|
'HomeAssistantHTTP.register_view') as reg_view:
|
||||||
loop.run_until_complete(async_setup_component(hass, 'config', {}))
|
loop.run_until_complete(async_setup_component(hass, 'config', {}))
|
||||||
assert 'config' in hass.config.components
|
assert 'config' in hass.config.components
|
||||||
assert reg_view.called is False
|
assert reg_view.called is False
|
||||||
|
@ -25,7 +25,7 @@ def test_setup_check_env_works(hass, loop):
|
||||||
with patch.dict(os.environ, {'FORCE_HASSBIAN': '1'}), \
|
with patch.dict(os.environ, {'FORCE_HASSBIAN': '1'}), \
|
||||||
patch.object(config, 'SECTIONS', ['hassbian']), \
|
patch.object(config, 'SECTIONS', ['hassbian']), \
|
||||||
patch('homeassistant.components.http.'
|
patch('homeassistant.components.http.'
|
||||||
'HomeAssistantWSGI.register_view') as reg_view:
|
'HomeAssistantHTTP.register_view') as reg_view:
|
||||||
loop.run_until_complete(async_setup_component(hass, 'config', {}))
|
loop.run_until_complete(async_setup_component(hass, 'config', {}))
|
||||||
assert 'config' in hass.config.components
|
assert 'config' in hass.config.components
|
||||||
assert len(reg_view.mock_calls) == 2
|
assert len(reg_view.mock_calls) == 2
|
||||||
|
|
|
@ -2,19 +2,11 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from homeassistant.const import EVENT_COMPONENT_LOADED
|
from homeassistant.const import EVENT_COMPONENT_LOADED
|
||||||
from homeassistant.setup import async_setup_component, ATTR_COMPONENT
|
from homeassistant.setup import async_setup_component, ATTR_COMPONENT
|
||||||
from homeassistant.components import config
|
from homeassistant.components import config
|
||||||
|
|
||||||
from tests.common import mock_http_component, mock_coro, mock_component
|
from tests.common import mock_coro, mock_component
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def stub_http(hass):
|
|
||||||
"""Stub the HTTP component."""
|
|
||||||
mock_http_component(hass)
|
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
|
|
|
@ -3,28 +3,30 @@ import asyncio
|
||||||
import json
|
import json
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from homeassistant.bootstrap import async_setup_component
|
from homeassistant.bootstrap import async_setup_component
|
||||||
from homeassistant.components import config
|
from homeassistant.components import config
|
||||||
|
|
||||||
from homeassistant.components.zwave import DATA_NETWORK, const
|
from homeassistant.components.zwave import DATA_NETWORK, const
|
||||||
from homeassistant.components.config.zwave import (
|
|
||||||
ZWaveNodeValueView, ZWaveNodeGroupView, ZWaveNodeConfigView,
|
|
||||||
ZWaveUserCodeView, ZWaveConfigWriteView)
|
|
||||||
from tests.common import mock_http_component_app
|
|
||||||
from tests.mock.zwave import MockNode, MockValue, MockEntityValues
|
from tests.mock.zwave import MockNode, MockValue, MockEntityValues
|
||||||
|
|
||||||
|
|
||||||
VIEW_NAME = 'api:config:zwave:device_config'
|
VIEW_NAME = 'api:config:zwave:device_config'
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@pytest.fixture
|
||||||
def test_get_device_config(hass, test_client):
|
def client(loop, hass, test_client):
|
||||||
"""Test getting device config."""
|
"""Client to communicate with Z-Wave config views."""
|
||||||
with patch.object(config, 'SECTIONS', ['zwave']):
|
with patch.object(config, 'SECTIONS', ['zwave']):
|
||||||
yield from async_setup_component(hass, 'config', {})
|
loop.run_until_complete(async_setup_component(hass, 'config', {}))
|
||||||
|
|
||||||
client = yield from test_client(hass.http.app)
|
return loop.run_until_complete(test_client(hass.http.app))
|
||||||
|
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def test_get_device_config(client):
|
||||||
|
"""Test getting device config."""
|
||||||
def mock_read(path):
|
def mock_read(path):
|
||||||
"""Mock reading data."""
|
"""Mock reading data."""
|
||||||
return {
|
return {
|
||||||
|
@ -47,13 +49,8 @@ def test_get_device_config(hass, test_client):
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def test_update_device_config(hass, test_client):
|
def test_update_device_config(client):
|
||||||
"""Test updating device config."""
|
"""Test updating device config."""
|
||||||
with patch.object(config, 'SECTIONS', ['zwave']):
|
|
||||||
yield from async_setup_component(hass, 'config', {})
|
|
||||||
|
|
||||||
client = yield from test_client(hass.http.app)
|
|
||||||
|
|
||||||
orig_data = {
|
orig_data = {
|
||||||
'hello.beer': {
|
'hello.beer': {
|
||||||
'ignored': True,
|
'ignored': True,
|
||||||
|
@ -90,13 +87,8 @@ def test_update_device_config(hass, test_client):
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def test_update_device_config_invalid_key(hass, test_client):
|
def test_update_device_config_invalid_key(client):
|
||||||
"""Test updating device config."""
|
"""Test updating device config."""
|
||||||
with patch.object(config, 'SECTIONS', ['zwave']):
|
|
||||||
yield from async_setup_component(hass, 'config', {})
|
|
||||||
|
|
||||||
client = yield from test_client(hass.http.app)
|
|
||||||
|
|
||||||
resp = yield from client.post(
|
resp = yield from client.post(
|
||||||
'/api/config/zwave/device_config/invalid_entity', data=json.dumps({
|
'/api/config/zwave/device_config/invalid_entity', data=json.dumps({
|
||||||
'polling_intensity': 2
|
'polling_intensity': 2
|
||||||
|
@ -106,13 +98,8 @@ def test_update_device_config_invalid_key(hass, test_client):
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def test_update_device_config_invalid_data(hass, test_client):
|
def test_update_device_config_invalid_data(client):
|
||||||
"""Test updating device config."""
|
"""Test updating device config."""
|
||||||
with patch.object(config, 'SECTIONS', ['zwave']):
|
|
||||||
yield from async_setup_component(hass, 'config', {})
|
|
||||||
|
|
||||||
client = yield from test_client(hass.http.app)
|
|
||||||
|
|
||||||
resp = yield from client.post(
|
resp = yield from client.post(
|
||||||
'/api/config/zwave/device_config/hello.beer', data=json.dumps({
|
'/api/config/zwave/device_config/hello.beer', data=json.dumps({
|
||||||
'invalid_option': 2
|
'invalid_option': 2
|
||||||
|
@ -122,13 +109,8 @@ def test_update_device_config_invalid_data(hass, test_client):
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def test_update_device_config_invalid_json(hass, test_client):
|
def test_update_device_config_invalid_json(client):
|
||||||
"""Test updating device config."""
|
"""Test updating device config."""
|
||||||
with patch.object(config, 'SECTIONS', ['zwave']):
|
|
||||||
yield from async_setup_component(hass, 'config', {})
|
|
||||||
|
|
||||||
client = yield from test_client(hass.http.app)
|
|
||||||
|
|
||||||
resp = yield from client.post(
|
resp = yield from client.post(
|
||||||
'/api/config/zwave/device_config/hello.beer', data='not json')
|
'/api/config/zwave/device_config/hello.beer', data='not json')
|
||||||
|
|
||||||
|
@ -136,11 +118,8 @@ def test_update_device_config_invalid_json(hass, test_client):
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def test_get_values(hass, test_client):
|
def test_get_values(hass, client):
|
||||||
"""Test getting values on node."""
|
"""Test getting values on node."""
|
||||||
app = mock_http_component_app(hass)
|
|
||||||
ZWaveNodeValueView().register(app.router)
|
|
||||||
|
|
||||||
node = MockNode(node_id=1)
|
node = MockNode(node_id=1)
|
||||||
value = MockValue(value_id=123456, node=node, label='Test Label',
|
value = MockValue(value_id=123456, node=node, label='Test Label',
|
||||||
instance=1, index=2, poll_intensity=4)
|
instance=1, index=2, poll_intensity=4)
|
||||||
|
@ -150,8 +129,6 @@ def test_get_values(hass, test_client):
|
||||||
values2 = MockEntityValues(primary=value2)
|
values2 = MockEntityValues(primary=value2)
|
||||||
hass.data[const.DATA_ENTITY_VALUES] = [values, values2]
|
hass.data[const.DATA_ENTITY_VALUES] = [values, values2]
|
||||||
|
|
||||||
client = yield from test_client(app)
|
|
||||||
|
|
||||||
resp = yield from client.get('/api/zwave/values/1')
|
resp = yield from client.get('/api/zwave/values/1')
|
||||||
|
|
||||||
assert resp.status == 200
|
assert resp.status == 200
|
||||||
|
@ -168,11 +145,8 @@ def test_get_values(hass, test_client):
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def test_get_groups(hass, test_client):
|
def test_get_groups(hass, client):
|
||||||
"""Test getting groupdata on node."""
|
"""Test getting groupdata on node."""
|
||||||
app = mock_http_component_app(hass)
|
|
||||||
ZWaveNodeGroupView().register(app.router)
|
|
||||||
|
|
||||||
network = hass.data[DATA_NETWORK] = MagicMock()
|
network = hass.data[DATA_NETWORK] = MagicMock()
|
||||||
node = MockNode(node_id=2)
|
node = MockNode(node_id=2)
|
||||||
node.groups.associations = 'assoc'
|
node.groups.associations = 'assoc'
|
||||||
|
@ -182,8 +156,6 @@ def test_get_groups(hass, test_client):
|
||||||
node.groups = {1: node.groups}
|
node.groups = {1: node.groups}
|
||||||
network.nodes = {2: node}
|
network.nodes = {2: node}
|
||||||
|
|
||||||
client = yield from test_client(app)
|
|
||||||
|
|
||||||
resp = yield from client.get('/api/zwave/groups/2')
|
resp = yield from client.get('/api/zwave/groups/2')
|
||||||
|
|
||||||
assert resp.status == 200
|
assert resp.status == 200
|
||||||
|
@ -200,18 +172,13 @@ def test_get_groups(hass, test_client):
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def test_get_groups_nogroups(hass, test_client):
|
def test_get_groups_nogroups(hass, client):
|
||||||
"""Test getting groupdata on node with no groups."""
|
"""Test getting groupdata on node with no groups."""
|
||||||
app = mock_http_component_app(hass)
|
|
||||||
ZWaveNodeGroupView().register(app.router)
|
|
||||||
|
|
||||||
network = hass.data[DATA_NETWORK] = MagicMock()
|
network = hass.data[DATA_NETWORK] = MagicMock()
|
||||||
node = MockNode(node_id=2)
|
node = MockNode(node_id=2)
|
||||||
|
|
||||||
network.nodes = {2: node}
|
network.nodes = {2: node}
|
||||||
|
|
||||||
client = yield from test_client(app)
|
|
||||||
|
|
||||||
resp = yield from client.get('/api/zwave/groups/2')
|
resp = yield from client.get('/api/zwave/groups/2')
|
||||||
|
|
||||||
assert resp.status == 200
|
assert resp.status == 200
|
||||||
|
@ -221,16 +188,11 @@ def test_get_groups_nogroups(hass, test_client):
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def test_get_groups_nonode(hass, test_client):
|
def test_get_groups_nonode(hass, client):
|
||||||
"""Test getting groupdata on nonexisting node."""
|
"""Test getting groupdata on nonexisting node."""
|
||||||
app = mock_http_component_app(hass)
|
|
||||||
ZWaveNodeGroupView().register(app.router)
|
|
||||||
|
|
||||||
network = hass.data[DATA_NETWORK] = MagicMock()
|
network = hass.data[DATA_NETWORK] = MagicMock()
|
||||||
network.nodes = {1: 1, 5: 5}
|
network.nodes = {1: 1, 5: 5}
|
||||||
|
|
||||||
client = yield from test_client(app)
|
|
||||||
|
|
||||||
resp = yield from client.get('/api/zwave/groups/2')
|
resp = yield from client.get('/api/zwave/groups/2')
|
||||||
|
|
||||||
assert resp.status == 404
|
assert resp.status == 404
|
||||||
|
@ -240,11 +202,8 @@ def test_get_groups_nonode(hass, test_client):
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def test_get_config(hass, test_client):
|
def test_get_config(hass, client):
|
||||||
"""Test getting config on node."""
|
"""Test getting config on node."""
|
||||||
app = mock_http_component_app(hass)
|
|
||||||
ZWaveNodeConfigView().register(app.router)
|
|
||||||
|
|
||||||
network = hass.data[DATA_NETWORK] = MagicMock()
|
network = hass.data[DATA_NETWORK] = MagicMock()
|
||||||
node = MockNode(node_id=2)
|
node = MockNode(node_id=2)
|
||||||
value = MockValue(
|
value = MockValue(
|
||||||
|
@ -261,8 +220,6 @@ def test_get_config(hass, test_client):
|
||||||
network.nodes = {2: node}
|
network.nodes = {2: node}
|
||||||
node.get_values.return_value = node.values
|
node.get_values.return_value = node.values
|
||||||
|
|
||||||
client = yield from test_client(app)
|
|
||||||
|
|
||||||
resp = yield from client.get('/api/zwave/config/2')
|
resp = yield from client.get('/api/zwave/config/2')
|
||||||
|
|
||||||
assert resp.status == 200
|
assert resp.status == 200
|
||||||
|
@ -278,19 +235,14 @@ def test_get_config(hass, test_client):
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def test_get_config_noconfig_node(hass, test_client):
|
def test_get_config_noconfig_node(hass, client):
|
||||||
"""Test getting config on node without config."""
|
"""Test getting config on node without config."""
|
||||||
app = mock_http_component_app(hass)
|
|
||||||
ZWaveNodeConfigView().register(app.router)
|
|
||||||
|
|
||||||
network = hass.data[DATA_NETWORK] = MagicMock()
|
network = hass.data[DATA_NETWORK] = MagicMock()
|
||||||
node = MockNode(node_id=2)
|
node = MockNode(node_id=2)
|
||||||
|
|
||||||
network.nodes = {2: node}
|
network.nodes = {2: node}
|
||||||
node.get_values.return_value = node.values
|
node.get_values.return_value = node.values
|
||||||
|
|
||||||
client = yield from test_client(app)
|
|
||||||
|
|
||||||
resp = yield from client.get('/api/zwave/config/2')
|
resp = yield from client.get('/api/zwave/config/2')
|
||||||
|
|
||||||
assert resp.status == 200
|
assert resp.status == 200
|
||||||
|
@ -300,16 +252,11 @@ def test_get_config_noconfig_node(hass, test_client):
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def test_get_config_nonode(hass, test_client):
|
def test_get_config_nonode(hass, client):
|
||||||
"""Test getting config on nonexisting node."""
|
"""Test getting config on nonexisting node."""
|
||||||
app = mock_http_component_app(hass)
|
|
||||||
ZWaveNodeConfigView().register(app.router)
|
|
||||||
|
|
||||||
network = hass.data[DATA_NETWORK] = MagicMock()
|
network = hass.data[DATA_NETWORK] = MagicMock()
|
||||||
network.nodes = {1: 1, 5: 5}
|
network.nodes = {1: 1, 5: 5}
|
||||||
|
|
||||||
client = yield from test_client(app)
|
|
||||||
|
|
||||||
resp = yield from client.get('/api/zwave/config/2')
|
resp = yield from client.get('/api/zwave/config/2')
|
||||||
|
|
||||||
assert resp.status == 404
|
assert resp.status == 404
|
||||||
|
@ -319,16 +266,11 @@ def test_get_config_nonode(hass, test_client):
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def test_get_usercodes_nonode(hass, test_client):
|
def test_get_usercodes_nonode(hass, client):
|
||||||
"""Test getting usercodes on nonexisting node."""
|
"""Test getting usercodes on nonexisting node."""
|
||||||
app = mock_http_component_app(hass)
|
|
||||||
ZWaveUserCodeView().register(app.router)
|
|
||||||
|
|
||||||
network = hass.data[DATA_NETWORK] = MagicMock()
|
network = hass.data[DATA_NETWORK] = MagicMock()
|
||||||
network.nodes = {1: 1, 5: 5}
|
network.nodes = {1: 1, 5: 5}
|
||||||
|
|
||||||
client = yield from test_client(app)
|
|
||||||
|
|
||||||
resp = yield from client.get('/api/zwave/usercodes/2')
|
resp = yield from client.get('/api/zwave/usercodes/2')
|
||||||
|
|
||||||
assert resp.status == 404
|
assert resp.status == 404
|
||||||
|
@ -338,11 +280,8 @@ def test_get_usercodes_nonode(hass, test_client):
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def test_get_usercodes(hass, test_client):
|
def test_get_usercodes(hass, client):
|
||||||
"""Test getting usercodes on node."""
|
"""Test getting usercodes on node."""
|
||||||
app = mock_http_component_app(hass)
|
|
||||||
ZWaveUserCodeView().register(app.router)
|
|
||||||
|
|
||||||
network = hass.data[DATA_NETWORK] = MagicMock()
|
network = hass.data[DATA_NETWORK] = MagicMock()
|
||||||
node = MockNode(node_id=18,
|
node = MockNode(node_id=18,
|
||||||
command_classes=[const.COMMAND_CLASS_USER_CODE])
|
command_classes=[const.COMMAND_CLASS_USER_CODE])
|
||||||
|
@ -356,8 +295,6 @@ def test_get_usercodes(hass, test_client):
|
||||||
network.nodes = {18: node}
|
network.nodes = {18: node}
|
||||||
node.get_values.return_value = node.values
|
node.get_values.return_value = node.values
|
||||||
|
|
||||||
client = yield from test_client(app)
|
|
||||||
|
|
||||||
resp = yield from client.get('/api/zwave/usercodes/18')
|
resp = yield from client.get('/api/zwave/usercodes/18')
|
||||||
|
|
||||||
assert resp.status == 200
|
assert resp.status == 200
|
||||||
|
@ -369,19 +306,14 @@ def test_get_usercodes(hass, test_client):
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def test_get_usercode_nousercode_node(hass, test_client):
|
def test_get_usercode_nousercode_node(hass, client):
|
||||||
"""Test getting usercodes on node without usercodes."""
|
"""Test getting usercodes on node without usercodes."""
|
||||||
app = mock_http_component_app(hass)
|
|
||||||
ZWaveUserCodeView().register(app.router)
|
|
||||||
|
|
||||||
network = hass.data[DATA_NETWORK] = MagicMock()
|
network = hass.data[DATA_NETWORK] = MagicMock()
|
||||||
node = MockNode(node_id=18)
|
node = MockNode(node_id=18)
|
||||||
|
|
||||||
network.nodes = {18: node}
|
network.nodes = {18: node}
|
||||||
node.get_values.return_value = node.values
|
node.get_values.return_value = node.values
|
||||||
|
|
||||||
client = yield from test_client(app)
|
|
||||||
|
|
||||||
resp = yield from client.get('/api/zwave/usercodes/18')
|
resp = yield from client.get('/api/zwave/usercodes/18')
|
||||||
|
|
||||||
assert resp.status == 200
|
assert resp.status == 200
|
||||||
|
@ -391,11 +323,8 @@ def test_get_usercode_nousercode_node(hass, test_client):
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def test_get_usercodes_no_genreuser(hass, test_client):
|
def test_get_usercodes_no_genreuser(hass, client):
|
||||||
"""Test getting usercodes on node missing genre user."""
|
"""Test getting usercodes on node missing genre user."""
|
||||||
app = mock_http_component_app(hass)
|
|
||||||
ZWaveUserCodeView().register(app.router)
|
|
||||||
|
|
||||||
network = hass.data[DATA_NETWORK] = MagicMock()
|
network = hass.data[DATA_NETWORK] = MagicMock()
|
||||||
node = MockNode(node_id=18,
|
node = MockNode(node_id=18,
|
||||||
command_classes=[const.COMMAND_CLASS_USER_CODE])
|
command_classes=[const.COMMAND_CLASS_USER_CODE])
|
||||||
|
@ -409,8 +338,6 @@ def test_get_usercodes_no_genreuser(hass, test_client):
|
||||||
network.nodes = {18: node}
|
network.nodes = {18: node}
|
||||||
node.get_values.return_value = node.values
|
node.get_values.return_value = node.values
|
||||||
|
|
||||||
client = yield from test_client(app)
|
|
||||||
|
|
||||||
resp = yield from client.get('/api/zwave/usercodes/18')
|
resp = yield from client.get('/api/zwave/usercodes/18')
|
||||||
|
|
||||||
assert resp.status == 200
|
assert resp.status == 200
|
||||||
|
@ -420,13 +347,8 @@ def test_get_usercodes_no_genreuser(hass, test_client):
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def test_save_config_no_network(hass, test_client):
|
def test_save_config_no_network(hass, client):
|
||||||
"""Test saving configuration without network data."""
|
"""Test saving configuration without network data."""
|
||||||
app = mock_http_component_app(hass)
|
|
||||||
ZWaveConfigWriteView().register(app.router)
|
|
||||||
|
|
||||||
client = yield from test_client(app)
|
|
||||||
|
|
||||||
resp = yield from client.post('/api/zwave/saveconfig')
|
resp = yield from client.post('/api/zwave/saveconfig')
|
||||||
|
|
||||||
assert resp.status == 404
|
assert resp.status == 404
|
||||||
|
@ -435,15 +357,10 @@ def test_save_config_no_network(hass, test_client):
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def test_save_config(hass, test_client):
|
def test_save_config(hass, client):
|
||||||
"""Test saving configuration."""
|
"""Test saving configuration."""
|
||||||
app = mock_http_component_app(hass)
|
|
||||||
ZWaveConfigWriteView().register(app.router)
|
|
||||||
|
|
||||||
network = hass.data[DATA_NETWORK] = MagicMock()
|
network = hass.data[DATA_NETWORK] = MagicMock()
|
||||||
|
|
||||||
client = yield from test_client(app)
|
|
||||||
|
|
||||||
resp = yield from client.post('/api/zwave/saveconfig')
|
resp = yield from client.post('/api/zwave/saveconfig')
|
||||||
|
|
||||||
assert resp.status == 200
|
assert resp.status == 200
|
||||||
|
|
|
@ -5,11 +5,10 @@ import logging
|
||||||
from unittest.mock import patch, MagicMock
|
from unittest.mock import patch, MagicMock
|
||||||
import aioautomatic
|
import aioautomatic
|
||||||
|
|
||||||
|
from homeassistant.setup import async_setup_component
|
||||||
from homeassistant.components.device_tracker.automatic import (
|
from homeassistant.components.device_tracker.automatic import (
|
||||||
async_setup_scanner)
|
async_setup_scanner)
|
||||||
|
|
||||||
from tests.common import mock_http_component
|
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@ -23,8 +22,7 @@ def test_invalid_credentials(
|
||||||
mock_open, mock_isfile, mock_makedirs, mock_json_dump, mock_json_load,
|
mock_open, mock_isfile, mock_makedirs, mock_json_dump, mock_json_load,
|
||||||
mock_create_session, hass):
|
mock_create_session, hass):
|
||||||
"""Test with invalid credentials."""
|
"""Test with invalid credentials."""
|
||||||
mock_http_component(hass)
|
hass.loop.run_until_complete(async_setup_component(hass, 'http', {}))
|
||||||
|
|
||||||
mock_json_load.return_value = {'refresh_token': 'bad_token'}
|
mock_json_load.return_value = {'refresh_token': 'bad_token'}
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
|
@ -59,8 +57,7 @@ def test_valid_credentials(
|
||||||
mock_open, mock_isfile, mock_makedirs, mock_json_dump, mock_json_load,
|
mock_open, mock_isfile, mock_makedirs, mock_json_dump, mock_json_load,
|
||||||
mock_ws_connect, mock_create_session, hass):
|
mock_ws_connect, mock_create_session, hass):
|
||||||
"""Test with valid credentials."""
|
"""Test with valid credentials."""
|
||||||
mock_http_component(hass)
|
hass.loop.run_until_complete(async_setup_component(hass, 'http', {}))
|
||||||
|
|
||||||
mock_json_load.return_value = {'refresh_token': 'good_token'}
|
mock_json_load.return_value = {'refresh_token': 'good_token'}
|
||||||
|
|
||||||
session = MagicMock()
|
session = MagicMock()
|
||||||
|
|
|
@ -1 +1,38 @@
|
||||||
"""Tests for the HTTP component."""
|
"""Tests for the HTTP component."""
|
||||||
|
import asyncio
|
||||||
|
from ipaddress import ip_address
|
||||||
|
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
|
from homeassistant.components.http.const import KEY_REAL_IP
|
||||||
|
|
||||||
|
|
||||||
|
def mock_real_ip(app):
|
||||||
|
"""Inject middleware to mock real IP.
|
||||||
|
|
||||||
|
Returns a function to set the real IP.
|
||||||
|
"""
|
||||||
|
ip_to_mock = None
|
||||||
|
|
||||||
|
def set_ip_to_mock(value):
|
||||||
|
nonlocal ip_to_mock
|
||||||
|
ip_to_mock = value
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
@web.middleware
|
||||||
|
def mock_real_ip(request, handler):
|
||||||
|
"""Mock Real IP middleware."""
|
||||||
|
nonlocal ip_to_mock
|
||||||
|
|
||||||
|
request[KEY_REAL_IP] = ip_address(ip_to_mock)
|
||||||
|
|
||||||
|
return (yield from handler(request))
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def real_ip_startup(app):
|
||||||
|
"""Startup of real ip."""
|
||||||
|
app.middlewares.insert(0, mock_real_ip)
|
||||||
|
|
||||||
|
app.on_startup.append(real_ip_startup)
|
||||||
|
|
||||||
|
return set_ip_to_mock
|
||||||
|
|
|
@ -1,195 +1,156 @@
|
||||||
"""The tests for the Home Assistant HTTP component."""
|
"""The tests for the Home Assistant HTTP component."""
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
import asyncio
|
import asyncio
|
||||||
from ipaddress import ip_address, ip_network
|
from ipaddress import ip_network
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import aiohttp
|
from aiohttp import BasicAuth, web
|
||||||
|
from aiohttp.web_exceptions import HTTPUnauthorized
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant import const
|
from homeassistant.const import HTTP_HEADER_HA_AUTH
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
import homeassistant.components.http as http
|
from homeassistant.components.http.auth import setup_auth
|
||||||
from homeassistant.components.http.const import (
|
from homeassistant.components.http.real_ip import setup_real_ip
|
||||||
KEY_TRUSTED_NETWORKS, KEY_USE_X_FORWARDED_FOR, HTTP_HEADER_X_FORWARDED_FOR)
|
from homeassistant.components.http.const import KEY_AUTHENTICATED
|
||||||
|
|
||||||
|
from . import mock_real_ip
|
||||||
|
|
||||||
API_PASSWORD = 'test1234'
|
API_PASSWORD = 'test1234'
|
||||||
|
|
||||||
# Don't add 127.0.0.1/::1 as trusted, as it may interfere with other test cases
|
# Don't add 127.0.0.1/::1 as trusted, as it may interfere with other test cases
|
||||||
TRUSTED_NETWORKS = ['192.0.2.0/24', '2001:DB8:ABCD::/48', '100.64.0.1',
|
TRUSTED_NETWORKS = [
|
||||||
'FD01:DB8::1']
|
ip_network('192.0.2.0/24'),
|
||||||
|
ip_network('2001:DB8:ABCD::/48'),
|
||||||
|
ip_network('100.64.0.1'),
|
||||||
|
ip_network('FD01:DB8::1'),
|
||||||
|
]
|
||||||
TRUSTED_ADDRESSES = ['100.64.0.1', '192.0.2.100', 'FD01:DB8::1',
|
TRUSTED_ADDRESSES = ['100.64.0.1', '192.0.2.100', 'FD01:DB8::1',
|
||||||
'2001:DB8:ABCD::1']
|
'2001:DB8:ABCD::1']
|
||||||
UNTRUSTED_ADDRESSES = ['198.51.100.1', '2001:DB8:FA1::1', '127.0.0.1', '::1']
|
UNTRUSTED_ADDRESSES = ['198.51.100.1', '2001:DB8:FA1::1', '127.0.0.1', '::1']
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@asyncio.coroutine
|
||||||
def mock_api_client(hass, test_client):
|
def mock_handler(request):
|
||||||
"""Start the Hass HTTP component."""
|
"""Return if request was authenticated."""
|
||||||
hass.loop.run_until_complete(async_setup_component(hass, 'api', {
|
if not request[KEY_AUTHENTICATED]:
|
||||||
'http': {
|
raise HTTPUnauthorized
|
||||||
http.CONF_API_PASSWORD: API_PASSWORD,
|
return web.Response(status=200)
|
||||||
}
|
|
||||||
}))
|
|
||||||
return hass.loop.run_until_complete(test_client(hass.http.app))
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_trusted_networks(hass, mock_api_client):
|
def app():
|
||||||
"""Mock trusted networks."""
|
"""Fixture to setup a web.Application."""
|
||||||
hass.http.app[KEY_TRUSTED_NETWORKS] = [
|
app = web.Application()
|
||||||
ip_network(trusted_network)
|
app.router.add_get('/', mock_handler)
|
||||||
for trusted_network in TRUSTED_NETWORKS]
|
setup_real_ip(app, False)
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def test_access_denied_without_password(mock_api_client):
|
def test_auth_middleware_loaded_by_default(hass):
|
||||||
|
"""Test accessing to server from banned IP when feature is off."""
|
||||||
|
with patch('homeassistant.components.http.setup_auth') as mock_setup:
|
||||||
|
yield from async_setup_component(hass, 'http', {
|
||||||
|
'http': {}
|
||||||
|
})
|
||||||
|
|
||||||
|
assert len(mock_setup.mock_calls) == 1
|
||||||
|
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def test_access_without_password(app, test_client):
|
||||||
"""Test access without password."""
|
"""Test access without password."""
|
||||||
resp = yield from mock_api_client.get(const.URL_API)
|
setup_auth(app, [], None)
|
||||||
|
client = yield from test_client(app)
|
||||||
|
|
||||||
|
resp = yield from client.get('/')
|
||||||
|
assert resp.status == 200
|
||||||
|
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def test_access_with_password_in_header(app, test_client):
|
||||||
|
"""Test access with password in URL."""
|
||||||
|
setup_auth(app, [], API_PASSWORD)
|
||||||
|
client = yield from test_client(app)
|
||||||
|
|
||||||
|
req = yield from client.get(
|
||||||
|
'/', headers={HTTP_HEADER_HA_AUTH: API_PASSWORD})
|
||||||
|
assert req.status == 200
|
||||||
|
|
||||||
|
req = yield from client.get(
|
||||||
|
'/', headers={HTTP_HEADER_HA_AUTH: 'wrong-pass'})
|
||||||
|
assert req.status == 401
|
||||||
|
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def test_access_with_password_in_query(app, test_client):
|
||||||
|
"""Test access without password."""
|
||||||
|
setup_auth(app, [], API_PASSWORD)
|
||||||
|
client = yield from test_client(app)
|
||||||
|
|
||||||
|
resp = yield from client.get('/', params={
|
||||||
|
'api_password': API_PASSWORD
|
||||||
|
})
|
||||||
|
assert resp.status == 200
|
||||||
|
|
||||||
|
resp = yield from client.get('/')
|
||||||
assert resp.status == 401
|
assert resp.status == 401
|
||||||
|
|
||||||
|
resp = yield from client.get('/', params={
|
||||||
@asyncio.coroutine
|
'api_password': 'wrong-password'
|
||||||
def test_access_denied_with_wrong_password_in_header(mock_api_client):
|
|
||||||
"""Test access with wrong password."""
|
|
||||||
resp = yield from mock_api_client.get(const.URL_API, headers={
|
|
||||||
const.HTTP_HEADER_HA_AUTH: 'wrongpassword'
|
|
||||||
})
|
})
|
||||||
assert resp.status == 401
|
assert resp.status == 401
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def test_access_denied_with_x_forwarded_for(hass, mock_api_client,
|
def test_basic_auth_works(app, test_client):
|
||||||
mock_trusted_networks):
|
|
||||||
"""Test access denied through the X-Forwarded-For http header."""
|
|
||||||
hass.http.use_x_forwarded_for = True
|
|
||||||
for remote_addr in UNTRUSTED_ADDRESSES:
|
|
||||||
resp = yield from mock_api_client.get(const.URL_API, headers={
|
|
||||||
HTTP_HEADER_X_FORWARDED_FOR: remote_addr})
|
|
||||||
|
|
||||||
assert resp.status == 401, \
|
|
||||||
"{} shouldn't be trusted".format(remote_addr)
|
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
|
||||||
def test_access_denied_with_untrusted_ip(mock_api_client,
|
|
||||||
mock_trusted_networks):
|
|
||||||
"""Test access with an untrusted ip address."""
|
|
||||||
for remote_addr in UNTRUSTED_ADDRESSES:
|
|
||||||
with patch('homeassistant.components.http.'
|
|
||||||
'util.get_real_ip',
|
|
||||||
return_value=ip_address(remote_addr)):
|
|
||||||
resp = yield from mock_api_client.get(
|
|
||||||
const.URL_API, params={'api_password': ''})
|
|
||||||
|
|
||||||
assert resp.status == 401, \
|
|
||||||
"{} shouldn't be trusted".format(remote_addr)
|
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
|
||||||
def test_access_with_password_in_header(mock_api_client, caplog):
|
|
||||||
"""Test access with password in URL."""
|
|
||||||
# Hide logging from requests package that we use to test logging
|
|
||||||
req = yield from mock_api_client.get(
|
|
||||||
const.URL_API, headers={const.HTTP_HEADER_HA_AUTH: API_PASSWORD})
|
|
||||||
|
|
||||||
assert req.status == 200
|
|
||||||
|
|
||||||
logs = caplog.text
|
|
||||||
|
|
||||||
assert const.URL_API in logs
|
|
||||||
assert API_PASSWORD not in logs
|
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
|
||||||
def test_access_denied_with_wrong_password_in_url(mock_api_client):
|
|
||||||
"""Test access with wrong password."""
|
|
||||||
resp = yield from mock_api_client.get(
|
|
||||||
const.URL_API, params={'api_password': 'wrongpassword'})
|
|
||||||
|
|
||||||
assert resp.status == 401
|
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
|
||||||
def test_access_with_password_in_url(mock_api_client, caplog):
|
|
||||||
"""Test access with password in URL."""
|
|
||||||
req = yield from mock_api_client.get(
|
|
||||||
const.URL_API, params={'api_password': API_PASSWORD})
|
|
||||||
|
|
||||||
assert req.status == 200
|
|
||||||
|
|
||||||
logs = caplog.text
|
|
||||||
|
|
||||||
assert const.URL_API in logs
|
|
||||||
assert API_PASSWORD not in logs
|
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
|
||||||
def test_access_granted_with_x_forwarded_for(hass, mock_api_client, caplog,
|
|
||||||
mock_trusted_networks):
|
|
||||||
"""Test access denied through the X-Forwarded-For http header."""
|
|
||||||
hass.http.app[KEY_USE_X_FORWARDED_FOR] = True
|
|
||||||
for remote_addr in TRUSTED_ADDRESSES:
|
|
||||||
resp = yield from mock_api_client.get(const.URL_API, headers={
|
|
||||||
HTTP_HEADER_X_FORWARDED_FOR: remote_addr})
|
|
||||||
|
|
||||||
assert resp.status == 200, \
|
|
||||||
"{} should be trusted".format(remote_addr)
|
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
|
||||||
def test_access_granted_with_trusted_ip(mock_api_client, caplog,
|
|
||||||
mock_trusted_networks):
|
|
||||||
"""Test access with trusted addresses."""
|
|
||||||
for remote_addr in TRUSTED_ADDRESSES:
|
|
||||||
with patch('homeassistant.components.http.'
|
|
||||||
'auth.get_real_ip',
|
|
||||||
return_value=ip_address(remote_addr)):
|
|
||||||
resp = yield from mock_api_client.get(
|
|
||||||
const.URL_API, params={'api_password': ''})
|
|
||||||
|
|
||||||
assert resp.status == 200, \
|
|
||||||
'{} should be trusted'.format(remote_addr)
|
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
|
||||||
def test_basic_auth_works(mock_api_client, caplog):
|
|
||||||
"""Test access with basic authentication."""
|
"""Test access with basic authentication."""
|
||||||
req = yield from mock_api_client.get(
|
setup_auth(app, [], API_PASSWORD)
|
||||||
const.URL_API,
|
client = yield from test_client(app)
|
||||||
auth=aiohttp.BasicAuth('homeassistant', API_PASSWORD))
|
|
||||||
|
|
||||||
|
req = yield from client.get(
|
||||||
|
'/',
|
||||||
|
auth=BasicAuth('homeassistant', API_PASSWORD))
|
||||||
assert req.status == 200
|
assert req.status == 200
|
||||||
assert const.URL_API in caplog.text
|
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
|
||||||
def test_basic_auth_username_homeassistant(mock_api_client, caplog):
|
|
||||||
"""Test access with basic auth requires username homeassistant."""
|
|
||||||
req = yield from mock_api_client.get(
|
|
||||||
const.URL_API,
|
|
||||||
auth=aiohttp.BasicAuth('wrong_username', API_PASSWORD))
|
|
||||||
|
|
||||||
|
req = yield from client.get(
|
||||||
|
'/',
|
||||||
|
auth=BasicAuth('wrong_username', API_PASSWORD))
|
||||||
assert req.status == 401
|
assert req.status == 401
|
||||||
|
|
||||||
|
req = yield from client.get(
|
||||||
@asyncio.coroutine
|
'/',
|
||||||
def test_basic_auth_wrong_password(mock_api_client, caplog):
|
auth=BasicAuth('homeassistant', 'wrong password'))
|
||||||
"""Test access with basic auth not allowed with wrong password."""
|
|
||||||
req = yield from mock_api_client.get(
|
|
||||||
const.URL_API,
|
|
||||||
auth=aiohttp.BasicAuth('homeassistant', 'wrong password'))
|
|
||||||
|
|
||||||
assert req.status == 401
|
assert req.status == 401
|
||||||
|
|
||||||
|
req = yield from client.get(
|
||||||
@asyncio.coroutine
|
'/',
|
||||||
def test_authorization_header_must_be_basic_type(mock_api_client, caplog):
|
|
||||||
"""Test only basic authorization is allowed for auth header."""
|
|
||||||
req = yield from mock_api_client.get(
|
|
||||||
const.URL_API,
|
|
||||||
headers={
|
headers={
|
||||||
'authorization': 'NotBasic abcdefg'
|
'authorization': 'NotBasic abcdefg'
|
||||||
})
|
})
|
||||||
|
|
||||||
assert req.status == 401
|
assert req.status == 401
|
||||||
|
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def test_access_with_trusted_ip(test_client):
|
||||||
|
"""Test access with an untrusted ip address."""
|
||||||
|
app = web.Application()
|
||||||
|
app.router.add_get('/', mock_handler)
|
||||||
|
|
||||||
|
setup_auth(app, TRUSTED_NETWORKS, 'some-pass')
|
||||||
|
|
||||||
|
set_mock_ip = mock_real_ip(app)
|
||||||
|
client = yield from test_client(app)
|
||||||
|
|
||||||
|
for remote_addr in UNTRUSTED_ADDRESSES:
|
||||||
|
set_mock_ip(remote_addr)
|
||||||
|
resp = yield from client.get('/')
|
||||||
|
assert resp.status == 401, \
|
||||||
|
"{} shouldn't be trusted".format(remote_addr)
|
||||||
|
|
||||||
|
for remote_addr in TRUSTED_ADDRESSES:
|
||||||
|
set_mock_ip(remote_addr)
|
||||||
|
resp = yield from client.get('/')
|
||||||
|
assert resp.status == 200, \
|
||||||
|
"{} should be trusted".format(remote_addr)
|
||||||
|
|
|
@ -1,91 +1,96 @@
|
||||||
"""The tests for the Home Assistant HTTP component."""
|
"""The tests for the Home Assistant HTTP component."""
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
import asyncio
|
import asyncio
|
||||||
from ipaddress import ip_address
|
|
||||||
from unittest.mock import patch, mock_open
|
from unittest.mock import patch, mock_open
|
||||||
|
|
||||||
import pytest
|
from aiohttp import web
|
||||||
|
from aiohttp.web_exceptions import HTTPUnauthorized
|
||||||
|
|
||||||
from homeassistant import const
|
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
import homeassistant.components.http as http
|
import homeassistant.components.http as http
|
||||||
from homeassistant.components.http.const import (
|
from homeassistant.components.http.ban import (
|
||||||
KEY_BANS_ENABLED, KEY_LOGIN_THRESHOLD, KEY_BANNED_IPS)
|
IpBan, IP_BANS_FILE, setup_bans, KEY_BANNED_IPS)
|
||||||
from homeassistant.components.http.ban import IpBan, IP_BANS_FILE
|
|
||||||
|
from . import mock_real_ip
|
||||||
|
|
||||||
API_PASSWORD = 'test1234'
|
|
||||||
BANNED_IPS = ['200.201.202.203', '100.64.0.2']
|
BANNED_IPS = ['200.201.202.203', '100.64.0.2']
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_api_client(hass, test_client):
|
|
||||||
"""Start the Hass HTTP component."""
|
|
||||||
hass.loop.run_until_complete(async_setup_component(hass, 'api', {
|
|
||||||
'http': {
|
|
||||||
http.CONF_API_PASSWORD: API_PASSWORD,
|
|
||||||
}
|
|
||||||
}))
|
|
||||||
hass.http.app[KEY_BANNED_IPS] = [IpBan(banned_ip) for banned_ip
|
|
||||||
in BANNED_IPS]
|
|
||||||
return hass.loop.run_until_complete(test_client(hass.http.app))
|
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def test_access_from_banned_ip(hass, mock_api_client):
|
def test_access_from_banned_ip(hass, test_client):
|
||||||
"""Test accessing to server from banned IP. Both trusted and not."""
|
"""Test accessing to server from banned IP. Both trusted and not."""
|
||||||
hass.http.app[KEY_BANS_ENABLED] = True
|
app = web.Application()
|
||||||
|
setup_bans(hass, app, 5)
|
||||||
|
set_real_ip = mock_real_ip(app)
|
||||||
|
|
||||||
|
with patch('homeassistant.components.http.ban.load_ip_bans_config',
|
||||||
|
return_value=[IpBan(banned_ip) for banned_ip
|
||||||
|
in BANNED_IPS]):
|
||||||
|
client = yield from test_client(app)
|
||||||
|
|
||||||
for remote_addr in BANNED_IPS:
|
for remote_addr in BANNED_IPS:
|
||||||
with patch('homeassistant.components.http.'
|
set_real_ip(remote_addr)
|
||||||
'ban.get_real_ip',
|
resp = yield from client.get('/')
|
||||||
return_value=ip_address(remote_addr)):
|
assert resp.status == 403
|
||||||
resp = yield from mock_api_client.get(
|
|
||||||
const.URL_API)
|
|
||||||
assert resp.status == 403
|
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def test_access_from_banned_ip_when_ban_is_off(hass, mock_api_client):
|
def test_ban_middleware_not_loaded_by_config(hass):
|
||||||
"""Test accessing to server from banned IP when feature is off."""
|
"""Test accessing to server from banned IP when feature is off."""
|
||||||
hass.http.app[KEY_BANS_ENABLED] = False
|
with patch('homeassistant.components.http.setup_bans') as mock_setup:
|
||||||
for remote_addr in BANNED_IPS:
|
yield from async_setup_component(hass, 'http', {
|
||||||
with patch('homeassistant.components.http.'
|
'http': {
|
||||||
'ban.get_real_ip',
|
http.CONF_IP_BAN_ENABLED: False,
|
||||||
return_value=ip_address(remote_addr)):
|
}
|
||||||
resp = yield from mock_api_client.get(
|
})
|
||||||
const.URL_API,
|
|
||||||
headers={const.HTTP_HEADER_HA_AUTH: API_PASSWORD})
|
assert len(mock_setup.mock_calls) == 0
|
||||||
assert resp.status == 200
|
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def test_ip_bans_file_creation(hass, mock_api_client):
|
def test_ban_middleware_loaded_by_default(hass):
|
||||||
|
"""Test accessing to server from banned IP when feature is off."""
|
||||||
|
with patch('homeassistant.components.http.setup_bans') as mock_setup:
|
||||||
|
yield from async_setup_component(hass, 'http', {
|
||||||
|
'http': {}
|
||||||
|
})
|
||||||
|
|
||||||
|
assert len(mock_setup.mock_calls) == 1
|
||||||
|
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def test_ip_bans_file_creation(hass, test_client):
|
||||||
"""Testing if banned IP file created."""
|
"""Testing if banned IP file created."""
|
||||||
hass.http.app[KEY_BANS_ENABLED] = True
|
app = web.Application()
|
||||||
hass.http.app[KEY_LOGIN_THRESHOLD] = 1
|
app['hass'] = hass
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def unauth_handler(request):
|
||||||
|
"""Return a mock web response."""
|
||||||
|
raise HTTPUnauthorized
|
||||||
|
|
||||||
|
app.router.add_get('/', unauth_handler)
|
||||||
|
setup_bans(hass, app, 1)
|
||||||
|
mock_real_ip(app)("200.201.202.204")
|
||||||
|
|
||||||
|
with patch('homeassistant.components.http.ban.load_ip_bans_config',
|
||||||
|
return_value=[IpBan(banned_ip) for banned_ip
|
||||||
|
in BANNED_IPS]):
|
||||||
|
client = yield from test_client(app)
|
||||||
|
|
||||||
m = mock_open()
|
m = mock_open()
|
||||||
|
|
||||||
@asyncio.coroutine
|
|
||||||
def call_server():
|
|
||||||
with patch('homeassistant.components.http.'
|
|
||||||
'ban.get_real_ip',
|
|
||||||
return_value=ip_address("200.201.202.204")):
|
|
||||||
resp = yield from mock_api_client.get(
|
|
||||||
const.URL_API,
|
|
||||||
headers={const.HTTP_HEADER_HA_AUTH: 'Wrong password'})
|
|
||||||
return resp
|
|
||||||
|
|
||||||
with patch('homeassistant.components.http.ban.open', m, create=True):
|
with patch('homeassistant.components.http.ban.open', m, create=True):
|
||||||
resp = yield from call_server()
|
resp = yield from client.get('/')
|
||||||
assert resp.status == 401
|
assert resp.status == 401
|
||||||
assert len(hass.http.app[KEY_BANNED_IPS]) == len(BANNED_IPS)
|
assert len(app[KEY_BANNED_IPS]) == len(BANNED_IPS)
|
||||||
assert m.call_count == 0
|
assert m.call_count == 0
|
||||||
|
|
||||||
resp = yield from call_server()
|
resp = yield from client.get('/')
|
||||||
assert resp.status == 401
|
assert resp.status == 401
|
||||||
assert len(hass.http.app[KEY_BANNED_IPS]) == len(BANNED_IPS) + 1
|
assert len(app[KEY_BANNED_IPS]) == len(BANNED_IPS) + 1
|
||||||
m.assert_called_once_with(hass.config.path(IP_BANS_FILE), 'a')
|
m.assert_called_once_with(hass.config.path(IP_BANS_FILE), 'a')
|
||||||
|
|
||||||
resp = yield from call_server()
|
resp = yield from client.get('/')
|
||||||
assert resp.status == 403
|
assert resp.status == 403
|
||||||
assert m.call_count == 1
|
assert m.call_count == 1
|
||||||
|
|
104
tests/components/http/test_cors.py
Normal file
104
tests/components/http/test_cors.py
Normal file
|
@ -0,0 +1,104 @@
|
||||||
|
"""Test cors for the HTTP component."""
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from aiohttp import web
|
||||||
|
from aiohttp.hdrs import (
|
||||||
|
ACCESS_CONTROL_ALLOW_ORIGIN,
|
||||||
|
ACCESS_CONTROL_ALLOW_HEADERS,
|
||||||
|
ACCESS_CONTROL_REQUEST_HEADERS,
|
||||||
|
ACCESS_CONTROL_REQUEST_METHOD,
|
||||||
|
ORIGIN
|
||||||
|
)
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from homeassistant.const import HTTP_HEADER_HA_AUTH
|
||||||
|
from homeassistant.setup import async_setup_component
|
||||||
|
from homeassistant.components.http.cors import setup_cors
|
||||||
|
|
||||||
|
|
||||||
|
TRUSTED_ORIGIN = 'https://home-assistant.io'
|
||||||
|
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def test_cors_middleware_not_loaded_by_default(hass):
|
||||||
|
"""Test accessing to server from banned IP when feature is off."""
|
||||||
|
with patch('homeassistant.components.http.setup_cors') as mock_setup:
|
||||||
|
yield from async_setup_component(hass, 'http', {
|
||||||
|
'http': {}
|
||||||
|
})
|
||||||
|
|
||||||
|
assert len(mock_setup.mock_calls) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def test_cors_middleware_loaded_from_config(hass):
|
||||||
|
"""Test accessing to server from banned IP when feature is off."""
|
||||||
|
with patch('homeassistant.components.http.setup_cors') as mock_setup:
|
||||||
|
yield from async_setup_component(hass, 'http', {
|
||||||
|
'http': {
|
||||||
|
'cors_allowed_origins': ['http://home-assistant.io']
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
assert len(mock_setup.mock_calls) == 1
|
||||||
|
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def mock_handler(request):
|
||||||
|
"""Return if request was authenticated."""
|
||||||
|
return web.Response(status=200)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client(loop, test_client):
|
||||||
|
"""Fixture to setup a web.Application."""
|
||||||
|
app = web.Application()
|
||||||
|
app.router.add_get('/', mock_handler)
|
||||||
|
setup_cors(app, [TRUSTED_ORIGIN])
|
||||||
|
return loop.run_until_complete(test_client(app))
|
||||||
|
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def test_cors_requests(client):
|
||||||
|
"""Test cross origin requests."""
|
||||||
|
req = yield from client.get('/', headers={
|
||||||
|
ORIGIN: TRUSTED_ORIGIN
|
||||||
|
})
|
||||||
|
assert req.status == 200
|
||||||
|
assert req.headers[ACCESS_CONTROL_ALLOW_ORIGIN] == \
|
||||||
|
TRUSTED_ORIGIN
|
||||||
|
|
||||||
|
# With password in URL
|
||||||
|
req = yield from client.get('/', params={
|
||||||
|
'api_password': 'some-pass'
|
||||||
|
}, headers={
|
||||||
|
ORIGIN: TRUSTED_ORIGIN
|
||||||
|
})
|
||||||
|
assert req.status == 200
|
||||||
|
assert req.headers[ACCESS_CONTROL_ALLOW_ORIGIN] == \
|
||||||
|
TRUSTED_ORIGIN
|
||||||
|
|
||||||
|
# With password in headers
|
||||||
|
req = yield from client.get('/', headers={
|
||||||
|
HTTP_HEADER_HA_AUTH: 'some-pass',
|
||||||
|
ORIGIN: TRUSTED_ORIGIN
|
||||||
|
})
|
||||||
|
assert req.status == 200
|
||||||
|
assert req.headers[ACCESS_CONTROL_ALLOW_ORIGIN] == \
|
||||||
|
TRUSTED_ORIGIN
|
||||||
|
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def test_cors_preflight_allowed(client):
|
||||||
|
"""Test cross origin resource sharing preflight (OPTIONS) request."""
|
||||||
|
req = yield from client.options('/', headers={
|
||||||
|
ORIGIN: TRUSTED_ORIGIN,
|
||||||
|
ACCESS_CONTROL_REQUEST_METHOD: 'GET',
|
||||||
|
ACCESS_CONTROL_REQUEST_HEADERS: 'x-ha-access'
|
||||||
|
})
|
||||||
|
|
||||||
|
assert req.status == 200
|
||||||
|
assert req.headers[ACCESS_CONTROL_ALLOW_ORIGIN] == TRUSTED_ORIGIN
|
||||||
|
assert req.headers[ACCESS_CONTROL_ALLOW_HEADERS] == \
|
||||||
|
HTTP_HEADER_HA_AUTH.upper()
|
|
@ -1,124 +1,10 @@
|
||||||
"""The tests for the Home Assistant HTTP component."""
|
"""The tests for the Home Assistant HTTP component."""
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
from aiohttp.hdrs import (
|
from homeassistant.setup import async_setup_component
|
||||||
ORIGIN, ACCESS_CONTROL_ALLOW_ORIGIN, ACCESS_CONTROL_ALLOW_HEADERS,
|
|
||||||
ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS,
|
|
||||||
CONTENT_TYPE)
|
|
||||||
import requests
|
|
||||||
from tests.common import get_test_instance_port, get_test_home_assistant
|
|
||||||
|
|
||||||
from homeassistant import const, setup
|
|
||||||
import homeassistant.components.http as http
|
import homeassistant.components.http as http
|
||||||
|
|
||||||
API_PASSWORD = 'test1234'
|
|
||||||
SERVER_PORT = get_test_instance_port()
|
|
||||||
HTTP_BASE = '127.0.0.1:{}'.format(SERVER_PORT)
|
|
||||||
HTTP_BASE_URL = 'http://{}'.format(HTTP_BASE)
|
|
||||||
HA_HEADERS = {
|
|
||||||
const.HTTP_HEADER_HA_AUTH: API_PASSWORD,
|
|
||||||
CONTENT_TYPE: const.CONTENT_TYPE_JSON,
|
|
||||||
}
|
|
||||||
CORS_ORIGINS = [HTTP_BASE_URL, HTTP_BASE]
|
|
||||||
|
|
||||||
hass = None
|
|
||||||
|
|
||||||
|
|
||||||
def _url(path=''):
|
|
||||||
"""Helper method to generate URLs."""
|
|
||||||
return HTTP_BASE_URL + path
|
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=invalid-name
|
|
||||||
def setUpModule():
|
|
||||||
"""Initialize a Home Assistant server."""
|
|
||||||
global hass
|
|
||||||
|
|
||||||
hass = get_test_home_assistant()
|
|
||||||
|
|
||||||
setup.setup_component(
|
|
||||||
hass, http.DOMAIN, {
|
|
||||||
http.DOMAIN: {
|
|
||||||
http.CONF_API_PASSWORD: API_PASSWORD,
|
|
||||||
http.CONF_SERVER_PORT: SERVER_PORT,
|
|
||||||
http.CONF_CORS_ORIGINS: CORS_ORIGINS,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
setup.setup_component(hass, 'api')
|
|
||||||
|
|
||||||
# Registering static path as it caused CORS to blow up
|
|
||||||
hass.http.register_static_path(
|
|
||||||
'/custom_components', hass.config.path('custom_components'))
|
|
||||||
|
|
||||||
hass.start()
|
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=invalid-name
|
|
||||||
def tearDownModule():
|
|
||||||
"""Stop the Home Assistant server."""
|
|
||||||
hass.stop()
|
|
||||||
|
|
||||||
|
|
||||||
class TestCors:
|
|
||||||
"""Test HTTP component."""
|
|
||||||
|
|
||||||
def test_cors_allowed_with_password_in_url(self):
|
|
||||||
"""Test cross origin resource sharing with password in url."""
|
|
||||||
req = requests.get(_url(const.URL_API),
|
|
||||||
params={'api_password': API_PASSWORD},
|
|
||||||
headers={ORIGIN: HTTP_BASE_URL})
|
|
||||||
|
|
||||||
allow_origin = ACCESS_CONTROL_ALLOW_ORIGIN
|
|
||||||
|
|
||||||
assert req.status_code == 200
|
|
||||||
assert req.headers.get(allow_origin) == HTTP_BASE_URL
|
|
||||||
|
|
||||||
def test_cors_allowed_with_password_in_header(self):
|
|
||||||
"""Test cross origin resource sharing with password in header."""
|
|
||||||
headers = {
|
|
||||||
const.HTTP_HEADER_HA_AUTH: API_PASSWORD,
|
|
||||||
ORIGIN: HTTP_BASE_URL
|
|
||||||
}
|
|
||||||
req = requests.get(_url(const.URL_API), headers=headers)
|
|
||||||
|
|
||||||
allow_origin = ACCESS_CONTROL_ALLOW_ORIGIN
|
|
||||||
|
|
||||||
assert req.status_code == 200
|
|
||||||
assert req.headers.get(allow_origin) == HTTP_BASE_URL
|
|
||||||
|
|
||||||
def test_cors_denied_without_origin_header(self):
|
|
||||||
"""Test cross origin resource sharing with password in header."""
|
|
||||||
headers = {
|
|
||||||
const.HTTP_HEADER_HA_AUTH: API_PASSWORD
|
|
||||||
}
|
|
||||||
req = requests.get(_url(const.URL_API), headers=headers)
|
|
||||||
|
|
||||||
allow_origin = ACCESS_CONTROL_ALLOW_ORIGIN
|
|
||||||
allow_headers = ACCESS_CONTROL_ALLOW_HEADERS
|
|
||||||
|
|
||||||
assert req.status_code == 200
|
|
||||||
assert allow_origin not in req.headers
|
|
||||||
assert allow_headers not in req.headers
|
|
||||||
|
|
||||||
def test_cors_preflight_allowed(self):
|
|
||||||
"""Test cross origin resource sharing preflight (OPTIONS) request."""
|
|
||||||
headers = {
|
|
||||||
ORIGIN: HTTP_BASE_URL,
|
|
||||||
ACCESS_CONTROL_REQUEST_METHOD: 'GET',
|
|
||||||
ACCESS_CONTROL_REQUEST_HEADERS: 'x-ha-access'
|
|
||||||
}
|
|
||||||
req = requests.options(_url(const.URL_API), headers=headers)
|
|
||||||
|
|
||||||
allow_origin = ACCESS_CONTROL_ALLOW_ORIGIN
|
|
||||||
allow_headers = ACCESS_CONTROL_ALLOW_HEADERS
|
|
||||||
|
|
||||||
assert req.status_code == 200
|
|
||||||
assert req.headers.get(allow_origin) == HTTP_BASE_URL
|
|
||||||
assert req.headers.get(allow_headers) == \
|
|
||||||
const.HTTP_HEADER_HA_AUTH.upper()
|
|
||||||
|
|
||||||
|
|
||||||
class TestView(http.HomeAssistantView):
|
class TestView(http.HomeAssistantView):
|
||||||
"""Test the HTTP views."""
|
"""Test the HTTP views."""
|
||||||
|
@ -133,12 +19,12 @@ class TestView(http.HomeAssistantView):
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def test_registering_view_while_running(hass, test_client):
|
def test_registering_view_while_running(hass, test_client, unused_port):
|
||||||
"""Test that we can register a view while the server is running."""
|
"""Test that we can register a view while the server is running."""
|
||||||
yield from setup.async_setup_component(
|
yield from async_setup_component(
|
||||||
hass, http.DOMAIN, {
|
hass, http.DOMAIN, {
|
||||||
http.DOMAIN: {
|
http.DOMAIN: {
|
||||||
http.CONF_SERVER_PORT: get_test_instance_port(),
|
http.CONF_SERVER_PORT: unused_port(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
@ -151,7 +37,7 @@ def test_registering_view_while_running(hass, test_client):
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def test_api_base_url_with_domain(hass):
|
def test_api_base_url_with_domain(hass):
|
||||||
"""Test setting API URL."""
|
"""Test setting API URL."""
|
||||||
result = yield from setup.async_setup_component(hass, 'http', {
|
result = yield from async_setup_component(hass, 'http', {
|
||||||
'http': {
|
'http': {
|
||||||
'base_url': 'example.com'
|
'base_url': 'example.com'
|
||||||
}
|
}
|
||||||
|
@ -163,7 +49,7 @@ def test_api_base_url_with_domain(hass):
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def test_api_base_url_with_ip(hass):
|
def test_api_base_url_with_ip(hass):
|
||||||
"""Test setting api url."""
|
"""Test setting api url."""
|
||||||
result = yield from setup.async_setup_component(hass, 'http', {
|
result = yield from async_setup_component(hass, 'http', {
|
||||||
'http': {
|
'http': {
|
||||||
'server_host': '1.1.1.1'
|
'server_host': '1.1.1.1'
|
||||||
}
|
}
|
||||||
|
@ -175,7 +61,7 @@ def test_api_base_url_with_ip(hass):
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def test_api_base_url_with_ip_port(hass):
|
def test_api_base_url_with_ip_port(hass):
|
||||||
"""Test setting api url."""
|
"""Test setting api url."""
|
||||||
result = yield from setup.async_setup_component(hass, 'http', {
|
result = yield from async_setup_component(hass, 'http', {
|
||||||
'http': {
|
'http': {
|
||||||
'base_url': '1.1.1.1:8124'
|
'base_url': '1.1.1.1:8124'
|
||||||
}
|
}
|
||||||
|
@ -187,9 +73,34 @@ def test_api_base_url_with_ip_port(hass):
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def test_api_no_base_url(hass):
|
def test_api_no_base_url(hass):
|
||||||
"""Test setting api url."""
|
"""Test setting api url."""
|
||||||
result = yield from setup.async_setup_component(hass, 'http', {
|
result = yield from async_setup_component(hass, 'http', {
|
||||||
'http': {
|
'http': {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
assert result
|
assert result
|
||||||
assert hass.config.api.base_url == 'http://127.0.0.1:8123'
|
assert hass.config.api.base_url == 'http://127.0.0.1:8123'
|
||||||
|
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def test_not_log_password(hass, unused_port, test_client, caplog):
|
||||||
|
"""Test access with password doesn't get logged."""
|
||||||
|
result = yield from async_setup_component(hass, 'api', {
|
||||||
|
'http': {
|
||||||
|
http.CONF_SERVER_PORT: unused_port(),
|
||||||
|
http.CONF_API_PASSWORD: 'some-pass'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
assert result
|
||||||
|
|
||||||
|
client = yield from test_client(hass.http.app)
|
||||||
|
|
||||||
|
resp = yield from client.get('/api/', params={
|
||||||
|
'api_password': 'some-pass'
|
||||||
|
})
|
||||||
|
|
||||||
|
assert resp.status == 200
|
||||||
|
logs = caplog.text
|
||||||
|
|
||||||
|
# Ensure we don't log API passwords
|
||||||
|
assert '/api/' in logs
|
||||||
|
assert 'some-pass' not in logs
|
||||||
|
|
48
tests/components/http/test_real_ip.py
Normal file
48
tests/components/http/test_real_ip.py
Normal file
|
@ -0,0 +1,48 @@
|
||||||
|
"""Test real IP middleware."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from aiohttp import web
|
||||||
|
from aiohttp.hdrs import X_FORWARDED_FOR
|
||||||
|
|
||||||
|
from homeassistant.components.http.real_ip import setup_real_ip
|
||||||
|
from homeassistant.components.http.const import KEY_REAL_IP
|
||||||
|
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def mock_handler(request):
|
||||||
|
"""Handler that returns the real IP as text."""
|
||||||
|
return web.Response(text=str(request[KEY_REAL_IP]))
|
||||||
|
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def test_ignore_x_forwarded_for(test_client):
|
||||||
|
"""Test that we get the IP from the transport."""
|
||||||
|
app = web.Application()
|
||||||
|
app.router.add_get('/', mock_handler)
|
||||||
|
setup_real_ip(app, False)
|
||||||
|
|
||||||
|
mock_api_client = yield from test_client(app)
|
||||||
|
|
||||||
|
resp = yield from mock_api_client.get('/', headers={
|
||||||
|
X_FORWARDED_FOR: '255.255.255.255'
|
||||||
|
})
|
||||||
|
assert resp.status == 200
|
||||||
|
text = yield from resp.text()
|
||||||
|
assert text != '255.255.255.255'
|
||||||
|
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def test_use_x_forwarded_for(test_client):
|
||||||
|
"""Test that we get the IP from the transport."""
|
||||||
|
app = web.Application()
|
||||||
|
app.router.add_get('/', mock_handler)
|
||||||
|
setup_real_ip(app, True)
|
||||||
|
|
||||||
|
mock_api_client = yield from test_client(app)
|
||||||
|
|
||||||
|
resp = yield from mock_api_client.get('/', headers={
|
||||||
|
X_FORWARDED_FOR: '255.255.255.255'
|
||||||
|
})
|
||||||
|
assert resp.status == 200
|
||||||
|
text = yield from resp.text()
|
||||||
|
assert text == '255.255.255.255'
|
|
@ -4,8 +4,7 @@ from unittest.mock import Mock, MagicMock, patch
|
||||||
from homeassistant.setup import setup_component
|
from homeassistant.setup import setup_component
|
||||||
import homeassistant.components.mqtt as mqtt
|
import homeassistant.components.mqtt as mqtt
|
||||||
|
|
||||||
from tests.common import (
|
from tests.common import get_test_home_assistant, mock_coro
|
||||||
get_test_home_assistant, mock_coro, mock_http_component)
|
|
||||||
|
|
||||||
|
|
||||||
class TestMQTT:
|
class TestMQTT:
|
||||||
|
@ -14,7 +13,9 @@ class TestMQTT:
|
||||||
def setup_method(self, method):
|
def setup_method(self, method):
|
||||||
"""Setup things to be run when tests are started."""
|
"""Setup things to be run when tests are started."""
|
||||||
self.hass = get_test_home_assistant()
|
self.hass = get_test_home_assistant()
|
||||||
mock_http_component(self.hass, 'super_secret')
|
setup_component(self.hass, 'http', {
|
||||||
|
'api_password': 'super_secret'
|
||||||
|
})
|
||||||
|
|
||||||
def teardown_method(self, method):
|
def teardown_method(self, method):
|
||||||
"""Stop everything that was started."""
|
"""Stop everything that was started."""
|
||||||
|
|
|
@ -4,12 +4,10 @@ import json
|
||||||
from unittest.mock import patch, MagicMock, mock_open
|
from unittest.mock import patch, MagicMock, mock_open
|
||||||
from aiohttp.hdrs import AUTHORIZATION
|
from aiohttp.hdrs import AUTHORIZATION
|
||||||
|
|
||||||
|
from homeassistant.setup import async_setup_component
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.util.json import save_json
|
|
||||||
from homeassistant.components.notify import html5
|
from homeassistant.components.notify import html5
|
||||||
|
|
||||||
from tests.common import mock_http_component_app
|
|
||||||
|
|
||||||
CONFIG_FILE = 'file.conf'
|
CONFIG_FILE = 'file.conf'
|
||||||
|
|
||||||
SUBSCRIPTION_1 = {
|
SUBSCRIPTION_1 = {
|
||||||
|
@ -52,6 +50,23 @@ REGISTER_URL = '/api/notify.html5'
|
||||||
PUBLISH_URL = '/api/notify.html5/callback'
|
PUBLISH_URL = '/api/notify.html5/callback'
|
||||||
|
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def mock_client(hass, test_client, registrations=None):
|
||||||
|
"""Create a test client for HTML5 views."""
|
||||||
|
if registrations is None:
|
||||||
|
registrations = {}
|
||||||
|
|
||||||
|
with patch('homeassistant.components.notify.html5._load_config',
|
||||||
|
return_value=registrations):
|
||||||
|
yield from async_setup_component(hass, 'notify', {
|
||||||
|
'notify': {
|
||||||
|
'platform': 'html5'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return (yield from test_client(hass.http.app))
|
||||||
|
|
||||||
|
|
||||||
class TestHtml5Notify(object):
|
class TestHtml5Notify(object):
|
||||||
"""Tests for HTML5 notify platform."""
|
"""Tests for HTML5 notify platform."""
|
||||||
|
|
||||||
|
@ -89,8 +104,6 @@ class TestHtml5Notify(object):
|
||||||
service.send_message('Hello', target=['device', 'non_existing'],
|
service.send_message('Hello', target=['device', 'non_existing'],
|
||||||
data={'icon': 'beer.png'})
|
data={'icon': 'beer.png'})
|
||||||
|
|
||||||
print(mock_wp.mock_calls)
|
|
||||||
|
|
||||||
assert len(mock_wp.mock_calls) == 3
|
assert len(mock_wp.mock_calls) == 3
|
||||||
|
|
||||||
# WebPusher constructor
|
# WebPusher constructor
|
||||||
|
@ -104,421 +117,224 @@ class TestHtml5Notify(object):
|
||||||
assert payload['body'] == 'Hello'
|
assert payload['body'] == 'Hello'
|
||||||
assert payload['icon'] == 'beer.png'
|
assert payload['icon'] == 'beer.png'
|
||||||
|
|
||||||
@asyncio.coroutine
|
|
||||||
def test_registering_new_device_view(self, loop, test_client):
|
|
||||||
"""Test that the HTML view works."""
|
|
||||||
hass = MagicMock()
|
|
||||||
expected = {
|
|
||||||
'unnamed device': SUBSCRIPTION_1,
|
|
||||||
}
|
|
||||||
|
|
||||||
hass.config.path.return_value = CONFIG_FILE
|
@asyncio.coroutine
|
||||||
service = html5.get_service(hass, {})
|
def test_registering_new_device_view(hass, test_client):
|
||||||
|
"""Test that the HTML view works."""
|
||||||
|
client = yield from mock_client(hass, test_client)
|
||||||
|
|
||||||
assert service is not None
|
with patch('homeassistant.components.notify.html5.save_json') as mock_save:
|
||||||
|
|
||||||
assert len(hass.mock_calls) == 3
|
|
||||||
|
|
||||||
view = hass.mock_calls[1][1][0]
|
|
||||||
assert view.json_path == hass.config.path.return_value
|
|
||||||
assert view.registrations == {}
|
|
||||||
|
|
||||||
hass.loop = loop
|
|
||||||
app = mock_http_component_app(hass)
|
|
||||||
view.register(app.router)
|
|
||||||
client = yield from test_client(app)
|
|
||||||
hass.http.is_banned_ip.return_value = False
|
|
||||||
resp = yield from client.post(REGISTER_URL,
|
resp = yield from client.post(REGISTER_URL,
|
||||||
data=json.dumps(SUBSCRIPTION_1))
|
data=json.dumps(SUBSCRIPTION_1))
|
||||||
|
|
||||||
content = yield from resp.text()
|
assert resp.status == 200
|
||||||
assert resp.status == 200, content
|
assert len(mock_save.mock_calls) == 1
|
||||||
assert view.registrations == expected
|
assert mock_save.mock_calls[0][1][1] == {
|
||||||
|
'unnamed device': SUBSCRIPTION_1,
|
||||||
|
}
|
||||||
|
|
||||||
hass.async_add_job.assert_called_with(save_json, CONFIG_FILE, expected)
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def test_registering_new_device_expiration_view(self, loop, test_client):
|
def test_registering_new_device_expiration_view(hass, test_client):
|
||||||
"""Test that the HTML view works."""
|
"""Test that the HTML view works."""
|
||||||
hass = MagicMock()
|
client = yield from mock_client(hass, test_client)
|
||||||
expected = {
|
|
||||||
'unnamed device': SUBSCRIPTION_4,
|
|
||||||
}
|
|
||||||
|
|
||||||
hass.config.path.return_value = CONFIG_FILE
|
with patch('homeassistant.components.notify.html5.save_json') as mock_save:
|
||||||
service = html5.get_service(hass, {})
|
|
||||||
|
|
||||||
assert service is not None
|
|
||||||
|
|
||||||
# assert hass.called
|
|
||||||
assert len(hass.mock_calls) == 3
|
|
||||||
|
|
||||||
view = hass.mock_calls[1][1][0]
|
|
||||||
assert view.json_path == hass.config.path.return_value
|
|
||||||
assert view.registrations == {}
|
|
||||||
|
|
||||||
hass.loop = loop
|
|
||||||
app = mock_http_component_app(hass)
|
|
||||||
view.register(app.router)
|
|
||||||
client = yield from test_client(app)
|
|
||||||
hass.http.is_banned_ip.return_value = False
|
|
||||||
resp = yield from client.post(REGISTER_URL,
|
resp = yield from client.post(REGISTER_URL,
|
||||||
data=json.dumps(SUBSCRIPTION_4))
|
data=json.dumps(SUBSCRIPTION_4))
|
||||||
|
|
||||||
content = yield from resp.text()
|
assert resp.status == 200
|
||||||
assert resp.status == 200, content
|
assert mock_save.mock_calls[0][1][1] == {
|
||||||
assert view.registrations == expected
|
'unnamed device': SUBSCRIPTION_4,
|
||||||
|
}
|
||||||
|
|
||||||
hass.async_add_job.assert_called_with(save_json, CONFIG_FILE, expected)
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def test_registering_new_device_fails_view(self, loop, test_client):
|
def test_registering_new_device_fails_view(hass, test_client):
|
||||||
"""Test subs. are not altered when registering a new device fails."""
|
"""Test subs. are not altered when registering a new device fails."""
|
||||||
hass = MagicMock()
|
registrations = {}
|
||||||
expected = {}
|
client = yield from mock_client(hass, test_client, registrations)
|
||||||
|
|
||||||
hass.config.path.return_value = CONFIG_FILE
|
|
||||||
html5.get_service(hass, {})
|
|
||||||
view = hass.mock_calls[1][1][0]
|
|
||||||
|
|
||||||
hass.loop = loop
|
|
||||||
app = mock_http_component_app(hass)
|
|
||||||
view.register(app.router)
|
|
||||||
client = yield from test_client(app)
|
|
||||||
hass.http.is_banned_ip.return_value = False
|
|
||||||
|
|
||||||
hass.async_add_job.side_effect = HomeAssistantError()
|
|
||||||
|
|
||||||
|
with patch('homeassistant.components.notify.html5.save_json',
|
||||||
|
side_effect=HomeAssistantError()):
|
||||||
resp = yield from client.post(REGISTER_URL,
|
resp = yield from client.post(REGISTER_URL,
|
||||||
data=json.dumps(SUBSCRIPTION_1))
|
data=json.dumps(SUBSCRIPTION_4))
|
||||||
|
|
||||||
content = yield from resp.text()
|
assert resp.status == 500
|
||||||
assert resp.status == 500, content
|
assert registrations == {}
|
||||||
assert view.registrations == expected
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
|
||||||
def test_registering_existing_device_view(self, loop, test_client):
|
|
||||||
"""Test subscription is updated when registering existing device."""
|
|
||||||
hass = MagicMock()
|
|
||||||
expected = {
|
|
||||||
'unnamed device': SUBSCRIPTION_4,
|
|
||||||
}
|
|
||||||
|
|
||||||
hass.config.path.return_value = CONFIG_FILE
|
@asyncio.coroutine
|
||||||
html5.get_service(hass, {})
|
def test_registering_existing_device_view(hass, test_client):
|
||||||
view = hass.mock_calls[1][1][0]
|
"""Test subscription is updated when registering existing device."""
|
||||||
|
registrations = {}
|
||||||
hass.loop = loop
|
client = yield from mock_client(hass, test_client, registrations)
|
||||||
app = mock_http_component_app(hass)
|
|
||||||
view.register(app.router)
|
|
||||||
client = yield from test_client(app)
|
|
||||||
hass.http.is_banned_ip.return_value = False
|
|
||||||
|
|
||||||
|
with patch('homeassistant.components.notify.html5.save_json') as mock_save:
|
||||||
yield from client.post(REGISTER_URL,
|
yield from client.post(REGISTER_URL,
|
||||||
data=json.dumps(SUBSCRIPTION_1))
|
data=json.dumps(SUBSCRIPTION_1))
|
||||||
resp = yield from client.post(REGISTER_URL,
|
resp = yield from client.post(REGISTER_URL,
|
||||||
data=json.dumps(SUBSCRIPTION_4))
|
data=json.dumps(SUBSCRIPTION_4))
|
||||||
|
|
||||||
content = yield from resp.text()
|
assert resp.status == 200
|
||||||
assert resp.status == 200, content
|
assert mock_save.mock_calls[0][1][1] == {
|
||||||
assert view.registrations == expected
|
'unnamed device': SUBSCRIPTION_4,
|
||||||
|
}
|
||||||
|
assert registrations == {
|
||||||
|
'unnamed device': SUBSCRIPTION_4,
|
||||||
|
}
|
||||||
|
|
||||||
hass.async_add_job.assert_called_with(save_json, CONFIG_FILE, expected)
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def test_registering_existing_device_fails_view(self, loop, test_client):
|
def test_registering_existing_device_fails_view(hass, test_client):
|
||||||
"""Test sub. is not updated when registering existing device fails."""
|
"""Test sub. is not updated when registering existing device fails."""
|
||||||
hass = MagicMock()
|
registrations = {}
|
||||||
expected = {
|
client = yield from mock_client(hass, test_client, registrations)
|
||||||
'unnamed device': SUBSCRIPTION_1,
|
|
||||||
}
|
|
||||||
|
|
||||||
hass.config.path.return_value = CONFIG_FILE
|
|
||||||
html5.get_service(hass, {})
|
|
||||||
view = hass.mock_calls[1][1][0]
|
|
||||||
|
|
||||||
hass.loop = loop
|
|
||||||
app = mock_http_component_app(hass)
|
|
||||||
view.register(app.router)
|
|
||||||
client = yield from test_client(app)
|
|
||||||
hass.http.is_banned_ip.return_value = False
|
|
||||||
|
|
||||||
|
with patch('homeassistant.components.notify.html5.save_json') as mock_save:
|
||||||
yield from client.post(REGISTER_URL,
|
yield from client.post(REGISTER_URL,
|
||||||
data=json.dumps(SUBSCRIPTION_1))
|
data=json.dumps(SUBSCRIPTION_1))
|
||||||
|
mock_save.side_effect = HomeAssistantError
|
||||||
hass.async_add_job.side_effect = HomeAssistantError()
|
|
||||||
resp = yield from client.post(REGISTER_URL,
|
resp = yield from client.post(REGISTER_URL,
|
||||||
data=json.dumps(SUBSCRIPTION_4))
|
data=json.dumps(SUBSCRIPTION_4))
|
||||||
|
|
||||||
content = yield from resp.text()
|
assert resp.status == 500
|
||||||
assert resp.status == 500, content
|
assert registrations == {
|
||||||
assert view.registrations == expected
|
'unnamed device': SUBSCRIPTION_1,
|
||||||
|
}
|
||||||
|
|
||||||
@asyncio.coroutine
|
|
||||||
def test_registering_new_device_validation(self, loop, test_client):
|
|
||||||
"""Test various errors when registering a new device."""
|
|
||||||
hass = MagicMock()
|
|
||||||
|
|
||||||
m = mock_open()
|
@asyncio.coroutine
|
||||||
with patch(
|
def test_registering_new_device_validation(hass, test_client):
|
||||||
'homeassistant.util.json.open',
|
"""Test various errors when registering a new device."""
|
||||||
m, create=True
|
client = yield from mock_client(hass, test_client)
|
||||||
):
|
|
||||||
hass.config.path.return_value = CONFIG_FILE
|
|
||||||
service = html5.get_service(hass, {})
|
|
||||||
|
|
||||||
assert service is not None
|
resp = yield from client.post(REGISTER_URL, data=json.dumps({
|
||||||
|
'browser': 'invalid browser',
|
||||||
|
'subscription': 'sub info',
|
||||||
|
}))
|
||||||
|
assert resp.status == 400
|
||||||
|
|
||||||
# assert hass.called
|
resp = yield from client.post(REGISTER_URL, data=json.dumps({
|
||||||
assert len(hass.mock_calls) == 3
|
'browser': 'chrome',
|
||||||
|
}))
|
||||||
|
assert resp.status == 400
|
||||||
|
|
||||||
view = hass.mock_calls[1][1][0]
|
with patch('homeassistant.components.notify.html5.save_json',
|
||||||
|
return_value=False):
|
||||||
|
resp = yield from client.post(REGISTER_URL, data=json.dumps({
|
||||||
|
'browser': 'chrome',
|
||||||
|
'subscription': 'sub info',
|
||||||
|
}))
|
||||||
|
assert resp.status == 400
|
||||||
|
|
||||||
hass.loop = loop
|
|
||||||
app = mock_http_component_app(hass)
|
|
||||||
view.register(app.router)
|
|
||||||
client = yield from test_client(app)
|
|
||||||
hass.http.is_banned_ip.return_value = False
|
|
||||||
|
|
||||||
resp = yield from client.post(REGISTER_URL, data=json.dumps({
|
@asyncio.coroutine
|
||||||
'browser': 'invalid browser',
|
def test_unregistering_device_view(hass, test_client):
|
||||||
'subscription': 'sub info',
|
"""Test that the HTML unregister view works."""
|
||||||
}))
|
registrations = {
|
||||||
assert resp.status == 400
|
'some device': SUBSCRIPTION_1,
|
||||||
|
'other device': SUBSCRIPTION_2,
|
||||||
|
}
|
||||||
|
client = yield from mock_client(hass, test_client, registrations)
|
||||||
|
|
||||||
resp = yield from client.post(REGISTER_URL, data=json.dumps({
|
with patch('homeassistant.components.notify.html5.save_json') as mock_save:
|
||||||
'browser': 'chrome',
|
resp = yield from client.delete(REGISTER_URL, data=json.dumps({
|
||||||
}))
|
'subscription': SUBSCRIPTION_1['subscription'],
|
||||||
assert resp.status == 400
|
}))
|
||||||
|
|
||||||
with patch('homeassistant.components.notify.html5.save_json',
|
assert resp.status == 200
|
||||||
return_value=False):
|
assert len(mock_save.mock_calls) == 1
|
||||||
# resp = view.post(Request(builder.get_environ()))
|
assert registrations == {
|
||||||
resp = yield from client.post(REGISTER_URL, data=json.dumps({
|
'other device': SUBSCRIPTION_2
|
||||||
'browser': 'chrome',
|
}
|
||||||
'subscription': 'sub info',
|
|
||||||
}))
|
|
||||||
|
|
||||||
assert resp.status == 400
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def test_unregistering_device_view(self, loop, test_client):
|
def test_unregister_device_view_handle_unknown_subscription(hass, test_client):
|
||||||
"""Test that the HTML unregister view works."""
|
"""Test that the HTML unregister view handles unknown subscriptions."""
|
||||||
hass = MagicMock()
|
registrations = {}
|
||||||
|
client = yield from mock_client(hass, test_client, registrations)
|
||||||
|
|
||||||
config = {
|
with patch('homeassistant.components.notify.html5.save_json') as mock_save:
|
||||||
'some device': SUBSCRIPTION_1,
|
resp = yield from client.delete(REGISTER_URL, data=json.dumps({
|
||||||
'other device': SUBSCRIPTION_2,
|
'subscription': SUBSCRIPTION_3['subscription']
|
||||||
}
|
}))
|
||||||
|
|
||||||
m = mock_open(read_data=json.dumps(config))
|
assert resp.status == 200, resp.response
|
||||||
with patch(
|
assert registrations == {}
|
||||||
'homeassistant.util.json.open',
|
assert len(mock_save.mock_calls) == 0
|
||||||
m, create=True
|
|
||||||
):
|
|
||||||
hass.config.path.return_value = CONFIG_FILE
|
|
||||||
service = html5.get_service(hass, {})
|
|
||||||
|
|
||||||
assert service is not None
|
|
||||||
|
|
||||||
# assert hass.called
|
@asyncio.coroutine
|
||||||
assert len(hass.mock_calls) == 3
|
def test_unregistering_device_view_handles_save_error(hass, test_client):
|
||||||
|
"""Test that the HTML unregister view handles save errors."""
|
||||||
|
registrations = {
|
||||||
|
'some device': SUBSCRIPTION_1,
|
||||||
|
'other device': SUBSCRIPTION_2,
|
||||||
|
}
|
||||||
|
client = yield from mock_client(hass, test_client, registrations)
|
||||||
|
|
||||||
view = hass.mock_calls[1][1][0]
|
with patch('homeassistant.components.notify.html5.save_json',
|
||||||
assert view.json_path == hass.config.path.return_value
|
side_effect=HomeAssistantError()):
|
||||||
assert view.registrations == config
|
resp = yield from client.delete(REGISTER_URL, data=json.dumps({
|
||||||
|
'subscription': SUBSCRIPTION_1['subscription'],
|
||||||
|
}))
|
||||||
|
|
||||||
hass.loop = loop
|
assert resp.status == 500, resp.response
|
||||||
app = mock_http_component_app(hass)
|
assert registrations == {
|
||||||
view.register(app.router)
|
'some device': SUBSCRIPTION_1,
|
||||||
client = yield from test_client(app)
|
'other device': SUBSCRIPTION_2,
|
||||||
hass.http.is_banned_ip.return_value = False
|
}
|
||||||
|
|
||||||
resp = yield from client.delete(REGISTER_URL, data=json.dumps({
|
|
||||||
'subscription': SUBSCRIPTION_1['subscription'],
|
|
||||||
}))
|
|
||||||
|
|
||||||
config.pop('some device')
|
@asyncio.coroutine
|
||||||
|
def test_callback_view_no_jwt(hass, test_client):
|
||||||
|
"""Test that the notification callback view works without JWT."""
|
||||||
|
client = yield from mock_client(hass, test_client)
|
||||||
|
resp = yield from client.post(PUBLISH_URL, data=json.dumps({
|
||||||
|
'type': 'push',
|
||||||
|
'tag': '3bc28d69-0921-41f1-ac6a-7a627ba0aa72'
|
||||||
|
}))
|
||||||
|
|
||||||
assert resp.status == 200, resp.response
|
assert resp.status == 401, resp.response
|
||||||
assert view.registrations == config
|
|
||||||
|
|
||||||
hass.async_add_job.assert_called_with(save_json, CONFIG_FILE,
|
|
||||||
config)
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def test_unregister_device_view_handle_unknown_subscription(
|
def test_callback_view_with_jwt(hass, test_client):
|
||||||
self, loop, test_client):
|
"""Test that the notification callback view works with JWT."""
|
||||||
"""Test that the HTML unregister view handles unknown subscriptions."""
|
registrations = {
|
||||||
hass = MagicMock()
|
'device': SUBSCRIPTION_1
|
||||||
|
}
|
||||||
|
client = yield from mock_client(hass, test_client, registrations)
|
||||||
|
|
||||||
config = {
|
with patch('pywebpush.WebPusher') as mock_wp:
|
||||||
'some device': SUBSCRIPTION_1,
|
yield from hass.services.async_call('notify', 'notify', {
|
||||||
'other device': SUBSCRIPTION_2,
|
'message': 'Hello',
|
||||||
}
|
'target': ['device'],
|
||||||
|
'data': {'icon': 'beer.png'}
|
||||||
|
}, blocking=True)
|
||||||
|
|
||||||
m = mock_open(read_data=json.dumps(config))
|
assert len(mock_wp.mock_calls) == 3
|
||||||
with patch(
|
|
||||||
'homeassistant.util.json.open',
|
|
||||||
m, create=True
|
|
||||||
):
|
|
||||||
hass.config.path.return_value = CONFIG_FILE
|
|
||||||
service = html5.get_service(hass, {})
|
|
||||||
|
|
||||||
assert service is not None
|
# WebPusher constructor
|
||||||
|
assert mock_wp.mock_calls[0][1][0] == \
|
||||||
|
SUBSCRIPTION_1['subscription']
|
||||||
|
# Third mock_call checks the status_code of the response.
|
||||||
|
assert mock_wp.mock_calls[2][0] == '().send().status_code.__eq__'
|
||||||
|
|
||||||
# assert hass.called
|
# Call to send
|
||||||
assert len(hass.mock_calls) == 3
|
push_payload = json.loads(mock_wp.mock_calls[1][1][0])
|
||||||
|
|
||||||
view = hass.mock_calls[1][1][0]
|
assert push_payload['body'] == 'Hello'
|
||||||
assert view.json_path == hass.config.path.return_value
|
assert push_payload['icon'] == 'beer.png'
|
||||||
assert view.registrations == config
|
|
||||||
|
|
||||||
hass.loop = loop
|
bearer_token = "Bearer {}".format(push_payload['data']['jwt'])
|
||||||
app = mock_http_component_app(hass)
|
|
||||||
view.register(app.router)
|
|
||||||
client = yield from test_client(app)
|
|
||||||
hass.http.is_banned_ip.return_value = False
|
|
||||||
|
|
||||||
resp = yield from client.delete(REGISTER_URL, data=json.dumps({
|
resp = yield from client.post(PUBLISH_URL, json={
|
||||||
'subscription': SUBSCRIPTION_3['subscription']
|
'type': 'push',
|
||||||
}))
|
}, headers={AUTHORIZATION: bearer_token})
|
||||||
|
|
||||||
assert resp.status == 200, resp.response
|
assert resp.status == 200
|
||||||
assert view.registrations == config
|
body = yield from resp.json()
|
||||||
|
assert body == {"event": "push", "status": "ok"}
|
||||||
hass.async_add_job.assert_not_called()
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
|
||||||
def test_unregistering_device_view_handles_save_error(
|
|
||||||
self, loop, test_client):
|
|
||||||
"""Test that the HTML unregister view handles save errors."""
|
|
||||||
hass = MagicMock()
|
|
||||||
|
|
||||||
config = {
|
|
||||||
'some device': SUBSCRIPTION_1,
|
|
||||||
'other device': SUBSCRIPTION_2,
|
|
||||||
}
|
|
||||||
|
|
||||||
m = mock_open(read_data=json.dumps(config))
|
|
||||||
with patch(
|
|
||||||
'homeassistant.util.json.open',
|
|
||||||
m, create=True
|
|
||||||
):
|
|
||||||
hass.config.path.return_value = CONFIG_FILE
|
|
||||||
service = html5.get_service(hass, {})
|
|
||||||
|
|
||||||
assert service is not None
|
|
||||||
|
|
||||||
# assert hass.called
|
|
||||||
assert len(hass.mock_calls) == 3
|
|
||||||
|
|
||||||
view = hass.mock_calls[1][1][0]
|
|
||||||
assert view.json_path == hass.config.path.return_value
|
|
||||||
assert view.registrations == config
|
|
||||||
|
|
||||||
hass.loop = loop
|
|
||||||
app = mock_http_component_app(hass)
|
|
||||||
view.register(app.router)
|
|
||||||
client = yield from test_client(app)
|
|
||||||
hass.http.is_banned_ip.return_value = False
|
|
||||||
|
|
||||||
hass.async_add_job.side_effect = HomeAssistantError()
|
|
||||||
resp = yield from client.delete(REGISTER_URL, data=json.dumps({
|
|
||||||
'subscription': SUBSCRIPTION_1['subscription'],
|
|
||||||
}))
|
|
||||||
|
|
||||||
assert resp.status == 500, resp.response
|
|
||||||
assert view.registrations == config
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
|
||||||
def test_callback_view_no_jwt(self, loop, test_client):
|
|
||||||
"""Test that the notification callback view works without JWT."""
|
|
||||||
hass = MagicMock()
|
|
||||||
|
|
||||||
m = mock_open()
|
|
||||||
with patch(
|
|
||||||
'homeassistant.util.json.open',
|
|
||||||
m, create=True
|
|
||||||
):
|
|
||||||
hass.config.path.return_value = CONFIG_FILE
|
|
||||||
service = html5.get_service(hass, {})
|
|
||||||
|
|
||||||
assert service is not None
|
|
||||||
|
|
||||||
# assert hass.called
|
|
||||||
assert len(hass.mock_calls) == 3
|
|
||||||
|
|
||||||
view = hass.mock_calls[2][1][0]
|
|
||||||
|
|
||||||
hass.loop = loop
|
|
||||||
app = mock_http_component_app(hass)
|
|
||||||
view.register(app.router)
|
|
||||||
client = yield from test_client(app)
|
|
||||||
hass.http.is_banned_ip.return_value = False
|
|
||||||
|
|
||||||
resp = yield from client.post(PUBLISH_URL, data=json.dumps({
|
|
||||||
'type': 'push',
|
|
||||||
'tag': '3bc28d69-0921-41f1-ac6a-7a627ba0aa72'
|
|
||||||
}))
|
|
||||||
|
|
||||||
assert resp.status == 401, resp.response
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
|
||||||
def test_callback_view_with_jwt(self, loop, test_client):
|
|
||||||
"""Test that the notification callback view works with JWT."""
|
|
||||||
hass = MagicMock()
|
|
||||||
|
|
||||||
data = {
|
|
||||||
'device': SUBSCRIPTION_1
|
|
||||||
}
|
|
||||||
|
|
||||||
m = mock_open(read_data=json.dumps(data))
|
|
||||||
with patch(
|
|
||||||
'homeassistant.util.json.open',
|
|
||||||
m, create=True
|
|
||||||
):
|
|
||||||
hass.config.path.return_value = CONFIG_FILE
|
|
||||||
service = html5.get_service(hass, {'gcm_sender_id': '100'})
|
|
||||||
|
|
||||||
assert service is not None
|
|
||||||
|
|
||||||
# assert hass.called
|
|
||||||
assert len(hass.mock_calls) == 3
|
|
||||||
|
|
||||||
with patch('pywebpush.WebPusher') as mock_wp:
|
|
||||||
service.send_message(
|
|
||||||
'Hello', target=['device'], data={'icon': 'beer.png'})
|
|
||||||
|
|
||||||
assert len(mock_wp.mock_calls) == 3
|
|
||||||
|
|
||||||
# WebPusher constructor
|
|
||||||
assert mock_wp.mock_calls[0][1][0] == \
|
|
||||||
SUBSCRIPTION_1['subscription']
|
|
||||||
# Third mock_call checks the status_code of the response.
|
|
||||||
assert mock_wp.mock_calls[2][0] == '().send().status_code.__eq__'
|
|
||||||
|
|
||||||
# Call to send
|
|
||||||
push_payload = json.loads(mock_wp.mock_calls[1][1][0])
|
|
||||||
|
|
||||||
assert push_payload['body'] == 'Hello'
|
|
||||||
assert push_payload['icon'] == 'beer.png'
|
|
||||||
|
|
||||||
view = hass.mock_calls[2][1][0]
|
|
||||||
view.registrations = data
|
|
||||||
|
|
||||||
bearer_token = "Bearer {}".format(push_payload['data']['jwt'])
|
|
||||||
|
|
||||||
hass.loop = loop
|
|
||||||
app = mock_http_component_app(hass)
|
|
||||||
view.register(app.router)
|
|
||||||
client = yield from test_client(app)
|
|
||||||
hass.http.is_banned_ip.return_value = False
|
|
||||||
|
|
||||||
resp = yield from client.post(PUBLISH_URL, data=json.dumps({
|
|
||||||
'type': 'push',
|
|
||||||
}), headers={AUTHORIZATION: bearer_token})
|
|
||||||
|
|
||||||
assert resp.status == 200
|
|
||||||
body = yield from resp.json()
|
|
||||||
assert body == {"event": "push", "status": "ok"}
|
|
||||||
|
|
|
@ -10,8 +10,7 @@ import homeassistant.util.dt as dt_util
|
||||||
from homeassistant.components import history, recorder
|
from homeassistant.components import history, recorder
|
||||||
|
|
||||||
from tests.common import (
|
from tests.common import (
|
||||||
init_recorder_component, mock_http_component, mock_state_change_event,
|
init_recorder_component, mock_state_change_event, get_test_home_assistant)
|
||||||
get_test_home_assistant)
|
|
||||||
|
|
||||||
|
|
||||||
class TestComponentHistory(unittest.TestCase):
|
class TestComponentHistory(unittest.TestCase):
|
||||||
|
@ -38,7 +37,6 @@ class TestComponentHistory(unittest.TestCase):
|
||||||
|
|
||||||
def test_setup(self):
|
def test_setup(self):
|
||||||
"""Test setup method of history."""
|
"""Test setup method of history."""
|
||||||
mock_http_component(self.hass)
|
|
||||||
config = history.CONFIG_SCHEMA({
|
config = history.CONFIG_SCHEMA({
|
||||||
# ha.DOMAIN: {},
|
# ha.DOMAIN: {},
|
||||||
history.DOMAIN: {
|
history.DOMAIN: {
|
||||||
|
|
|
@ -14,7 +14,7 @@ from homeassistant.components import logbook
|
||||||
from homeassistant.setup import setup_component
|
from homeassistant.setup import setup_component
|
||||||
|
|
||||||
from tests.common import (
|
from tests.common import (
|
||||||
mock_http_component, init_recorder_component, get_test_home_assistant)
|
init_recorder_component, get_test_home_assistant)
|
||||||
|
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
@ -29,10 +29,7 @@ class TestComponentLogbook(unittest.TestCase):
|
||||||
"""Setup things to be run when tests are started."""
|
"""Setup things to be run when tests are started."""
|
||||||
self.hass = get_test_home_assistant()
|
self.hass = get_test_home_assistant()
|
||||||
init_recorder_component(self.hass) # Force an in memory DB
|
init_recorder_component(self.hass) # Force an in memory DB
|
||||||
mock_http_component(self.hass)
|
assert setup_component(self.hass, logbook.DOMAIN, self.EMPTY_CONFIG)
|
||||||
self.hass.config.components |= set(['frontend', 'recorder', 'api'])
|
|
||||||
assert setup_component(self.hass, logbook.DOMAIN,
|
|
||||||
self.EMPTY_CONFIG)
|
|
||||||
self.hass.start()
|
self.hass.start()
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
|
|
|
@ -150,7 +150,6 @@ def test_api_update_fails(hass, test_client):
|
||||||
assert resp.status == 404
|
assert resp.status == 404
|
||||||
|
|
||||||
beer_id = hass.data['shopping_list'].items[0]['id']
|
beer_id = hass.data['shopping_list'].items[0]['id']
|
||||||
client = yield from test_client(hass.http.app)
|
|
||||||
resp = yield from client.post(
|
resp = yield from client.post(
|
||||||
'/api/shopping_list/item/{}'.format(beer_id), json={
|
'/api/shopping_list/item/{}'.format(beer_id), json={
|
||||||
'name': 123,
|
'name': 123,
|
||||||
|
|
|
@ -8,8 +8,9 @@ import pytest
|
||||||
|
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import callback
|
||||||
from homeassistant.components import websocket_api as wapi, frontend
|
from homeassistant.components import websocket_api as wapi, frontend
|
||||||
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
from tests.common import mock_http_component_app, mock_coro
|
from tests.common import mock_coro
|
||||||
|
|
||||||
API_PASSWORD = 'test1234'
|
API_PASSWORD = 'test1234'
|
||||||
|
|
||||||
|
@ -17,10 +18,10 @@ API_PASSWORD = 'test1234'
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def websocket_client(loop, hass, test_client):
|
def websocket_client(loop, hass, test_client):
|
||||||
"""Websocket client fixture connected to websocket server."""
|
"""Websocket client fixture connected to websocket server."""
|
||||||
websocket_app = mock_http_component_app(hass)
|
assert loop.run_until_complete(
|
||||||
wapi.WebsocketAPIView().register(websocket_app.router)
|
async_setup_component(hass, 'websocket_api'))
|
||||||
|
|
||||||
client = loop.run_until_complete(test_client(websocket_app))
|
client = loop.run_until_complete(test_client(hass.http.app))
|
||||||
ws = loop.run_until_complete(client.ws_connect(wapi.URL))
|
ws = loop.run_until_complete(client.ws_connect(wapi.URL))
|
||||||
|
|
||||||
auth_ok = loop.run_until_complete(ws.receive_json())
|
auth_ok = loop.run_until_complete(ws.receive_json())
|
||||||
|
@ -35,10 +36,14 @@ def websocket_client(loop, hass, test_client):
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def no_auth_websocket_client(hass, loop, test_client):
|
def no_auth_websocket_client(hass, loop, test_client):
|
||||||
"""Websocket connection that requires authentication."""
|
"""Websocket connection that requires authentication."""
|
||||||
websocket_app = mock_http_component_app(hass, API_PASSWORD)
|
assert loop.run_until_complete(
|
||||||
wapi.WebsocketAPIView().register(websocket_app.router)
|
async_setup_component(hass, 'websocket_api', {
|
||||||
|
'http': {
|
||||||
|
'api_password': API_PASSWORD
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
client = loop.run_until_complete(test_client(websocket_app))
|
client = loop.run_until_complete(test_client(hass.http.app))
|
||||||
ws = loop.run_until_complete(client.ws_connect(wapi.URL))
|
ws = loop.run_until_complete(client.ws_connect(wapi.URL))
|
||||||
|
|
||||||
auth_ok = loop.run_until_complete(ws.receive_json())
|
auth_ok = loop.run_until_complete(ws.receive_json())
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue