Reorganize HTTP component (#4575)
* Move HTTP to own folder * Break HTTP into middlewares * Lint * Split tests per middleware * Clean up HTTP tests * Make HomeAssistantViews more stateless * Lint * Make HTTP setup async
This commit is contained in:
parent
58b85b2e0e
commit
32ffd006fa
35 changed files with 1318 additions and 1084 deletions
|
@ -118,7 +118,7 @@ class AlexaIntentsView(HomeAssistantView):
|
||||||
|
|
||||||
def __init__(self, hass, intents):
|
def __init__(self, hass, intents):
|
||||||
"""Initialize Alexa view."""
|
"""Initialize Alexa view."""
|
||||||
super().__init__(hass)
|
super().__init__()
|
||||||
|
|
||||||
intents = copy.deepcopy(intents)
|
intents = copy.deepcopy(intents)
|
||||||
template.attach(hass, intents)
|
template.attach(hass, intents)
|
||||||
|
@ -150,7 +150,7 @@ class AlexaIntentsView(HomeAssistantView):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
intent = req.get('intent')
|
intent = req.get('intent')
|
||||||
response = AlexaResponse(self.hass, intent)
|
response = AlexaResponse(request.app['hass'], intent)
|
||||||
|
|
||||||
if req_type == 'LaunchRequest':
|
if req_type == 'LaunchRequest':
|
||||||
response.add_speech(
|
response.add_speech(
|
||||||
|
@ -282,7 +282,7 @@ class AlexaFlashBriefingView(HomeAssistantView):
|
||||||
|
|
||||||
def __init__(self, hass, flash_briefings):
|
def __init__(self, hass, flash_briefings):
|
||||||
"""Initialize Alexa view."""
|
"""Initialize Alexa view."""
|
||||||
super().__init__(hass)
|
super().__init__()
|
||||||
self.flash_briefings = copy.deepcopy(flash_briefings)
|
self.flash_briefings = copy.deepcopy(flash_briefings)
|
||||||
template.attach(hass, self.flash_briefings)
|
template.attach(hass, self.flash_briefings)
|
||||||
|
|
||||||
|
|
|
@ -77,8 +77,10 @@ class APIEventStream(HomeAssistantView):
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def get(self, request):
|
def get(self, request):
|
||||||
"""Provide a streaming interface for the event bus."""
|
"""Provide a streaming interface for the event bus."""
|
||||||
|
# pylint: disable=no-self-use
|
||||||
|
hass = request.app['hass']
|
||||||
stop_obj = object()
|
stop_obj = object()
|
||||||
to_write = asyncio.Queue(loop=self.hass.loop)
|
to_write = asyncio.Queue(loop=hass.loop)
|
||||||
|
|
||||||
restrict = request.GET.get('restrict')
|
restrict = request.GET.get('restrict')
|
||||||
if restrict:
|
if restrict:
|
||||||
|
@ -106,7 +108,7 @@ class APIEventStream(HomeAssistantView):
|
||||||
response.content_type = 'text/event-stream'
|
response.content_type = 'text/event-stream'
|
||||||
yield from response.prepare(request)
|
yield from response.prepare(request)
|
||||||
|
|
||||||
unsub_stream = self.hass.bus.async_listen(MATCH_ALL, forward_events)
|
unsub_stream = hass.bus.async_listen(MATCH_ALL, forward_events)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
_LOGGER.debug('STREAM %s ATTACHED', id(stop_obj))
|
_LOGGER.debug('STREAM %s ATTACHED', id(stop_obj))
|
||||||
|
@ -117,7 +119,7 @@ class APIEventStream(HomeAssistantView):
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
with async_timeout.timeout(STREAM_PING_INTERVAL,
|
with async_timeout.timeout(STREAM_PING_INTERVAL,
|
||||||
loop=self.hass.loop):
|
loop=hass.loop):
|
||||||
payload = yield from to_write.get()
|
payload = yield from to_write.get()
|
||||||
|
|
||||||
if payload is stop_obj:
|
if payload is stop_obj:
|
||||||
|
@ -145,7 +147,7 @@ class APIConfigView(HomeAssistantView):
|
||||||
@ha.callback
|
@ha.callback
|
||||||
def get(self, request):
|
def get(self, request):
|
||||||
"""Get current configuration."""
|
"""Get current configuration."""
|
||||||
return self.json(self.hass.config.as_dict())
|
return self.json(request.app['hass'].config.as_dict())
|
||||||
|
|
||||||
|
|
||||||
class APIDiscoveryView(HomeAssistantView):
|
class APIDiscoveryView(HomeAssistantView):
|
||||||
|
@ -158,10 +160,11 @@ class APIDiscoveryView(HomeAssistantView):
|
||||||
@ha.callback
|
@ha.callback
|
||||||
def get(self, request):
|
def get(self, request):
|
||||||
"""Get discovery info."""
|
"""Get discovery info."""
|
||||||
needs_auth = self.hass.config.api.api_password is not None
|
hass = request.app['hass']
|
||||||
|
needs_auth = hass.config.api.api_password is not None
|
||||||
return self.json({
|
return self.json({
|
||||||
'base_url': self.hass.config.api.base_url,
|
'base_url': hass.config.api.base_url,
|
||||||
'location_name': self.hass.config.location_name,
|
'location_name': hass.config.location_name,
|
||||||
'requires_api_password': needs_auth,
|
'requires_api_password': needs_auth,
|
||||||
'version': __version__
|
'version': __version__
|
||||||
})
|
})
|
||||||
|
@ -176,7 +179,7 @@ class APIStatesView(HomeAssistantView):
|
||||||
@ha.callback
|
@ha.callback
|
||||||
def get(self, request):
|
def get(self, request):
|
||||||
"""Get current states."""
|
"""Get current states."""
|
||||||
return self.json(self.hass.states.async_all())
|
return self.json(request.app['hass'].states.async_all())
|
||||||
|
|
||||||
|
|
||||||
class APIEntityStateView(HomeAssistantView):
|
class APIEntityStateView(HomeAssistantView):
|
||||||
|
@ -188,7 +191,7 @@ class APIEntityStateView(HomeAssistantView):
|
||||||
@ha.callback
|
@ha.callback
|
||||||
def get(self, request, entity_id):
|
def get(self, request, entity_id):
|
||||||
"""Retrieve state of entity."""
|
"""Retrieve state of entity."""
|
||||||
state = self.hass.states.get(entity_id)
|
state = request.app['hass'].states.get(entity_id)
|
||||||
if state:
|
if state:
|
||||||
return self.json(state)
|
return self.json(state)
|
||||||
else:
|
else:
|
||||||
|
@ -197,6 +200,7 @@ class APIEntityStateView(HomeAssistantView):
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def post(self, request, entity_id):
|
def post(self, request, entity_id):
|
||||||
"""Update state of entity."""
|
"""Update state of entity."""
|
||||||
|
hass = request.app['hass']
|
||||||
try:
|
try:
|
||||||
data = yield from request.json()
|
data = yield from request.json()
|
||||||
except ValueError:
|
except ValueError:
|
||||||
|
@ -211,15 +215,14 @@ class APIEntityStateView(HomeAssistantView):
|
||||||
attributes = data.get('attributes')
|
attributes = data.get('attributes')
|
||||||
force_update = data.get('force_update', False)
|
force_update = data.get('force_update', False)
|
||||||
|
|
||||||
is_new_state = self.hass.states.get(entity_id) is None
|
is_new_state = hass.states.get(entity_id) is None
|
||||||
|
|
||||||
# Write state
|
# Write state
|
||||||
self.hass.states.async_set(entity_id, new_state, attributes,
|
hass.states.async_set(entity_id, new_state, attributes, force_update)
|
||||||
force_update)
|
|
||||||
|
|
||||||
# Read the state back for our response
|
# Read the state back for our response
|
||||||
status_code = HTTP_CREATED if is_new_state else 200
|
status_code = HTTP_CREATED if is_new_state else 200
|
||||||
resp = self.json(self.hass.states.get(entity_id), status_code)
|
resp = self.json(hass.states.get(entity_id), status_code)
|
||||||
|
|
||||||
resp.headers.add('Location', URL_API_STATES_ENTITY.format(entity_id))
|
resp.headers.add('Location', URL_API_STATES_ENTITY.format(entity_id))
|
||||||
|
|
||||||
|
@ -228,7 +231,7 @@ class APIEntityStateView(HomeAssistantView):
|
||||||
@ha.callback
|
@ha.callback
|
||||||
def delete(self, request, entity_id):
|
def delete(self, request, entity_id):
|
||||||
"""Remove entity."""
|
"""Remove entity."""
|
||||||
if self.hass.states.async_remove(entity_id):
|
if request.app['hass'].states.async_remove(entity_id):
|
||||||
return self.json_message('Entity removed')
|
return self.json_message('Entity removed')
|
||||||
else:
|
else:
|
||||||
return self.json_message('Entity not found', HTTP_NOT_FOUND)
|
return self.json_message('Entity not found', HTTP_NOT_FOUND)
|
||||||
|
@ -243,7 +246,7 @@ class APIEventListenersView(HomeAssistantView):
|
||||||
@ha.callback
|
@ha.callback
|
||||||
def get(self, request):
|
def get(self, request):
|
||||||
"""Get event listeners."""
|
"""Get event listeners."""
|
||||||
return self.json(async_events_json(self.hass))
|
return self.json(async_events_json(request.app['hass']))
|
||||||
|
|
||||||
|
|
||||||
class APIEventView(HomeAssistantView):
|
class APIEventView(HomeAssistantView):
|
||||||
|
@ -271,7 +274,8 @@ class APIEventView(HomeAssistantView):
|
||||||
if state:
|
if state:
|
||||||
event_data[key] = state
|
event_data[key] = state
|
||||||
|
|
||||||
self.hass.bus.async_fire(event_type, event_data, ha.EventOrigin.remote)
|
request.app['hass'].bus.async_fire(event_type, event_data,
|
||||||
|
ha.EventOrigin.remote)
|
||||||
|
|
||||||
return self.json_message("Event {} fired.".format(event_type))
|
return self.json_message("Event {} fired.".format(event_type))
|
||||||
|
|
||||||
|
@ -285,7 +289,7 @@ class APIServicesView(HomeAssistantView):
|
||||||
@ha.callback
|
@ha.callback
|
||||||
def get(self, request):
|
def get(self, request):
|
||||||
"""Get registered services."""
|
"""Get registered services."""
|
||||||
return self.json(async_services_json(self.hass))
|
return self.json(async_services_json(request.app['hass']))
|
||||||
|
|
||||||
|
|
||||||
class APIDomainServicesView(HomeAssistantView):
|
class APIDomainServicesView(HomeAssistantView):
|
||||||
|
@ -300,12 +304,12 @@ class APIDomainServicesView(HomeAssistantView):
|
||||||
|
|
||||||
Returns a list of changed states.
|
Returns a list of changed states.
|
||||||
"""
|
"""
|
||||||
|
hass = request.app['hass']
|
||||||
body = yield from request.text()
|
body = yield from request.text()
|
||||||
data = json.loads(body) if body else None
|
data = json.loads(body) if body else None
|
||||||
|
|
||||||
with AsyncTrackStates(self.hass) as changed_states:
|
with AsyncTrackStates(hass) as changed_states:
|
||||||
yield from self.hass.services.async_call(domain, service, data,
|
yield from hass.services.async_call(domain, service, data, True)
|
||||||
True)
|
|
||||||
|
|
||||||
return self.json(changed_states)
|
return self.json(changed_states)
|
||||||
|
|
||||||
|
@ -320,6 +324,7 @@ class APIEventForwardingView(HomeAssistantView):
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def post(self, request):
|
def post(self, request):
|
||||||
"""Setup an event forwarder."""
|
"""Setup an event forwarder."""
|
||||||
|
hass = request.app['hass']
|
||||||
try:
|
try:
|
||||||
data = yield from request.json()
|
data = yield from request.json()
|
||||||
except ValueError:
|
except ValueError:
|
||||||
|
@ -340,14 +345,14 @@ class APIEventForwardingView(HomeAssistantView):
|
||||||
|
|
||||||
api = rem.API(host, api_password, port)
|
api = rem.API(host, api_password, port)
|
||||||
|
|
||||||
valid = yield from self.hass.loop.run_in_executor(
|
valid = yield from hass.loop.run_in_executor(
|
||||||
None, api.validate_api)
|
None, api.validate_api)
|
||||||
if not valid:
|
if not valid:
|
||||||
return self.json_message("Unable to validate API.",
|
return self.json_message("Unable to validate API.",
|
||||||
HTTP_UNPROCESSABLE_ENTITY)
|
HTTP_UNPROCESSABLE_ENTITY)
|
||||||
|
|
||||||
if self.event_forwarder is None:
|
if self.event_forwarder is None:
|
||||||
self.event_forwarder = rem.EventForwarder(self.hass)
|
self.event_forwarder = rem.EventForwarder(hass)
|
||||||
|
|
||||||
self.event_forwarder.async_connect(api)
|
self.event_forwarder.async_connect(api)
|
||||||
|
|
||||||
|
@ -389,7 +394,7 @@ class APIComponentsView(HomeAssistantView):
|
||||||
@ha.callback
|
@ha.callback
|
||||||
def get(self, request):
|
def get(self, request):
|
||||||
"""Get current loaded components."""
|
"""Get current loaded components."""
|
||||||
return self.json(self.hass.config.components)
|
return self.json(request.app['hass'].config.components)
|
||||||
|
|
||||||
|
|
||||||
class APIErrorLogView(HomeAssistantView):
|
class APIErrorLogView(HomeAssistantView):
|
||||||
|
@ -402,7 +407,7 @@ class APIErrorLogView(HomeAssistantView):
|
||||||
def get(self, request):
|
def get(self, request):
|
||||||
"""Serve error log."""
|
"""Serve error log."""
|
||||||
resp = yield from self.file(
|
resp = yield from self.file(
|
||||||
request, self.hass.config.path(ERROR_LOG_FILENAME))
|
request, request.app['hass'].config.path(ERROR_LOG_FILENAME))
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
@ -417,7 +422,7 @@ class APITemplateView(HomeAssistantView):
|
||||||
"""Render a template."""
|
"""Render a template."""
|
||||||
try:
|
try:
|
||||||
data = yield from request.json()
|
data = yield from request.json()
|
||||||
tpl = template.Template(data['template'], self.hass)
|
tpl = template.Template(data['template'], request.app['hass'])
|
||||||
return tpl.async_render(data.get('variables'))
|
return tpl.async_render(data.get('variables'))
|
||||||
except (ValueError, TemplateError) as ex:
|
except (ValueError, TemplateError) as ex:
|
||||||
return self.json_message('Error rendering template: {}'.format(ex),
|
return self.json_message('Error rendering template: {}'.format(ex),
|
||||||
|
|
|
@ -13,7 +13,7 @@ from aiohttp import web
|
||||||
from homeassistant.helpers.entity import Entity
|
from homeassistant.helpers.entity import Entity
|
||||||
from homeassistant.helpers.entity_component import EntityComponent
|
from homeassistant.helpers.entity_component import EntityComponent
|
||||||
from homeassistant.helpers.config_validation import PLATFORM_SCHEMA # noqa
|
from homeassistant.helpers.config_validation import PLATFORM_SCHEMA # noqa
|
||||||
from homeassistant.components.http import HomeAssistantView
|
from homeassistant.components.http import HomeAssistantView, KEY_AUTHENTICATED
|
||||||
|
|
||||||
DOMAIN = 'camera'
|
DOMAIN = 'camera'
|
||||||
DEPENDENCIES = ['http']
|
DEPENDENCIES = ['http']
|
||||||
|
@ -33,8 +33,8 @@ def async_setup(hass, config):
|
||||||
component = EntityComponent(
|
component = EntityComponent(
|
||||||
logging.getLogger(__name__), DOMAIN, hass, SCAN_INTERVAL)
|
logging.getLogger(__name__), DOMAIN, hass, SCAN_INTERVAL)
|
||||||
|
|
||||||
hass.http.register_view(CameraImageView(hass, component.entities))
|
hass.http.register_view(CameraImageView(component.entities))
|
||||||
hass.http.register_view(CameraMjpegStream(hass, component.entities))
|
hass.http.register_view(CameraMjpegStream(component.entities))
|
||||||
|
|
||||||
yield from component.async_setup(config)
|
yield from component.async_setup(config)
|
||||||
return True
|
return True
|
||||||
|
@ -165,9 +165,8 @@ class CameraView(HomeAssistantView):
|
||||||
|
|
||||||
requires_auth = False
|
requires_auth = False
|
||||||
|
|
||||||
def __init__(self, hass, entities):
|
def __init__(self, entities):
|
||||||
"""Initialize a basic camera view."""
|
"""Initialize a basic camera view."""
|
||||||
super().__init__(hass)
|
|
||||||
self.entities = entities
|
self.entities = entities
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
|
@ -178,7 +177,7 @@ class CameraView(HomeAssistantView):
|
||||||
if camera is None:
|
if camera is None:
|
||||||
return web.Response(status=404)
|
return web.Response(status=404)
|
||||||
|
|
||||||
authenticated = (request.authenticated or
|
authenticated = (request[KEY_AUTHENTICATED] or
|
||||||
request.GET.get('token') == camera.access_token)
|
request.GET.get('token') == camera.access_token)
|
||||||
|
|
||||||
if not authenticated:
|
if not authenticated:
|
||||||
|
|
|
@ -21,7 +21,7 @@ DEPENDENCIES = ['http']
|
||||||
|
|
||||||
def setup_scanner(hass, config, see):
|
def setup_scanner(hass, config, see):
|
||||||
"""Setup an endpoint for the GPSLogger application."""
|
"""Setup an endpoint for the GPSLogger application."""
|
||||||
hass.http.register_view(GPSLoggerView(hass, see))
|
hass.http.register_view(GPSLoggerView(see))
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@ -32,20 +32,18 @@ class GPSLoggerView(HomeAssistantView):
|
||||||
url = '/api/gpslogger'
|
url = '/api/gpslogger'
|
||||||
name = 'api:gpslogger'
|
name = 'api:gpslogger'
|
||||||
|
|
||||||
def __init__(self, hass, see):
|
def __init__(self, see):
|
||||||
"""Initialize GPSLogger url endpoints."""
|
"""Initialize GPSLogger url endpoints."""
|
||||||
super().__init__(hass)
|
|
||||||
self.see = see
|
self.see = see
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def get(self, request):
|
def get(self, request):
|
||||||
"""A GPSLogger message received as GET."""
|
"""A GPSLogger message received as GET."""
|
||||||
res = yield from self._handle(request.GET)
|
res = yield from self._handle(request.app['hass'], request.GET)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
# pylint: disable=too-many-return-statements
|
def _handle(self, hass, data):
|
||||||
def _handle(self, data):
|
|
||||||
"""Handle gpslogger request."""
|
"""Handle gpslogger request."""
|
||||||
if 'latitude' not in data or 'longitude' not in data:
|
if 'latitude' not in data or 'longitude' not in data:
|
||||||
return ('Latitude and longitude not specified.',
|
return ('Latitude and longitude not specified.',
|
||||||
|
@ -66,7 +64,7 @@ class GPSLoggerView(HomeAssistantView):
|
||||||
if 'battery' in data:
|
if 'battery' in data:
|
||||||
battery = float(data['battery'])
|
battery = float(data['battery'])
|
||||||
|
|
||||||
yield from self.hass.loop.run_in_executor(
|
yield from hass.loop.run_in_executor(
|
||||||
None, partial(self.see, dev_id=device,
|
None, partial(self.see, dev_id=device,
|
||||||
gps=gps_location, battery=battery,
|
gps=gps_location, battery=battery,
|
||||||
gps_accuracy=accuracy))
|
gps_accuracy=accuracy))
|
||||||
|
|
|
@ -23,7 +23,7 @@ DEPENDENCIES = ['http']
|
||||||
|
|
||||||
def setup_scanner(hass, config, see):
|
def setup_scanner(hass, config, see):
|
||||||
"""Setup an endpoint for the Locative application."""
|
"""Setup an endpoint for the Locative application."""
|
||||||
hass.http.register_view(LocativeView(hass, see))
|
hass.http.register_view(LocativeView(see))
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@ -34,27 +34,26 @@ class LocativeView(HomeAssistantView):
|
||||||
url = '/api/locative'
|
url = '/api/locative'
|
||||||
name = 'api:locative'
|
name = 'api:locative'
|
||||||
|
|
||||||
def __init__(self, hass, see):
|
def __init__(self, see):
|
||||||
"""Initialize Locative url endpoints."""
|
"""Initialize Locative url endpoints."""
|
||||||
super().__init__(hass)
|
|
||||||
self.see = see
|
self.see = see
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def get(self, request):
|
def get(self, request):
|
||||||
"""Locative message received as GET."""
|
"""Locative message received as GET."""
|
||||||
res = yield from self._handle(request.GET)
|
res = yield from self._handle(request.app['hass'], request.GET)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def post(self, request):
|
def post(self, request):
|
||||||
"""Locative message received."""
|
"""Locative message received."""
|
||||||
data = yield from request.post()
|
data = yield from request.post()
|
||||||
res = yield from self._handle(data)
|
res = yield from self._handle(request.app['hass'], data)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
# pylint: disable=too-many-return-statements
|
# pylint: disable=too-many-return-statements
|
||||||
def _handle(self, data):
|
def _handle(self, hass, data):
|
||||||
"""Handle locative request."""
|
"""Handle locative request."""
|
||||||
if 'latitude' not in data or 'longitude' not in data:
|
if 'latitude' not in data or 'longitude' not in data:
|
||||||
return ('Latitude and longitude not specified.',
|
return ('Latitude and longitude not specified.',
|
||||||
|
@ -81,19 +80,19 @@ class LocativeView(HomeAssistantView):
|
||||||
gps_location = (data[ATTR_LATITUDE], data[ATTR_LONGITUDE])
|
gps_location = (data[ATTR_LATITUDE], data[ATTR_LONGITUDE])
|
||||||
|
|
||||||
if direction == 'enter':
|
if direction == 'enter':
|
||||||
yield from self.hass.loop.run_in_executor(
|
yield from hass.loop.run_in_executor(
|
||||||
None, partial(self.see, dev_id=device,
|
None, partial(self.see, dev_id=device,
|
||||||
location_name=location_name,
|
location_name=location_name,
|
||||||
gps=gps_location))
|
gps=gps_location))
|
||||||
return 'Setting location to {}'.format(location_name)
|
return 'Setting location to {}'.format(location_name)
|
||||||
|
|
||||||
elif direction == 'exit':
|
elif direction == 'exit':
|
||||||
current_state = self.hass.states.get(
|
current_state = hass.states.get(
|
||||||
'{}.{}'.format(DOMAIN, device))
|
'{}.{}'.format(DOMAIN, device))
|
||||||
|
|
||||||
if current_state is None or current_state.state == location_name:
|
if current_state is None or current_state.state == location_name:
|
||||||
location_name = STATE_NOT_HOME
|
location_name = STATE_NOT_HOME
|
||||||
yield from self.hass.loop.run_in_executor(
|
yield from hass.loop.run_in_executor(
|
||||||
None, partial(self.see, dev_id=device,
|
None, partial(self.see, dev_id=device,
|
||||||
location_name=location_name,
|
location_name=location_name,
|
||||||
gps=gps_location))
|
gps=gps_location))
|
||||||
|
|
|
@ -78,14 +78,13 @@ def setup(hass, yaml_config):
|
||||||
cors_origins=None,
|
cors_origins=None,
|
||||||
use_x_forwarded_for=False,
|
use_x_forwarded_for=False,
|
||||||
trusted_networks=None,
|
trusted_networks=None,
|
||||||
ip_bans=None,
|
|
||||||
login_threshold=0,
|
login_threshold=0,
|
||||||
is_ban_enabled=False
|
is_ban_enabled=False
|
||||||
)
|
)
|
||||||
|
|
||||||
server.register_view(DescriptionXmlView(hass, config))
|
server.register_view(DescriptionXmlView(config))
|
||||||
server.register_view(HueUsernameView(hass))
|
server.register_view(HueUsernameView)
|
||||||
server.register_view(HueLightsView(hass, config))
|
server.register_view(HueLightsView(config))
|
||||||
|
|
||||||
upnp_listener = UPNPResponderThread(
|
upnp_listener = UPNPResponderThread(
|
||||||
config.host_ip_addr, config.listen_port)
|
config.host_ip_addr, config.listen_port)
|
||||||
|
@ -157,9 +156,8 @@ class DescriptionXmlView(HomeAssistantView):
|
||||||
name = 'description:xml'
|
name = 'description:xml'
|
||||||
requires_auth = False
|
requires_auth = False
|
||||||
|
|
||||||
def __init__(self, hass, config):
|
def __init__(self, config):
|
||||||
"""Initialize the instance of the view."""
|
"""Initialize the instance of the view."""
|
||||||
super().__init__(hass)
|
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
@core.callback
|
@core.callback
|
||||||
|
@ -201,10 +199,6 @@ class HueUsernameView(HomeAssistantView):
|
||||||
extra_urls = ['/api/']
|
extra_urls = ['/api/']
|
||||||
requires_auth = False
|
requires_auth = False
|
||||||
|
|
||||||
def __init__(self, hass):
|
|
||||||
"""Initialize the instance of the view."""
|
|
||||||
super().__init__(hass)
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def post(self, request):
|
def post(self, request):
|
||||||
"""Handle a POST request."""
|
"""Handle a POST request."""
|
||||||
|
@ -229,30 +223,33 @@ class HueLightsView(HomeAssistantView):
|
||||||
'/api/{username}/lights/{entity_id}/state']
|
'/api/{username}/lights/{entity_id}/state']
|
||||||
requires_auth = False
|
requires_auth = False
|
||||||
|
|
||||||
def __init__(self, hass, config):
|
def __init__(self, config):
|
||||||
"""Initialize the instance of the view."""
|
"""Initialize the instance of the view."""
|
||||||
super().__init__(hass)
|
|
||||||
self.config = config
|
self.config = config
|
||||||
self.cached_states = {}
|
self.cached_states = {}
|
||||||
|
|
||||||
@core.callback
|
@core.callback
|
||||||
def get(self, request, username, entity_id=None):
|
def get(self, request, username, entity_id=None):
|
||||||
"""Handle a GET request."""
|
"""Handle a GET request."""
|
||||||
|
hass = request.app['hass']
|
||||||
|
|
||||||
if entity_id is None:
|
if entity_id is None:
|
||||||
return self.async_get_lights_list()
|
return self.async_get_lights_list(hass)
|
||||||
|
|
||||||
if not request.path.endswith('state'):
|
if not request.path.endswith('state'):
|
||||||
return self.async_get_light_state(entity_id)
|
return self.async_get_light_state(hass, entity_id)
|
||||||
|
|
||||||
return web.Response(text="Method not allowed", status=405)
|
return web.Response(text="Method not allowed", status=405)
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def put(self, request, username, entity_id=None):
|
def put(self, request, username, entity_id=None):
|
||||||
"""Handle a PUT request."""
|
"""Handle a PUT request."""
|
||||||
|
hass = request.app['hass']
|
||||||
|
|
||||||
if not request.path.endswith('state'):
|
if not request.path.endswith('state'):
|
||||||
return web.Response(text="Method not allowed", status=405)
|
return web.Response(text="Method not allowed", status=405)
|
||||||
|
|
||||||
if entity_id and self.hass.states.get(entity_id) is None:
|
if entity_id and hass.states.get(entity_id) is None:
|
||||||
return self.json_message('Entity not found', HTTP_NOT_FOUND)
|
return self.json_message('Entity not found', HTTP_NOT_FOUND)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -260,24 +257,25 @@ class HueLightsView(HomeAssistantView):
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return self.json_message('Invalid JSON', HTTP_BAD_REQUEST)
|
return self.json_message('Invalid JSON', HTTP_BAD_REQUEST)
|
||||||
|
|
||||||
result = yield from self.async_put_light_state(json_data, entity_id)
|
result = yield from self.async_put_light_state(hass, json_data,
|
||||||
|
entity_id)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@core.callback
|
@core.callback
|
||||||
def async_get_lights_list(self):
|
def async_get_lights_list(self, hass):
|
||||||
"""Process a request to get the list of available lights."""
|
"""Process a request to get the list of available lights."""
|
||||||
json_response = {}
|
json_response = {}
|
||||||
|
|
||||||
for entity in self.hass.states.async_all():
|
for entity in hass.states.async_all():
|
||||||
if self.is_entity_exposed(entity):
|
if self.is_entity_exposed(entity):
|
||||||
json_response[entity.entity_id] = entity_to_json(entity)
|
json_response[entity.entity_id] = entity_to_json(entity)
|
||||||
|
|
||||||
return self.json(json_response)
|
return self.json(json_response)
|
||||||
|
|
||||||
@core.callback
|
@core.callback
|
||||||
def async_get_light_state(self, entity_id):
|
def async_get_light_state(self, hass, entity_id):
|
||||||
"""Process a request to get the state of an individual light."""
|
"""Process a request to get the state of an individual light."""
|
||||||
entity = self.hass.states.get(entity_id)
|
entity = hass.states.get(entity_id)
|
||||||
if entity is None or not self.is_entity_exposed(entity):
|
if entity is None or not self.is_entity_exposed(entity):
|
||||||
return web.Response(text="Entity not found", status=404)
|
return web.Response(text="Entity not found", status=404)
|
||||||
|
|
||||||
|
@ -295,12 +293,12 @@ class HueLightsView(HomeAssistantView):
|
||||||
return self.json(json_response)
|
return self.json(json_response)
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def async_put_light_state(self, request_json, entity_id):
|
def async_put_light_state(self, hass, request_json, entity_id):
|
||||||
"""Process a request to set the state of an individual light."""
|
"""Process a request to set the state of an individual light."""
|
||||||
config = self.config
|
config = self.config
|
||||||
|
|
||||||
# Retrieve the entity from the state machine
|
# Retrieve the entity from the state machine
|
||||||
entity = self.hass.states.get(entity_id)
|
entity = hass.states.get(entity_id)
|
||||||
if entity is None:
|
if entity is None:
|
||||||
return web.Response(text="Entity not found", status=404)
|
return web.Response(text="Entity not found", status=404)
|
||||||
|
|
||||||
|
@ -345,7 +343,7 @@ class HueLightsView(HomeAssistantView):
|
||||||
self.cached_states[entity_id] = (result, brightness)
|
self.cached_states[entity_id] = (result, brightness)
|
||||||
|
|
||||||
# Perform the requested action
|
# Perform the requested action
|
||||||
yield from self.hass.services.async_call(core.DOMAIN, service, data,
|
yield from hass.services.async_call(core.DOMAIN, service, data,
|
||||||
blocking=True)
|
blocking=True)
|
||||||
|
|
||||||
json_response = \
|
json_response = \
|
||||||
|
|
|
@ -75,8 +75,7 @@ def setup(hass, config):
|
||||||
descriptions[DOMAIN][SERVICE_CHECKIN],
|
descriptions[DOMAIN][SERVICE_CHECKIN],
|
||||||
schema=CHECKIN_SERVICE_SCHEMA)
|
schema=CHECKIN_SERVICE_SCHEMA)
|
||||||
|
|
||||||
hass.http.register_view(FoursquarePushReceiver(
|
hass.http.register_view(FoursquarePushReceiver(config[CONF_PUSH_SECRET]))
|
||||||
hass, config[CONF_PUSH_SECRET]))
|
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@ -88,9 +87,8 @@ class FoursquarePushReceiver(HomeAssistantView):
|
||||||
url = "/api/foursquare"
|
url = "/api/foursquare"
|
||||||
name = "foursquare"
|
name = "foursquare"
|
||||||
|
|
||||||
def __init__(self, hass, push_secret):
|
def __init__(self, push_secret):
|
||||||
"""Initialize the OAuth callback view."""
|
"""Initialize the OAuth callback view."""
|
||||||
super().__init__(hass)
|
|
||||||
self.push_secret = push_secret
|
self.push_secret = push_secret
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
|
@ -110,4 +108,4 @@ class FoursquarePushReceiver(HomeAssistantView):
|
||||||
"push secret: %s", secret)
|
"push secret: %s", secret)
|
||||||
return self.json_message('Incorrect secret', HTTP_BAD_REQUEST)
|
return self.json_message('Incorrect secret', HTTP_BAD_REQUEST)
|
||||||
|
|
||||||
self.hass.bus.async_fire(EVENT_PUSH, data)
|
request.app['hass'].bus.async_fire(EVENT_PUSH, data)
|
||||||
|
|
|
@ -11,6 +11,8 @@ from homeassistant.core import callback
|
||||||
from homeassistant.const import HTTP_NOT_FOUND
|
from homeassistant.const import HTTP_NOT_FOUND
|
||||||
from homeassistant.components import api, group
|
from homeassistant.components import api, group
|
||||||
from homeassistant.components.http import HomeAssistantView
|
from homeassistant.components.http import HomeAssistantView
|
||||||
|
from homeassistant.components.http.auth import is_trusted_ip
|
||||||
|
from homeassistant.components.http.const import KEY_DEVELOPMENT
|
||||||
from .version import FINGERPRINTS
|
from .version import FINGERPRINTS
|
||||||
|
|
||||||
DOMAIN = 'frontend'
|
DOMAIN = 'frontend'
|
||||||
|
@ -155,7 +157,7 @@ def setup(hass, config):
|
||||||
if os.path.isdir(local):
|
if os.path.isdir(local):
|
||||||
hass.http.register_static_path("/local", local)
|
hass.http.register_static_path("/local", local)
|
||||||
|
|
||||||
index_view = hass.data[DATA_INDEX_VIEW] = IndexView(hass)
|
index_view = hass.data[DATA_INDEX_VIEW] = IndexView()
|
||||||
hass.http.register_view(index_view)
|
hass.http.register_view(index_view)
|
||||||
|
|
||||||
# Components have registered panels before frontend got setup.
|
# Components have registered panels before frontend got setup.
|
||||||
|
@ -185,12 +187,14 @@ class BootstrapView(HomeAssistantView):
|
||||||
@callback
|
@callback
|
||||||
def get(self, request):
|
def get(self, request):
|
||||||
"""Return all data needed to bootstrap Home Assistant."""
|
"""Return all data needed to bootstrap Home Assistant."""
|
||||||
|
hass = request.app['hass']
|
||||||
|
|
||||||
return self.json({
|
return self.json({
|
||||||
'config': self.hass.config.as_dict(),
|
'config': hass.config.as_dict(),
|
||||||
'states': self.hass.states.async_all(),
|
'states': hass.states.async_all(),
|
||||||
'events': api.async_events_json(self.hass),
|
'events': api.async_events_json(hass),
|
||||||
'services': api.async_services_json(self.hass),
|
'services': api.async_services_json(hass),
|
||||||
'panels': self.hass.data[DATA_PANELS],
|
'panels': hass.data[DATA_PANELS],
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@ -202,10 +206,8 @@ class IndexView(HomeAssistantView):
|
||||||
requires_auth = False
|
requires_auth = False
|
||||||
extra_urls = ['/states', '/states/{entity_id}']
|
extra_urls = ['/states', '/states/{entity_id}']
|
||||||
|
|
||||||
def __init__(self, hass):
|
def __init__(self):
|
||||||
"""Initialize the frontend view."""
|
"""Initialize the frontend view."""
|
||||||
super().__init__(hass)
|
|
||||||
|
|
||||||
from jinja2 import FileSystemLoader, Environment
|
from jinja2 import FileSystemLoader, Environment
|
||||||
|
|
||||||
self.templates = Environment(
|
self.templates = Environment(
|
||||||
|
@ -217,14 +219,16 @@ class IndexView(HomeAssistantView):
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def get(self, request, entity_id=None):
|
def get(self, request, entity_id=None):
|
||||||
"""Serve the index view."""
|
"""Serve the index view."""
|
||||||
|
hass = request.app['hass']
|
||||||
|
|
||||||
if entity_id is not None:
|
if entity_id is not None:
|
||||||
state = self.hass.states.get(entity_id)
|
state = hass.states.get(entity_id)
|
||||||
|
|
||||||
if (not state or state.domain != 'group' or
|
if (not state or state.domain != 'group' or
|
||||||
not state.attributes.get(group.ATTR_VIEW)):
|
not state.attributes.get(group.ATTR_VIEW)):
|
||||||
return self.json_message('Entity not found', HTTP_NOT_FOUND)
|
return self.json_message('Entity not found', HTTP_NOT_FOUND)
|
||||||
|
|
||||||
if self.hass.http.development:
|
if request.app[KEY_DEVELOPMENT]:
|
||||||
core_url = '/static/home-assistant-polymer/build/core.js'
|
core_url = '/static/home-assistant-polymer/build/core.js'
|
||||||
ui_url = '/static/home-assistant-polymer/src/home-assistant.html'
|
ui_url = '/static/home-assistant-polymer/src/home-assistant.html'
|
||||||
else:
|
else:
|
||||||
|
@ -241,19 +245,18 @@ class IndexView(HomeAssistantView):
|
||||||
if panel == 'states':
|
if panel == 'states':
|
||||||
panel_url = ''
|
panel_url = ''
|
||||||
else:
|
else:
|
||||||
panel_url = self.hass.data[DATA_PANELS][panel]['url']
|
panel_url = hass.data[DATA_PANELS][panel]['url']
|
||||||
|
|
||||||
no_auth = 'true'
|
no_auth = 'true'
|
||||||
if self.hass.config.api.api_password:
|
if hass.config.api.api_password:
|
||||||
# require password if set
|
# require password if set
|
||||||
no_auth = 'false'
|
no_auth = 'false'
|
||||||
if self.hass.http.is_trusted_ip(
|
if is_trusted_ip(request):
|
||||||
self.hass.http.get_real_ip(request)):
|
|
||||||
# bypass for trusted networks
|
# bypass for trusted networks
|
||||||
no_auth = 'true'
|
no_auth = 'true'
|
||||||
|
|
||||||
icons_url = '/static/mdi-{}.html'.format(FINGERPRINTS['mdi.html'])
|
icons_url = '/static/mdi-{}.html'.format(FINGERPRINTS['mdi.html'])
|
||||||
template = yield from self.hass.loop.run_in_executor(
|
template = yield from hass.loop.run_in_executor(
|
||||||
None, self.templates.get_template, 'index.html')
|
None, self.templates.get_template, 'index.html')
|
||||||
|
|
||||||
# pylint is wrong
|
# pylint is wrong
|
||||||
|
@ -262,7 +265,7 @@ class IndexView(HomeAssistantView):
|
||||||
resp = template.render(
|
resp = template.render(
|
||||||
core_url=core_url, ui_url=ui_url, no_auth=no_auth,
|
core_url=core_url, ui_url=ui_url, no_auth=no_auth,
|
||||||
icons_url=icons_url, icons=FINGERPRINTS['mdi.html'],
|
icons_url=icons_url, icons=FINGERPRINTS['mdi.html'],
|
||||||
panel_url=panel_url, panels=self.hass.data[DATA_PANELS])
|
panel_url=panel_url, panels=hass.data[DATA_PANELS])
|
||||||
|
|
||||||
return web.Response(text=resp, content_type='text/html')
|
return web.Response(text=resp, content_type='text/html')
|
||||||
|
|
||||||
|
|
|
@ -184,8 +184,8 @@ def setup(hass, config):
|
||||||
filters.included_entities = include[CONF_ENTITIES]
|
filters.included_entities = include[CONF_ENTITIES]
|
||||||
filters.included_domains = include[CONF_DOMAINS]
|
filters.included_domains = include[CONF_DOMAINS]
|
||||||
|
|
||||||
hass.http.register_view(Last5StatesView(hass))
|
hass.http.register_view(Last5StatesView)
|
||||||
hass.http.register_view(HistoryPeriodView(hass, filters))
|
hass.http.register_view(HistoryPeriodView(filters))
|
||||||
register_built_in_panel(hass, 'history', 'History', 'mdi:poll-box')
|
register_built_in_panel(hass, 'history', 'History', 'mdi:poll-box')
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
@ -197,14 +197,10 @@ class Last5StatesView(HomeAssistantView):
|
||||||
url = '/api/history/entity/{entity_id}/recent_states'
|
url = '/api/history/entity/{entity_id}/recent_states'
|
||||||
name = 'api:history:entity-recent-states'
|
name = 'api:history:entity-recent-states'
|
||||||
|
|
||||||
def __init__(self, hass):
|
|
||||||
"""Initilalize the history last 5 states view."""
|
|
||||||
super().__init__(hass)
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def get(self, request, entity_id):
|
def get(self, request, entity_id):
|
||||||
"""Retrieve last 5 states of entity."""
|
"""Retrieve last 5 states of entity."""
|
||||||
result = yield from self.hass.loop.run_in_executor(
|
result = yield from request.app['hass'].loop.run_in_executor(
|
||||||
None, last_5_states, entity_id)
|
None, last_5_states, entity_id)
|
||||||
return self.json(result)
|
return self.json(result)
|
||||||
|
|
||||||
|
@ -216,9 +212,8 @@ class HistoryPeriodView(HomeAssistantView):
|
||||||
name = 'api:history:view-period'
|
name = 'api:history:view-period'
|
||||||
extra_urls = ['/api/history/period/{datetime}']
|
extra_urls = ['/api/history/period/{datetime}']
|
||||||
|
|
||||||
def __init__(self, hass, filters):
|
def __init__(self, filters):
|
||||||
"""Initilalize the history period view."""
|
"""Initilalize the history period view."""
|
||||||
super().__init__(hass)
|
|
||||||
self.filters = filters
|
self.filters = filters
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
|
@ -240,7 +235,7 @@ class HistoryPeriodView(HomeAssistantView):
|
||||||
end_time = start_time + one_day
|
end_time = start_time + one_day
|
||||||
entity_id = request.GET.get('filter_entity_id')
|
entity_id = request.GET.get('filter_entity_id')
|
||||||
|
|
||||||
result = yield from self.hass.loop.run_in_executor(
|
result = yield from request.app['hass'].loop.run_in_executor(
|
||||||
None, get_significant_states, start_time, end_time, entity_id,
|
None, get_significant_states, start_time, end_time, entity_id,
|
||||||
self.filters)
|
self.filters)
|
||||||
|
|
||||||
|
|
|
@ -1,641 +0,0 @@
|
||||||
"""
|
|
||||||
This module provides WSGI application to serve the Home Assistant API.
|
|
||||||
|
|
||||||
For more details about this component, please refer to the documentation at
|
|
||||||
https://home-assistant.io/components/http/
|
|
||||||
"""
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import mimetypes
|
|
||||||
import ssl
|
|
||||||
from datetime import datetime
|
|
||||||
from ipaddress import ip_address, ip_network
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import hmac
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
import voluptuous as vol
|
|
||||||
from aiohttp import web, hdrs
|
|
||||||
from aiohttp.file_sender import FileSender
|
|
||||||
from aiohttp.web_exceptions import (
|
|
||||||
HTTPUnauthorized, HTTPMovedPermanently, HTTPNotModified, HTTPForbidden)
|
|
||||||
from aiohttp.web_urldispatcher import StaticResource
|
|
||||||
|
|
||||||
import homeassistant.helpers.config_validation as cv
|
|
||||||
import homeassistant.remote as rem
|
|
||||||
from homeassistant import util
|
|
||||||
from homeassistant.components import persistent_notification
|
|
||||||
from homeassistant.config import load_yaml_config_file
|
|
||||||
from homeassistant.const import (
|
|
||||||
SERVER_PORT, HTTP_HEADER_HA_AUTH, # HTTP_HEADER_CACHE_CONTROL,
|
|
||||||
CONTENT_TYPE_JSON, ALLOWED_CORS_HEADERS, EVENT_HOMEASSISTANT_STOP,
|
|
||||||
EVENT_HOMEASSISTANT_START, HTTP_HEADER_X_FORWARDED_FOR)
|
|
||||||
from homeassistant.core import is_callback
|
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
|
||||||
from homeassistant.util.yaml import dump
|
|
||||||
|
|
||||||
DOMAIN = 'http'
|
|
||||||
REQUIREMENTS = ('aiohttp_cors==0.5.0',)
|
|
||||||
|
|
||||||
CONF_API_PASSWORD = 'api_password'
|
|
||||||
CONF_SERVER_HOST = 'server_host'
|
|
||||||
CONF_SERVER_PORT = 'server_port'
|
|
||||||
CONF_DEVELOPMENT = 'development'
|
|
||||||
CONF_SSL_CERTIFICATE = 'ssl_certificate'
|
|
||||||
CONF_SSL_KEY = 'ssl_key'
|
|
||||||
CONF_CORS_ORIGINS = 'cors_allowed_origins'
|
|
||||||
CONF_USE_X_FORWARDED_FOR = 'use_x_forwarded_for'
|
|
||||||
CONF_TRUSTED_NETWORKS = 'trusted_networks'
|
|
||||||
CONF_LOGIN_ATTEMPTS_THRESHOLD = 'login_attempts_threshold'
|
|
||||||
CONF_IP_BAN_ENABLED = 'ip_ban_enabled'
|
|
||||||
|
|
||||||
DATA_API_PASSWORD = 'api_password'
|
|
||||||
NOTIFICATION_ID_LOGIN = 'http-login'
|
|
||||||
NOTIFICATION_ID_BAN = 'ip-ban'
|
|
||||||
|
|
||||||
IP_BANS = 'ip_bans.yaml'
|
|
||||||
ATTR_BANNED_AT = "banned_at"
|
|
||||||
|
|
||||||
|
|
||||||
# TLS configuation follows the best-practice guidelines specified here:
|
|
||||||
# https://wiki.mozilla.org/Security/Server_Side_TLS
|
|
||||||
# Intermediate guidelines are followed.
|
|
||||||
SSL_VERSION = ssl.PROTOCOL_SSLv23
|
|
||||||
SSL_OPTS = ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3
|
|
||||||
if hasattr(ssl, 'OP_NO_COMPRESSION'):
|
|
||||||
SSL_OPTS |= ssl.OP_NO_COMPRESSION
|
|
||||||
CIPHERS = "ECDHE-ECDSA-CHACHA20-POLY1305:ECDHE-RSA-CHACHA20-POLY1305:" \
|
|
||||||
"ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256:" \
|
|
||||||
"ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384:" \
|
|
||||||
"DHE-RSA-AES128-GCM-SHA256:DHE-RSA-AES256-GCM-SHA384:" \
|
|
||||||
"ECDHE-ECDSA-AES128-SHA256:ECDHE-RSA-AES128-SHA256:" \
|
|
||||||
"ECDHE-ECDSA-AES128-SHA:ECDHE-RSA-AES256-SHA384:" \
|
|
||||||
"ECDHE-RSA-AES128-SHA:ECDHE-ECDSA-AES256-SHA384:" \
|
|
||||||
"ECDHE-ECDSA-AES256-SHA:ECDHE-RSA-AES256-SHA:" \
|
|
||||||
"DHE-RSA-AES128-SHA256:DHE-RSA-AES128-SHA:DHE-RSA-AES256-SHA256:" \
|
|
||||||
"DHE-RSA-AES256-SHA:ECDHE-ECDSA-DES-CBC3-SHA:" \
|
|
||||||
"ECDHE-RSA-DES-CBC3-SHA:EDH-RSA-DES-CBC3-SHA:AES128-GCM-SHA256:" \
|
|
||||||
"AES256-GCM-SHA384:AES128-SHA256:AES256-SHA256:AES128-SHA:" \
|
|
||||||
"AES256-SHA:DES-CBC3-SHA:!DSS"
|
|
||||||
|
|
||||||
_FINGERPRINT = re.compile(r'^(.+)-[a-z0-9]{32}\.(\w+)$', re.IGNORECASE)
|
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
CONFIG_SCHEMA = vol.Schema({
|
|
||||||
DOMAIN: vol.Schema({
|
|
||||||
vol.Optional(CONF_API_PASSWORD): cv.string,
|
|
||||||
vol.Optional(CONF_SERVER_HOST): cv.string,
|
|
||||||
vol.Optional(CONF_SERVER_PORT, default=SERVER_PORT):
|
|
||||||
vol.All(vol.Coerce(int), vol.Range(min=1, max=65535)),
|
|
||||||
vol.Optional(CONF_DEVELOPMENT): cv.string,
|
|
||||||
vol.Optional(CONF_SSL_CERTIFICATE): cv.isfile,
|
|
||||||
vol.Optional(CONF_SSL_KEY): cv.isfile,
|
|
||||||
vol.Optional(CONF_CORS_ORIGINS): vol.All(cv.ensure_list, [cv.string]),
|
|
||||||
vol.Optional(CONF_USE_X_FORWARDED_FOR, default=False): cv.boolean,
|
|
||||||
vol.Optional(CONF_TRUSTED_NETWORKS):
|
|
||||||
vol.All(cv.ensure_list, [ip_network]),
|
|
||||||
vol.Optional(CONF_LOGIN_ATTEMPTS_THRESHOLD): cv.positive_int,
|
|
||||||
vol.Optional(CONF_IP_BAN_ENABLED): cv.boolean
|
|
||||||
}),
|
|
||||||
}, extra=vol.ALLOW_EXTRA)
|
|
||||||
|
|
||||||
|
|
||||||
# TEMP TO GET TESTS TO RUN
|
|
||||||
def request_class():
|
|
||||||
"""."""
|
|
||||||
raise Exception('not implemented')
|
|
||||||
|
|
||||||
|
|
||||||
class HideSensitiveFilter(logging.Filter):
|
|
||||||
"""Filter API password calls."""
|
|
||||||
|
|
||||||
def __init__(self, hass):
|
|
||||||
"""Initialize sensitive data filter."""
|
|
||||||
super().__init__()
|
|
||||||
self.hass = hass
|
|
||||||
|
|
||||||
def filter(self, record):
|
|
||||||
"""Hide sensitive data in messages."""
|
|
||||||
if self.hass.http.api_password is None:
|
|
||||||
return True
|
|
||||||
|
|
||||||
record.msg = record.msg.replace(self.hass.http.api_password, '*******')
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def setup(hass, config):
|
|
||||||
"""Set up the HTTP API and debug interface."""
|
|
||||||
logging.getLogger('aiohttp.access').addFilter(HideSensitiveFilter(hass))
|
|
||||||
|
|
||||||
conf = config.get(DOMAIN, {})
|
|
||||||
|
|
||||||
api_password = util.convert(conf.get(CONF_API_PASSWORD), str)
|
|
||||||
server_host = conf.get(CONF_SERVER_HOST, '0.0.0.0')
|
|
||||||
server_port = conf.get(CONF_SERVER_PORT, SERVER_PORT)
|
|
||||||
development = str(conf.get(CONF_DEVELOPMENT, '')) == '1'
|
|
||||||
ssl_certificate = conf.get(CONF_SSL_CERTIFICATE)
|
|
||||||
ssl_key = conf.get(CONF_SSL_KEY)
|
|
||||||
cors_origins = conf.get(CONF_CORS_ORIGINS, [])
|
|
||||||
use_x_forwarded_for = conf.get(CONF_USE_X_FORWARDED_FOR, False)
|
|
||||||
trusted_networks = [
|
|
||||||
ip_network(trusted_network)
|
|
||||||
for trusted_network in conf.get(CONF_TRUSTED_NETWORKS, [])]
|
|
||||||
is_ban_enabled = bool(conf.get(CONF_IP_BAN_ENABLED, False))
|
|
||||||
login_threshold = int(conf.get(CONF_LOGIN_ATTEMPTS_THRESHOLD, -1))
|
|
||||||
ip_bans = load_ip_bans_config(hass.config.path(IP_BANS))
|
|
||||||
|
|
||||||
server = HomeAssistantWSGI(
|
|
||||||
hass,
|
|
||||||
development=development,
|
|
||||||
server_host=server_host,
|
|
||||||
server_port=server_port,
|
|
||||||
api_password=api_password,
|
|
||||||
ssl_certificate=ssl_certificate,
|
|
||||||
ssl_key=ssl_key,
|
|
||||||
cors_origins=cors_origins,
|
|
||||||
use_x_forwarded_for=use_x_forwarded_for,
|
|
||||||
trusted_networks=trusted_networks,
|
|
||||||
ip_bans=ip_bans,
|
|
||||||
login_threshold=login_threshold,
|
|
||||||
is_ban_enabled=is_ban_enabled
|
|
||||||
)
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
|
||||||
def stop_server(event):
|
|
||||||
"""Callback to stop the server."""
|
|
||||||
yield from server.stop()
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
|
||||||
def start_server(event):
|
|
||||||
"""Callback to start the server."""
|
|
||||||
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, stop_server)
|
|
||||||
yield from server.start()
|
|
||||||
|
|
||||||
hass.bus.listen_once(EVENT_HOMEASSISTANT_START, start_server)
|
|
||||||
|
|
||||||
hass.http = server
|
|
||||||
hass.config.api = rem.API(server_host if server_host != '0.0.0.0'
|
|
||||||
else util.get_local_ip(),
|
|
||||||
api_password, server_port,
|
|
||||||
ssl_certificate is not None)
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
class GzipFileSender(FileSender):
|
|
||||||
"""FileSender class capable of sending gzip version if available."""
|
|
||||||
|
|
||||||
# pylint: disable=invalid-name
|
|
||||||
|
|
||||||
development = False
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
|
||||||
def send(self, request, filepath):
|
|
||||||
"""Send filepath to client using request."""
|
|
||||||
gzip = False
|
|
||||||
if 'gzip' in request.headers[hdrs.ACCEPT_ENCODING]:
|
|
||||||
gzip_path = filepath.with_name(filepath.name + '.gz')
|
|
||||||
|
|
||||||
if gzip_path.is_file():
|
|
||||||
filepath = gzip_path
|
|
||||||
gzip = True
|
|
||||||
|
|
||||||
st = filepath.stat()
|
|
||||||
|
|
||||||
modsince = request.if_modified_since
|
|
||||||
if modsince is not None and st.st_mtime <= modsince.timestamp():
|
|
||||||
raise HTTPNotModified()
|
|
||||||
|
|
||||||
ct, encoding = mimetypes.guess_type(str(filepath))
|
|
||||||
if not ct:
|
|
||||||
ct = 'application/octet-stream'
|
|
||||||
|
|
||||||
resp = self._response_factory()
|
|
||||||
resp.content_type = ct
|
|
||||||
if encoding:
|
|
||||||
resp.headers[hdrs.CONTENT_ENCODING] = encoding
|
|
||||||
if gzip:
|
|
||||||
resp.headers[hdrs.VARY] = hdrs.ACCEPT_ENCODING
|
|
||||||
resp.last_modified = st.st_mtime
|
|
||||||
|
|
||||||
# CACHE HACK
|
|
||||||
if not self.development:
|
|
||||||
cache_time = 31 * 86400 # = 1 month
|
|
||||||
resp.headers[hdrs.CACHE_CONTROL] = "public, max-age={}".format(
|
|
||||||
cache_time)
|
|
||||||
|
|
||||||
file_size = st.st_size
|
|
||||||
|
|
||||||
resp.content_length = file_size
|
|
||||||
with filepath.open('rb') as f:
|
|
||||||
yield from self._sendfile(request, resp, f, file_size)
|
|
||||||
|
|
||||||
return resp
|
|
||||||
|
|
||||||
|
|
||||||
_GZIP_FILE_SENDER = GzipFileSender()
|
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
|
||||||
def staticresource_enhancer(app, handler):
|
|
||||||
"""Enhance StaticResourceHandler.
|
|
||||||
|
|
||||||
Adds gzip encoding and fingerprinting matching.
|
|
||||||
"""
|
|
||||||
inst = getattr(handler, '__self__', None)
|
|
||||||
if not isinstance(inst, StaticResource):
|
|
||||||
return handler
|
|
||||||
|
|
||||||
# pylint: disable=protected-access
|
|
||||||
inst._file_sender = _GZIP_FILE_SENDER
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
|
||||||
def middleware_handler(request):
|
|
||||||
"""Strip out fingerprints from resource names."""
|
|
||||||
fingerprinted = _FINGERPRINT.match(request.match_info['filename'])
|
|
||||||
|
|
||||||
if fingerprinted:
|
|
||||||
request.match_info['filename'] = \
|
|
||||||
'{}.{}'.format(*fingerprinted.groups())
|
|
||||||
|
|
||||||
resp = yield from handler(request)
|
|
||||||
return resp
|
|
||||||
|
|
||||||
return middleware_handler
|
|
||||||
|
|
||||||
|
|
||||||
class HomeAssistantWSGI(object):
|
|
||||||
"""WSGI server for Home Assistant."""
|
|
||||||
|
|
||||||
def __init__(self, hass, development, api_password, ssl_certificate,
|
|
||||||
ssl_key, server_host, server_port, cors_origins,
|
|
||||||
use_x_forwarded_for, trusted_networks,
|
|
||||||
ip_bans, login_threshold, is_ban_enabled):
|
|
||||||
"""Initialize the WSGI Home Assistant server."""
|
|
||||||
import aiohttp_cors
|
|
||||||
|
|
||||||
self.app = web.Application(middlewares=[staticresource_enhancer],
|
|
||||||
loop=hass.loop)
|
|
||||||
self.hass = hass
|
|
||||||
self.development = development
|
|
||||||
self.api_password = api_password
|
|
||||||
self.ssl_certificate = ssl_certificate
|
|
||||||
self.ssl_key = ssl_key
|
|
||||||
self.server_host = server_host
|
|
||||||
self.server_port = server_port
|
|
||||||
self.use_x_forwarded_for = use_x_forwarded_for
|
|
||||||
self.trusted_networks = trusted_networks \
|
|
||||||
if trusted_networks is not None else []
|
|
||||||
self.event_forwarder = None
|
|
||||||
self._handler = None
|
|
||||||
self.server = None
|
|
||||||
self.login_threshold = login_threshold
|
|
||||||
self.ip_bans = ip_bans if ip_bans is not None else []
|
|
||||||
self.failed_login_attempts = {}
|
|
||||||
self.is_ban_enabled = is_ban_enabled
|
|
||||||
|
|
||||||
if cors_origins:
|
|
||||||
self.cors = aiohttp_cors.setup(self.app, defaults={
|
|
||||||
host: aiohttp_cors.ResourceOptions(
|
|
||||||
allow_headers=ALLOWED_CORS_HEADERS,
|
|
||||||
allow_methods='*',
|
|
||||||
) for host in cors_origins
|
|
||||||
})
|
|
||||||
else:
|
|
||||||
self.cors = None
|
|
||||||
|
|
||||||
# CACHE HACK
|
|
||||||
_GZIP_FILE_SENDER.development = development
|
|
||||||
|
|
||||||
def register_view(self, view):
|
|
||||||
"""Register a view with the WSGI server.
|
|
||||||
|
|
||||||
The view argument must be a class that inherits from HomeAssistantView.
|
|
||||||
It is optional to instantiate it before registering; this method will
|
|
||||||
handle it either way.
|
|
||||||
"""
|
|
||||||
if isinstance(view, type):
|
|
||||||
# Instantiate the view, if needed
|
|
||||||
view = view(self.hass)
|
|
||||||
|
|
||||||
view.register(self.app.router)
|
|
||||||
|
|
||||||
def register_redirect(self, url, redirect_to):
|
|
||||||
"""Register a redirect with the server.
|
|
||||||
|
|
||||||
If given this must be either a string or callable. In case of a
|
|
||||||
callable it's called with the url adapter that triggered the match and
|
|
||||||
the values of the URL as keyword arguments and has to return the target
|
|
||||||
for the redirect, otherwise it has to be a string with placeholders in
|
|
||||||
rule syntax.
|
|
||||||
"""
|
|
||||||
def redirect(request):
|
|
||||||
"""Redirect to location."""
|
|
||||||
raise HTTPMovedPermanently(redirect_to)
|
|
||||||
|
|
||||||
self.app.router.add_route('GET', url, redirect)
|
|
||||||
|
|
||||||
def register_static_path(self, url_root, path, cache_length=31):
|
|
||||||
"""Register a folder to serve as a static path.
|
|
||||||
|
|
||||||
Specify optional cache length of asset in days.
|
|
||||||
"""
|
|
||||||
if os.path.isdir(path):
|
|
||||||
self.app.router.add_static(url_root, path)
|
|
||||||
return
|
|
||||||
|
|
||||||
filepath = Path(path)
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
|
||||||
def serve_file(request):
|
|
||||||
"""Redirect to location."""
|
|
||||||
res = yield from _GZIP_FILE_SENDER.send(request, filepath)
|
|
||||||
return res
|
|
||||||
|
|
||||||
# aiohttp supports regex matching for variables. Using that as temp
|
|
||||||
# to work around cache busting MD5.
|
|
||||||
# Turns something like /static/dev-panel.html into
|
|
||||||
# /static/{filename:dev-panel(-[a-z0-9]{32}|)\.html}
|
|
||||||
base, ext = url_root.rsplit('.', 1)
|
|
||||||
base, file = base.rsplit('/', 1)
|
|
||||||
regex = r"{}(-[a-z0-9]{{32}}|)\.{}".format(file, ext)
|
|
||||||
url_pattern = "{}/{{filename:{}}}".format(base, regex)
|
|
||||||
|
|
||||||
self.app.router.add_route('GET', url_pattern, serve_file)
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
|
||||||
def start(self):
|
|
||||||
"""Start the wsgi server."""
|
|
||||||
if self.cors is not None:
|
|
||||||
for route in list(self.app.router.routes()):
|
|
||||||
self.cors.add(route)
|
|
||||||
|
|
||||||
if self.ssl_certificate:
|
|
||||||
context = ssl.SSLContext(SSL_VERSION)
|
|
||||||
context.options |= SSL_OPTS
|
|
||||||
context.set_ciphers(CIPHERS)
|
|
||||||
context.load_cert_chain(self.ssl_certificate, self.ssl_key)
|
|
||||||
else:
|
|
||||||
context = None
|
|
||||||
|
|
||||||
self._handler = self.app.make_handler()
|
|
||||||
self.server = yield from self.hass.loop.create_server(
|
|
||||||
self._handler, self.server_host, self.server_port, ssl=context)
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
|
||||||
def stop(self):
|
|
||||||
"""Stop the wsgi server."""
|
|
||||||
self.server.close()
|
|
||||||
yield from self.server.wait_closed()
|
|
||||||
yield from self.app.shutdown()
|
|
||||||
yield from self._handler.finish_connections(60.0)
|
|
||||||
yield from self.app.cleanup()
|
|
||||||
|
|
||||||
def get_real_ip(self, request):
|
|
||||||
"""Return the clients correct ip address, even in proxied setups."""
|
|
||||||
if self.use_x_forwarded_for \
|
|
||||||
and HTTP_HEADER_X_FORWARDED_FOR in request.headers:
|
|
||||||
return request.headers.get(
|
|
||||||
HTTP_HEADER_X_FORWARDED_FOR).split(',')[0]
|
|
||||||
else:
|
|
||||||
peername = request.transport.get_extra_info('peername')
|
|
||||||
return peername[0] if peername is not None else None
|
|
||||||
|
|
||||||
def is_trusted_ip(self, remote_addr):
|
|
||||||
"""Match an ip address against trusted CIDR networks."""
|
|
||||||
return any(ip_address(remote_addr) in trusted_network
|
|
||||||
for trusted_network in self.hass.http.trusted_networks)
|
|
||||||
|
|
||||||
def wrong_login_attempt(self, remote_addr):
|
|
||||||
"""Registering wrong login attempt."""
|
|
||||||
if not self.is_ban_enabled or self.login_threshold < 1:
|
|
||||||
return
|
|
||||||
|
|
||||||
if remote_addr in self.failed_login_attempts:
|
|
||||||
self.failed_login_attempts[remote_addr] += 1
|
|
||||||
else:
|
|
||||||
self.failed_login_attempts[remote_addr] = 1
|
|
||||||
|
|
||||||
if self.failed_login_attempts[remote_addr] > self.login_threshold:
|
|
||||||
new_ban = IpBan(remote_addr)
|
|
||||||
self.ip_bans.append(new_ban)
|
|
||||||
update_ip_bans_config(self.hass.config.path(IP_BANS), new_ban)
|
|
||||||
_LOGGER.warning('Banned IP %s for too many login attempts',
|
|
||||||
remote_addr)
|
|
||||||
persistent_notification.async_create(
|
|
||||||
self.hass,
|
|
||||||
'Too many login attempts from {}'.format(remote_addr),
|
|
||||||
'Banning IP address', NOTIFICATION_ID_BAN)
|
|
||||||
|
|
||||||
def is_banned_ip(self, remote_addr):
|
|
||||||
"""Check if IP address is in a ban list."""
|
|
||||||
if not self.is_ban_enabled:
|
|
||||||
return False
|
|
||||||
|
|
||||||
ip_address_ = ip_address(remote_addr)
|
|
||||||
for ip_ban in self.ip_bans:
|
|
||||||
if ip_ban.ip_address == ip_address_:
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
class HomeAssistantView(object):
|
|
||||||
"""Base view for all views."""
|
|
||||||
|
|
||||||
url = None
|
|
||||||
extra_urls = []
|
|
||||||
requires_auth = True # Views inheriting from this class can override this
|
|
||||||
|
|
||||||
def __init__(self, hass):
|
|
||||||
"""Initilalize the base view."""
|
|
||||||
if not hasattr(self, 'url'):
|
|
||||||
class_name = self.__class__.__name__
|
|
||||||
raise AttributeError(
|
|
||||||
'{0} missing required attribute "url"'.format(class_name)
|
|
||||||
)
|
|
||||||
|
|
||||||
if not hasattr(self, 'name'):
|
|
||||||
class_name = self.__class__.__name__
|
|
||||||
raise AttributeError(
|
|
||||||
'{0} missing required attribute "name"'.format(class_name)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.hass = hass
|
|
||||||
|
|
||||||
# pylint: disable=no-self-use
|
|
||||||
def json(self, result, status_code=200):
|
|
||||||
"""Return a JSON response."""
|
|
||||||
msg = json.dumps(
|
|
||||||
result, sort_keys=True, cls=rem.JSONEncoder).encode('UTF-8')
|
|
||||||
return web.Response(
|
|
||||||
body=msg, content_type=CONTENT_TYPE_JSON, status=status_code)
|
|
||||||
|
|
||||||
def json_message(self, error, status_code=200):
|
|
||||||
"""Return a JSON message response."""
|
|
||||||
return self.json({'message': error}, status_code)
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
|
||||||
# pylint: disable=no-self-use
|
|
||||||
def file(self, request, fil):
|
|
||||||
"""Return a file."""
|
|
||||||
assert isinstance(fil, str), 'only string paths allowed'
|
|
||||||
response = yield from _GZIP_FILE_SENDER.send(request, Path(fil))
|
|
||||||
return response
|
|
||||||
|
|
||||||
def register(self, router):
|
|
||||||
"""Register the view with a router."""
|
|
||||||
assert self.url is not None, 'No url set for view'
|
|
||||||
urls = [self.url] + self.extra_urls
|
|
||||||
|
|
||||||
for method in ('get', 'post', 'delete', 'put'):
|
|
||||||
handler = getattr(self, method, None)
|
|
||||||
|
|
||||||
if not handler:
|
|
||||||
continue
|
|
||||||
|
|
||||||
handler = request_handler_factory(self, handler)
|
|
||||||
|
|
||||||
for url in urls:
|
|
||||||
router.add_route(method, url, handler)
|
|
||||||
|
|
||||||
# aiohttp_cors does not work with class based views
|
|
||||||
# self.app.router.add_route('*', self.url, self, name=self.name)
|
|
||||||
|
|
||||||
# for url in self.extra_urls:
|
|
||||||
# self.app.router.add_route('*', url, self)
|
|
||||||
|
|
||||||
|
|
||||||
def request_handler_factory(view, handler):
|
|
||||||
"""Factory to wrap our handler classes.
|
|
||||||
|
|
||||||
Eventually authentication should be managed by middleware.
|
|
||||||
"""
|
|
||||||
@asyncio.coroutine
|
|
||||||
def handle(request):
|
|
||||||
"""Handle incoming request."""
|
|
||||||
if not view.hass.is_running:
|
|
||||||
return web.Response(status=503)
|
|
||||||
|
|
||||||
remote_addr = view.hass.http.get_real_ip(request)
|
|
||||||
|
|
||||||
if view.hass.http.is_banned_ip(remote_addr):
|
|
||||||
raise HTTPForbidden()
|
|
||||||
|
|
||||||
# Auth code verbose on purpose
|
|
||||||
authenticated = False
|
|
||||||
|
|
||||||
if view.hass.http.api_password is None:
|
|
||||||
authenticated = True
|
|
||||||
|
|
||||||
elif view.hass.http.is_trusted_ip(remote_addr):
|
|
||||||
authenticated = True
|
|
||||||
|
|
||||||
elif hmac.compare_digest(request.headers.get(HTTP_HEADER_HA_AUTH, ''),
|
|
||||||
view.hass.http.api_password):
|
|
||||||
# A valid auth header has been set
|
|
||||||
authenticated = True
|
|
||||||
|
|
||||||
elif hmac.compare_digest(request.GET.get(DATA_API_PASSWORD, ''),
|
|
||||||
view.hass.http.api_password):
|
|
||||||
authenticated = True
|
|
||||||
|
|
||||||
if view.requires_auth and not authenticated:
|
|
||||||
view.hass.http.wrong_login_attempt(remote_addr)
|
|
||||||
_LOGGER.warning('Login attempt or request with an invalid '
|
|
||||||
'password from %s', remote_addr)
|
|
||||||
persistent_notification.async_create(
|
|
||||||
view.hass,
|
|
||||||
'Invalid password used from {}'.format(remote_addr),
|
|
||||||
'Login attempt failed', NOTIFICATION_ID_LOGIN)
|
|
||||||
raise HTTPUnauthorized()
|
|
||||||
|
|
||||||
request.authenticated = authenticated
|
|
||||||
|
|
||||||
_LOGGER.info('Serving %s to %s (auth: %s)',
|
|
||||||
request.path, remote_addr, authenticated)
|
|
||||||
|
|
||||||
assert asyncio.iscoroutinefunction(handler) or is_callback(handler), \
|
|
||||||
"Handler should be a coroutine or a callback."
|
|
||||||
|
|
||||||
result = handler(request, **request.match_info)
|
|
||||||
|
|
||||||
if asyncio.iscoroutine(result):
|
|
||||||
result = yield from result
|
|
||||||
|
|
||||||
if isinstance(result, web.StreamResponse):
|
|
||||||
# The method handler returned a ready-made Response, how nice of it
|
|
||||||
return result
|
|
||||||
|
|
||||||
status_code = 200
|
|
||||||
|
|
||||||
if isinstance(result, tuple):
|
|
||||||
result, status_code = result
|
|
||||||
|
|
||||||
if isinstance(result, str):
|
|
||||||
result = result.encode('utf-8')
|
|
||||||
elif result is None:
|
|
||||||
result = b''
|
|
||||||
elif not isinstance(result, bytes):
|
|
||||||
assert False, ('Result should be None, string, bytes or Response. '
|
|
||||||
'Got: {}').format(result)
|
|
||||||
|
|
||||||
return web.Response(body=result, status=status_code)
|
|
||||||
|
|
||||||
return handle
|
|
||||||
|
|
||||||
|
|
||||||
class IpBan(object):
|
|
||||||
"""Represents banned IP address."""
|
|
||||||
|
|
||||||
def __init__(self, ip_ban: str, banned_at: datetime=None) -> None:
|
|
||||||
"""Initializing Ip Ban object."""
|
|
||||||
self.ip_address = ip_address(ip_ban)
|
|
||||||
self.banned_at = banned_at
|
|
||||||
if self.banned_at is None:
|
|
||||||
self.banned_at = datetime.utcnow()
|
|
||||||
|
|
||||||
|
|
||||||
def load_ip_bans_config(path: str):
|
|
||||||
"""Loading list of banned IPs from config file."""
|
|
||||||
ip_list = []
|
|
||||||
ip_schema = vol.Schema({
|
|
||||||
vol.Optional('banned_at'): vol.Any(None, cv.datetime)
|
|
||||||
})
|
|
||||||
|
|
||||||
try:
|
|
||||||
try:
|
|
||||||
list_ = load_yaml_config_file(path)
|
|
||||||
except HomeAssistantError as err:
|
|
||||||
_LOGGER.error('Unable to load %s: %s', path, str(err))
|
|
||||||
return []
|
|
||||||
|
|
||||||
for ip_ban, ip_info in list_.items():
|
|
||||||
try:
|
|
||||||
ip_info = ip_schema(ip_info)
|
|
||||||
ip_info['ip_ban'] = ip_address(ip_ban)
|
|
||||||
ip_list.append(IpBan(**ip_info))
|
|
||||||
except vol.Invalid:
|
|
||||||
_LOGGER.exception('Failed to load IP ban')
|
|
||||||
continue
|
|
||||||
|
|
||||||
except(HomeAssistantError, FileNotFoundError):
|
|
||||||
# No need to report error, file absence means
|
|
||||||
# that no bans were applied.
|
|
||||||
return []
|
|
||||||
|
|
||||||
return ip_list
|
|
||||||
|
|
||||||
|
|
||||||
def update_ip_bans_config(path: str, ip_ban: IpBan):
|
|
||||||
"""Update config file with new banned IP address."""
|
|
||||||
with open(path, 'a') as out:
|
|
||||||
ip_ = {str(ip_ban.ip_address): {
|
|
||||||
ATTR_BANNED_AT: ip_ban.banned_at.strftime("%Y-%m-%dT%H:%M:%S")
|
|
||||||
}}
|
|
||||||
out.write('\n')
|
|
||||||
out.write(dump(ip_))
|
|
407
homeassistant/components/http/__init__.py
Normal file
407
homeassistant/components/http/__init__.py
Normal file
|
@ -0,0 +1,407 @@
|
||||||
|
"""
|
||||||
|
This module provides WSGI application to serve the Home Assistant API.
|
||||||
|
|
||||||
|
For more details about this component, please refer to the documentation at
|
||||||
|
https://home-assistant.io/components/http/
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import ssl
|
||||||
|
from ipaddress import ip_network
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import os
|
||||||
|
import voluptuous as vol
|
||||||
|
from aiohttp import web
|
||||||
|
from aiohttp.web_exceptions import HTTPUnauthorized, HTTPMovedPermanently
|
||||||
|
|
||||||
|
import homeassistant.helpers.config_validation as cv
|
||||||
|
import homeassistant.remote as rem
|
||||||
|
from homeassistant.util import get_local_ip
|
||||||
|
from homeassistant.components import persistent_notification
|
||||||
|
from homeassistant.const import (
|
||||||
|
SERVER_PORT, CONTENT_TYPE_JSON, ALLOWED_CORS_HEADERS,
|
||||||
|
EVENT_HOMEASSISTANT_STOP, EVENT_HOMEASSISTANT_START)
|
||||||
|
from homeassistant.core import is_callback
|
||||||
|
from homeassistant.util.logging import HideSensitiveDataFilter
|
||||||
|
|
||||||
|
from .auth import auth_middleware
|
||||||
|
from .ban import ban_middleware, process_wrong_login
|
||||||
|
from .const import (
|
||||||
|
KEY_USE_X_FORWARDED_FOR, KEY_TRUSTED_NETWORKS,
|
||||||
|
KEY_BANS_ENABLED, KEY_LOGIN_THRESHOLD,
|
||||||
|
KEY_DEVELOPMENT, KEY_AUTHENTICATED)
|
||||||
|
from .static import GZIP_FILE_SENDER, staticresource_middleware
|
||||||
|
from .util import get_real_ip
|
||||||
|
|
||||||
|
DOMAIN = 'http'
|
||||||
|
REQUIREMENTS = ('aiohttp_cors==0.5.0',)
|
||||||
|
|
||||||
|
CONF_API_PASSWORD = 'api_password'
|
||||||
|
CONF_SERVER_HOST = 'server_host'
|
||||||
|
CONF_SERVER_PORT = 'server_port'
|
||||||
|
CONF_DEVELOPMENT = 'development'
|
||||||
|
CONF_SSL_CERTIFICATE = 'ssl_certificate'
|
||||||
|
CONF_SSL_KEY = 'ssl_key'
|
||||||
|
CONF_CORS_ORIGINS = 'cors_allowed_origins'
|
||||||
|
CONF_USE_X_FORWARDED_FOR = 'use_x_forwarded_for'
|
||||||
|
CONF_TRUSTED_NETWORKS = 'trusted_networks'
|
||||||
|
CONF_LOGIN_ATTEMPTS_THRESHOLD = 'login_attempts_threshold'
|
||||||
|
CONF_IP_BAN_ENABLED = 'ip_ban_enabled'
|
||||||
|
|
||||||
|
NOTIFICATION_ID_LOGIN = 'http-login'
|
||||||
|
|
||||||
|
# TLS configuation follows the best-practice guidelines specified here:
|
||||||
|
# https://wiki.mozilla.org/Security/Server_Side_TLS
|
||||||
|
# Intermediate guidelines are followed.
|
||||||
|
SSL_VERSION = ssl.PROTOCOL_SSLv23
|
||||||
|
SSL_OPTS = ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3
|
||||||
|
if hasattr(ssl, 'OP_NO_COMPRESSION'):
|
||||||
|
SSL_OPTS |= ssl.OP_NO_COMPRESSION
|
||||||
|
CIPHERS = "ECDHE-ECDSA-CHACHA20-POLY1305:ECDHE-RSA-CHACHA20-POLY1305:" \
|
||||||
|
"ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256:" \
|
||||||
|
"ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384:" \
|
||||||
|
"DHE-RSA-AES128-GCM-SHA256:DHE-RSA-AES256-GCM-SHA384:" \
|
||||||
|
"ECDHE-ECDSA-AES128-SHA256:ECDHE-RSA-AES128-SHA256:" \
|
||||||
|
"ECDHE-ECDSA-AES128-SHA:ECDHE-RSA-AES256-SHA384:" \
|
||||||
|
"ECDHE-RSA-AES128-SHA:ECDHE-ECDSA-AES256-SHA384:" \
|
||||||
|
"ECDHE-ECDSA-AES256-SHA:ECDHE-RSA-AES256-SHA:" \
|
||||||
|
"DHE-RSA-AES128-SHA256:DHE-RSA-AES128-SHA:DHE-RSA-AES256-SHA256:" \
|
||||||
|
"DHE-RSA-AES256-SHA:ECDHE-ECDSA-DES-CBC3-SHA:" \
|
||||||
|
"ECDHE-RSA-DES-CBC3-SHA:EDH-RSA-DES-CBC3-SHA:AES128-GCM-SHA256:" \
|
||||||
|
"AES256-GCM-SHA384:AES128-SHA256:AES256-SHA256:AES128-SHA:" \
|
||||||
|
"AES256-SHA:DES-CBC3-SHA:!DSS"
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
DEFAULT_SERVER_HOST = '0.0.0.0'
|
||||||
|
DEFAULT_DEVELOPMENT = '0'
|
||||||
|
DEFAULT_LOGIN_ATTEMPT_THRESHOLD = -1
|
||||||
|
|
||||||
|
HTTP_SCHEMA = vol.Schema({
|
||||||
|
vol.Optional(CONF_API_PASSWORD, default=None): cv.string,
|
||||||
|
vol.Optional(CONF_SERVER_HOST, default=DEFAULT_SERVER_HOST): cv.string,
|
||||||
|
vol.Optional(CONF_SERVER_PORT, default=SERVER_PORT):
|
||||||
|
vol.All(vol.Coerce(int), vol.Range(min=1, max=65535)),
|
||||||
|
vol.Optional(CONF_DEVELOPMENT, default=DEFAULT_DEVELOPMENT): cv.string,
|
||||||
|
vol.Optional(CONF_SSL_CERTIFICATE, default=None): cv.isfile,
|
||||||
|
vol.Optional(CONF_SSL_KEY, default=None): cv.isfile,
|
||||||
|
vol.Optional(CONF_CORS_ORIGINS, default=[]): vol.All(cv.ensure_list,
|
||||||
|
[cv.string]),
|
||||||
|
vol.Optional(CONF_USE_X_FORWARDED_FOR, default=False): cv.boolean,
|
||||||
|
vol.Optional(CONF_TRUSTED_NETWORKS, default=[]):
|
||||||
|
vol.All(cv.ensure_list, [ip_network]),
|
||||||
|
vol.Optional(CONF_LOGIN_ATTEMPTS_THRESHOLD,
|
||||||
|
default=DEFAULT_LOGIN_ATTEMPT_THRESHOLD): cv.positive_int,
|
||||||
|
vol.Optional(CONF_IP_BAN_ENABLED, default=True): cv.boolean
|
||||||
|
})
|
||||||
|
|
||||||
|
CONFIG_SCHEMA = vol.Schema({
|
||||||
|
DOMAIN: HTTP_SCHEMA,
|
||||||
|
}, extra=vol.ALLOW_EXTRA)
|
||||||
|
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def async_setup(hass, config):
|
||||||
|
"""Set up the HTTP API and debug interface."""
|
||||||
|
conf = config.get(DOMAIN)
|
||||||
|
|
||||||
|
if conf is None:
|
||||||
|
conf = HTTP_SCHEMA({})
|
||||||
|
|
||||||
|
api_password = conf[CONF_API_PASSWORD]
|
||||||
|
server_host = conf[CONF_SERVER_HOST]
|
||||||
|
server_port = conf[CONF_SERVER_PORT]
|
||||||
|
development = conf[CONF_DEVELOPMENT] == '1'
|
||||||
|
ssl_certificate = conf[CONF_SSL_CERTIFICATE]
|
||||||
|
ssl_key = conf[CONF_SSL_KEY]
|
||||||
|
cors_origins = conf[CONF_CORS_ORIGINS]
|
||||||
|
use_x_forwarded_for = conf[CONF_USE_X_FORWARDED_FOR]
|
||||||
|
trusted_networks = conf[CONF_TRUSTED_NETWORKS]
|
||||||
|
is_ban_enabled = conf[CONF_IP_BAN_ENABLED]
|
||||||
|
login_threshold = conf[CONF_LOGIN_ATTEMPTS_THRESHOLD]
|
||||||
|
|
||||||
|
if api_password is not None:
|
||||||
|
logging.getLogger('aiohttp.access').addFilter(
|
||||||
|
HideSensitiveDataFilter(api_password))
|
||||||
|
|
||||||
|
server = HomeAssistantWSGI(
|
||||||
|
hass,
|
||||||
|
development=development,
|
||||||
|
server_host=server_host,
|
||||||
|
server_port=server_port,
|
||||||
|
api_password=api_password,
|
||||||
|
ssl_certificate=ssl_certificate,
|
||||||
|
ssl_key=ssl_key,
|
||||||
|
cors_origins=cors_origins,
|
||||||
|
use_x_forwarded_for=use_x_forwarded_for,
|
||||||
|
trusted_networks=trusted_networks,
|
||||||
|
login_threshold=login_threshold,
|
||||||
|
is_ban_enabled=is_ban_enabled
|
||||||
|
)
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def stop_server(event):
|
||||||
|
"""Callback to stop the server."""
|
||||||
|
yield from server.stop()
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def start_server(event):
|
||||||
|
"""Callback to start the server."""
|
||||||
|
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, stop_server)
|
||||||
|
yield from server.start()
|
||||||
|
|
||||||
|
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, start_server)
|
||||||
|
|
||||||
|
hass.http = server
|
||||||
|
hass.config.api = rem.API(server_host if server_host != '0.0.0.0'
|
||||||
|
else get_local_ip(),
|
||||||
|
api_password, server_port,
|
||||||
|
ssl_certificate is not None)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class HomeAssistantWSGI(object):
|
||||||
|
"""WSGI server for Home Assistant."""
|
||||||
|
|
||||||
|
def __init__(self, hass, development, api_password, ssl_certificate,
|
||||||
|
ssl_key, server_host, server_port, cors_origins,
|
||||||
|
use_x_forwarded_for, trusted_networks,
|
||||||
|
login_threshold, is_ban_enabled):
|
||||||
|
"""Initialize the WSGI Home Assistant server."""
|
||||||
|
import aiohttp_cors
|
||||||
|
|
||||||
|
middlewares = [auth_middleware, staticresource_middleware]
|
||||||
|
|
||||||
|
if is_ban_enabled:
|
||||||
|
middlewares.insert(0, ban_middleware)
|
||||||
|
|
||||||
|
self.app = web.Application(middlewares=middlewares, loop=hass.loop)
|
||||||
|
self.app['hass'] = hass
|
||||||
|
self.app[KEY_USE_X_FORWARDED_FOR] = use_x_forwarded_for
|
||||||
|
self.app[KEY_TRUSTED_NETWORKS] = trusted_networks
|
||||||
|
self.app[KEY_BANS_ENABLED] = is_ban_enabled
|
||||||
|
self.app[KEY_LOGIN_THRESHOLD] = login_threshold
|
||||||
|
self.app[KEY_DEVELOPMENT] = development
|
||||||
|
|
||||||
|
self.hass = hass
|
||||||
|
self.development = development
|
||||||
|
self.api_password = api_password
|
||||||
|
self.ssl_certificate = ssl_certificate
|
||||||
|
self.ssl_key = ssl_key
|
||||||
|
self.server_host = server_host
|
||||||
|
self.server_port = server_port
|
||||||
|
self._handler = None
|
||||||
|
self.server = None
|
||||||
|
|
||||||
|
if cors_origins:
|
||||||
|
self.cors = aiohttp_cors.setup(self.app, defaults={
|
||||||
|
host: aiohttp_cors.ResourceOptions(
|
||||||
|
allow_headers=ALLOWED_CORS_HEADERS,
|
||||||
|
allow_methods='*',
|
||||||
|
) for host in cors_origins
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
self.cors = None
|
||||||
|
|
||||||
|
def register_view(self, view):
|
||||||
|
"""Register a view with the WSGI server.
|
||||||
|
|
||||||
|
The view argument must be a class that inherits from HomeAssistantView.
|
||||||
|
It is optional to instantiate it before registering; this method will
|
||||||
|
handle it either way.
|
||||||
|
"""
|
||||||
|
if isinstance(view, type):
|
||||||
|
# Instantiate the view, if needed
|
||||||
|
view = view()
|
||||||
|
|
||||||
|
if not hasattr(view, 'url'):
|
||||||
|
class_name = view.__class__.__name__
|
||||||
|
raise AttributeError(
|
||||||
|
'{0} missing required attribute "url"'.format(class_name)
|
||||||
|
)
|
||||||
|
|
||||||
|
if not hasattr(view, 'name'):
|
||||||
|
class_name = view.__class__.__name__
|
||||||
|
raise AttributeError(
|
||||||
|
'{0} missing required attribute "name"'.format(class_name)
|
||||||
|
)
|
||||||
|
|
||||||
|
view.register(self.app.router)
|
||||||
|
|
||||||
|
def register_redirect(self, url, redirect_to):
|
||||||
|
"""Register a redirect with the server.
|
||||||
|
|
||||||
|
If given this must be either a string or callable. In case of a
|
||||||
|
callable it's called with the url adapter that triggered the match and
|
||||||
|
the values of the URL as keyword arguments and has to return the target
|
||||||
|
for the redirect, otherwise it has to be a string with placeholders in
|
||||||
|
rule syntax.
|
||||||
|
"""
|
||||||
|
def redirect(request):
|
||||||
|
"""Redirect to location."""
|
||||||
|
raise HTTPMovedPermanently(redirect_to)
|
||||||
|
|
||||||
|
self.app.router.add_route('GET', url, redirect)
|
||||||
|
|
||||||
|
def register_static_path(self, url_root, path, cache_length=31):
|
||||||
|
"""Register a folder to serve as a static path.
|
||||||
|
|
||||||
|
Specify optional cache length of asset in days.
|
||||||
|
"""
|
||||||
|
if os.path.isdir(path):
|
||||||
|
self.app.router.add_static(url_root, path)
|
||||||
|
return
|
||||||
|
|
||||||
|
filepath = Path(path)
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def serve_file(request):
|
||||||
|
"""Serve file from disk."""
|
||||||
|
res = yield from GZIP_FILE_SENDER.send(request, filepath)
|
||||||
|
return res
|
||||||
|
|
||||||
|
# aiohttp supports regex matching for variables. Using that as temp
|
||||||
|
# to work around cache busting MD5.
|
||||||
|
# Turns something like /static/dev-panel.html into
|
||||||
|
# /static/{filename:dev-panel(-[a-z0-9]{32}|)\.html}
|
||||||
|
base, ext = url_root.rsplit('.', 1)
|
||||||
|
base, file = base.rsplit('/', 1)
|
||||||
|
regex = r"{}(-[a-z0-9]{{32}}|)\.{}".format(file, ext)
|
||||||
|
url_pattern = "{}/{{filename:{}}}".format(base, regex)
|
||||||
|
|
||||||
|
self.app.router.add_route('GET', url_pattern, serve_file)
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def start(self):
|
||||||
|
"""Start the wsgi server."""
|
||||||
|
if self.cors is not None:
|
||||||
|
for route in list(self.app.router.routes()):
|
||||||
|
self.cors.add(route)
|
||||||
|
|
||||||
|
if self.ssl_certificate:
|
||||||
|
context = ssl.SSLContext(SSL_VERSION)
|
||||||
|
context.options |= SSL_OPTS
|
||||||
|
context.set_ciphers(CIPHERS)
|
||||||
|
context.load_cert_chain(self.ssl_certificate, self.ssl_key)
|
||||||
|
else:
|
||||||
|
context = None
|
||||||
|
|
||||||
|
self._handler = self.app.make_handler()
|
||||||
|
self.server = yield from self.hass.loop.create_server(
|
||||||
|
self._handler, self.server_host, self.server_port, ssl=context)
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def stop(self):
|
||||||
|
"""Stop the wsgi server."""
|
||||||
|
self.server.close()
|
||||||
|
yield from self.server.wait_closed()
|
||||||
|
yield from self.app.shutdown()
|
||||||
|
yield from self._handler.finish_connections(60.0)
|
||||||
|
yield from self.app.cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
class HomeAssistantView(object):
|
||||||
|
"""Base view for all views."""
|
||||||
|
|
||||||
|
url = None
|
||||||
|
extra_urls = []
|
||||||
|
requires_auth = True # Views inheriting from this class can override this
|
||||||
|
|
||||||
|
# pylint: disable=no-self-use
|
||||||
|
def json(self, result, status_code=200):
|
||||||
|
"""Return a JSON response."""
|
||||||
|
msg = json.dumps(
|
||||||
|
result, sort_keys=True, cls=rem.JSONEncoder).encode('UTF-8')
|
||||||
|
return web.Response(
|
||||||
|
body=msg, content_type=CONTENT_TYPE_JSON, status=status_code)
|
||||||
|
|
||||||
|
def json_message(self, error, status_code=200):
|
||||||
|
"""Return a JSON message response."""
|
||||||
|
return self.json({'message': error}, status_code)
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
# pylint: disable=no-self-use
|
||||||
|
def file(self, request, fil):
|
||||||
|
"""Return a file."""
|
||||||
|
assert isinstance(fil, str), 'only string paths allowed'
|
||||||
|
response = yield from GZIP_FILE_SENDER.send(request, Path(fil))
|
||||||
|
return response
|
||||||
|
|
||||||
|
def register(self, router):
|
||||||
|
"""Register the view with a router."""
|
||||||
|
assert self.url is not None, 'No url set for view'
|
||||||
|
urls = [self.url] + self.extra_urls
|
||||||
|
|
||||||
|
for method in ('get', 'post', 'delete', 'put'):
|
||||||
|
handler = getattr(self, method, None)
|
||||||
|
|
||||||
|
if not handler:
|
||||||
|
continue
|
||||||
|
|
||||||
|
handler = request_handler_factory(self, handler)
|
||||||
|
|
||||||
|
for url in urls:
|
||||||
|
router.add_route(method, url, handler)
|
||||||
|
|
||||||
|
# aiohttp_cors does not work with class based views
|
||||||
|
# self.app.router.add_route('*', self.url, self, name=self.name)
|
||||||
|
|
||||||
|
# for url in self.extra_urls:
|
||||||
|
# self.app.router.add_route('*', url, self)
|
||||||
|
|
||||||
|
|
||||||
|
def request_handler_factory(view, handler):
|
||||||
|
"""Factory to wrap our handler classes."""
|
||||||
|
assert asyncio.iscoroutinefunction(handler) or is_callback(handler), \
|
||||||
|
"Handler should be a coroutine or a callback."
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def handle(request):
|
||||||
|
"""Handle incoming request."""
|
||||||
|
if not request.app['hass'].is_running:
|
||||||
|
return web.Response(status=503)
|
||||||
|
|
||||||
|
remote_addr = get_real_ip(request)
|
||||||
|
authenticated = request.get(KEY_AUTHENTICATED, False)
|
||||||
|
|
||||||
|
if view.requires_auth and not authenticated:
|
||||||
|
yield from process_wrong_login(request)
|
||||||
|
_LOGGER.warning('Login attempt or request with an invalid '
|
||||||
|
'password from %s', remote_addr)
|
||||||
|
persistent_notification.async_create(
|
||||||
|
request.app['hass'],
|
||||||
|
'Invalid password used from {}'.format(remote_addr),
|
||||||
|
'Login attempt failed', NOTIFICATION_ID_LOGIN)
|
||||||
|
raise HTTPUnauthorized()
|
||||||
|
|
||||||
|
_LOGGER.info('Serving %s to %s (auth: %s)',
|
||||||
|
request.path, remote_addr, authenticated)
|
||||||
|
|
||||||
|
result = handler(request, **request.match_info)
|
||||||
|
|
||||||
|
if asyncio.iscoroutine(result):
|
||||||
|
result = yield from result
|
||||||
|
|
||||||
|
if isinstance(result, web.StreamResponse):
|
||||||
|
# The method handler returned a ready-made Response, how nice of it
|
||||||
|
return result
|
||||||
|
|
||||||
|
status_code = 200
|
||||||
|
|
||||||
|
if isinstance(result, tuple):
|
||||||
|
result, status_code = result
|
||||||
|
|
||||||
|
if isinstance(result, str):
|
||||||
|
result = result.encode('utf-8')
|
||||||
|
elif result is None:
|
||||||
|
result = b''
|
||||||
|
elif not isinstance(result, bytes):
|
||||||
|
assert False, ('Result should be None, string, bytes or Response. '
|
||||||
|
'Got: {}').format(result)
|
||||||
|
|
||||||
|
return web.Response(body=result, status=status_code)
|
||||||
|
|
||||||
|
return handle
|
61
homeassistant/components/http/auth.py
Normal file
61
homeassistant/components/http/auth.py
Normal file
|
@ -0,0 +1,61 @@
|
||||||
|
"""Authentication for HTTP component."""
|
||||||
|
import asyncio
|
||||||
|
import hmac
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from homeassistant.const import HTTP_HEADER_HA_AUTH
|
||||||
|
from .util import get_real_ip
|
||||||
|
from .const import KEY_TRUSTED_NETWORKS, KEY_AUTHENTICATED
|
||||||
|
|
||||||
|
DATA_API_PASSWORD = 'api_password'
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def auth_middleware(app, handler):
|
||||||
|
"""Authentication middleware."""
|
||||||
|
# If no password set, just always set authenticated=True
|
||||||
|
if app['hass'].http.api_password is None:
|
||||||
|
@asyncio.coroutine
|
||||||
|
def no_auth_middleware_handler(request):
|
||||||
|
"""Auth middleware to approve all requests."""
|
||||||
|
request[KEY_AUTHENTICATED] = True
|
||||||
|
return handler(request)
|
||||||
|
|
||||||
|
return no_auth_middleware_handler
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def auth_middleware_handler(request):
|
||||||
|
"""Auth middleware to check authentication."""
|
||||||
|
hass = app['hass']
|
||||||
|
|
||||||
|
# Auth code verbose on purpose
|
||||||
|
authenticated = False
|
||||||
|
|
||||||
|
if hmac.compare_digest(request.headers.get(HTTP_HEADER_HA_AUTH, ''),
|
||||||
|
hass.http.api_password):
|
||||||
|
# A valid auth header has been set
|
||||||
|
authenticated = True
|
||||||
|
|
||||||
|
elif hmac.compare_digest(request.GET.get(DATA_API_PASSWORD, ''),
|
||||||
|
hass.http.api_password):
|
||||||
|
authenticated = True
|
||||||
|
|
||||||
|
elif is_trusted_ip(request):
|
||||||
|
authenticated = True
|
||||||
|
|
||||||
|
request[KEY_AUTHENTICATED] = authenticated
|
||||||
|
|
||||||
|
return handler(request)
|
||||||
|
|
||||||
|
return auth_middleware_handler
|
||||||
|
|
||||||
|
|
||||||
|
def is_trusted_ip(request):
|
||||||
|
"""Test if request is from a trusted ip."""
|
||||||
|
ip_addr = get_real_ip(request)
|
||||||
|
|
||||||
|
return ip_addr and any(
|
||||||
|
ip_addr in trusted_network for trusted_network
|
||||||
|
in request.app[KEY_TRUSTED_NETWORKS])
|
132
homeassistant/components/http/ban.py
Normal file
132
homeassistant/components/http/ban.py
Normal file
|
@ -0,0 +1,132 @@
|
||||||
|
"""Ban logic for HTTP component."""
|
||||||
|
import asyncio
|
||||||
|
from collections import defaultdict
|
||||||
|
from datetime import datetime
|
||||||
|
from ipaddress import ip_address
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from aiohttp.web_exceptions import HTTPForbidden
|
||||||
|
import voluptuous as vol
|
||||||
|
|
||||||
|
from homeassistant.components import persistent_notification
|
||||||
|
from homeassistant.config import load_yaml_config_file
|
||||||
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
|
import homeassistant.helpers.config_validation as cv
|
||||||
|
from homeassistant.util.yaml import dump
|
||||||
|
from .const import (
|
||||||
|
KEY_BANS_ENABLED, KEY_BANNED_IPS, KEY_LOGIN_THRESHOLD,
|
||||||
|
KEY_FAILED_LOGIN_ATTEMPTS)
|
||||||
|
from .util import get_real_ip
|
||||||
|
|
||||||
|
NOTIFICATION_ID_BAN = 'ip-ban'
|
||||||
|
|
||||||
|
IP_BANS_FILE = 'ip_bans.yaml'
|
||||||
|
ATTR_BANNED_AT = "banned_at"
|
||||||
|
|
||||||
|
SCHEMA_IP_BAN_ENTRY = vol.Schema({
|
||||||
|
vol.Optional('banned_at'): vol.Any(None, cv.datetime)
|
||||||
|
})
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def ban_middleware(app, handler):
|
||||||
|
"""IP Ban middleware."""
|
||||||
|
if not app[KEY_BANS_ENABLED]:
|
||||||
|
return handler
|
||||||
|
|
||||||
|
if KEY_BANNED_IPS not in app:
|
||||||
|
hass = app['hass']
|
||||||
|
app[KEY_BANNED_IPS] = yield from hass.loop.run_in_executor(
|
||||||
|
None, load_ip_bans_config, hass.config.path(IP_BANS_FILE))
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def ban_middleware_handler(request):
|
||||||
|
"""Verify if IP is not banned."""
|
||||||
|
ip_address_ = get_real_ip(request)
|
||||||
|
|
||||||
|
is_banned = any(ip_ban.ip_address == ip_address_
|
||||||
|
for ip_ban in request.app[KEY_BANNED_IPS])
|
||||||
|
|
||||||
|
if is_banned:
|
||||||
|
raise HTTPForbidden()
|
||||||
|
|
||||||
|
return handler(request)
|
||||||
|
|
||||||
|
return ban_middleware_handler
|
||||||
|
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def process_wrong_login(request):
|
||||||
|
"""Process a wrong login attempt."""
|
||||||
|
if (not request.app[KEY_BANS_ENABLED] or
|
||||||
|
request.app[KEY_LOGIN_THRESHOLD] < 1):
|
||||||
|
return
|
||||||
|
|
||||||
|
if KEY_FAILED_LOGIN_ATTEMPTS not in request.app:
|
||||||
|
request.app[KEY_FAILED_LOGIN_ATTEMPTS] = defaultdict(int)
|
||||||
|
|
||||||
|
remote_addr = get_real_ip(request)
|
||||||
|
|
||||||
|
request.app[KEY_FAILED_LOGIN_ATTEMPTS][remote_addr] += 1
|
||||||
|
|
||||||
|
if (request.app[KEY_FAILED_LOGIN_ATTEMPTS][remote_addr] >
|
||||||
|
request.app[KEY_LOGIN_THRESHOLD]):
|
||||||
|
new_ban = IpBan(remote_addr)
|
||||||
|
request.app[KEY_BANNED_IPS].append(new_ban)
|
||||||
|
|
||||||
|
hass = request.app['hass']
|
||||||
|
yield from hass.loop.run_in_executor(
|
||||||
|
None, update_ip_bans_config, hass.config.path(IP_BANS_FILE),
|
||||||
|
new_ban)
|
||||||
|
|
||||||
|
_LOGGER.warning('Banned IP %s for too many login attempts',
|
||||||
|
remote_addr)
|
||||||
|
|
||||||
|
persistent_notification.async_create(
|
||||||
|
hass,
|
||||||
|
'Too many login attempts from {}'.format(remote_addr),
|
||||||
|
'Banning IP address', NOTIFICATION_ID_BAN)
|
||||||
|
|
||||||
|
|
||||||
|
class IpBan(object):
|
||||||
|
"""Represents banned IP address."""
|
||||||
|
|
||||||
|
def __init__(self, ip_ban: str, banned_at: datetime=None) -> None:
|
||||||
|
"""Initializing Ip Ban object."""
|
||||||
|
self.ip_address = ip_address(ip_ban)
|
||||||
|
self.banned_at = banned_at or datetime.utcnow()
|
||||||
|
|
||||||
|
|
||||||
|
def load_ip_bans_config(path: str):
|
||||||
|
"""Loading list of banned IPs from config file."""
|
||||||
|
ip_list = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
list_ = load_yaml_config_file(path)
|
||||||
|
except FileNotFoundError:
|
||||||
|
return []
|
||||||
|
except HomeAssistantError as err:
|
||||||
|
_LOGGER.error('Unable to load %s: %s', path, str(err))
|
||||||
|
return []
|
||||||
|
|
||||||
|
for ip_ban, ip_info in list_.items():
|
||||||
|
try:
|
||||||
|
ip_info = SCHEMA_IP_BAN_ENTRY(ip_info)
|
||||||
|
ip_list.append(IpBan(ip_ban, ip_info['banned_at']))
|
||||||
|
except vol.Invalid as err:
|
||||||
|
_LOGGER.error('Failed to load IP ban %s: %s', ip_info, err)
|
||||||
|
continue
|
||||||
|
|
||||||
|
return ip_list
|
||||||
|
|
||||||
|
|
||||||
|
def update_ip_bans_config(path: str, ip_ban: IpBan):
|
||||||
|
"""Update config file with new banned IP address."""
|
||||||
|
with open(path, 'a') as out:
|
||||||
|
ip_ = {str(ip_ban.ip_address): {
|
||||||
|
ATTR_BANNED_AT: ip_ban.banned_at.strftime("%Y-%m-%dT%H:%M:%S")
|
||||||
|
}}
|
||||||
|
out.write('\n')
|
||||||
|
out.write(dump(ip_))
|
12
homeassistant/components/http/const.py
Normal file
12
homeassistant/components/http/const.py
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
"""HTTP specific constants."""
|
||||||
|
KEY_AUTHENTICATED = 'ha_authenticated'
|
||||||
|
KEY_USE_X_FORWARDED_FOR = 'ha_use_x_forwarded_for'
|
||||||
|
KEY_TRUSTED_NETWORKS = 'ha_trusted_networks'
|
||||||
|
KEY_REAL_IP = 'ha_real_ip'
|
||||||
|
KEY_BANS_ENABLED = 'ha_bans_enabled'
|
||||||
|
KEY_BANNED_IPS = 'ha_banned_ips'
|
||||||
|
KEY_FAILED_LOGIN_ATTEMPTS = 'ha_failed_login_attempts'
|
||||||
|
KEY_LOGIN_THRESHOLD = 'ha_login_treshold'
|
||||||
|
KEY_DEVELOPMENT = 'ha_development'
|
||||||
|
|
||||||
|
HTTP_HEADER_X_FORWARDED_FOR = 'X-Forwarded-For'
|
93
homeassistant/components/http/static.py
Normal file
93
homeassistant/components/http/static.py
Normal file
|
@ -0,0 +1,93 @@
|
||||||
|
"""Static file handling for HTTP component."""
|
||||||
|
import asyncio
|
||||||
|
import mimetypes
|
||||||
|
import re
|
||||||
|
|
||||||
|
from aiohttp import hdrs
|
||||||
|
from aiohttp.file_sender import FileSender
|
||||||
|
from aiohttp.web_urldispatcher import StaticResource
|
||||||
|
from aiohttp.web_exceptions import HTTPNotModified
|
||||||
|
|
||||||
|
from .const import KEY_DEVELOPMENT
|
||||||
|
|
||||||
|
_FINGERPRINT = re.compile(r'^(.+)-[a-z0-9]{32}\.(\w+)$', re.IGNORECASE)
|
||||||
|
|
||||||
|
|
||||||
|
class GzipFileSender(FileSender):
|
||||||
|
"""FileSender class capable of sending gzip version if available."""
|
||||||
|
|
||||||
|
# pylint: disable=invalid-name
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def send(self, request, filepath):
|
||||||
|
"""Send filepath to client using request."""
|
||||||
|
gzip = False
|
||||||
|
if 'gzip' in request.headers[hdrs.ACCEPT_ENCODING]:
|
||||||
|
gzip_path = filepath.with_name(filepath.name + '.gz')
|
||||||
|
|
||||||
|
if gzip_path.is_file():
|
||||||
|
filepath = gzip_path
|
||||||
|
gzip = True
|
||||||
|
|
||||||
|
st = filepath.stat()
|
||||||
|
|
||||||
|
modsince = request.if_modified_since
|
||||||
|
if modsince is not None and st.st_mtime <= modsince.timestamp():
|
||||||
|
raise HTTPNotModified()
|
||||||
|
|
||||||
|
ct, encoding = mimetypes.guess_type(str(filepath))
|
||||||
|
if not ct:
|
||||||
|
ct = 'application/octet-stream'
|
||||||
|
|
||||||
|
resp = self._response_factory()
|
||||||
|
resp.content_type = ct
|
||||||
|
if encoding:
|
||||||
|
resp.headers[hdrs.CONTENT_ENCODING] = encoding
|
||||||
|
if gzip:
|
||||||
|
resp.headers[hdrs.VARY] = hdrs.ACCEPT_ENCODING
|
||||||
|
resp.last_modified = st.st_mtime
|
||||||
|
|
||||||
|
# CACHE HACK
|
||||||
|
if not request.app[KEY_DEVELOPMENT]:
|
||||||
|
cache_time = 31 * 86400 # = 1 month
|
||||||
|
resp.headers[hdrs.CACHE_CONTROL] = "public, max-age={}".format(
|
||||||
|
cache_time)
|
||||||
|
|
||||||
|
file_size = st.st_size
|
||||||
|
|
||||||
|
resp.content_length = file_size
|
||||||
|
with filepath.open('rb') as f:
|
||||||
|
yield from self._sendfile(request, resp, f, file_size)
|
||||||
|
|
||||||
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
GZIP_FILE_SENDER = GzipFileSender()
|
||||||
|
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def staticresource_middleware(app, handler):
|
||||||
|
"""Enhance StaticResourceHandler middleware.
|
||||||
|
|
||||||
|
Adds gzip encoding and fingerprinting matching.
|
||||||
|
"""
|
||||||
|
inst = getattr(handler, '__self__', None)
|
||||||
|
if not isinstance(inst, StaticResource):
|
||||||
|
return handler
|
||||||
|
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
inst._file_sender = GZIP_FILE_SENDER
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def static_middleware_handler(request):
|
||||||
|
"""Strip out fingerprints from resource names."""
|
||||||
|
fingerprinted = _FINGERPRINT.match(request.match_info['filename'])
|
||||||
|
|
||||||
|
if fingerprinted:
|
||||||
|
request.match_info['filename'] = \
|
||||||
|
'{}.{}'.format(*fingerprinted.groups())
|
||||||
|
|
||||||
|
resp = yield from handler(request)
|
||||||
|
return resp
|
||||||
|
|
||||||
|
return static_middleware_handler
|
25
homeassistant/components/http/util.py
Normal file
25
homeassistant/components/http/util.py
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
"""HTTP utilities."""
|
||||||
|
from ipaddress import ip_address
|
||||||
|
|
||||||
|
from .const import (
|
||||||
|
KEY_REAL_IP, KEY_USE_X_FORWARDED_FOR, HTTP_HEADER_X_FORWARDED_FOR)
|
||||||
|
|
||||||
|
|
||||||
|
def get_real_ip(request):
|
||||||
|
"""Get IP address of client."""
|
||||||
|
if KEY_REAL_IP in request:
|
||||||
|
return request[KEY_REAL_IP]
|
||||||
|
|
||||||
|
if (request.app[KEY_USE_X_FORWARDED_FOR] and
|
||||||
|
HTTP_HEADER_X_FORWARDED_FOR in request.headers):
|
||||||
|
request[KEY_REAL_IP] = ip_address(
|
||||||
|
request.headers.get(HTTP_HEADER_X_FORWARDED_FOR).split(',')[0])
|
||||||
|
else:
|
||||||
|
peername = request.transport.get_extra_info('peername')
|
||||||
|
|
||||||
|
if peername:
|
||||||
|
request[KEY_REAL_IP] = ip_address(peername[0])
|
||||||
|
else:
|
||||||
|
request[KEY_REAL_IP] = None
|
||||||
|
|
||||||
|
return request[KEY_REAL_IP]
|
|
@ -250,11 +250,10 @@ def setup(hass, config):
|
||||||
|
|
||||||
discovery.load_platform(hass, "sensor", DOMAIN, {}, config)
|
discovery.load_platform(hass, "sensor", DOMAIN, {}, config)
|
||||||
|
|
||||||
hass.http.register_view(iOSIdentifyDeviceView(hass))
|
hass.http.register_view(iOSIdentifyDeviceView)
|
||||||
|
|
||||||
app_config = config.get(DOMAIN, {})
|
app_config = config.get(DOMAIN, {})
|
||||||
hass.http.register_view(iOSPushConfigView(hass,
|
hass.http.register_view(iOSPushConfigView(app_config.get(CONF_PUSH, {})))
|
||||||
app_config.get(CONF_PUSH, {})))
|
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@ -266,9 +265,8 @@ class iOSPushConfigView(HomeAssistantView):
|
||||||
url = "/api/ios/push"
|
url = "/api/ios/push"
|
||||||
name = "api:ios:push"
|
name = "api:ios:push"
|
||||||
|
|
||||||
def __init__(self, hass, push_config):
|
def __init__(self, push_config):
|
||||||
"""Init the view."""
|
"""Init the view."""
|
||||||
super().__init__(hass)
|
|
||||||
self.push_config = push_config
|
self.push_config = push_config
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
|
@ -283,10 +281,6 @@ class iOSIdentifyDeviceView(HomeAssistantView):
|
||||||
url = "/api/ios/identify"
|
url = "/api/ios/identify"
|
||||||
name = "api:ios:identify"
|
name = "api:ios:identify"
|
||||||
|
|
||||||
def __init__(self, hass):
|
|
||||||
"""Init the view."""
|
|
||||||
super().__init__(hass)
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def post(self, request):
|
def post(self, request):
|
||||||
"""Handle the POST request for device identification."""
|
"""Handle the POST request for device identification."""
|
||||||
|
|
|
@ -101,7 +101,7 @@ def setup(hass, config):
|
||||||
message = message.async_render()
|
message = message.async_render()
|
||||||
async_log_entry(hass, name, message, domain, entity_id)
|
async_log_entry(hass, name, message, domain, entity_id)
|
||||||
|
|
||||||
hass.http.register_view(LogbookView(hass, config))
|
hass.http.register_view(LogbookView(config))
|
||||||
|
|
||||||
register_built_in_panel(hass, 'logbook', 'Logbook',
|
register_built_in_panel(hass, 'logbook', 'Logbook',
|
||||||
'mdi:format-list-bulleted-type')
|
'mdi:format-list-bulleted-type')
|
||||||
|
@ -118,9 +118,8 @@ class LogbookView(HomeAssistantView):
|
||||||
name = 'api:logbook'
|
name = 'api:logbook'
|
||||||
extra_urls = ['/api/logbook/{datetime}']
|
extra_urls = ['/api/logbook/{datetime}']
|
||||||
|
|
||||||
def __init__(self, hass, config):
|
def __init__(self, config):
|
||||||
"""Initilalize the logbook view."""
|
"""Initilalize the logbook view."""
|
||||||
super().__init__(hass)
|
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
|
@ -146,7 +145,8 @@ class LogbookView(HomeAssistantView):
|
||||||
events = recorder.execute(query)
|
events = recorder.execute(query)
|
||||||
return _exclude_events(events, self.config)
|
return _exclude_events(events, self.config)
|
||||||
|
|
||||||
events = yield from self.hass.loop.run_in_executor(None, get_results)
|
events = yield from request.app['hass'].loop.run_in_executor(
|
||||||
|
None, get_results)
|
||||||
|
|
||||||
return self.json(humanify(events))
|
return self.json(humanify(events))
|
||||||
|
|
||||||
|
|
|
@ -17,7 +17,7 @@ from homeassistant.config import load_yaml_config_file
|
||||||
from homeassistant.helpers.entity import Entity
|
from homeassistant.helpers.entity import Entity
|
||||||
from homeassistant.helpers.entity_component import EntityComponent
|
from homeassistant.helpers.entity_component import EntityComponent
|
||||||
from homeassistant.helpers.config_validation import PLATFORM_SCHEMA # noqa
|
from homeassistant.helpers.config_validation import PLATFORM_SCHEMA # noqa
|
||||||
from homeassistant.components.http import HomeAssistantView
|
from homeassistant.components.http import HomeAssistantView, KEY_AUTHENTICATED
|
||||||
import homeassistant.helpers.config_validation as cv
|
import homeassistant.helpers.config_validation as cv
|
||||||
from homeassistant.util.async import run_coroutine_threadsafe
|
from homeassistant.util.async import run_coroutine_threadsafe
|
||||||
from homeassistant.const import (
|
from homeassistant.const import (
|
||||||
|
@ -304,7 +304,7 @@ def setup(hass, config):
|
||||||
component = EntityComponent(
|
component = EntityComponent(
|
||||||
logging.getLogger(__name__), DOMAIN, hass, SCAN_INTERVAL)
|
logging.getLogger(__name__), DOMAIN, hass, SCAN_INTERVAL)
|
||||||
|
|
||||||
hass.http.register_view(MediaPlayerImageView(hass, component.entities))
|
hass.http.register_view(MediaPlayerImageView(component.entities))
|
||||||
|
|
||||||
component.setup(config)
|
component.setup(config)
|
||||||
|
|
||||||
|
@ -736,9 +736,8 @@ class MediaPlayerImageView(HomeAssistantView):
|
||||||
url = "/api/media_player_proxy/{entity_id}"
|
url = "/api/media_player_proxy/{entity_id}"
|
||||||
name = "api:media_player:image"
|
name = "api:media_player:image"
|
||||||
|
|
||||||
def __init__(self, hass, entities):
|
def __init__(self, entities):
|
||||||
"""Initialize a media player view."""
|
"""Initialize a media player view."""
|
||||||
super().__init__(hass)
|
|
||||||
self.entities = entities
|
self.entities = entities
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
|
@ -748,14 +747,14 @@ class MediaPlayerImageView(HomeAssistantView):
|
||||||
if player is None:
|
if player is None:
|
||||||
return web.Response(status=404)
|
return web.Response(status=404)
|
||||||
|
|
||||||
authenticated = (request.authenticated or
|
authenticated = (request[KEY_AUTHENTICATED] or
|
||||||
request.GET.get('token') == player.access_token)
|
request.GET.get('token') == player.access_token)
|
||||||
|
|
||||||
if not authenticated:
|
if not authenticated:
|
||||||
return web.Response(status=401)
|
return web.Response(status=401)
|
||||||
|
|
||||||
data, content_type = yield from _async_fetch_image(
|
data, content_type = yield from _async_fetch_image(
|
||||||
self.hass, player.media_image_url)
|
request.app['hass'], player.media_image_url)
|
||||||
|
|
||||||
if data is None:
|
if data is None:
|
||||||
return web.Response(status=500)
|
return web.Response(status=500)
|
||||||
|
|
|
@ -107,8 +107,8 @@ def get_service(hass, config):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
hass.http.register_view(
|
hass.http.register_view(
|
||||||
HTML5PushRegistrationView(hass, registrations, json_path))
|
HTML5PushRegistrationView(registrations, json_path))
|
||||||
hass.http.register_view(HTML5PushCallbackView(hass, registrations))
|
hass.http.register_view(HTML5PushCallbackView(registrations))
|
||||||
|
|
||||||
gcm_api_key = config.get(ATTR_GCM_API_KEY)
|
gcm_api_key = config.get(ATTR_GCM_API_KEY)
|
||||||
gcm_sender_id = config.get(ATTR_GCM_SENDER_ID)
|
gcm_sender_id = config.get(ATTR_GCM_SENDER_ID)
|
||||||
|
@ -168,9 +168,8 @@ class HTML5PushRegistrationView(HomeAssistantView):
|
||||||
url = '/api/notify.html5'
|
url = '/api/notify.html5'
|
||||||
name = 'api:notify.html5'
|
name = 'api:notify.html5'
|
||||||
|
|
||||||
def __init__(self, hass, registrations, json_path):
|
def __init__(self, registrations, json_path):
|
||||||
"""Init HTML5PushRegistrationView."""
|
"""Init HTML5PushRegistrationView."""
|
||||||
super().__init__(hass)
|
|
||||||
self.registrations = registrations
|
self.registrations = registrations
|
||||||
self.json_path = json_path
|
self.json_path = json_path
|
||||||
|
|
||||||
|
@ -237,9 +236,8 @@ class HTML5PushCallbackView(HomeAssistantView):
|
||||||
url = '/api/notify.html5/callback'
|
url = '/api/notify.html5/callback'
|
||||||
name = 'api:notify.html5/callback'
|
name = 'api:notify.html5/callback'
|
||||||
|
|
||||||
def __init__(self, hass, registrations):
|
def __init__(self, registrations):
|
||||||
"""Init HTML5PushCallbackView."""
|
"""Init HTML5PushCallbackView."""
|
||||||
super().__init__(hass)
|
|
||||||
self.registrations = registrations
|
self.registrations = registrations
|
||||||
|
|
||||||
def decode_jwt(self, token):
|
def decode_jwt(self, token):
|
||||||
|
@ -324,7 +322,7 @@ class HTML5PushCallbackView(HomeAssistantView):
|
||||||
|
|
||||||
event_name = '{}.{}'.format(NOTIFY_CALLBACK_EVENT,
|
event_name = '{}.{}'.format(NOTIFY_CALLBACK_EVENT,
|
||||||
event_payload[ATTR_TYPE])
|
event_payload[ATTR_TYPE])
|
||||||
self.hass.bus.fire(event_name, event_payload)
|
request.app['hass'].bus.fire(event_name, event_payload)
|
||||||
return self.json({'status': 'ok',
|
return self.json({'status': 'ok',
|
||||||
'event': event_payload[ATTR_TYPE]})
|
'event': event_payload[ATTR_TYPE]})
|
||||||
|
|
||||||
|
|
|
@ -274,7 +274,7 @@ def setup_platform(hass, config, add_devices, discovery_info=None):
|
||||||
|
|
||||||
hass.http.register_redirect(FITBIT_AUTH_START, fitbit_auth_start_url)
|
hass.http.register_redirect(FITBIT_AUTH_START, fitbit_auth_start_url)
|
||||||
hass.http.register_view(FitbitAuthCallbackView(
|
hass.http.register_view(FitbitAuthCallbackView(
|
||||||
hass, config, add_devices, oauth))
|
config, add_devices, oauth))
|
||||||
|
|
||||||
request_oauth_completion(hass)
|
request_oauth_completion(hass)
|
||||||
|
|
||||||
|
@ -286,9 +286,8 @@ class FitbitAuthCallbackView(HomeAssistantView):
|
||||||
url = '/auth/fitbit/callback'
|
url = '/auth/fitbit/callback'
|
||||||
name = 'auth:fitbit:callback'
|
name = 'auth:fitbit:callback'
|
||||||
|
|
||||||
def __init__(self, hass, config, add_devices, oauth):
|
def __init__(self, config, add_devices, oauth):
|
||||||
"""Initialize the OAuth callback view."""
|
"""Initialize the OAuth callback view."""
|
||||||
super().__init__(hass)
|
|
||||||
self.config = config
|
self.config = config
|
||||||
self.add_devices = add_devices
|
self.add_devices = add_devices
|
||||||
self.oauth = oauth
|
self.oauth = oauth
|
||||||
|
@ -299,6 +298,7 @@ class FitbitAuthCallbackView(HomeAssistantView):
|
||||||
from oauthlib.oauth2.rfc6749.errors import MismatchingStateError
|
from oauthlib.oauth2.rfc6749.errors import MismatchingStateError
|
||||||
from oauthlib.oauth2.rfc6749.errors import MissingTokenError
|
from oauthlib.oauth2.rfc6749.errors import MissingTokenError
|
||||||
|
|
||||||
|
hass = request.app['hass']
|
||||||
data = request.GET
|
data = request.GET
|
||||||
|
|
||||||
response_message = """Fitbit has been successfully authorized!
|
response_message = """Fitbit has been successfully authorized!
|
||||||
|
@ -306,7 +306,7 @@ class FitbitAuthCallbackView(HomeAssistantView):
|
||||||
|
|
||||||
if data.get('code') is not None:
|
if data.get('code') is not None:
|
||||||
redirect_uri = '{}{}'.format(
|
redirect_uri = '{}{}'.format(
|
||||||
self.hass.config.api.base_url, FITBIT_AUTH_CALLBACK_PATH)
|
hass.config.api.base_url, FITBIT_AUTH_CALLBACK_PATH)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.oauth.fetch_access_token(data.get('code'), redirect_uri)
|
self.oauth.fetch_access_token(data.get('code'), redirect_uri)
|
||||||
|
@ -336,12 +336,11 @@ class FitbitAuthCallbackView(HomeAssistantView):
|
||||||
ATTR_CLIENT_ID: self.oauth.client_id,
|
ATTR_CLIENT_ID: self.oauth.client_id,
|
||||||
ATTR_CLIENT_SECRET: self.oauth.client_secret
|
ATTR_CLIENT_SECRET: self.oauth.client_secret
|
||||||
}
|
}
|
||||||
if not config_from_file(self.hass.config.path(FITBIT_CONFIG_FILE),
|
if not config_from_file(hass.config.path(FITBIT_CONFIG_FILE),
|
||||||
config_contents):
|
config_contents):
|
||||||
_LOGGER.error("Failed to save config file")
|
_LOGGER.error("Failed to save config file")
|
||||||
|
|
||||||
self.hass.async_add_job(setup_platform, self.hass, self.config,
|
hass.async_add_job(setup_platform, hass, self.config, self.add_devices)
|
||||||
self.add_devices)
|
|
||||||
|
|
||||||
return html_response
|
return html_response
|
||||||
|
|
||||||
|
|
|
@ -59,7 +59,7 @@ def setup_platform(hass, config, add_devices, discovery_info=None):
|
||||||
sensors = {}
|
sensors = {}
|
||||||
|
|
||||||
hass.http.register_view(TorqueReceiveDataView(
|
hass.http.register_view(TorqueReceiveDataView(
|
||||||
hass, email, vehicle, sensors, add_devices))
|
email, vehicle, sensors, add_devices))
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
@ -69,9 +69,8 @@ class TorqueReceiveDataView(HomeAssistantView):
|
||||||
url = API_PATH
|
url = API_PATH
|
||||||
name = 'api:torque'
|
name = 'api:torque'
|
||||||
|
|
||||||
def __init__(self, hass, email, vehicle, sensors, add_devices):
|
def __init__(self, email, vehicle, sensors, add_devices):
|
||||||
"""Initialize a Torque view."""
|
"""Initialize a Torque view."""
|
||||||
super().__init__(hass)
|
|
||||||
self.email = email
|
self.email = email
|
||||||
self.vehicle = vehicle
|
self.vehicle = vehicle
|
||||||
self.sensors = sensors
|
self.sensors = sensors
|
||||||
|
@ -80,6 +79,7 @@ class TorqueReceiveDataView(HomeAssistantView):
|
||||||
@callback
|
@callback
|
||||||
def get(self, request):
|
def get(self, request):
|
||||||
"""Handle Torque data request."""
|
"""Handle Torque data request."""
|
||||||
|
hass = request.app['hass']
|
||||||
data = request.GET
|
data = request.GET
|
||||||
|
|
||||||
if self.email is not None and self.email != data[SENSOR_EMAIL_FIELD]:
|
if self.email is not None and self.email != data[SENSOR_EMAIL_FIELD]:
|
||||||
|
@ -108,7 +108,7 @@ class TorqueReceiveDataView(HomeAssistantView):
|
||||||
self.sensors[pid] = TorqueSensor(
|
self.sensors[pid] = TorqueSensor(
|
||||||
ENTITY_NAME_FORMAT.format(self.vehicle, names[pid]),
|
ENTITY_NAME_FORMAT.format(self.vehicle, names[pid]),
|
||||||
units.get(pid, None))
|
units.get(pid, None))
|
||||||
self.hass.async_add_job(self.add_devices, [self.sensors[pid]])
|
hass.async_add_job(self.add_devices, [self.sensors[pid]])
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
|
@ -97,6 +97,7 @@ class NetioApiView(HomeAssistantView):
|
||||||
@callback
|
@callback
|
||||||
def get(self, request, host):
|
def get(self, request, host):
|
||||||
"""Request handler."""
|
"""Request handler."""
|
||||||
|
hass = request.app['hass']
|
||||||
data = request.GET
|
data = request.GET
|
||||||
states, consumptions, cumulated_consumptions, start_dates = \
|
states, consumptions, cumulated_consumptions, start_dates = \
|
||||||
[], [], [], []
|
[], [], [], []
|
||||||
|
@ -119,7 +120,7 @@ class NetioApiView(HomeAssistantView):
|
||||||
ndev.start_dates = start_dates
|
ndev.start_dates = start_dates
|
||||||
|
|
||||||
for dev in DEVICES[host].entities:
|
for dev in DEVICES[host].entities:
|
||||||
self.hass.async_add_job(dev.async_update_ha_state())
|
hass.async_add_job(dev.async_update_ha_state())
|
||||||
|
|
||||||
return self.json(True)
|
return self.json(True)
|
||||||
|
|
||||||
|
|
|
@ -360,7 +360,6 @@ HTTP_HEADER_CONTENT_LENGTH = 'Content-Length'
|
||||||
HTTP_HEADER_CACHE_CONTROL = 'Cache-Control'
|
HTTP_HEADER_CACHE_CONTROL = 'Cache-Control'
|
||||||
HTTP_HEADER_EXPIRES = 'Expires'
|
HTTP_HEADER_EXPIRES = 'Expires'
|
||||||
HTTP_HEADER_ORIGIN = 'Origin'
|
HTTP_HEADER_ORIGIN = 'Origin'
|
||||||
HTTP_HEADER_X_FORWARDED_FOR = 'X-Forwarded-For'
|
|
||||||
HTTP_HEADER_X_REQUESTED_WITH = 'X-Requested-With'
|
HTTP_HEADER_X_REQUESTED_WITH = 'X-Requested-With'
|
||||||
HTTP_HEADER_ACCEPT = 'Accept'
|
HTTP_HEADER_ACCEPT = 'Accept'
|
||||||
HTTP_HEADER_ACCESS_CONTROL_ALLOW_ORIGIN = 'Access-Control-Allow-Origin'
|
HTTP_HEADER_ACCESS_CONTROL_ALLOW_ORIGIN = 'Access-Control-Allow-Origin'
|
||||||
|
|
17
homeassistant/util/logging.py
Normal file
17
homeassistant/util/logging.py
Normal file
|
@ -0,0 +1,17 @@
|
||||||
|
"""Logging utilities."""
|
||||||
|
import logging
|
||||||
|
|
||||||
|
|
||||||
|
class HideSensitiveDataFilter(logging.Filter):
|
||||||
|
"""Filter API password calls."""
|
||||||
|
|
||||||
|
def __init__(self, text):
|
||||||
|
"""Initialize sensitive data filter."""
|
||||||
|
super().__init__()
|
||||||
|
self.text = text
|
||||||
|
|
||||||
|
def filter(self, record):
|
||||||
|
"""Hide sensitive data in messages."""
|
||||||
|
record.msg = record.msg.replace(self.text, '*******')
|
||||||
|
|
||||||
|
return True
|
|
@ -10,6 +10,8 @@ import logging
|
||||||
import threading
|
import threading
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
from homeassistant import core as ha, loader
|
from homeassistant import core as ha, loader
|
||||||
from homeassistant.bootstrap import (
|
from homeassistant.bootstrap import (
|
||||||
setup_component, async_prepare_setup_component)
|
setup_component, async_prepare_setup_component)
|
||||||
|
@ -22,6 +24,9 @@ from homeassistant.const import (
|
||||||
EVENT_STATE_CHANGED, EVENT_PLATFORM_DISCOVERED, ATTR_SERVICE,
|
EVENT_STATE_CHANGED, EVENT_PLATFORM_DISCOVERED, ATTR_SERVICE,
|
||||||
ATTR_DISCOVERED, SERVER_PORT)
|
ATTR_DISCOVERED, SERVER_PORT)
|
||||||
from homeassistant.components import sun, mqtt
|
from homeassistant.components import sun, mqtt
|
||||||
|
from homeassistant.components.http.auth import auth_middleware
|
||||||
|
from homeassistant.components.http.const import (
|
||||||
|
KEY_USE_X_FORWARDED_FOR, KEY_BANS_ENABLED)
|
||||||
|
|
||||||
_TEST_INSTANCE_PORT = SERVER_PORT
|
_TEST_INSTANCE_PORT = SERVER_PORT
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
@ -210,13 +215,23 @@ def mock_http_component(hass):
|
||||||
"""Store registered view."""
|
"""Store registered view."""
|
||||||
if isinstance(view, type):
|
if isinstance(view, type):
|
||||||
# Instantiate the view, if needed
|
# Instantiate the view, if needed
|
||||||
view = view(hass)
|
view = view()
|
||||||
|
|
||||||
hass.http.views[view.name] = view
|
hass.http.views[view.name] = view
|
||||||
|
|
||||||
hass.http.register_view = mock_register_view
|
hass.http.register_view = mock_register_view
|
||||||
|
|
||||||
|
|
||||||
|
def mock_http_component_app(hass):
|
||||||
|
"""Create an aiohttp.web.Application instance for testing."""
|
||||||
|
hass.http.api_password = None
|
||||||
|
app = web.Application(middlewares=[auth_middleware], loop=hass.loop)
|
||||||
|
app['hass'] = hass
|
||||||
|
app[KEY_USE_X_FORWARDED_FOR] = False
|
||||||
|
app[KEY_BANS_ENABLED] = False
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
def mock_mqtt_component(hass):
|
def mock_mqtt_component(hass):
|
||||||
"""Mock the MQTT component."""
|
"""Mock the MQTT component."""
|
||||||
with mock.patch('homeassistant.components.mqtt.MQTT') as mock_mqtt:
|
with mock.patch('homeassistant.components.mqtt.MQTT') as mock_mqtt:
|
||||||
|
|
|
@ -27,8 +27,8 @@ def test_fetching_url(aioclient_mock, hass, test_client):
|
||||||
|
|
||||||
resp = yield from client.get('/api/camera_proxy/camera.config_test')
|
resp = yield from client.get('/api/camera_proxy/camera.config_test')
|
||||||
|
|
||||||
assert aioclient_mock.call_count == 1
|
|
||||||
assert resp.status == 200
|
assert resp.status == 200
|
||||||
|
assert aioclient_mock.call_count == 1
|
||||||
body = yield from resp.text()
|
body = yield from resp.text()
|
||||||
assert body == 'hello world'
|
assert body == 'hello world'
|
||||||
|
|
||||||
|
|
1
tests/components/http/__init__.py
Normal file
1
tests/components/http/__init__.py
Normal file
|
@ -0,0 +1 @@
|
||||||
|
"""Tests for the HTTP component."""
|
169
tests/components/http/test_auth.py
Normal file
169
tests/components/http/test_auth.py
Normal file
|
@ -0,0 +1,169 @@
|
||||||
|
"""The tests for the Home Assistant HTTP component."""
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
import logging
|
||||||
|
from ipaddress import ip_address, ip_network
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from homeassistant import bootstrap, const
|
||||||
|
import homeassistant.components.http as http
|
||||||
|
from homeassistant.components.http.const import (
|
||||||
|
KEY_TRUSTED_NETWORKS, KEY_USE_X_FORWARDED_FOR, HTTP_HEADER_X_FORWARDED_FOR)
|
||||||
|
|
||||||
|
from tests.common import get_test_instance_port, get_test_home_assistant
|
||||||
|
|
||||||
|
API_PASSWORD = 'test1234'
|
||||||
|
SERVER_PORT = get_test_instance_port()
|
||||||
|
HTTP_BASE = '127.0.0.1:{}'.format(SERVER_PORT)
|
||||||
|
HTTP_BASE_URL = 'http://{}'.format(HTTP_BASE)
|
||||||
|
HA_HEADERS = {
|
||||||
|
const.HTTP_HEADER_HA_AUTH: API_PASSWORD,
|
||||||
|
const.HTTP_HEADER_CONTENT_TYPE: const.CONTENT_TYPE_JSON,
|
||||||
|
}
|
||||||
|
# Don't add 127.0.0.1/::1 as trusted, as it may interfere with other test cases
|
||||||
|
TRUSTED_NETWORKS = ['192.0.2.0/24', '2001:DB8:ABCD::/48', '100.64.0.1',
|
||||||
|
'FD01:DB8::1']
|
||||||
|
TRUSTED_ADDRESSES = ['100.64.0.1', '192.0.2.100', 'FD01:DB8::1',
|
||||||
|
'2001:DB8:ABCD::1']
|
||||||
|
UNTRUSTED_ADDRESSES = ['198.51.100.1', '2001:DB8:FA1::1', '127.0.0.1', '::1']
|
||||||
|
|
||||||
|
hass = None
|
||||||
|
|
||||||
|
|
||||||
|
def _url(path=''):
|
||||||
|
"""Helper method to generate URLs."""
|
||||||
|
return HTTP_BASE_URL + path
|
||||||
|
|
||||||
|
|
||||||
|
# pylint: disable=invalid-name
|
||||||
|
def setUpModule():
|
||||||
|
"""Initialize a Home Assistant server."""
|
||||||
|
global hass
|
||||||
|
|
||||||
|
hass = get_test_home_assistant()
|
||||||
|
|
||||||
|
bootstrap.setup_component(
|
||||||
|
hass, http.DOMAIN, {
|
||||||
|
http.DOMAIN: {
|
||||||
|
http.CONF_API_PASSWORD: API_PASSWORD,
|
||||||
|
http.CONF_SERVER_PORT: SERVER_PORT,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
bootstrap.setup_component(hass, 'api')
|
||||||
|
|
||||||
|
hass.http.app[KEY_TRUSTED_NETWORKS] = [
|
||||||
|
ip_network(trusted_network)
|
||||||
|
for trusted_network in TRUSTED_NETWORKS]
|
||||||
|
|
||||||
|
hass.start()
|
||||||
|
|
||||||
|
|
||||||
|
# pylint: disable=invalid-name
|
||||||
|
def tearDownModule():
|
||||||
|
"""Stop the Home Assistant server."""
|
||||||
|
hass.stop()
|
||||||
|
|
||||||
|
|
||||||
|
class TestHttp:
|
||||||
|
"""Test HTTP component."""
|
||||||
|
|
||||||
|
def test_access_denied_without_password(self):
|
||||||
|
"""Test access without password."""
|
||||||
|
req = requests.get(_url(const.URL_API))
|
||||||
|
|
||||||
|
assert req.status_code == 401
|
||||||
|
|
||||||
|
def test_access_denied_with_wrong_password_in_header(self):
|
||||||
|
"""Test access with wrong password."""
|
||||||
|
req = requests.get(
|
||||||
|
_url(const.URL_API),
|
||||||
|
headers={const.HTTP_HEADER_HA_AUTH: 'wrongpassword'})
|
||||||
|
|
||||||
|
assert req.status_code == 401
|
||||||
|
|
||||||
|
def test_access_denied_with_x_forwarded_for(self, caplog):
|
||||||
|
"""Test access denied through the X-Forwarded-For http header."""
|
||||||
|
hass.http.use_x_forwarded_for = True
|
||||||
|
for remote_addr in UNTRUSTED_ADDRESSES:
|
||||||
|
req = requests.get(_url(const.URL_API), headers={
|
||||||
|
HTTP_HEADER_X_FORWARDED_FOR: remote_addr})
|
||||||
|
|
||||||
|
assert req.status_code == 401, \
|
||||||
|
"{} shouldn't be trusted".format(remote_addr)
|
||||||
|
|
||||||
|
def test_access_denied_with_untrusted_ip(self, caplog):
|
||||||
|
"""Test access with an untrusted ip address."""
|
||||||
|
for remote_addr in UNTRUSTED_ADDRESSES:
|
||||||
|
with patch('homeassistant.components.http.'
|
||||||
|
'util.get_real_ip',
|
||||||
|
return_value=ip_address(remote_addr)):
|
||||||
|
req = requests.get(
|
||||||
|
_url(const.URL_API), params={'api_password': ''})
|
||||||
|
|
||||||
|
assert req.status_code == 401, \
|
||||||
|
"{} shouldn't be trusted".format(remote_addr)
|
||||||
|
|
||||||
|
def test_access_with_password_in_header(self, caplog):
|
||||||
|
"""Test access with password in URL."""
|
||||||
|
# Hide logging from requests package that we use to test logging
|
||||||
|
caplog.set_level(
|
||||||
|
logging.WARNING, logger='requests.packages.urllib3.connectionpool')
|
||||||
|
|
||||||
|
req = requests.get(
|
||||||
|
_url(const.URL_API),
|
||||||
|
headers={const.HTTP_HEADER_HA_AUTH: API_PASSWORD})
|
||||||
|
|
||||||
|
assert req.status_code == 200
|
||||||
|
|
||||||
|
logs = caplog.text
|
||||||
|
|
||||||
|
assert const.URL_API in logs
|
||||||
|
assert API_PASSWORD not in logs
|
||||||
|
|
||||||
|
def test_access_denied_with_wrong_password_in_url(self):
|
||||||
|
"""Test access with wrong password."""
|
||||||
|
req = requests.get(
|
||||||
|
_url(const.URL_API), params={'api_password': 'wrongpassword'})
|
||||||
|
|
||||||
|
assert req.status_code == 401
|
||||||
|
|
||||||
|
def test_access_with_password_in_url(self, caplog):
|
||||||
|
"""Test access with password in URL."""
|
||||||
|
# Hide logging from requests package that we use to test logging
|
||||||
|
caplog.set_level(
|
||||||
|
logging.WARNING, logger='requests.packages.urllib3.connectionpool')
|
||||||
|
|
||||||
|
req = requests.get(
|
||||||
|
_url(const.URL_API), params={'api_password': API_PASSWORD})
|
||||||
|
|
||||||
|
assert req.status_code == 200
|
||||||
|
|
||||||
|
logs = caplog.text
|
||||||
|
|
||||||
|
assert const.URL_API in logs
|
||||||
|
assert API_PASSWORD not in logs
|
||||||
|
|
||||||
|
def test_access_granted_with_x_forwarded_for(self, caplog):
|
||||||
|
"""Test access denied through the X-Forwarded-For http header."""
|
||||||
|
hass.http.app[KEY_USE_X_FORWARDED_FOR] = True
|
||||||
|
for remote_addr in TRUSTED_ADDRESSES:
|
||||||
|
req = requests.get(_url(const.URL_API), headers={
|
||||||
|
HTTP_HEADER_X_FORWARDED_FOR: remote_addr})
|
||||||
|
|
||||||
|
assert req.status_code == 200, \
|
||||||
|
"{} should be trusted".format(remote_addr)
|
||||||
|
|
||||||
|
def test_access_granted_with_trusted_ip(self, caplog):
|
||||||
|
"""Test access with trusted addresses."""
|
||||||
|
for remote_addr in TRUSTED_ADDRESSES:
|
||||||
|
with patch('homeassistant.components.http.'
|
||||||
|
'auth.get_real_ip',
|
||||||
|
return_value=ip_address(remote_addr)):
|
||||||
|
req = requests.get(
|
||||||
|
_url(const.URL_API), params={'api_password': ''})
|
||||||
|
|
||||||
|
assert req.status_code == 200, \
|
||||||
|
'{} should be trusted'.format(remote_addr)
|
118
tests/components/http/test_ban.py
Normal file
118
tests/components/http/test_ban.py
Normal file
|
@ -0,0 +1,118 @@
|
||||||
|
"""The tests for the Home Assistant HTTP component."""
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
from ipaddress import ip_address
|
||||||
|
from unittest.mock import patch, mock_open
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from homeassistant import bootstrap, const
|
||||||
|
import homeassistant.components.http as http
|
||||||
|
from homeassistant.components.http.const import (
|
||||||
|
KEY_BANS_ENABLED, KEY_LOGIN_THRESHOLD, KEY_BANNED_IPS)
|
||||||
|
from homeassistant.components.http.ban import IpBan, IP_BANS_FILE
|
||||||
|
|
||||||
|
from tests.common import get_test_instance_port, get_test_home_assistant
|
||||||
|
|
||||||
|
API_PASSWORD = 'test1234'
|
||||||
|
SERVER_PORT = get_test_instance_port()
|
||||||
|
HTTP_BASE = '127.0.0.1:{}'.format(SERVER_PORT)
|
||||||
|
HTTP_BASE_URL = 'http://{}'.format(HTTP_BASE)
|
||||||
|
HA_HEADERS = {
|
||||||
|
const.HTTP_HEADER_HA_AUTH: API_PASSWORD,
|
||||||
|
const.HTTP_HEADER_CONTENT_TYPE: const.CONTENT_TYPE_JSON,
|
||||||
|
}
|
||||||
|
BANNED_IPS = ['200.201.202.203', '100.64.0.2']
|
||||||
|
|
||||||
|
hass = None
|
||||||
|
|
||||||
|
|
||||||
|
def _url(path=''):
|
||||||
|
"""Helper method to generate URLs."""
|
||||||
|
return HTTP_BASE_URL + path
|
||||||
|
|
||||||
|
|
||||||
|
# pylint: disable=invalid-name
|
||||||
|
def setUpModule():
|
||||||
|
"""Initialize a Home Assistant server."""
|
||||||
|
global hass
|
||||||
|
|
||||||
|
hass = get_test_home_assistant()
|
||||||
|
|
||||||
|
bootstrap.setup_component(
|
||||||
|
hass, http.DOMAIN, {
|
||||||
|
http.DOMAIN: {
|
||||||
|
http.CONF_API_PASSWORD: API_PASSWORD,
|
||||||
|
http.CONF_SERVER_PORT: SERVER_PORT,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
bootstrap.setup_component(hass, 'api')
|
||||||
|
|
||||||
|
hass.http.app[KEY_BANNED_IPS] = [IpBan(banned_ip) for banned_ip
|
||||||
|
in BANNED_IPS]
|
||||||
|
hass.start()
|
||||||
|
|
||||||
|
|
||||||
|
# pylint: disable=invalid-name
|
||||||
|
def tearDownModule():
|
||||||
|
"""Stop the Home Assistant server."""
|
||||||
|
hass.stop()
|
||||||
|
|
||||||
|
|
||||||
|
class TestHttp:
|
||||||
|
"""Test HTTP component."""
|
||||||
|
|
||||||
|
def test_access_from_banned_ip(self):
|
||||||
|
"""Test accessing to server from banned IP. Both trusted and not."""
|
||||||
|
hass.http.app[KEY_BANS_ENABLED] = True
|
||||||
|
for remote_addr in BANNED_IPS:
|
||||||
|
with patch('homeassistant.components.http.'
|
||||||
|
'ban.get_real_ip',
|
||||||
|
return_value=ip_address(remote_addr)):
|
||||||
|
req = requests.get(
|
||||||
|
_url(const.URL_API))
|
||||||
|
assert req.status_code == 403
|
||||||
|
|
||||||
|
def test_access_from_banned_ip_when_ban_is_off(self):
|
||||||
|
"""Test accessing to server from banned IP when feature is off"""
|
||||||
|
hass.http.app[KEY_BANS_ENABLED] = False
|
||||||
|
for remote_addr in BANNED_IPS:
|
||||||
|
with patch('homeassistant.components.http.'
|
||||||
|
'ban.get_real_ip',
|
||||||
|
return_value=ip_address(remote_addr)):
|
||||||
|
req = requests.get(
|
||||||
|
_url(const.URL_API),
|
||||||
|
headers={const.HTTP_HEADER_HA_AUTH: API_PASSWORD})
|
||||||
|
assert req.status_code == 200
|
||||||
|
|
||||||
|
def test_ip_bans_file_creation(self):
|
||||||
|
"""Testing if banned IP file created"""
|
||||||
|
hass.http.app[KEY_BANS_ENABLED] = True
|
||||||
|
hass.http.app[KEY_LOGIN_THRESHOLD] = 1
|
||||||
|
|
||||||
|
m = mock_open()
|
||||||
|
|
||||||
|
def call_server():
|
||||||
|
with patch('homeassistant.components.http.'
|
||||||
|
'ban.get_real_ip',
|
||||||
|
return_value=ip_address("200.201.202.204")):
|
||||||
|
print("GETTING API")
|
||||||
|
return requests.get(
|
||||||
|
_url(const.URL_API),
|
||||||
|
headers={const.HTTP_HEADER_HA_AUTH: 'Wrong password'})
|
||||||
|
|
||||||
|
with patch('homeassistant.components.http.ban.open', m, create=True):
|
||||||
|
req = call_server()
|
||||||
|
assert req.status_code == 401
|
||||||
|
assert len(hass.http.app[KEY_BANNED_IPS]) == len(BANNED_IPS)
|
||||||
|
assert m.call_count == 0
|
||||||
|
|
||||||
|
req = call_server()
|
||||||
|
assert req.status_code == 401
|
||||||
|
assert len(hass.http.app[KEY_BANNED_IPS]) == len(BANNED_IPS) + 1
|
||||||
|
m.assert_called_once_with(hass.config.path(IP_BANS_FILE), 'a')
|
||||||
|
|
||||||
|
req = call_server()
|
||||||
|
assert req.status_code == 403
|
||||||
|
assert m.call_count == 1
|
111
tests/components/http/test_init.py
Normal file
111
tests/components/http/test_init.py
Normal file
|
@ -0,0 +1,111 @@
|
||||||
|
"""The tests for the Home Assistant HTTP component."""
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from homeassistant import bootstrap, const
|
||||||
|
import homeassistant.components.http as http
|
||||||
|
|
||||||
|
from tests.common import get_test_instance_port, get_test_home_assistant
|
||||||
|
|
||||||
|
API_PASSWORD = 'test1234'
|
||||||
|
SERVER_PORT = get_test_instance_port()
|
||||||
|
HTTP_BASE = '127.0.0.1:{}'.format(SERVER_PORT)
|
||||||
|
HTTP_BASE_URL = 'http://{}'.format(HTTP_BASE)
|
||||||
|
HA_HEADERS = {
|
||||||
|
const.HTTP_HEADER_HA_AUTH: API_PASSWORD,
|
||||||
|
const.HTTP_HEADER_CONTENT_TYPE: const.CONTENT_TYPE_JSON,
|
||||||
|
}
|
||||||
|
CORS_ORIGINS = [HTTP_BASE_URL, HTTP_BASE]
|
||||||
|
|
||||||
|
hass = None
|
||||||
|
|
||||||
|
|
||||||
|
def _url(path=''):
|
||||||
|
"""Helper method to generate URLs."""
|
||||||
|
return HTTP_BASE_URL + path
|
||||||
|
|
||||||
|
|
||||||
|
# pylint: disable=invalid-name
|
||||||
|
def setUpModule():
|
||||||
|
"""Initialize a Home Assistant server."""
|
||||||
|
global hass
|
||||||
|
|
||||||
|
hass = get_test_home_assistant()
|
||||||
|
|
||||||
|
bootstrap.setup_component(
|
||||||
|
hass, http.DOMAIN, {
|
||||||
|
http.DOMAIN: {
|
||||||
|
http.CONF_API_PASSWORD: API_PASSWORD,
|
||||||
|
http.CONF_SERVER_PORT: SERVER_PORT,
|
||||||
|
http.CONF_CORS_ORIGINS: CORS_ORIGINS,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
bootstrap.setup_component(hass, 'api')
|
||||||
|
|
||||||
|
hass.start()
|
||||||
|
|
||||||
|
|
||||||
|
# pylint: disable=invalid-name
|
||||||
|
def tearDownModule():
|
||||||
|
"""Stop the Home Assistant server."""
|
||||||
|
hass.stop()
|
||||||
|
|
||||||
|
|
||||||
|
class TestHttp:
|
||||||
|
"""Test HTTP component."""
|
||||||
|
|
||||||
|
def test_cors_allowed_with_password_in_url(self):
|
||||||
|
"""Test cross origin resource sharing with password in url."""
|
||||||
|
req = requests.get(_url(const.URL_API),
|
||||||
|
params={'api_password': API_PASSWORD},
|
||||||
|
headers={const.HTTP_HEADER_ORIGIN: HTTP_BASE_URL})
|
||||||
|
|
||||||
|
allow_origin = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_ORIGIN
|
||||||
|
|
||||||
|
assert req.status_code == 200
|
||||||
|
assert req.headers.get(allow_origin) == HTTP_BASE_URL
|
||||||
|
|
||||||
|
def test_cors_allowed_with_password_in_header(self):
|
||||||
|
"""Test cross origin resource sharing with password in header."""
|
||||||
|
headers = {
|
||||||
|
const.HTTP_HEADER_HA_AUTH: API_PASSWORD,
|
||||||
|
const.HTTP_HEADER_ORIGIN: HTTP_BASE_URL
|
||||||
|
}
|
||||||
|
req = requests.get(_url(const.URL_API), headers=headers)
|
||||||
|
|
||||||
|
allow_origin = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_ORIGIN
|
||||||
|
|
||||||
|
assert req.status_code == 200
|
||||||
|
assert req.headers.get(allow_origin) == HTTP_BASE_URL
|
||||||
|
|
||||||
|
def test_cors_denied_without_origin_header(self):
|
||||||
|
"""Test cross origin resource sharing with password in header."""
|
||||||
|
headers = {
|
||||||
|
const.HTTP_HEADER_HA_AUTH: API_PASSWORD
|
||||||
|
}
|
||||||
|
req = requests.get(_url(const.URL_API), headers=headers)
|
||||||
|
|
||||||
|
allow_origin = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_ORIGIN
|
||||||
|
allow_headers = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_HEADERS
|
||||||
|
|
||||||
|
assert req.status_code == 200
|
||||||
|
assert allow_origin not in req.headers
|
||||||
|
assert allow_headers not in req.headers
|
||||||
|
|
||||||
|
def test_cors_preflight_allowed(self):
|
||||||
|
"""Test cross origin resource sharing preflight (OPTIONS) request."""
|
||||||
|
headers = {
|
||||||
|
const.HTTP_HEADER_ORIGIN: HTTP_BASE_URL,
|
||||||
|
'Access-Control-Request-Method': 'GET',
|
||||||
|
'Access-Control-Request-Headers': 'x-ha-access'
|
||||||
|
}
|
||||||
|
req = requests.options(_url(const.URL_API), headers=headers)
|
||||||
|
|
||||||
|
allow_origin = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_ORIGIN
|
||||||
|
allow_headers = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_HEADERS
|
||||||
|
|
||||||
|
assert req.status_code == 200
|
||||||
|
assert req.headers.get(allow_origin) == HTTP_BASE_URL
|
||||||
|
assert req.headers.get(allow_headers) == \
|
||||||
|
const.HTTP_HEADER_HA_AUTH.upper()
|
|
@ -3,10 +3,10 @@ import asyncio
|
||||||
import json
|
import json
|
||||||
from unittest.mock import patch, MagicMock, mock_open
|
from unittest.mock import patch, MagicMock, mock_open
|
||||||
|
|
||||||
from aiohttp import web
|
|
||||||
|
|
||||||
from homeassistant.components.notify import html5
|
from homeassistant.components.notify import html5
|
||||||
|
|
||||||
|
from tests.common import mock_http_component_app
|
||||||
|
|
||||||
SUBSCRIPTION_1 = {
|
SUBSCRIPTION_1 = {
|
||||||
'browser': 'chrome',
|
'browser': 'chrome',
|
||||||
'subscription': {
|
'subscription': {
|
||||||
|
@ -121,7 +121,8 @@ class TestHtml5Notify(object):
|
||||||
assert view.json_path == hass.config.path.return_value
|
assert view.json_path == hass.config.path.return_value
|
||||||
assert view.registrations == {}
|
assert view.registrations == {}
|
||||||
|
|
||||||
app = web.Application(loop=loop)
|
hass.loop = loop
|
||||||
|
app = mock_http_component_app(hass)
|
||||||
view.register(app.router)
|
view.register(app.router)
|
||||||
client = yield from test_client(app)
|
client = yield from test_client(app)
|
||||||
hass.http.is_banned_ip.return_value = False
|
hass.http.is_banned_ip.return_value = False
|
||||||
|
@ -153,7 +154,8 @@ class TestHtml5Notify(object):
|
||||||
|
|
||||||
view = hass.mock_calls[1][1][0]
|
view = hass.mock_calls[1][1][0]
|
||||||
|
|
||||||
app = web.Application(loop=loop)
|
hass.loop = loop
|
||||||
|
app = mock_http_component_app(hass)
|
||||||
view.register(app.router)
|
view.register(app.router)
|
||||||
client = yield from test_client(app)
|
client = yield from test_client(app)
|
||||||
hass.http.is_banned_ip.return_value = False
|
hass.http.is_banned_ip.return_value = False
|
||||||
|
@ -208,7 +210,8 @@ class TestHtml5Notify(object):
|
||||||
assert view.json_path == hass.config.path.return_value
|
assert view.json_path == hass.config.path.return_value
|
||||||
assert view.registrations == config
|
assert view.registrations == config
|
||||||
|
|
||||||
app = web.Application(loop=loop)
|
hass.loop = loop
|
||||||
|
app = mock_http_component_app(hass)
|
||||||
view.register(app.router)
|
view.register(app.router)
|
||||||
client = yield from test_client(app)
|
client = yield from test_client(app)
|
||||||
hass.http.is_banned_ip.return_value = False
|
hass.http.is_banned_ip.return_value = False
|
||||||
|
@ -253,7 +256,8 @@ class TestHtml5Notify(object):
|
||||||
assert view.json_path == hass.config.path.return_value
|
assert view.json_path == hass.config.path.return_value
|
||||||
assert view.registrations == config
|
assert view.registrations == config
|
||||||
|
|
||||||
app = web.Application(loop=loop)
|
hass.loop = loop
|
||||||
|
app = mock_http_component_app(hass)
|
||||||
view.register(app.router)
|
view.register(app.router)
|
||||||
client = yield from test_client(app)
|
client = yield from test_client(app)
|
||||||
hass.http.is_banned_ip.return_value = False
|
hass.http.is_banned_ip.return_value = False
|
||||||
|
@ -296,7 +300,8 @@ class TestHtml5Notify(object):
|
||||||
assert view.json_path == hass.config.path.return_value
|
assert view.json_path == hass.config.path.return_value
|
||||||
assert view.registrations == config
|
assert view.registrations == config
|
||||||
|
|
||||||
app = web.Application(loop=loop)
|
hass.loop = loop
|
||||||
|
app = mock_http_component_app(hass)
|
||||||
view.register(app.router)
|
view.register(app.router)
|
||||||
client = yield from test_client(app)
|
client = yield from test_client(app)
|
||||||
hass.http.is_banned_ip.return_value = False
|
hass.http.is_banned_ip.return_value = False
|
||||||
|
@ -331,7 +336,8 @@ class TestHtml5Notify(object):
|
||||||
|
|
||||||
view = hass.mock_calls[2][1][0]
|
view = hass.mock_calls[2][1][0]
|
||||||
|
|
||||||
app = web.Application(loop=loop)
|
hass.loop = loop
|
||||||
|
app = mock_http_component_app(hass)
|
||||||
view.register(app.router)
|
view.register(app.router)
|
||||||
client = yield from test_client(app)
|
client = yield from test_client(app)
|
||||||
hass.http.is_banned_ip.return_value = False
|
hass.http.is_banned_ip.return_value = False
|
||||||
|
@ -387,7 +393,8 @@ class TestHtml5Notify(object):
|
||||||
|
|
||||||
bearer_token = "Bearer {}".format(push_payload['data']['jwt'])
|
bearer_token = "Bearer {}".format(push_payload['data']['jwt'])
|
||||||
|
|
||||||
app = web.Application(loop=loop)
|
hass.loop = loop
|
||||||
|
app = mock_http_component_app(hass)
|
||||||
view.register(app.router)
|
view.register(app.router)
|
||||||
client = yield from test_client(app)
|
client = yield from test_client(app)
|
||||||
hass.http.is_banned_ip.return_value = False
|
hass.http.is_banned_ip.return_value = False
|
||||||
|
|
|
@ -6,7 +6,7 @@ import unittest
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
import homeassistant.bootstrap as bootstrap
|
import homeassistant.bootstrap as bootstrap
|
||||||
from homeassistant.components import frontend, http
|
from homeassistant.components import http
|
||||||
from homeassistant.const import HTTP_HEADER_HA_AUTH
|
from homeassistant.const import HTTP_HEADER_HA_AUTH
|
||||||
|
|
||||||
from tests.common import get_test_instance_port, get_test_home_assistant
|
from tests.common import get_test_instance_port, get_test_home_assistant
|
||||||
|
@ -45,7 +45,6 @@ def setUpModule():
|
||||||
def tearDownModule():
|
def tearDownModule():
|
||||||
"""Stop everything that was started."""
|
"""Stop everything that was started."""
|
||||||
hass.stop()
|
hass.stop()
|
||||||
frontend.PANELS = {}
|
|
||||||
|
|
||||||
|
|
||||||
class TestFrontend(unittest.TestCase):
|
class TestFrontend(unittest.TestCase):
|
||||||
|
|
|
@ -1,285 +0,0 @@
|
||||||
"""The tests for the Home Assistant HTTP component."""
|
|
||||||
# pylint: disable=protected-access
|
|
||||||
import logging
|
|
||||||
from ipaddress import ip_network
|
|
||||||
from unittest.mock import patch, mock_open
|
|
||||||
|
|
||||||
import requests
|
|
||||||
|
|
||||||
from homeassistant import bootstrap, const
|
|
||||||
import homeassistant.components.http as http
|
|
||||||
|
|
||||||
from tests.common import get_test_instance_port, get_test_home_assistant
|
|
||||||
|
|
||||||
API_PASSWORD = 'test1234'
|
|
||||||
SERVER_PORT = get_test_instance_port()
|
|
||||||
HTTP_BASE = '127.0.0.1:{}'.format(SERVER_PORT)
|
|
||||||
HTTP_BASE_URL = 'http://{}'.format(HTTP_BASE)
|
|
||||||
HA_HEADERS = {
|
|
||||||
const.HTTP_HEADER_HA_AUTH: API_PASSWORD,
|
|
||||||
const.HTTP_HEADER_CONTENT_TYPE: const.CONTENT_TYPE_JSON,
|
|
||||||
}
|
|
||||||
# Don't add 127.0.0.1/::1 as trusted, as it may interfere with other test cases
|
|
||||||
TRUSTED_NETWORKS = ['192.0.2.0/24', '2001:DB8:ABCD::/48', '100.64.0.1',
|
|
||||||
'FD01:DB8::1']
|
|
||||||
TRUSTED_ADDRESSES = ['100.64.0.1', '192.0.2.100', 'FD01:DB8::1',
|
|
||||||
'2001:DB8:ABCD::1']
|
|
||||||
UNTRUSTED_ADDRESSES = ['198.51.100.1', '2001:DB8:FA1::1', '127.0.0.1', '::1']
|
|
||||||
BANNED_IPS = ['200.201.202.203', '100.64.0.1']
|
|
||||||
|
|
||||||
CORS_ORIGINS = [HTTP_BASE_URL, HTTP_BASE]
|
|
||||||
|
|
||||||
hass = None
|
|
||||||
|
|
||||||
|
|
||||||
def _url(path=''):
|
|
||||||
"""Helper method to generate URLs."""
|
|
||||||
return HTTP_BASE_URL + path
|
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=invalid-name
|
|
||||||
def setUpModule():
|
|
||||||
"""Initialize a Home Assistant server."""
|
|
||||||
global hass
|
|
||||||
|
|
||||||
hass = get_test_home_assistant()
|
|
||||||
|
|
||||||
hass.bus.listen('test_event', lambda _: _)
|
|
||||||
hass.states.set('test.test', 'a_state')
|
|
||||||
|
|
||||||
bootstrap.setup_component(
|
|
||||||
hass, http.DOMAIN, {
|
|
||||||
http.DOMAIN: {
|
|
||||||
http.CONF_API_PASSWORD: API_PASSWORD,
|
|
||||||
http.CONF_SERVER_PORT: SERVER_PORT,
|
|
||||||
http.CONF_CORS_ORIGINS: CORS_ORIGINS,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
bootstrap.setup_component(hass, 'api')
|
|
||||||
|
|
||||||
hass.http.trusted_networks = [
|
|
||||||
ip_network(trusted_network)
|
|
||||||
for trusted_network in TRUSTED_NETWORKS]
|
|
||||||
|
|
||||||
hass.http.ip_bans = [http.IpBan(banned_ip)
|
|
||||||
for banned_ip in BANNED_IPS]
|
|
||||||
|
|
||||||
hass.start()
|
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=invalid-name
|
|
||||||
def tearDownModule():
|
|
||||||
"""Stop the Home Assistant server."""
|
|
||||||
hass.stop()
|
|
||||||
|
|
||||||
|
|
||||||
class TestHttp:
|
|
||||||
"""Test HTTP component."""
|
|
||||||
|
|
||||||
def test_access_denied_without_password(self):
|
|
||||||
"""Test access without password."""
|
|
||||||
req = requests.get(_url(const.URL_API))
|
|
||||||
|
|
||||||
assert req.status_code == 401
|
|
||||||
|
|
||||||
def test_access_denied_with_wrong_password_in_header(self):
|
|
||||||
"""Test access with wrong password."""
|
|
||||||
req = requests.get(
|
|
||||||
_url(const.URL_API),
|
|
||||||
headers={const.HTTP_HEADER_HA_AUTH: 'wrongpassword'})
|
|
||||||
|
|
||||||
assert req.status_code == 401
|
|
||||||
|
|
||||||
def test_access_denied_with_x_forwarded_for(self, caplog):
|
|
||||||
"""Test access denied through the X-Forwarded-For http header."""
|
|
||||||
hass.http.use_x_forwarded_for = True
|
|
||||||
for remote_addr in UNTRUSTED_ADDRESSES:
|
|
||||||
req = requests.get(_url(const.URL_API), headers={
|
|
||||||
const.HTTP_HEADER_X_FORWARDED_FOR: remote_addr})
|
|
||||||
|
|
||||||
assert req.status_code == 401, \
|
|
||||||
"{} shouldn't be trusted".format(remote_addr)
|
|
||||||
|
|
||||||
def test_access_denied_with_untrusted_ip(self, caplog):
|
|
||||||
"""Test access with an untrusted ip address."""
|
|
||||||
for remote_addr in UNTRUSTED_ADDRESSES:
|
|
||||||
with patch('homeassistant.components.http.'
|
|
||||||
'HomeAssistantWSGI.get_real_ip',
|
|
||||||
return_value=remote_addr):
|
|
||||||
req = requests.get(
|
|
||||||
_url(const.URL_API), params={'api_password': ''})
|
|
||||||
|
|
||||||
assert req.status_code == 401, \
|
|
||||||
"{} shouldn't be trusted".format(remote_addr)
|
|
||||||
|
|
||||||
def test_access_with_password_in_header(self, caplog):
|
|
||||||
"""Test access with password in URL."""
|
|
||||||
# Hide logging from requests package that we use to test logging
|
|
||||||
caplog.set_level(
|
|
||||||
logging.WARNING, logger='requests.packages.urllib3.connectionpool')
|
|
||||||
|
|
||||||
req = requests.get(
|
|
||||||
_url(const.URL_API),
|
|
||||||
headers={const.HTTP_HEADER_HA_AUTH: API_PASSWORD})
|
|
||||||
|
|
||||||
assert req.status_code == 200
|
|
||||||
|
|
||||||
logs = caplog.text
|
|
||||||
|
|
||||||
# assert const.URL_API in logs
|
|
||||||
assert API_PASSWORD not in logs
|
|
||||||
|
|
||||||
def test_access_denied_with_wrong_password_in_url(self):
|
|
||||||
"""Test access with wrong password."""
|
|
||||||
req = requests.get(
|
|
||||||
_url(const.URL_API), params={'api_password': 'wrongpassword'})
|
|
||||||
|
|
||||||
assert req.status_code == 401
|
|
||||||
|
|
||||||
def test_access_with_password_in_url(self, caplog):
|
|
||||||
"""Test access with password in URL."""
|
|
||||||
# Hide logging from requests package that we use to test logging
|
|
||||||
caplog.set_level(
|
|
||||||
logging.WARNING, logger='requests.packages.urllib3.connectionpool')
|
|
||||||
|
|
||||||
req = requests.get(
|
|
||||||
_url(const.URL_API), params={'api_password': API_PASSWORD})
|
|
||||||
|
|
||||||
assert req.status_code == 200
|
|
||||||
|
|
||||||
logs = caplog.text
|
|
||||||
|
|
||||||
# assert const.URL_API in logs
|
|
||||||
assert API_PASSWORD not in logs
|
|
||||||
|
|
||||||
def test_access_granted_with_x_forwarded_for(self, caplog):
|
|
||||||
"""Test access denied through the X-Forwarded-For http header."""
|
|
||||||
hass.http.use_x_forwarded_for = True
|
|
||||||
for remote_addr in TRUSTED_ADDRESSES:
|
|
||||||
req = requests.get(_url(const.URL_API), headers={
|
|
||||||
const.HTTP_HEADER_X_FORWARDED_FOR: remote_addr})
|
|
||||||
|
|
||||||
assert req.status_code == 200, \
|
|
||||||
"{} should be trusted".format(remote_addr)
|
|
||||||
|
|
||||||
def test_access_granted_with_trusted_ip(self, caplog):
|
|
||||||
"""Test access with trusted addresses."""
|
|
||||||
for remote_addr in TRUSTED_ADDRESSES:
|
|
||||||
with patch('homeassistant.components.http.'
|
|
||||||
'HomeAssistantWSGI.get_real_ip',
|
|
||||||
return_value=remote_addr):
|
|
||||||
req = requests.get(
|
|
||||||
_url(const.URL_API), params={'api_password': ''})
|
|
||||||
|
|
||||||
assert req.status_code == 200, \
|
|
||||||
'{} should be trusted'.format(remote_addr)
|
|
||||||
|
|
||||||
def test_cors_allowed_with_password_in_url(self):
|
|
||||||
"""Test cross origin resource sharing with password in url."""
|
|
||||||
req = requests.get(_url(const.URL_API),
|
|
||||||
params={'api_password': API_PASSWORD},
|
|
||||||
headers={const.HTTP_HEADER_ORIGIN: HTTP_BASE_URL})
|
|
||||||
|
|
||||||
allow_origin = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_ORIGIN
|
|
||||||
|
|
||||||
assert req.status_code == 200
|
|
||||||
assert req.headers.get(allow_origin) == HTTP_BASE_URL
|
|
||||||
|
|
||||||
def test_cors_allowed_with_password_in_header(self):
|
|
||||||
"""Test cross origin resource sharing with password in header."""
|
|
||||||
headers = {
|
|
||||||
const.HTTP_HEADER_HA_AUTH: API_PASSWORD,
|
|
||||||
const.HTTP_HEADER_ORIGIN: HTTP_BASE_URL
|
|
||||||
}
|
|
||||||
req = requests.get(_url(const.URL_API), headers=headers)
|
|
||||||
|
|
||||||
allow_origin = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_ORIGIN
|
|
||||||
|
|
||||||
assert req.status_code == 200
|
|
||||||
assert req.headers.get(allow_origin) == HTTP_BASE_URL
|
|
||||||
|
|
||||||
def test_cors_denied_without_origin_header(self):
|
|
||||||
"""Test cross origin resource sharing with password in header."""
|
|
||||||
headers = {
|
|
||||||
const.HTTP_HEADER_HA_AUTH: API_PASSWORD
|
|
||||||
}
|
|
||||||
req = requests.get(_url(const.URL_API), headers=headers)
|
|
||||||
|
|
||||||
allow_origin = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_ORIGIN
|
|
||||||
allow_headers = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_HEADERS
|
|
||||||
|
|
||||||
assert req.status_code == 200
|
|
||||||
assert allow_origin not in req.headers
|
|
||||||
assert allow_headers not in req.headers
|
|
||||||
|
|
||||||
def test_cors_preflight_allowed(self):
|
|
||||||
"""Test cross origin resource sharing preflight (OPTIONS) request."""
|
|
||||||
headers = {
|
|
||||||
const.HTTP_HEADER_ORIGIN: HTTP_BASE_URL,
|
|
||||||
'Access-Control-Request-Method': 'GET',
|
|
||||||
'Access-Control-Request-Headers': 'x-ha-access'
|
|
||||||
}
|
|
||||||
req = requests.options(_url(const.URL_API), headers=headers)
|
|
||||||
|
|
||||||
allow_origin = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_ORIGIN
|
|
||||||
allow_headers = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_HEADERS
|
|
||||||
|
|
||||||
assert req.status_code == 200
|
|
||||||
assert req.headers.get(allow_origin) == HTTP_BASE_URL
|
|
||||||
assert req.headers.get(allow_headers) == \
|
|
||||||
const.HTTP_HEADER_HA_AUTH.upper()
|
|
||||||
|
|
||||||
def test_access_from_banned_ip(self):
|
|
||||||
"""Test accessing to server from banned IP. Both trusted and not."""
|
|
||||||
hass.http.is_ban_enabled = True
|
|
||||||
for remote_addr in BANNED_IPS:
|
|
||||||
with patch('homeassistant.components.http.'
|
|
||||||
'HomeAssistantWSGI.get_real_ip',
|
|
||||||
return_value=remote_addr):
|
|
||||||
req = requests.get(
|
|
||||||
_url(const.URL_API))
|
|
||||||
assert req.status_code == 403
|
|
||||||
|
|
||||||
def test_access_from_banned_ip_when_ban_is_off(self):
|
|
||||||
"""Test accessing to server from banned IP when feature is off"""
|
|
||||||
hass.http.is_ban_enabled = False
|
|
||||||
for remote_addr in BANNED_IPS:
|
|
||||||
with patch('homeassistant.components.http.'
|
|
||||||
'HomeAssistantWSGI.get_real_ip',
|
|
||||||
return_value=remote_addr):
|
|
||||||
req = requests.get(
|
|
||||||
_url(const.URL_API),
|
|
||||||
headers={const.HTTP_HEADER_HA_AUTH: API_PASSWORD})
|
|
||||||
assert req.status_code == 200
|
|
||||||
|
|
||||||
def test_ip_bans_file_creation(self):
|
|
||||||
"""Testing if banned IP file created"""
|
|
||||||
hass.http.is_ban_enabled = True
|
|
||||||
hass.http.login_threshold = 1
|
|
||||||
|
|
||||||
m = mock_open()
|
|
||||||
|
|
||||||
def call_server():
|
|
||||||
with patch('homeassistant.components.http.'
|
|
||||||
'HomeAssistantWSGI.get_real_ip',
|
|
||||||
return_value="200.201.202.204"):
|
|
||||||
return requests.get(
|
|
||||||
_url(const.URL_API),
|
|
||||||
headers={const.HTTP_HEADER_HA_AUTH: 'Wrong password'})
|
|
||||||
|
|
||||||
with patch('homeassistant.components.http.open', m, create=True):
|
|
||||||
req = call_server()
|
|
||||||
assert req.status_code == 401
|
|
||||||
assert len(hass.http.ip_bans) == len(BANNED_IPS)
|
|
||||||
assert m.call_count == 0
|
|
||||||
|
|
||||||
req = call_server()
|
|
||||||
assert req.status_code == 401
|
|
||||||
assert len(hass.http.ip_bans) == len(BANNED_IPS) + 1
|
|
||||||
m.assert_called_once_with(hass.config.path(http.IP_BANS), 'a')
|
|
||||||
|
|
||||||
req = call_server()
|
|
||||||
assert req.status_code == 403
|
|
||||||
assert m.call_count == 1
|
|
|
@ -165,7 +165,15 @@ class TestCheckConfig(unittest.TestCase):
|
||||||
|
|
||||||
self.assertDictEqual({
|
self.assertDictEqual({
|
||||||
'components': {'http': {'api_password': 'abc123',
|
'components': {'http': {'api_password': 'abc123',
|
||||||
|
'cors_allowed_origins': [],
|
||||||
|
'development': '0',
|
||||||
|
'ip_ban_enabled': True,
|
||||||
|
'login_attempts_threshold': -1,
|
||||||
|
'server_host': '0.0.0.0',
|
||||||
'server_port': 8123,
|
'server_port': 8123,
|
||||||
|
'ssl_certificate': None,
|
||||||
|
'ssl_key': None,
|
||||||
|
'trusted_networks': [],
|
||||||
'use_x_forwarded_for': False}},
|
'use_x_forwarded_for': False}},
|
||||||
'except': {},
|
'except': {},
|
||||||
'secret_cache': {secrets_path: {'http_pw': 'abc123'}},
|
'secret_cache': {secrets_path: {'http_pw': 'abc123'}},
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue