Extract data validator to own file and add tests (#12401)

This commit is contained in:
Paulus Schoutsen 2018-02-14 12:06:03 -08:00 committed by Pascal Vizeli
parent 416f64fc70
commit 78c44180f4
7 changed files with 139 additions and 46 deletions

View file

@ -6,8 +6,9 @@ import logging
import async_timeout
import voluptuous as vol
from homeassistant.components.http import (
HomeAssistantView, RequestDataValidator)
from homeassistant.components.http import HomeAssistantView
from homeassistant.components.http.data_validator import (
RequestDataValidator)
from . import auth_api
from .const import DOMAIN, REQUEST_TIMEOUT

View file

@ -12,6 +12,8 @@ import voluptuous as vol
from homeassistant import core
from homeassistant.components import http
from homeassistant.components.http.data_validator import (
RequestDataValidator)
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers import intent
@ -148,7 +150,7 @@ class ConversationProcessView(http.HomeAssistantView):
url = '/api/conversation/process'
name = "api:conversation:process"
@http.RequestDataValidator(vol.Schema({
@RequestDataValidator(vol.Schema({
vol.Required('text'): str,
}))
@asyncio.coroutine

View file

@ -5,7 +5,6 @@ For more details about this component, please refer to the documentation at
https://home-assistant.io/components/http/
"""
import asyncio
from functools import wraps
from ipaddress import ip_network
import json
import logging
@ -415,14 +414,13 @@ def request_handler_factory(view, handler):
if not request.app['hass'].is_running:
return web.Response(status=503)
remote_addr = get_real_ip(request)
authenticated = request.get(KEY_AUTHENTICATED, False)
if view.requires_auth and not authenticated:
raise HTTPUnauthorized()
_LOGGER.info('Serving %s to %s (auth: %s)',
request.path, remote_addr, authenticated)
request.path, get_real_ip(request), authenticated)
result = handler(request, **request.match_info)
@ -449,41 +447,3 @@ def request_handler_factory(view, handler):
return web.Response(body=result, status=status_code)
return handle
class RequestDataValidator:
"""Decorator that will validate the incoming data.
Takes in a voluptuous schema and adds 'post_data' as
keyword argument to the function call.
Will return a 400 if no JSON provided or doesn't match schema.
"""
def __init__(self, schema):
"""Initialize the decorator."""
self._schema = schema
def __call__(self, method):
"""Decorate a function."""
@asyncio.coroutine
@wraps(method)
def wrapper(view, request, *args, **kwargs):
"""Wrap a request handler with data validation."""
try:
data = yield from request.json()
except ValueError:
_LOGGER.error('Invalid JSON received.')
return view.json_message('Invalid JSON.', 400)
try:
kwargs['data'] = self._schema(data)
except vol.Invalid as err:
_LOGGER.error('Data does not match schema: %s', err)
return view.json_message(
'Message format incorrect: {}'.format(err), 400)
result = yield from method(view, request, *args, **kwargs)
return result
return wrapper

View file

@ -0,0 +1,51 @@
"""Decorator for view methods to help with data validation."""
import asyncio
from functools import wraps
import logging
import voluptuous as vol
_LOGGER = logging.getLogger(__name__)
class RequestDataValidator:
"""Decorator that will validate the incoming data.
Takes in a voluptuous schema and adds 'post_data' as
keyword argument to the function call.
Will return a 400 if no JSON provided or doesn't match schema.
"""
def __init__(self, schema, allow_empty=False):
"""Initialize the decorator."""
self._schema = schema
self._allow_empty = allow_empty
def __call__(self, method):
"""Decorate a function."""
@asyncio.coroutine
@wraps(method)
def wrapper(view, request, *args, **kwargs):
"""Wrap a request handler with data validation."""
data = None
try:
data = yield from request.json()
except ValueError:
if not self._allow_empty or \
(yield from request.content.read()) != b'':
_LOGGER.error('Invalid JSON received.')
return view.json_message('Invalid JSON.', 400)
data = {}
try:
kwargs['data'] = self._schema(data)
except vol.Invalid as err:
_LOGGER.error('Data does not match schema: %s', err)
return view.json_message(
'Message format incorrect: {}'.format(err), 400)
result = yield from method(view, request, *args, **kwargs)
return result
return wrapper

View file

@ -10,7 +10,7 @@ def get_real_ip(request):
if KEY_REAL_IP in request:
return request[KEY_REAL_IP]
if (request.app[KEY_USE_X_FORWARDED_FOR] and
if (request.app.get(KEY_USE_X_FORWARDED_FOR) and
HTTP_HEADER_X_FORWARDED_FOR in request.headers):
request[KEY_REAL_IP] = ip_address(
request.headers.get(HTTP_HEADER_X_FORWARDED_FOR).split(',')[0])

View file

@ -10,6 +10,8 @@ import voluptuous as vol
from homeassistant.const import HTTP_NOT_FOUND, HTTP_BAD_REQUEST
from homeassistant.core import callback
from homeassistant.components import http
from homeassistant.components.http.data_validator import (
RequestDataValidator)
from homeassistant.helpers import intent
import homeassistant.helpers.config_validation as cv
@ -199,7 +201,7 @@ class CreateShoppingListItemView(http.HomeAssistantView):
url = '/api/shopping_list/item'
name = "api:shopping_list:item"
@http.RequestDataValidator(vol.Schema({
@RequestDataValidator(vol.Schema({
vol.Required('name'): str,
}))
@asyncio.coroutine

View file

@ -0,0 +1,77 @@
"""Test data validator decorator."""
import asyncio
from unittest.mock import Mock
from aiohttp import web
import voluptuous as vol
from homeassistant.components.http import HomeAssistantView
from homeassistant.components.http.data_validator import RequestDataValidator
@asyncio.coroutine
def get_client(test_client, validator):
"""Generate a client that hits a view decorated with validator."""
app = web.Application()
app['hass'] = Mock(is_running=True)
class TestView(HomeAssistantView):
url = '/'
name = 'test'
requires_auth = False
@asyncio.coroutine
@validator
def post(self, request, data):
"""Test method."""
return b''
TestView().register(app.router)
client = yield from test_client(app)
return client
@asyncio.coroutine
def test_validator(test_client):
"""Test the validator."""
client = yield from get_client(
test_client, RequestDataValidator(vol.Schema({
vol.Required('test'): str
})))
resp = yield from client.post('/', json={
'test': 'bla'
})
assert resp.status == 200
resp = yield from client.post('/', json={
'test': 100
})
assert resp.status == 400
resp = yield from client.post('/')
assert resp.status == 400
@asyncio.coroutine
def test_validator_allow_empty(test_client):
"""Test the validator with empty data."""
client = yield from get_client(
test_client, RequestDataValidator(vol.Schema({
# Although we allow empty, our schema should still be able
# to validate an empty dict.
vol.Optional('test'): str
}), allow_empty=True))
resp = yield from client.post('/', json={
'test': 'bla'
})
assert resp.status == 200
resp = yield from client.post('/', json={
'test': 100
})
assert resp.status == 400
resp = yield from client.post('/')
assert resp.status == 200