Extract data validator to own file and add tests (#12401)
This commit is contained in:
parent
416f64fc70
commit
78c44180f4
7 changed files with 139 additions and 46 deletions
|
@ -5,7 +5,6 @@ For more details about this component, please refer to the documentation at
|
|||
https://home-assistant.io/components/http/
|
||||
"""
|
||||
import asyncio
|
||||
from functools import wraps
|
||||
from ipaddress import ip_network
|
||||
import json
|
||||
import logging
|
||||
|
@ -415,14 +414,13 @@ def request_handler_factory(view, handler):
|
|||
if not request.app['hass'].is_running:
|
||||
return web.Response(status=503)
|
||||
|
||||
remote_addr = get_real_ip(request)
|
||||
authenticated = request.get(KEY_AUTHENTICATED, False)
|
||||
|
||||
if view.requires_auth and not authenticated:
|
||||
raise HTTPUnauthorized()
|
||||
|
||||
_LOGGER.info('Serving %s to %s (auth: %s)',
|
||||
request.path, remote_addr, authenticated)
|
||||
request.path, get_real_ip(request), authenticated)
|
||||
|
||||
result = handler(request, **request.match_info)
|
||||
|
||||
|
@ -449,41 +447,3 @@ def request_handler_factory(view, handler):
|
|||
return web.Response(body=result, status=status_code)
|
||||
|
||||
return handle
|
||||
|
||||
|
||||
class RequestDataValidator:
|
||||
"""Decorator that will validate the incoming data.
|
||||
|
||||
Takes in a voluptuous schema and adds 'post_data' as
|
||||
keyword argument to the function call.
|
||||
|
||||
Will return a 400 if no JSON provided or doesn't match schema.
|
||||
"""
|
||||
|
||||
def __init__(self, schema):
|
||||
"""Initialize the decorator."""
|
||||
self._schema = schema
|
||||
|
||||
def __call__(self, method):
|
||||
"""Decorate a function."""
|
||||
@asyncio.coroutine
|
||||
@wraps(method)
|
||||
def wrapper(view, request, *args, **kwargs):
|
||||
"""Wrap a request handler with data validation."""
|
||||
try:
|
||||
data = yield from request.json()
|
||||
except ValueError:
|
||||
_LOGGER.error('Invalid JSON received.')
|
||||
return view.json_message('Invalid JSON.', 400)
|
||||
|
||||
try:
|
||||
kwargs['data'] = self._schema(data)
|
||||
except vol.Invalid as err:
|
||||
_LOGGER.error('Data does not match schema: %s', err)
|
||||
return view.json_message(
|
||||
'Message format incorrect: {}'.format(err), 400)
|
||||
|
||||
result = yield from method(view, request, *args, **kwargs)
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
|
51
homeassistant/components/http/data_validator.py
Normal file
51
homeassistant/components/http/data_validator.py
Normal file
|
@ -0,0 +1,51 @@
|
|||
"""Decorator for view methods to help with data validation."""
|
||||
import asyncio
|
||||
from functools import wraps
|
||||
import logging
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RequestDataValidator:
|
||||
"""Decorator that will validate the incoming data.
|
||||
|
||||
Takes in a voluptuous schema and adds 'post_data' as
|
||||
keyword argument to the function call.
|
||||
|
||||
Will return a 400 if no JSON provided or doesn't match schema.
|
||||
"""
|
||||
|
||||
def __init__(self, schema, allow_empty=False):
|
||||
"""Initialize the decorator."""
|
||||
self._schema = schema
|
||||
self._allow_empty = allow_empty
|
||||
|
||||
def __call__(self, method):
|
||||
"""Decorate a function."""
|
||||
@asyncio.coroutine
|
||||
@wraps(method)
|
||||
def wrapper(view, request, *args, **kwargs):
|
||||
"""Wrap a request handler with data validation."""
|
||||
data = None
|
||||
try:
|
||||
data = yield from request.json()
|
||||
except ValueError:
|
||||
if not self._allow_empty or \
|
||||
(yield from request.content.read()) != b'':
|
||||
_LOGGER.error('Invalid JSON received.')
|
||||
return view.json_message('Invalid JSON.', 400)
|
||||
data = {}
|
||||
|
||||
try:
|
||||
kwargs['data'] = self._schema(data)
|
||||
except vol.Invalid as err:
|
||||
_LOGGER.error('Data does not match schema: %s', err)
|
||||
return view.json_message(
|
||||
'Message format incorrect: {}'.format(err), 400)
|
||||
|
||||
result = yield from method(view, request, *args, **kwargs)
|
||||
return result
|
||||
|
||||
return wrapper
|
|
@ -10,7 +10,7 @@ def get_real_ip(request):
|
|||
if KEY_REAL_IP in request:
|
||||
return request[KEY_REAL_IP]
|
||||
|
||||
if (request.app[KEY_USE_X_FORWARDED_FOR] and
|
||||
if (request.app.get(KEY_USE_X_FORWARDED_FOR) and
|
||||
HTTP_HEADER_X_FORWARDED_FOR in request.headers):
|
||||
request[KEY_REAL_IP] = ip_address(
|
||||
request.headers.get(HTTP_HEADER_X_FORWARDED_FOR).split(',')[0])
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue