Reorganize HTTP component (#4575)

* Move HTTP to own folder

* Break HTTP into middlewares

* Lint

* Split tests per middleware

* Clean up HTTP tests

* Make HomeAssistantViews more stateless

* Lint

* Make HTTP setup async
This commit is contained in:
Paulus Schoutsen 2016-11-25 13:04:06 -08:00 committed by GitHub
parent 58b85b2e0e
commit 32ffd006fa
35 changed files with 1318 additions and 1084 deletions

View file

@ -118,7 +118,7 @@ class AlexaIntentsView(HomeAssistantView):
def __init__(self, hass, intents): def __init__(self, hass, intents):
"""Initialize Alexa view.""" """Initialize Alexa view."""
super().__init__(hass) super().__init__()
intents = copy.deepcopy(intents) intents = copy.deepcopy(intents)
template.attach(hass, intents) template.attach(hass, intents)
@ -150,7 +150,7 @@ class AlexaIntentsView(HomeAssistantView):
return None return None
intent = req.get('intent') intent = req.get('intent')
response = AlexaResponse(self.hass, intent) response = AlexaResponse(request.app['hass'], intent)
if req_type == 'LaunchRequest': if req_type == 'LaunchRequest':
response.add_speech( response.add_speech(
@ -282,7 +282,7 @@ class AlexaFlashBriefingView(HomeAssistantView):
def __init__(self, hass, flash_briefings): def __init__(self, hass, flash_briefings):
"""Initialize Alexa view.""" """Initialize Alexa view."""
super().__init__(hass) super().__init__()
self.flash_briefings = copy.deepcopy(flash_briefings) self.flash_briefings = copy.deepcopy(flash_briefings)
template.attach(hass, self.flash_briefings) template.attach(hass, self.flash_briefings)

View file

@ -77,8 +77,10 @@ class APIEventStream(HomeAssistantView):
@asyncio.coroutine @asyncio.coroutine
def get(self, request): def get(self, request):
"""Provide a streaming interface for the event bus.""" """Provide a streaming interface for the event bus."""
# pylint: disable=no-self-use
hass = request.app['hass']
stop_obj = object() stop_obj = object()
to_write = asyncio.Queue(loop=self.hass.loop) to_write = asyncio.Queue(loop=hass.loop)
restrict = request.GET.get('restrict') restrict = request.GET.get('restrict')
if restrict: if restrict:
@ -106,7 +108,7 @@ class APIEventStream(HomeAssistantView):
response.content_type = 'text/event-stream' response.content_type = 'text/event-stream'
yield from response.prepare(request) yield from response.prepare(request)
unsub_stream = self.hass.bus.async_listen(MATCH_ALL, forward_events) unsub_stream = hass.bus.async_listen(MATCH_ALL, forward_events)
try: try:
_LOGGER.debug('STREAM %s ATTACHED', id(stop_obj)) _LOGGER.debug('STREAM %s ATTACHED', id(stop_obj))
@ -117,7 +119,7 @@ class APIEventStream(HomeAssistantView):
while True: while True:
try: try:
with async_timeout.timeout(STREAM_PING_INTERVAL, with async_timeout.timeout(STREAM_PING_INTERVAL,
loop=self.hass.loop): loop=hass.loop):
payload = yield from to_write.get() payload = yield from to_write.get()
if payload is stop_obj: if payload is stop_obj:
@ -145,7 +147,7 @@ class APIConfigView(HomeAssistantView):
@ha.callback @ha.callback
def get(self, request): def get(self, request):
"""Get current configuration.""" """Get current configuration."""
return self.json(self.hass.config.as_dict()) return self.json(request.app['hass'].config.as_dict())
class APIDiscoveryView(HomeAssistantView): class APIDiscoveryView(HomeAssistantView):
@ -158,10 +160,11 @@ class APIDiscoveryView(HomeAssistantView):
@ha.callback @ha.callback
def get(self, request): def get(self, request):
"""Get discovery info.""" """Get discovery info."""
needs_auth = self.hass.config.api.api_password is not None hass = request.app['hass']
needs_auth = hass.config.api.api_password is not None
return self.json({ return self.json({
'base_url': self.hass.config.api.base_url, 'base_url': hass.config.api.base_url,
'location_name': self.hass.config.location_name, 'location_name': hass.config.location_name,
'requires_api_password': needs_auth, 'requires_api_password': needs_auth,
'version': __version__ 'version': __version__
}) })
@ -176,7 +179,7 @@ class APIStatesView(HomeAssistantView):
@ha.callback @ha.callback
def get(self, request): def get(self, request):
"""Get current states.""" """Get current states."""
return self.json(self.hass.states.async_all()) return self.json(request.app['hass'].states.async_all())
class APIEntityStateView(HomeAssistantView): class APIEntityStateView(HomeAssistantView):
@ -188,7 +191,7 @@ class APIEntityStateView(HomeAssistantView):
@ha.callback @ha.callback
def get(self, request, entity_id): def get(self, request, entity_id):
"""Retrieve state of entity.""" """Retrieve state of entity."""
state = self.hass.states.get(entity_id) state = request.app['hass'].states.get(entity_id)
if state: if state:
return self.json(state) return self.json(state)
else: else:
@ -197,6 +200,7 @@ class APIEntityStateView(HomeAssistantView):
@asyncio.coroutine @asyncio.coroutine
def post(self, request, entity_id): def post(self, request, entity_id):
"""Update state of entity.""" """Update state of entity."""
hass = request.app['hass']
try: try:
data = yield from request.json() data = yield from request.json()
except ValueError: except ValueError:
@ -211,15 +215,14 @@ class APIEntityStateView(HomeAssistantView):
attributes = data.get('attributes') attributes = data.get('attributes')
force_update = data.get('force_update', False) force_update = data.get('force_update', False)
is_new_state = self.hass.states.get(entity_id) is None is_new_state = hass.states.get(entity_id) is None
# Write state # Write state
self.hass.states.async_set(entity_id, new_state, attributes, hass.states.async_set(entity_id, new_state, attributes, force_update)
force_update)
# Read the state back for our response # Read the state back for our response
status_code = HTTP_CREATED if is_new_state else 200 status_code = HTTP_CREATED if is_new_state else 200
resp = self.json(self.hass.states.get(entity_id), status_code) resp = self.json(hass.states.get(entity_id), status_code)
resp.headers.add('Location', URL_API_STATES_ENTITY.format(entity_id)) resp.headers.add('Location', URL_API_STATES_ENTITY.format(entity_id))
@ -228,7 +231,7 @@ class APIEntityStateView(HomeAssistantView):
@ha.callback @ha.callback
def delete(self, request, entity_id): def delete(self, request, entity_id):
"""Remove entity.""" """Remove entity."""
if self.hass.states.async_remove(entity_id): if request.app['hass'].states.async_remove(entity_id):
return self.json_message('Entity removed') return self.json_message('Entity removed')
else: else:
return self.json_message('Entity not found', HTTP_NOT_FOUND) return self.json_message('Entity not found', HTTP_NOT_FOUND)
@ -243,7 +246,7 @@ class APIEventListenersView(HomeAssistantView):
@ha.callback @ha.callback
def get(self, request): def get(self, request):
"""Get event listeners.""" """Get event listeners."""
return self.json(async_events_json(self.hass)) return self.json(async_events_json(request.app['hass']))
class APIEventView(HomeAssistantView): class APIEventView(HomeAssistantView):
@ -271,7 +274,8 @@ class APIEventView(HomeAssistantView):
if state: if state:
event_data[key] = state event_data[key] = state
self.hass.bus.async_fire(event_type, event_data, ha.EventOrigin.remote) request.app['hass'].bus.async_fire(event_type, event_data,
ha.EventOrigin.remote)
return self.json_message("Event {} fired.".format(event_type)) return self.json_message("Event {} fired.".format(event_type))
@ -285,7 +289,7 @@ class APIServicesView(HomeAssistantView):
@ha.callback @ha.callback
def get(self, request): def get(self, request):
"""Get registered services.""" """Get registered services."""
return self.json(async_services_json(self.hass)) return self.json(async_services_json(request.app['hass']))
class APIDomainServicesView(HomeAssistantView): class APIDomainServicesView(HomeAssistantView):
@ -300,12 +304,12 @@ class APIDomainServicesView(HomeAssistantView):
Returns a list of changed states. Returns a list of changed states.
""" """
hass = request.app['hass']
body = yield from request.text() body = yield from request.text()
data = json.loads(body) if body else None data = json.loads(body) if body else None
with AsyncTrackStates(self.hass) as changed_states: with AsyncTrackStates(hass) as changed_states:
yield from self.hass.services.async_call(domain, service, data, yield from hass.services.async_call(domain, service, data, True)
True)
return self.json(changed_states) return self.json(changed_states)
@ -320,6 +324,7 @@ class APIEventForwardingView(HomeAssistantView):
@asyncio.coroutine @asyncio.coroutine
def post(self, request): def post(self, request):
"""Setup an event forwarder.""" """Setup an event forwarder."""
hass = request.app['hass']
try: try:
data = yield from request.json() data = yield from request.json()
except ValueError: except ValueError:
@ -340,14 +345,14 @@ class APIEventForwardingView(HomeAssistantView):
api = rem.API(host, api_password, port) api = rem.API(host, api_password, port)
valid = yield from self.hass.loop.run_in_executor( valid = yield from hass.loop.run_in_executor(
None, api.validate_api) None, api.validate_api)
if not valid: if not valid:
return self.json_message("Unable to validate API.", return self.json_message("Unable to validate API.",
HTTP_UNPROCESSABLE_ENTITY) HTTP_UNPROCESSABLE_ENTITY)
if self.event_forwarder is None: if self.event_forwarder is None:
self.event_forwarder = rem.EventForwarder(self.hass) self.event_forwarder = rem.EventForwarder(hass)
self.event_forwarder.async_connect(api) self.event_forwarder.async_connect(api)
@ -389,7 +394,7 @@ class APIComponentsView(HomeAssistantView):
@ha.callback @ha.callback
def get(self, request): def get(self, request):
"""Get current loaded components.""" """Get current loaded components."""
return self.json(self.hass.config.components) return self.json(request.app['hass'].config.components)
class APIErrorLogView(HomeAssistantView): class APIErrorLogView(HomeAssistantView):
@ -402,7 +407,7 @@ class APIErrorLogView(HomeAssistantView):
def get(self, request): def get(self, request):
"""Serve error log.""" """Serve error log."""
resp = yield from self.file( resp = yield from self.file(
request, self.hass.config.path(ERROR_LOG_FILENAME)) request, request.app['hass'].config.path(ERROR_LOG_FILENAME))
return resp return resp
@ -417,7 +422,7 @@ class APITemplateView(HomeAssistantView):
"""Render a template.""" """Render a template."""
try: try:
data = yield from request.json() data = yield from request.json()
tpl = template.Template(data['template'], self.hass) tpl = template.Template(data['template'], request.app['hass'])
return tpl.async_render(data.get('variables')) return tpl.async_render(data.get('variables'))
except (ValueError, TemplateError) as ex: except (ValueError, TemplateError) as ex:
return self.json_message('Error rendering template: {}'.format(ex), return self.json_message('Error rendering template: {}'.format(ex),

View file

@ -13,7 +13,7 @@ from aiohttp import web
from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity import Entity
from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.config_validation import PLATFORM_SCHEMA # noqa from homeassistant.helpers.config_validation import PLATFORM_SCHEMA # noqa
from homeassistant.components.http import HomeAssistantView from homeassistant.components.http import HomeAssistantView, KEY_AUTHENTICATED
DOMAIN = 'camera' DOMAIN = 'camera'
DEPENDENCIES = ['http'] DEPENDENCIES = ['http']
@ -33,8 +33,8 @@ def async_setup(hass, config):
component = EntityComponent( component = EntityComponent(
logging.getLogger(__name__), DOMAIN, hass, SCAN_INTERVAL) logging.getLogger(__name__), DOMAIN, hass, SCAN_INTERVAL)
hass.http.register_view(CameraImageView(hass, component.entities)) hass.http.register_view(CameraImageView(component.entities))
hass.http.register_view(CameraMjpegStream(hass, component.entities)) hass.http.register_view(CameraMjpegStream(component.entities))
yield from component.async_setup(config) yield from component.async_setup(config)
return True return True
@ -165,9 +165,8 @@ class CameraView(HomeAssistantView):
requires_auth = False requires_auth = False
def __init__(self, hass, entities): def __init__(self, entities):
"""Initialize a basic camera view.""" """Initialize a basic camera view."""
super().__init__(hass)
self.entities = entities self.entities = entities
@asyncio.coroutine @asyncio.coroutine
@ -178,7 +177,7 @@ class CameraView(HomeAssistantView):
if camera is None: if camera is None:
return web.Response(status=404) return web.Response(status=404)
authenticated = (request.authenticated or authenticated = (request[KEY_AUTHENTICATED] or
request.GET.get('token') == camera.access_token) request.GET.get('token') == camera.access_token)
if not authenticated: if not authenticated:

View file

@ -21,7 +21,7 @@ DEPENDENCIES = ['http']
def setup_scanner(hass, config, see): def setup_scanner(hass, config, see):
"""Setup an endpoint for the GPSLogger application.""" """Setup an endpoint for the GPSLogger application."""
hass.http.register_view(GPSLoggerView(hass, see)) hass.http.register_view(GPSLoggerView(see))
return True return True
@ -32,20 +32,18 @@ class GPSLoggerView(HomeAssistantView):
url = '/api/gpslogger' url = '/api/gpslogger'
name = 'api:gpslogger' name = 'api:gpslogger'
def __init__(self, hass, see): def __init__(self, see):
"""Initialize GPSLogger url endpoints.""" """Initialize GPSLogger url endpoints."""
super().__init__(hass)
self.see = see self.see = see
@asyncio.coroutine @asyncio.coroutine
def get(self, request): def get(self, request):
"""A GPSLogger message received as GET.""" """A GPSLogger message received as GET."""
res = yield from self._handle(request.GET) res = yield from self._handle(request.app['hass'], request.GET)
return res return res
@asyncio.coroutine @asyncio.coroutine
# pylint: disable=too-many-return-statements def _handle(self, hass, data):
def _handle(self, data):
"""Handle gpslogger request.""" """Handle gpslogger request."""
if 'latitude' not in data or 'longitude' not in data: if 'latitude' not in data or 'longitude' not in data:
return ('Latitude and longitude not specified.', return ('Latitude and longitude not specified.',
@ -66,7 +64,7 @@ class GPSLoggerView(HomeAssistantView):
if 'battery' in data: if 'battery' in data:
battery = float(data['battery']) battery = float(data['battery'])
yield from self.hass.loop.run_in_executor( yield from hass.loop.run_in_executor(
None, partial(self.see, dev_id=device, None, partial(self.see, dev_id=device,
gps=gps_location, battery=battery, gps=gps_location, battery=battery,
gps_accuracy=accuracy)) gps_accuracy=accuracy))

View file

@ -23,7 +23,7 @@ DEPENDENCIES = ['http']
def setup_scanner(hass, config, see): def setup_scanner(hass, config, see):
"""Setup an endpoint for the Locative application.""" """Setup an endpoint for the Locative application."""
hass.http.register_view(LocativeView(hass, see)) hass.http.register_view(LocativeView(see))
return True return True
@ -34,27 +34,26 @@ class LocativeView(HomeAssistantView):
url = '/api/locative' url = '/api/locative'
name = 'api:locative' name = 'api:locative'
def __init__(self, hass, see): def __init__(self, see):
"""Initialize Locative url endpoints.""" """Initialize Locative url endpoints."""
super().__init__(hass)
self.see = see self.see = see
@asyncio.coroutine @asyncio.coroutine
def get(self, request): def get(self, request):
"""Locative message received as GET.""" """Locative message received as GET."""
res = yield from self._handle(request.GET) res = yield from self._handle(request.app['hass'], request.GET)
return res return res
@asyncio.coroutine @asyncio.coroutine
def post(self, request): def post(self, request):
"""Locative message received.""" """Locative message received."""
data = yield from request.post() data = yield from request.post()
res = yield from self._handle(data) res = yield from self._handle(request.app['hass'], data)
return res return res
@asyncio.coroutine @asyncio.coroutine
# pylint: disable=too-many-return-statements # pylint: disable=too-many-return-statements
def _handle(self, data): def _handle(self, hass, data):
"""Handle locative request.""" """Handle locative request."""
if 'latitude' not in data or 'longitude' not in data: if 'latitude' not in data or 'longitude' not in data:
return ('Latitude and longitude not specified.', return ('Latitude and longitude not specified.',
@ -81,19 +80,19 @@ class LocativeView(HomeAssistantView):
gps_location = (data[ATTR_LATITUDE], data[ATTR_LONGITUDE]) gps_location = (data[ATTR_LATITUDE], data[ATTR_LONGITUDE])
if direction == 'enter': if direction == 'enter':
yield from self.hass.loop.run_in_executor( yield from hass.loop.run_in_executor(
None, partial(self.see, dev_id=device, None, partial(self.see, dev_id=device,
location_name=location_name, location_name=location_name,
gps=gps_location)) gps=gps_location))
return 'Setting location to {}'.format(location_name) return 'Setting location to {}'.format(location_name)
elif direction == 'exit': elif direction == 'exit':
current_state = self.hass.states.get( current_state = hass.states.get(
'{}.{}'.format(DOMAIN, device)) '{}.{}'.format(DOMAIN, device))
if current_state is None or current_state.state == location_name: if current_state is None or current_state.state == location_name:
location_name = STATE_NOT_HOME location_name = STATE_NOT_HOME
yield from self.hass.loop.run_in_executor( yield from hass.loop.run_in_executor(
None, partial(self.see, dev_id=device, None, partial(self.see, dev_id=device,
location_name=location_name, location_name=location_name,
gps=gps_location)) gps=gps_location))

View file

@ -78,14 +78,13 @@ def setup(hass, yaml_config):
cors_origins=None, cors_origins=None,
use_x_forwarded_for=False, use_x_forwarded_for=False,
trusted_networks=None, trusted_networks=None,
ip_bans=None,
login_threshold=0, login_threshold=0,
is_ban_enabled=False is_ban_enabled=False
) )
server.register_view(DescriptionXmlView(hass, config)) server.register_view(DescriptionXmlView(config))
server.register_view(HueUsernameView(hass)) server.register_view(HueUsernameView)
server.register_view(HueLightsView(hass, config)) server.register_view(HueLightsView(config))
upnp_listener = UPNPResponderThread( upnp_listener = UPNPResponderThread(
config.host_ip_addr, config.listen_port) config.host_ip_addr, config.listen_port)
@ -157,9 +156,8 @@ class DescriptionXmlView(HomeAssistantView):
name = 'description:xml' name = 'description:xml'
requires_auth = False requires_auth = False
def __init__(self, hass, config): def __init__(self, config):
"""Initialize the instance of the view.""" """Initialize the instance of the view."""
super().__init__(hass)
self.config = config self.config = config
@core.callback @core.callback
@ -201,10 +199,6 @@ class HueUsernameView(HomeAssistantView):
extra_urls = ['/api/'] extra_urls = ['/api/']
requires_auth = False requires_auth = False
def __init__(self, hass):
"""Initialize the instance of the view."""
super().__init__(hass)
@asyncio.coroutine @asyncio.coroutine
def post(self, request): def post(self, request):
"""Handle a POST request.""" """Handle a POST request."""
@ -229,30 +223,33 @@ class HueLightsView(HomeAssistantView):
'/api/{username}/lights/{entity_id}/state'] '/api/{username}/lights/{entity_id}/state']
requires_auth = False requires_auth = False
def __init__(self, hass, config): def __init__(self, config):
"""Initialize the instance of the view.""" """Initialize the instance of the view."""
super().__init__(hass)
self.config = config self.config = config
self.cached_states = {} self.cached_states = {}
@core.callback @core.callback
def get(self, request, username, entity_id=None): def get(self, request, username, entity_id=None):
"""Handle a GET request.""" """Handle a GET request."""
hass = request.app['hass']
if entity_id is None: if entity_id is None:
return self.async_get_lights_list() return self.async_get_lights_list(hass)
if not request.path.endswith('state'): if not request.path.endswith('state'):
return self.async_get_light_state(entity_id) return self.async_get_light_state(hass, entity_id)
return web.Response(text="Method not allowed", status=405) return web.Response(text="Method not allowed", status=405)
@asyncio.coroutine @asyncio.coroutine
def put(self, request, username, entity_id=None): def put(self, request, username, entity_id=None):
"""Handle a PUT request.""" """Handle a PUT request."""
hass = request.app['hass']
if not request.path.endswith('state'): if not request.path.endswith('state'):
return web.Response(text="Method not allowed", status=405) return web.Response(text="Method not allowed", status=405)
if entity_id and self.hass.states.get(entity_id) is None: if entity_id and hass.states.get(entity_id) is None:
return self.json_message('Entity not found', HTTP_NOT_FOUND) return self.json_message('Entity not found', HTTP_NOT_FOUND)
try: try:
@ -260,24 +257,25 @@ class HueLightsView(HomeAssistantView):
except ValueError: except ValueError:
return self.json_message('Invalid JSON', HTTP_BAD_REQUEST) return self.json_message('Invalid JSON', HTTP_BAD_REQUEST)
result = yield from self.async_put_light_state(json_data, entity_id) result = yield from self.async_put_light_state(hass, json_data,
entity_id)
return result return result
@core.callback @core.callback
def async_get_lights_list(self): def async_get_lights_list(self, hass):
"""Process a request to get the list of available lights.""" """Process a request to get the list of available lights."""
json_response = {} json_response = {}
for entity in self.hass.states.async_all(): for entity in hass.states.async_all():
if self.is_entity_exposed(entity): if self.is_entity_exposed(entity):
json_response[entity.entity_id] = entity_to_json(entity) json_response[entity.entity_id] = entity_to_json(entity)
return self.json(json_response) return self.json(json_response)
@core.callback @core.callback
def async_get_light_state(self, entity_id): def async_get_light_state(self, hass, entity_id):
"""Process a request to get the state of an individual light.""" """Process a request to get the state of an individual light."""
entity = self.hass.states.get(entity_id) entity = hass.states.get(entity_id)
if entity is None or not self.is_entity_exposed(entity): if entity is None or not self.is_entity_exposed(entity):
return web.Response(text="Entity not found", status=404) return web.Response(text="Entity not found", status=404)
@ -295,12 +293,12 @@ class HueLightsView(HomeAssistantView):
return self.json(json_response) return self.json(json_response)
@asyncio.coroutine @asyncio.coroutine
def async_put_light_state(self, request_json, entity_id): def async_put_light_state(self, hass, request_json, entity_id):
"""Process a request to set the state of an individual light.""" """Process a request to set the state of an individual light."""
config = self.config config = self.config
# Retrieve the entity from the state machine # Retrieve the entity from the state machine
entity = self.hass.states.get(entity_id) entity = hass.states.get(entity_id)
if entity is None: if entity is None:
return web.Response(text="Entity not found", status=404) return web.Response(text="Entity not found", status=404)
@ -345,7 +343,7 @@ class HueLightsView(HomeAssistantView):
self.cached_states[entity_id] = (result, brightness) self.cached_states[entity_id] = (result, brightness)
# Perform the requested action # Perform the requested action
yield from self.hass.services.async_call(core.DOMAIN, service, data, yield from hass.services.async_call(core.DOMAIN, service, data,
blocking=True) blocking=True)
json_response = \ json_response = \

View file

@ -75,8 +75,7 @@ def setup(hass, config):
descriptions[DOMAIN][SERVICE_CHECKIN], descriptions[DOMAIN][SERVICE_CHECKIN],
schema=CHECKIN_SERVICE_SCHEMA) schema=CHECKIN_SERVICE_SCHEMA)
hass.http.register_view(FoursquarePushReceiver( hass.http.register_view(FoursquarePushReceiver(config[CONF_PUSH_SECRET]))
hass, config[CONF_PUSH_SECRET]))
return True return True
@ -88,9 +87,8 @@ class FoursquarePushReceiver(HomeAssistantView):
url = "/api/foursquare" url = "/api/foursquare"
name = "foursquare" name = "foursquare"
def __init__(self, hass, push_secret): def __init__(self, push_secret):
"""Initialize the OAuth callback view.""" """Initialize the OAuth callback view."""
super().__init__(hass)
self.push_secret = push_secret self.push_secret = push_secret
@asyncio.coroutine @asyncio.coroutine
@ -110,4 +108,4 @@ class FoursquarePushReceiver(HomeAssistantView):
"push secret: %s", secret) "push secret: %s", secret)
return self.json_message('Incorrect secret', HTTP_BAD_REQUEST) return self.json_message('Incorrect secret', HTTP_BAD_REQUEST)
self.hass.bus.async_fire(EVENT_PUSH, data) request.app['hass'].bus.async_fire(EVENT_PUSH, data)

View file

@ -11,6 +11,8 @@ from homeassistant.core import callback
from homeassistant.const import HTTP_NOT_FOUND from homeassistant.const import HTTP_NOT_FOUND
from homeassistant.components import api, group from homeassistant.components import api, group
from homeassistant.components.http import HomeAssistantView from homeassistant.components.http import HomeAssistantView
from homeassistant.components.http.auth import is_trusted_ip
from homeassistant.components.http.const import KEY_DEVELOPMENT
from .version import FINGERPRINTS from .version import FINGERPRINTS
DOMAIN = 'frontend' DOMAIN = 'frontend'
@ -155,7 +157,7 @@ def setup(hass, config):
if os.path.isdir(local): if os.path.isdir(local):
hass.http.register_static_path("/local", local) hass.http.register_static_path("/local", local)
index_view = hass.data[DATA_INDEX_VIEW] = IndexView(hass) index_view = hass.data[DATA_INDEX_VIEW] = IndexView()
hass.http.register_view(index_view) hass.http.register_view(index_view)
# Components have registered panels before frontend got setup. # Components have registered panels before frontend got setup.
@ -185,12 +187,14 @@ class BootstrapView(HomeAssistantView):
@callback @callback
def get(self, request): def get(self, request):
"""Return all data needed to bootstrap Home Assistant.""" """Return all data needed to bootstrap Home Assistant."""
hass = request.app['hass']
return self.json({ return self.json({
'config': self.hass.config.as_dict(), 'config': hass.config.as_dict(),
'states': self.hass.states.async_all(), 'states': hass.states.async_all(),
'events': api.async_events_json(self.hass), 'events': api.async_events_json(hass),
'services': api.async_services_json(self.hass), 'services': api.async_services_json(hass),
'panels': self.hass.data[DATA_PANELS], 'panels': hass.data[DATA_PANELS],
}) })
@ -202,10 +206,8 @@ class IndexView(HomeAssistantView):
requires_auth = False requires_auth = False
extra_urls = ['/states', '/states/{entity_id}'] extra_urls = ['/states', '/states/{entity_id}']
def __init__(self, hass): def __init__(self):
"""Initialize the frontend view.""" """Initialize the frontend view."""
super().__init__(hass)
from jinja2 import FileSystemLoader, Environment from jinja2 import FileSystemLoader, Environment
self.templates = Environment( self.templates = Environment(
@ -217,14 +219,16 @@ class IndexView(HomeAssistantView):
@asyncio.coroutine @asyncio.coroutine
def get(self, request, entity_id=None): def get(self, request, entity_id=None):
"""Serve the index view.""" """Serve the index view."""
hass = request.app['hass']
if entity_id is not None: if entity_id is not None:
state = self.hass.states.get(entity_id) state = hass.states.get(entity_id)
if (not state or state.domain != 'group' or if (not state or state.domain != 'group' or
not state.attributes.get(group.ATTR_VIEW)): not state.attributes.get(group.ATTR_VIEW)):
return self.json_message('Entity not found', HTTP_NOT_FOUND) return self.json_message('Entity not found', HTTP_NOT_FOUND)
if self.hass.http.development: if request.app[KEY_DEVELOPMENT]:
core_url = '/static/home-assistant-polymer/build/core.js' core_url = '/static/home-assistant-polymer/build/core.js'
ui_url = '/static/home-assistant-polymer/src/home-assistant.html' ui_url = '/static/home-assistant-polymer/src/home-assistant.html'
else: else:
@ -241,19 +245,18 @@ class IndexView(HomeAssistantView):
if panel == 'states': if panel == 'states':
panel_url = '' panel_url = ''
else: else:
panel_url = self.hass.data[DATA_PANELS][panel]['url'] panel_url = hass.data[DATA_PANELS][panel]['url']
no_auth = 'true' no_auth = 'true'
if self.hass.config.api.api_password: if hass.config.api.api_password:
# require password if set # require password if set
no_auth = 'false' no_auth = 'false'
if self.hass.http.is_trusted_ip( if is_trusted_ip(request):
self.hass.http.get_real_ip(request)):
# bypass for trusted networks # bypass for trusted networks
no_auth = 'true' no_auth = 'true'
icons_url = '/static/mdi-{}.html'.format(FINGERPRINTS['mdi.html']) icons_url = '/static/mdi-{}.html'.format(FINGERPRINTS['mdi.html'])
template = yield from self.hass.loop.run_in_executor( template = yield from hass.loop.run_in_executor(
None, self.templates.get_template, 'index.html') None, self.templates.get_template, 'index.html')
# pylint is wrong # pylint is wrong
@ -262,7 +265,7 @@ class IndexView(HomeAssistantView):
resp = template.render( resp = template.render(
core_url=core_url, ui_url=ui_url, no_auth=no_auth, core_url=core_url, ui_url=ui_url, no_auth=no_auth,
icons_url=icons_url, icons=FINGERPRINTS['mdi.html'], icons_url=icons_url, icons=FINGERPRINTS['mdi.html'],
panel_url=panel_url, panels=self.hass.data[DATA_PANELS]) panel_url=panel_url, panels=hass.data[DATA_PANELS])
return web.Response(text=resp, content_type='text/html') return web.Response(text=resp, content_type='text/html')

View file

@ -184,8 +184,8 @@ def setup(hass, config):
filters.included_entities = include[CONF_ENTITIES] filters.included_entities = include[CONF_ENTITIES]
filters.included_domains = include[CONF_DOMAINS] filters.included_domains = include[CONF_DOMAINS]
hass.http.register_view(Last5StatesView(hass)) hass.http.register_view(Last5StatesView)
hass.http.register_view(HistoryPeriodView(hass, filters)) hass.http.register_view(HistoryPeriodView(filters))
register_built_in_panel(hass, 'history', 'History', 'mdi:poll-box') register_built_in_panel(hass, 'history', 'History', 'mdi:poll-box')
return True return True
@ -197,14 +197,10 @@ class Last5StatesView(HomeAssistantView):
url = '/api/history/entity/{entity_id}/recent_states' url = '/api/history/entity/{entity_id}/recent_states'
name = 'api:history:entity-recent-states' name = 'api:history:entity-recent-states'
def __init__(self, hass):
"""Initilalize the history last 5 states view."""
super().__init__(hass)
@asyncio.coroutine @asyncio.coroutine
def get(self, request, entity_id): def get(self, request, entity_id):
"""Retrieve last 5 states of entity.""" """Retrieve last 5 states of entity."""
result = yield from self.hass.loop.run_in_executor( result = yield from request.app['hass'].loop.run_in_executor(
None, last_5_states, entity_id) None, last_5_states, entity_id)
return self.json(result) return self.json(result)
@ -216,9 +212,8 @@ class HistoryPeriodView(HomeAssistantView):
name = 'api:history:view-period' name = 'api:history:view-period'
extra_urls = ['/api/history/period/{datetime}'] extra_urls = ['/api/history/period/{datetime}']
def __init__(self, hass, filters): def __init__(self, filters):
"""Initilalize the history period view.""" """Initilalize the history period view."""
super().__init__(hass)
self.filters = filters self.filters = filters
@asyncio.coroutine @asyncio.coroutine
@ -240,7 +235,7 @@ class HistoryPeriodView(HomeAssistantView):
end_time = start_time + one_day end_time = start_time + one_day
entity_id = request.GET.get('filter_entity_id') entity_id = request.GET.get('filter_entity_id')
result = yield from self.hass.loop.run_in_executor( result = yield from request.app['hass'].loop.run_in_executor(
None, get_significant_states, start_time, end_time, entity_id, None, get_significant_states, start_time, end_time, entity_id,
self.filters) self.filters)

View file

@ -1,641 +0,0 @@
"""
This module provides WSGI application to serve the Home Assistant API.
For more details about this component, please refer to the documentation at
https://home-assistant.io/components/http/
"""
import asyncio
import json
import logging
import mimetypes
import ssl
from datetime import datetime
from ipaddress import ip_address, ip_network
from pathlib import Path
import hmac
import os
import re
import voluptuous as vol
from aiohttp import web, hdrs
from aiohttp.file_sender import FileSender
from aiohttp.web_exceptions import (
HTTPUnauthorized, HTTPMovedPermanently, HTTPNotModified, HTTPForbidden)
from aiohttp.web_urldispatcher import StaticResource
import homeassistant.helpers.config_validation as cv
import homeassistant.remote as rem
from homeassistant import util
from homeassistant.components import persistent_notification
from homeassistant.config import load_yaml_config_file
from homeassistant.const import (
SERVER_PORT, HTTP_HEADER_HA_AUTH, # HTTP_HEADER_CACHE_CONTROL,
CONTENT_TYPE_JSON, ALLOWED_CORS_HEADERS, EVENT_HOMEASSISTANT_STOP,
EVENT_HOMEASSISTANT_START, HTTP_HEADER_X_FORWARDED_FOR)
from homeassistant.core import is_callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.util.yaml import dump
DOMAIN = 'http'
REQUIREMENTS = ('aiohttp_cors==0.5.0',)
CONF_API_PASSWORD = 'api_password'
CONF_SERVER_HOST = 'server_host'
CONF_SERVER_PORT = 'server_port'
CONF_DEVELOPMENT = 'development'
CONF_SSL_CERTIFICATE = 'ssl_certificate'
CONF_SSL_KEY = 'ssl_key'
CONF_CORS_ORIGINS = 'cors_allowed_origins'
CONF_USE_X_FORWARDED_FOR = 'use_x_forwarded_for'
CONF_TRUSTED_NETWORKS = 'trusted_networks'
CONF_LOGIN_ATTEMPTS_THRESHOLD = 'login_attempts_threshold'
CONF_IP_BAN_ENABLED = 'ip_ban_enabled'
DATA_API_PASSWORD = 'api_password'
NOTIFICATION_ID_LOGIN = 'http-login'
NOTIFICATION_ID_BAN = 'ip-ban'
IP_BANS = 'ip_bans.yaml'
ATTR_BANNED_AT = "banned_at"
# TLS configuation follows the best-practice guidelines specified here:
# https://wiki.mozilla.org/Security/Server_Side_TLS
# Intermediate guidelines are followed.
SSL_VERSION = ssl.PROTOCOL_SSLv23
SSL_OPTS = ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3
if hasattr(ssl, 'OP_NO_COMPRESSION'):
SSL_OPTS |= ssl.OP_NO_COMPRESSION
CIPHERS = "ECDHE-ECDSA-CHACHA20-POLY1305:ECDHE-RSA-CHACHA20-POLY1305:" \
"ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256:" \
"ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384:" \
"DHE-RSA-AES128-GCM-SHA256:DHE-RSA-AES256-GCM-SHA384:" \
"ECDHE-ECDSA-AES128-SHA256:ECDHE-RSA-AES128-SHA256:" \
"ECDHE-ECDSA-AES128-SHA:ECDHE-RSA-AES256-SHA384:" \
"ECDHE-RSA-AES128-SHA:ECDHE-ECDSA-AES256-SHA384:" \
"ECDHE-ECDSA-AES256-SHA:ECDHE-RSA-AES256-SHA:" \
"DHE-RSA-AES128-SHA256:DHE-RSA-AES128-SHA:DHE-RSA-AES256-SHA256:" \
"DHE-RSA-AES256-SHA:ECDHE-ECDSA-DES-CBC3-SHA:" \
"ECDHE-RSA-DES-CBC3-SHA:EDH-RSA-DES-CBC3-SHA:AES128-GCM-SHA256:" \
"AES256-GCM-SHA384:AES128-SHA256:AES256-SHA256:AES128-SHA:" \
"AES256-SHA:DES-CBC3-SHA:!DSS"
_FINGERPRINT = re.compile(r'^(.+)-[a-z0-9]{32}\.(\w+)$', re.IGNORECASE)
_LOGGER = logging.getLogger(__name__)
CONFIG_SCHEMA = vol.Schema({
DOMAIN: vol.Schema({
vol.Optional(CONF_API_PASSWORD): cv.string,
vol.Optional(CONF_SERVER_HOST): cv.string,
vol.Optional(CONF_SERVER_PORT, default=SERVER_PORT):
vol.All(vol.Coerce(int), vol.Range(min=1, max=65535)),
vol.Optional(CONF_DEVELOPMENT): cv.string,
vol.Optional(CONF_SSL_CERTIFICATE): cv.isfile,
vol.Optional(CONF_SSL_KEY): cv.isfile,
vol.Optional(CONF_CORS_ORIGINS): vol.All(cv.ensure_list, [cv.string]),
vol.Optional(CONF_USE_X_FORWARDED_FOR, default=False): cv.boolean,
vol.Optional(CONF_TRUSTED_NETWORKS):
vol.All(cv.ensure_list, [ip_network]),
vol.Optional(CONF_LOGIN_ATTEMPTS_THRESHOLD): cv.positive_int,
vol.Optional(CONF_IP_BAN_ENABLED): cv.boolean
}),
}, extra=vol.ALLOW_EXTRA)
# TEMP TO GET TESTS TO RUN
def request_class():
"""."""
raise Exception('not implemented')
class HideSensitiveFilter(logging.Filter):
"""Filter API password calls."""
def __init__(self, hass):
"""Initialize sensitive data filter."""
super().__init__()
self.hass = hass
def filter(self, record):
"""Hide sensitive data in messages."""
if self.hass.http.api_password is None:
return True
record.msg = record.msg.replace(self.hass.http.api_password, '*******')
return True
def setup(hass, config):
"""Set up the HTTP API and debug interface."""
logging.getLogger('aiohttp.access').addFilter(HideSensitiveFilter(hass))
conf = config.get(DOMAIN, {})
api_password = util.convert(conf.get(CONF_API_PASSWORD), str)
server_host = conf.get(CONF_SERVER_HOST, '0.0.0.0')
server_port = conf.get(CONF_SERVER_PORT, SERVER_PORT)
development = str(conf.get(CONF_DEVELOPMENT, '')) == '1'
ssl_certificate = conf.get(CONF_SSL_CERTIFICATE)
ssl_key = conf.get(CONF_SSL_KEY)
cors_origins = conf.get(CONF_CORS_ORIGINS, [])
use_x_forwarded_for = conf.get(CONF_USE_X_FORWARDED_FOR, False)
trusted_networks = [
ip_network(trusted_network)
for trusted_network in conf.get(CONF_TRUSTED_NETWORKS, [])]
is_ban_enabled = bool(conf.get(CONF_IP_BAN_ENABLED, False))
login_threshold = int(conf.get(CONF_LOGIN_ATTEMPTS_THRESHOLD, -1))
ip_bans = load_ip_bans_config(hass.config.path(IP_BANS))
server = HomeAssistantWSGI(
hass,
development=development,
server_host=server_host,
server_port=server_port,
api_password=api_password,
ssl_certificate=ssl_certificate,
ssl_key=ssl_key,
cors_origins=cors_origins,
use_x_forwarded_for=use_x_forwarded_for,
trusted_networks=trusted_networks,
ip_bans=ip_bans,
login_threshold=login_threshold,
is_ban_enabled=is_ban_enabled
)
@asyncio.coroutine
def stop_server(event):
"""Callback to stop the server."""
yield from server.stop()
@asyncio.coroutine
def start_server(event):
"""Callback to start the server."""
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, stop_server)
yield from server.start()
hass.bus.listen_once(EVENT_HOMEASSISTANT_START, start_server)
hass.http = server
hass.config.api = rem.API(server_host if server_host != '0.0.0.0'
else util.get_local_ip(),
api_password, server_port,
ssl_certificate is not None)
return True
class GzipFileSender(FileSender):
"""FileSender class capable of sending gzip version if available."""
# pylint: disable=invalid-name
development = False
@asyncio.coroutine
def send(self, request, filepath):
"""Send filepath to client using request."""
gzip = False
if 'gzip' in request.headers[hdrs.ACCEPT_ENCODING]:
gzip_path = filepath.with_name(filepath.name + '.gz')
if gzip_path.is_file():
filepath = gzip_path
gzip = True
st = filepath.stat()
modsince = request.if_modified_since
if modsince is not None and st.st_mtime <= modsince.timestamp():
raise HTTPNotModified()
ct, encoding = mimetypes.guess_type(str(filepath))
if not ct:
ct = 'application/octet-stream'
resp = self._response_factory()
resp.content_type = ct
if encoding:
resp.headers[hdrs.CONTENT_ENCODING] = encoding
if gzip:
resp.headers[hdrs.VARY] = hdrs.ACCEPT_ENCODING
resp.last_modified = st.st_mtime
# CACHE HACK
if not self.development:
cache_time = 31 * 86400 # = 1 month
resp.headers[hdrs.CACHE_CONTROL] = "public, max-age={}".format(
cache_time)
file_size = st.st_size
resp.content_length = file_size
with filepath.open('rb') as f:
yield from self._sendfile(request, resp, f, file_size)
return resp
_GZIP_FILE_SENDER = GzipFileSender()
@asyncio.coroutine
def staticresource_enhancer(app, handler):
"""Enhance StaticResourceHandler.
Adds gzip encoding and fingerprinting matching.
"""
inst = getattr(handler, '__self__', None)
if not isinstance(inst, StaticResource):
return handler
# pylint: disable=protected-access
inst._file_sender = _GZIP_FILE_SENDER
@asyncio.coroutine
def middleware_handler(request):
"""Strip out fingerprints from resource names."""
fingerprinted = _FINGERPRINT.match(request.match_info['filename'])
if fingerprinted:
request.match_info['filename'] = \
'{}.{}'.format(*fingerprinted.groups())
resp = yield from handler(request)
return resp
return middleware_handler
class HomeAssistantWSGI(object):
"""WSGI server for Home Assistant."""
def __init__(self, hass, development, api_password, ssl_certificate,
ssl_key, server_host, server_port, cors_origins,
use_x_forwarded_for, trusted_networks,
ip_bans, login_threshold, is_ban_enabled):
"""Initialize the WSGI Home Assistant server."""
import aiohttp_cors
self.app = web.Application(middlewares=[staticresource_enhancer],
loop=hass.loop)
self.hass = hass
self.development = development
self.api_password = api_password
self.ssl_certificate = ssl_certificate
self.ssl_key = ssl_key
self.server_host = server_host
self.server_port = server_port
self.use_x_forwarded_for = use_x_forwarded_for
self.trusted_networks = trusted_networks \
if trusted_networks is not None else []
self.event_forwarder = None
self._handler = None
self.server = None
self.login_threshold = login_threshold
self.ip_bans = ip_bans if ip_bans is not None else []
self.failed_login_attempts = {}
self.is_ban_enabled = is_ban_enabled
if cors_origins:
self.cors = aiohttp_cors.setup(self.app, defaults={
host: aiohttp_cors.ResourceOptions(
allow_headers=ALLOWED_CORS_HEADERS,
allow_methods='*',
) for host in cors_origins
})
else:
self.cors = None
# CACHE HACK
_GZIP_FILE_SENDER.development = development
def register_view(self, view):
"""Register a view with the WSGI server.
The view argument must be a class that inherits from HomeAssistantView.
It is optional to instantiate it before registering; this method will
handle it either way.
"""
if isinstance(view, type):
# Instantiate the view, if needed
view = view(self.hass)
view.register(self.app.router)
def register_redirect(self, url, redirect_to):
"""Register a redirect with the server.
If given this must be either a string or callable. In case of a
callable it's called with the url adapter that triggered the match and
the values of the URL as keyword arguments and has to return the target
for the redirect, otherwise it has to be a string with placeholders in
rule syntax.
"""
def redirect(request):
"""Redirect to location."""
raise HTTPMovedPermanently(redirect_to)
self.app.router.add_route('GET', url, redirect)
def register_static_path(self, url_root, path, cache_length=31):
"""Register a folder to serve as a static path.
Specify optional cache length of asset in days.
"""
if os.path.isdir(path):
self.app.router.add_static(url_root, path)
return
filepath = Path(path)
@asyncio.coroutine
def serve_file(request):
"""Redirect to location."""
res = yield from _GZIP_FILE_SENDER.send(request, filepath)
return res
# aiohttp supports regex matching for variables. Using that as temp
# to work around cache busting MD5.
# Turns something like /static/dev-panel.html into
# /static/{filename:dev-panel(-[a-z0-9]{32}|)\.html}
base, ext = url_root.rsplit('.', 1)
base, file = base.rsplit('/', 1)
regex = r"{}(-[a-z0-9]{{32}}|)\.{}".format(file, ext)
url_pattern = "{}/{{filename:{}}}".format(base, regex)
self.app.router.add_route('GET', url_pattern, serve_file)
@asyncio.coroutine
def start(self):
"""Start the wsgi server."""
if self.cors is not None:
for route in list(self.app.router.routes()):
self.cors.add(route)
if self.ssl_certificate:
context = ssl.SSLContext(SSL_VERSION)
context.options |= SSL_OPTS
context.set_ciphers(CIPHERS)
context.load_cert_chain(self.ssl_certificate, self.ssl_key)
else:
context = None
self._handler = self.app.make_handler()
self.server = yield from self.hass.loop.create_server(
self._handler, self.server_host, self.server_port, ssl=context)
@asyncio.coroutine
def stop(self):
"""Stop the wsgi server."""
self.server.close()
yield from self.server.wait_closed()
yield from self.app.shutdown()
yield from self._handler.finish_connections(60.0)
yield from self.app.cleanup()
def get_real_ip(self, request):
"""Return the clients correct ip address, even in proxied setups."""
if self.use_x_forwarded_for \
and HTTP_HEADER_X_FORWARDED_FOR in request.headers:
return request.headers.get(
HTTP_HEADER_X_FORWARDED_FOR).split(',')[0]
else:
peername = request.transport.get_extra_info('peername')
return peername[0] if peername is not None else None
def is_trusted_ip(self, remote_addr):
"""Match an ip address against trusted CIDR networks."""
return any(ip_address(remote_addr) in trusted_network
for trusted_network in self.hass.http.trusted_networks)
def wrong_login_attempt(self, remote_addr):
"""Registering wrong login attempt."""
if not self.is_ban_enabled or self.login_threshold < 1:
return
if remote_addr in self.failed_login_attempts:
self.failed_login_attempts[remote_addr] += 1
else:
self.failed_login_attempts[remote_addr] = 1
if self.failed_login_attempts[remote_addr] > self.login_threshold:
new_ban = IpBan(remote_addr)
self.ip_bans.append(new_ban)
update_ip_bans_config(self.hass.config.path(IP_BANS), new_ban)
_LOGGER.warning('Banned IP %s for too many login attempts',
remote_addr)
persistent_notification.async_create(
self.hass,
'Too many login attempts from {}'.format(remote_addr),
'Banning IP address', NOTIFICATION_ID_BAN)
def is_banned_ip(self, remote_addr):
"""Check if IP address is in a ban list."""
if not self.is_ban_enabled:
return False
ip_address_ = ip_address(remote_addr)
for ip_ban in self.ip_bans:
if ip_ban.ip_address == ip_address_:
return True
return False
class HomeAssistantView(object):
"""Base view for all views."""
url = None
extra_urls = []
requires_auth = True # Views inheriting from this class can override this
def __init__(self, hass):
"""Initilalize the base view."""
if not hasattr(self, 'url'):
class_name = self.__class__.__name__
raise AttributeError(
'{0} missing required attribute "url"'.format(class_name)
)
if not hasattr(self, 'name'):
class_name = self.__class__.__name__
raise AttributeError(
'{0} missing required attribute "name"'.format(class_name)
)
self.hass = hass
# pylint: disable=no-self-use
def json(self, result, status_code=200):
"""Return a JSON response."""
msg = json.dumps(
result, sort_keys=True, cls=rem.JSONEncoder).encode('UTF-8')
return web.Response(
body=msg, content_type=CONTENT_TYPE_JSON, status=status_code)
def json_message(self, error, status_code=200):
"""Return a JSON message response."""
return self.json({'message': error}, status_code)
@asyncio.coroutine
# pylint: disable=no-self-use
def file(self, request, fil):
"""Return a file."""
assert isinstance(fil, str), 'only string paths allowed'
response = yield from _GZIP_FILE_SENDER.send(request, Path(fil))
return response
def register(self, router):
"""Register the view with a router."""
assert self.url is not None, 'No url set for view'
urls = [self.url] + self.extra_urls
for method in ('get', 'post', 'delete', 'put'):
handler = getattr(self, method, None)
if not handler:
continue
handler = request_handler_factory(self, handler)
for url in urls:
router.add_route(method, url, handler)
# aiohttp_cors does not work with class based views
# self.app.router.add_route('*', self.url, self, name=self.name)
# for url in self.extra_urls:
# self.app.router.add_route('*', url, self)
def request_handler_factory(view, handler):
"""Factory to wrap our handler classes.
Eventually authentication should be managed by middleware.
"""
@asyncio.coroutine
def handle(request):
"""Handle incoming request."""
if not view.hass.is_running:
return web.Response(status=503)
remote_addr = view.hass.http.get_real_ip(request)
if view.hass.http.is_banned_ip(remote_addr):
raise HTTPForbidden()
# Auth code verbose on purpose
authenticated = False
if view.hass.http.api_password is None:
authenticated = True
elif view.hass.http.is_trusted_ip(remote_addr):
authenticated = True
elif hmac.compare_digest(request.headers.get(HTTP_HEADER_HA_AUTH, ''),
view.hass.http.api_password):
# A valid auth header has been set
authenticated = True
elif hmac.compare_digest(request.GET.get(DATA_API_PASSWORD, ''),
view.hass.http.api_password):
authenticated = True
if view.requires_auth and not authenticated:
view.hass.http.wrong_login_attempt(remote_addr)
_LOGGER.warning('Login attempt or request with an invalid '
'password from %s', remote_addr)
persistent_notification.async_create(
view.hass,
'Invalid password used from {}'.format(remote_addr),
'Login attempt failed', NOTIFICATION_ID_LOGIN)
raise HTTPUnauthorized()
request.authenticated = authenticated
_LOGGER.info('Serving %s to %s (auth: %s)',
request.path, remote_addr, authenticated)
assert asyncio.iscoroutinefunction(handler) or is_callback(handler), \
"Handler should be a coroutine or a callback."
result = handler(request, **request.match_info)
if asyncio.iscoroutine(result):
result = yield from result
if isinstance(result, web.StreamResponse):
# The method handler returned a ready-made Response, how nice of it
return result
status_code = 200
if isinstance(result, tuple):
result, status_code = result
if isinstance(result, str):
result = result.encode('utf-8')
elif result is None:
result = b''
elif not isinstance(result, bytes):
assert False, ('Result should be None, string, bytes or Response. '
'Got: {}').format(result)
return web.Response(body=result, status=status_code)
return handle
class IpBan(object):
"""Represents banned IP address."""
def __init__(self, ip_ban: str, banned_at: datetime=None) -> None:
"""Initializing Ip Ban object."""
self.ip_address = ip_address(ip_ban)
self.banned_at = banned_at
if self.banned_at is None:
self.banned_at = datetime.utcnow()
def load_ip_bans_config(path: str):
"""Loading list of banned IPs from config file."""
ip_list = []
ip_schema = vol.Schema({
vol.Optional('banned_at'): vol.Any(None, cv.datetime)
})
try:
try:
list_ = load_yaml_config_file(path)
except HomeAssistantError as err:
_LOGGER.error('Unable to load %s: %s', path, str(err))
return []
for ip_ban, ip_info in list_.items():
try:
ip_info = ip_schema(ip_info)
ip_info['ip_ban'] = ip_address(ip_ban)
ip_list.append(IpBan(**ip_info))
except vol.Invalid:
_LOGGER.exception('Failed to load IP ban')
continue
except(HomeAssistantError, FileNotFoundError):
# No need to report error, file absence means
# that no bans were applied.
return []
return ip_list
def update_ip_bans_config(path: str, ip_ban: IpBan):
"""Update config file with new banned IP address."""
with open(path, 'a') as out:
ip_ = {str(ip_ban.ip_address): {
ATTR_BANNED_AT: ip_ban.banned_at.strftime("%Y-%m-%dT%H:%M:%S")
}}
out.write('\n')
out.write(dump(ip_))

View file

@ -0,0 +1,407 @@
"""
This module provides WSGI application to serve the Home Assistant API.
For more details about this component, please refer to the documentation at
https://home-assistant.io/components/http/
"""
import asyncio
import json
import logging
import ssl
from ipaddress import ip_network
from pathlib import Path
import os
import voluptuous as vol
from aiohttp import web
from aiohttp.web_exceptions import HTTPUnauthorized, HTTPMovedPermanently
import homeassistant.helpers.config_validation as cv
import homeassistant.remote as rem
from homeassistant.util import get_local_ip
from homeassistant.components import persistent_notification
from homeassistant.const import (
SERVER_PORT, CONTENT_TYPE_JSON, ALLOWED_CORS_HEADERS,
EVENT_HOMEASSISTANT_STOP, EVENT_HOMEASSISTANT_START)
from homeassistant.core import is_callback
from homeassistant.util.logging import HideSensitiveDataFilter
from .auth import auth_middleware
from .ban import ban_middleware, process_wrong_login
from .const import (
KEY_USE_X_FORWARDED_FOR, KEY_TRUSTED_NETWORKS,
KEY_BANS_ENABLED, KEY_LOGIN_THRESHOLD,
KEY_DEVELOPMENT, KEY_AUTHENTICATED)
from .static import GZIP_FILE_SENDER, staticresource_middleware
from .util import get_real_ip
DOMAIN = 'http'
REQUIREMENTS = ('aiohttp_cors==0.5.0',)
CONF_API_PASSWORD = 'api_password'
CONF_SERVER_HOST = 'server_host'
CONF_SERVER_PORT = 'server_port'
CONF_DEVELOPMENT = 'development'
CONF_SSL_CERTIFICATE = 'ssl_certificate'
CONF_SSL_KEY = 'ssl_key'
CONF_CORS_ORIGINS = 'cors_allowed_origins'
CONF_USE_X_FORWARDED_FOR = 'use_x_forwarded_for'
CONF_TRUSTED_NETWORKS = 'trusted_networks'
CONF_LOGIN_ATTEMPTS_THRESHOLD = 'login_attempts_threshold'
CONF_IP_BAN_ENABLED = 'ip_ban_enabled'
NOTIFICATION_ID_LOGIN = 'http-login'
# TLS configuation follows the best-practice guidelines specified here:
# https://wiki.mozilla.org/Security/Server_Side_TLS
# Intermediate guidelines are followed.
SSL_VERSION = ssl.PROTOCOL_SSLv23
SSL_OPTS = ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3
if hasattr(ssl, 'OP_NO_COMPRESSION'):
SSL_OPTS |= ssl.OP_NO_COMPRESSION
CIPHERS = "ECDHE-ECDSA-CHACHA20-POLY1305:ECDHE-RSA-CHACHA20-POLY1305:" \
"ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256:" \
"ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384:" \
"DHE-RSA-AES128-GCM-SHA256:DHE-RSA-AES256-GCM-SHA384:" \
"ECDHE-ECDSA-AES128-SHA256:ECDHE-RSA-AES128-SHA256:" \
"ECDHE-ECDSA-AES128-SHA:ECDHE-RSA-AES256-SHA384:" \
"ECDHE-RSA-AES128-SHA:ECDHE-ECDSA-AES256-SHA384:" \
"ECDHE-ECDSA-AES256-SHA:ECDHE-RSA-AES256-SHA:" \
"DHE-RSA-AES128-SHA256:DHE-RSA-AES128-SHA:DHE-RSA-AES256-SHA256:" \
"DHE-RSA-AES256-SHA:ECDHE-ECDSA-DES-CBC3-SHA:" \
"ECDHE-RSA-DES-CBC3-SHA:EDH-RSA-DES-CBC3-SHA:AES128-GCM-SHA256:" \
"AES256-GCM-SHA384:AES128-SHA256:AES256-SHA256:AES128-SHA:" \
"AES256-SHA:DES-CBC3-SHA:!DSS"
_LOGGER = logging.getLogger(__name__)
DEFAULT_SERVER_HOST = '0.0.0.0'
DEFAULT_DEVELOPMENT = '0'
DEFAULT_LOGIN_ATTEMPT_THRESHOLD = -1
HTTP_SCHEMA = vol.Schema({
vol.Optional(CONF_API_PASSWORD, default=None): cv.string,
vol.Optional(CONF_SERVER_HOST, default=DEFAULT_SERVER_HOST): cv.string,
vol.Optional(CONF_SERVER_PORT, default=SERVER_PORT):
vol.All(vol.Coerce(int), vol.Range(min=1, max=65535)),
vol.Optional(CONF_DEVELOPMENT, default=DEFAULT_DEVELOPMENT): cv.string,
vol.Optional(CONF_SSL_CERTIFICATE, default=None): cv.isfile,
vol.Optional(CONF_SSL_KEY, default=None): cv.isfile,
vol.Optional(CONF_CORS_ORIGINS, default=[]): vol.All(cv.ensure_list,
[cv.string]),
vol.Optional(CONF_USE_X_FORWARDED_FOR, default=False): cv.boolean,
vol.Optional(CONF_TRUSTED_NETWORKS, default=[]):
vol.All(cv.ensure_list, [ip_network]),
vol.Optional(CONF_LOGIN_ATTEMPTS_THRESHOLD,
default=DEFAULT_LOGIN_ATTEMPT_THRESHOLD): cv.positive_int,
vol.Optional(CONF_IP_BAN_ENABLED, default=True): cv.boolean
})
CONFIG_SCHEMA = vol.Schema({
DOMAIN: HTTP_SCHEMA,
}, extra=vol.ALLOW_EXTRA)
@asyncio.coroutine
def async_setup(hass, config):
"""Set up the HTTP API and debug interface."""
conf = config.get(DOMAIN)
if conf is None:
conf = HTTP_SCHEMA({})
api_password = conf[CONF_API_PASSWORD]
server_host = conf[CONF_SERVER_HOST]
server_port = conf[CONF_SERVER_PORT]
development = conf[CONF_DEVELOPMENT] == '1'
ssl_certificate = conf[CONF_SSL_CERTIFICATE]
ssl_key = conf[CONF_SSL_KEY]
cors_origins = conf[CONF_CORS_ORIGINS]
use_x_forwarded_for = conf[CONF_USE_X_FORWARDED_FOR]
trusted_networks = conf[CONF_TRUSTED_NETWORKS]
is_ban_enabled = conf[CONF_IP_BAN_ENABLED]
login_threshold = conf[CONF_LOGIN_ATTEMPTS_THRESHOLD]
if api_password is not None:
logging.getLogger('aiohttp.access').addFilter(
HideSensitiveDataFilter(api_password))
server = HomeAssistantWSGI(
hass,
development=development,
server_host=server_host,
server_port=server_port,
api_password=api_password,
ssl_certificate=ssl_certificate,
ssl_key=ssl_key,
cors_origins=cors_origins,
use_x_forwarded_for=use_x_forwarded_for,
trusted_networks=trusted_networks,
login_threshold=login_threshold,
is_ban_enabled=is_ban_enabled
)
@asyncio.coroutine
def stop_server(event):
"""Callback to stop the server."""
yield from server.stop()
@asyncio.coroutine
def start_server(event):
"""Callback to start the server."""
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, stop_server)
yield from server.start()
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, start_server)
hass.http = server
hass.config.api = rem.API(server_host if server_host != '0.0.0.0'
else get_local_ip(),
api_password, server_port,
ssl_certificate is not None)
return True
class HomeAssistantWSGI(object):
"""WSGI server for Home Assistant."""
def __init__(self, hass, development, api_password, ssl_certificate,
ssl_key, server_host, server_port, cors_origins,
use_x_forwarded_for, trusted_networks,
login_threshold, is_ban_enabled):
"""Initialize the WSGI Home Assistant server."""
import aiohttp_cors
middlewares = [auth_middleware, staticresource_middleware]
if is_ban_enabled:
middlewares.insert(0, ban_middleware)
self.app = web.Application(middlewares=middlewares, loop=hass.loop)
self.app['hass'] = hass
self.app[KEY_USE_X_FORWARDED_FOR] = use_x_forwarded_for
self.app[KEY_TRUSTED_NETWORKS] = trusted_networks
self.app[KEY_BANS_ENABLED] = is_ban_enabled
self.app[KEY_LOGIN_THRESHOLD] = login_threshold
self.app[KEY_DEVELOPMENT] = development
self.hass = hass
self.development = development
self.api_password = api_password
self.ssl_certificate = ssl_certificate
self.ssl_key = ssl_key
self.server_host = server_host
self.server_port = server_port
self._handler = None
self.server = None
if cors_origins:
self.cors = aiohttp_cors.setup(self.app, defaults={
host: aiohttp_cors.ResourceOptions(
allow_headers=ALLOWED_CORS_HEADERS,
allow_methods='*',
) for host in cors_origins
})
else:
self.cors = None
def register_view(self, view):
"""Register a view with the WSGI server.
The view argument must be a class that inherits from HomeAssistantView.
It is optional to instantiate it before registering; this method will
handle it either way.
"""
if isinstance(view, type):
# Instantiate the view, if needed
view = view()
if not hasattr(view, 'url'):
class_name = view.__class__.__name__
raise AttributeError(
'{0} missing required attribute "url"'.format(class_name)
)
if not hasattr(view, 'name'):
class_name = view.__class__.__name__
raise AttributeError(
'{0} missing required attribute "name"'.format(class_name)
)
view.register(self.app.router)
def register_redirect(self, url, redirect_to):
"""Register a redirect with the server.
If given this must be either a string or callable. In case of a
callable it's called with the url adapter that triggered the match and
the values of the URL as keyword arguments and has to return the target
for the redirect, otherwise it has to be a string with placeholders in
rule syntax.
"""
def redirect(request):
"""Redirect to location."""
raise HTTPMovedPermanently(redirect_to)
self.app.router.add_route('GET', url, redirect)
def register_static_path(self, url_root, path, cache_length=31):
"""Register a folder to serve as a static path.
Specify optional cache length of asset in days.
"""
if os.path.isdir(path):
self.app.router.add_static(url_root, path)
return
filepath = Path(path)
@asyncio.coroutine
def serve_file(request):
"""Serve file from disk."""
res = yield from GZIP_FILE_SENDER.send(request, filepath)
return res
# aiohttp supports regex matching for variables. Using that as temp
# to work around cache busting MD5.
# Turns something like /static/dev-panel.html into
# /static/{filename:dev-panel(-[a-z0-9]{32}|)\.html}
base, ext = url_root.rsplit('.', 1)
base, file = base.rsplit('/', 1)
regex = r"{}(-[a-z0-9]{{32}}|)\.{}".format(file, ext)
url_pattern = "{}/{{filename:{}}}".format(base, regex)
self.app.router.add_route('GET', url_pattern, serve_file)
@asyncio.coroutine
def start(self):
"""Start the wsgi server."""
if self.cors is not None:
for route in list(self.app.router.routes()):
self.cors.add(route)
if self.ssl_certificate:
context = ssl.SSLContext(SSL_VERSION)
context.options |= SSL_OPTS
context.set_ciphers(CIPHERS)
context.load_cert_chain(self.ssl_certificate, self.ssl_key)
else:
context = None
self._handler = self.app.make_handler()
self.server = yield from self.hass.loop.create_server(
self._handler, self.server_host, self.server_port, ssl=context)
@asyncio.coroutine
def stop(self):
"""Stop the wsgi server."""
self.server.close()
yield from self.server.wait_closed()
yield from self.app.shutdown()
yield from self._handler.finish_connections(60.0)
yield from self.app.cleanup()
class HomeAssistantView(object):
"""Base view for all views."""
url = None
extra_urls = []
requires_auth = True # Views inheriting from this class can override this
# pylint: disable=no-self-use
def json(self, result, status_code=200):
"""Return a JSON response."""
msg = json.dumps(
result, sort_keys=True, cls=rem.JSONEncoder).encode('UTF-8')
return web.Response(
body=msg, content_type=CONTENT_TYPE_JSON, status=status_code)
def json_message(self, error, status_code=200):
"""Return a JSON message response."""
return self.json({'message': error}, status_code)
@asyncio.coroutine
# pylint: disable=no-self-use
def file(self, request, fil):
"""Return a file."""
assert isinstance(fil, str), 'only string paths allowed'
response = yield from GZIP_FILE_SENDER.send(request, Path(fil))
return response
def register(self, router):
"""Register the view with a router."""
assert self.url is not None, 'No url set for view'
urls = [self.url] + self.extra_urls
for method in ('get', 'post', 'delete', 'put'):
handler = getattr(self, method, None)
if not handler:
continue
handler = request_handler_factory(self, handler)
for url in urls:
router.add_route(method, url, handler)
# aiohttp_cors does not work with class based views
# self.app.router.add_route('*', self.url, self, name=self.name)
# for url in self.extra_urls:
# self.app.router.add_route('*', url, self)
def request_handler_factory(view, handler):
"""Factory to wrap our handler classes."""
assert asyncio.iscoroutinefunction(handler) or is_callback(handler), \
"Handler should be a coroutine or a callback."
@asyncio.coroutine
def handle(request):
"""Handle incoming request."""
if not request.app['hass'].is_running:
return web.Response(status=503)
remote_addr = get_real_ip(request)
authenticated = request.get(KEY_AUTHENTICATED, False)
if view.requires_auth and not authenticated:
yield from process_wrong_login(request)
_LOGGER.warning('Login attempt or request with an invalid '
'password from %s', remote_addr)
persistent_notification.async_create(
request.app['hass'],
'Invalid password used from {}'.format(remote_addr),
'Login attempt failed', NOTIFICATION_ID_LOGIN)
raise HTTPUnauthorized()
_LOGGER.info('Serving %s to %s (auth: %s)',
request.path, remote_addr, authenticated)
result = handler(request, **request.match_info)
if asyncio.iscoroutine(result):
result = yield from result
if isinstance(result, web.StreamResponse):
# The method handler returned a ready-made Response, how nice of it
return result
status_code = 200
if isinstance(result, tuple):
result, status_code = result
if isinstance(result, str):
result = result.encode('utf-8')
elif result is None:
result = b''
elif not isinstance(result, bytes):
assert False, ('Result should be None, string, bytes or Response. '
'Got: {}').format(result)
return web.Response(body=result, status=status_code)
return handle

View file

@ -0,0 +1,61 @@
"""Authentication for HTTP component."""
import asyncio
import hmac
import logging
from homeassistant.const import HTTP_HEADER_HA_AUTH
from .util import get_real_ip
from .const import KEY_TRUSTED_NETWORKS, KEY_AUTHENTICATED
DATA_API_PASSWORD = 'api_password'
_LOGGER = logging.getLogger(__name__)
@asyncio.coroutine
def auth_middleware(app, handler):
"""Authentication middleware."""
# If no password set, just always set authenticated=True
if app['hass'].http.api_password is None:
@asyncio.coroutine
def no_auth_middleware_handler(request):
"""Auth middleware to approve all requests."""
request[KEY_AUTHENTICATED] = True
return handler(request)
return no_auth_middleware_handler
@asyncio.coroutine
def auth_middleware_handler(request):
"""Auth middleware to check authentication."""
hass = app['hass']
# Auth code verbose on purpose
authenticated = False
if hmac.compare_digest(request.headers.get(HTTP_HEADER_HA_AUTH, ''),
hass.http.api_password):
# A valid auth header has been set
authenticated = True
elif hmac.compare_digest(request.GET.get(DATA_API_PASSWORD, ''),
hass.http.api_password):
authenticated = True
elif is_trusted_ip(request):
authenticated = True
request[KEY_AUTHENTICATED] = authenticated
return handler(request)
return auth_middleware_handler
def is_trusted_ip(request):
"""Test if request is from a trusted ip."""
ip_addr = get_real_ip(request)
return ip_addr and any(
ip_addr in trusted_network for trusted_network
in request.app[KEY_TRUSTED_NETWORKS])

View file

@ -0,0 +1,132 @@
"""Ban logic for HTTP component."""
import asyncio
from collections import defaultdict
from datetime import datetime
from ipaddress import ip_address
import logging
from aiohttp.web_exceptions import HTTPForbidden
import voluptuous as vol
from homeassistant.components import persistent_notification
from homeassistant.config import load_yaml_config_file
from homeassistant.exceptions import HomeAssistantError
import homeassistant.helpers.config_validation as cv
from homeassistant.util.yaml import dump
from .const import (
KEY_BANS_ENABLED, KEY_BANNED_IPS, KEY_LOGIN_THRESHOLD,
KEY_FAILED_LOGIN_ATTEMPTS)
from .util import get_real_ip
NOTIFICATION_ID_BAN = 'ip-ban'
IP_BANS_FILE = 'ip_bans.yaml'
ATTR_BANNED_AT = "banned_at"
SCHEMA_IP_BAN_ENTRY = vol.Schema({
vol.Optional('banned_at'): vol.Any(None, cv.datetime)
})
_LOGGER = logging.getLogger(__name__)
@asyncio.coroutine
def ban_middleware(app, handler):
"""IP Ban middleware."""
if not app[KEY_BANS_ENABLED]:
return handler
if KEY_BANNED_IPS not in app:
hass = app['hass']
app[KEY_BANNED_IPS] = yield from hass.loop.run_in_executor(
None, load_ip_bans_config, hass.config.path(IP_BANS_FILE))
@asyncio.coroutine
def ban_middleware_handler(request):
"""Verify if IP is not banned."""
ip_address_ = get_real_ip(request)
is_banned = any(ip_ban.ip_address == ip_address_
for ip_ban in request.app[KEY_BANNED_IPS])
if is_banned:
raise HTTPForbidden()
return handler(request)
return ban_middleware_handler
@asyncio.coroutine
def process_wrong_login(request):
"""Process a wrong login attempt."""
if (not request.app[KEY_BANS_ENABLED] or
request.app[KEY_LOGIN_THRESHOLD] < 1):
return
if KEY_FAILED_LOGIN_ATTEMPTS not in request.app:
request.app[KEY_FAILED_LOGIN_ATTEMPTS] = defaultdict(int)
remote_addr = get_real_ip(request)
request.app[KEY_FAILED_LOGIN_ATTEMPTS][remote_addr] += 1
if (request.app[KEY_FAILED_LOGIN_ATTEMPTS][remote_addr] >
request.app[KEY_LOGIN_THRESHOLD]):
new_ban = IpBan(remote_addr)
request.app[KEY_BANNED_IPS].append(new_ban)
hass = request.app['hass']
yield from hass.loop.run_in_executor(
None, update_ip_bans_config, hass.config.path(IP_BANS_FILE),
new_ban)
_LOGGER.warning('Banned IP %s for too many login attempts',
remote_addr)
persistent_notification.async_create(
hass,
'Too many login attempts from {}'.format(remote_addr),
'Banning IP address', NOTIFICATION_ID_BAN)
class IpBan(object):
"""Represents banned IP address."""
def __init__(self, ip_ban: str, banned_at: datetime=None) -> None:
"""Initializing Ip Ban object."""
self.ip_address = ip_address(ip_ban)
self.banned_at = banned_at or datetime.utcnow()
def load_ip_bans_config(path: str):
"""Loading list of banned IPs from config file."""
ip_list = []
try:
list_ = load_yaml_config_file(path)
except FileNotFoundError:
return []
except HomeAssistantError as err:
_LOGGER.error('Unable to load %s: %s', path, str(err))
return []
for ip_ban, ip_info in list_.items():
try:
ip_info = SCHEMA_IP_BAN_ENTRY(ip_info)
ip_list.append(IpBan(ip_ban, ip_info['banned_at']))
except vol.Invalid as err:
_LOGGER.error('Failed to load IP ban %s: %s', ip_info, err)
continue
return ip_list
def update_ip_bans_config(path: str, ip_ban: IpBan):
"""Update config file with new banned IP address."""
with open(path, 'a') as out:
ip_ = {str(ip_ban.ip_address): {
ATTR_BANNED_AT: ip_ban.banned_at.strftime("%Y-%m-%dT%H:%M:%S")
}}
out.write('\n')
out.write(dump(ip_))

View file

@ -0,0 +1,12 @@
"""HTTP specific constants."""
KEY_AUTHENTICATED = 'ha_authenticated'
KEY_USE_X_FORWARDED_FOR = 'ha_use_x_forwarded_for'
KEY_TRUSTED_NETWORKS = 'ha_trusted_networks'
KEY_REAL_IP = 'ha_real_ip'
KEY_BANS_ENABLED = 'ha_bans_enabled'
KEY_BANNED_IPS = 'ha_banned_ips'
KEY_FAILED_LOGIN_ATTEMPTS = 'ha_failed_login_attempts'
KEY_LOGIN_THRESHOLD = 'ha_login_treshold'
KEY_DEVELOPMENT = 'ha_development'
HTTP_HEADER_X_FORWARDED_FOR = 'X-Forwarded-For'

View file

@ -0,0 +1,93 @@
"""Static file handling for HTTP component."""
import asyncio
import mimetypes
import re
from aiohttp import hdrs
from aiohttp.file_sender import FileSender
from aiohttp.web_urldispatcher import StaticResource
from aiohttp.web_exceptions import HTTPNotModified
from .const import KEY_DEVELOPMENT
_FINGERPRINT = re.compile(r'^(.+)-[a-z0-9]{32}\.(\w+)$', re.IGNORECASE)
class GzipFileSender(FileSender):
"""FileSender class capable of sending gzip version if available."""
# pylint: disable=invalid-name
@asyncio.coroutine
def send(self, request, filepath):
"""Send filepath to client using request."""
gzip = False
if 'gzip' in request.headers[hdrs.ACCEPT_ENCODING]:
gzip_path = filepath.with_name(filepath.name + '.gz')
if gzip_path.is_file():
filepath = gzip_path
gzip = True
st = filepath.stat()
modsince = request.if_modified_since
if modsince is not None and st.st_mtime <= modsince.timestamp():
raise HTTPNotModified()
ct, encoding = mimetypes.guess_type(str(filepath))
if not ct:
ct = 'application/octet-stream'
resp = self._response_factory()
resp.content_type = ct
if encoding:
resp.headers[hdrs.CONTENT_ENCODING] = encoding
if gzip:
resp.headers[hdrs.VARY] = hdrs.ACCEPT_ENCODING
resp.last_modified = st.st_mtime
# CACHE HACK
if not request.app[KEY_DEVELOPMENT]:
cache_time = 31 * 86400 # = 1 month
resp.headers[hdrs.CACHE_CONTROL] = "public, max-age={}".format(
cache_time)
file_size = st.st_size
resp.content_length = file_size
with filepath.open('rb') as f:
yield from self._sendfile(request, resp, f, file_size)
return resp
GZIP_FILE_SENDER = GzipFileSender()
@asyncio.coroutine
def staticresource_middleware(app, handler):
"""Enhance StaticResourceHandler middleware.
Adds gzip encoding and fingerprinting matching.
"""
inst = getattr(handler, '__self__', None)
if not isinstance(inst, StaticResource):
return handler
# pylint: disable=protected-access
inst._file_sender = GZIP_FILE_SENDER
@asyncio.coroutine
def static_middleware_handler(request):
"""Strip out fingerprints from resource names."""
fingerprinted = _FINGERPRINT.match(request.match_info['filename'])
if fingerprinted:
request.match_info['filename'] = \
'{}.{}'.format(*fingerprinted.groups())
resp = yield from handler(request)
return resp
return static_middleware_handler

View file

@ -0,0 +1,25 @@
"""HTTP utilities."""
from ipaddress import ip_address
from .const import (
KEY_REAL_IP, KEY_USE_X_FORWARDED_FOR, HTTP_HEADER_X_FORWARDED_FOR)
def get_real_ip(request):
"""Get IP address of client."""
if KEY_REAL_IP in request:
return request[KEY_REAL_IP]
if (request.app[KEY_USE_X_FORWARDED_FOR] and
HTTP_HEADER_X_FORWARDED_FOR in request.headers):
request[KEY_REAL_IP] = ip_address(
request.headers.get(HTTP_HEADER_X_FORWARDED_FOR).split(',')[0])
else:
peername = request.transport.get_extra_info('peername')
if peername:
request[KEY_REAL_IP] = ip_address(peername[0])
else:
request[KEY_REAL_IP] = None
return request[KEY_REAL_IP]

View file

@ -250,11 +250,10 @@ def setup(hass, config):
discovery.load_platform(hass, "sensor", DOMAIN, {}, config) discovery.load_platform(hass, "sensor", DOMAIN, {}, config)
hass.http.register_view(iOSIdentifyDeviceView(hass)) hass.http.register_view(iOSIdentifyDeviceView)
app_config = config.get(DOMAIN, {}) app_config = config.get(DOMAIN, {})
hass.http.register_view(iOSPushConfigView(hass, hass.http.register_view(iOSPushConfigView(app_config.get(CONF_PUSH, {})))
app_config.get(CONF_PUSH, {})))
return True return True
@ -266,9 +265,8 @@ class iOSPushConfigView(HomeAssistantView):
url = "/api/ios/push" url = "/api/ios/push"
name = "api:ios:push" name = "api:ios:push"
def __init__(self, hass, push_config): def __init__(self, push_config):
"""Init the view.""" """Init the view."""
super().__init__(hass)
self.push_config = push_config self.push_config = push_config
@callback @callback
@ -283,10 +281,6 @@ class iOSIdentifyDeviceView(HomeAssistantView):
url = "/api/ios/identify" url = "/api/ios/identify"
name = "api:ios:identify" name = "api:ios:identify"
def __init__(self, hass):
"""Init the view."""
super().__init__(hass)
@asyncio.coroutine @asyncio.coroutine
def post(self, request): def post(self, request):
"""Handle the POST request for device identification.""" """Handle the POST request for device identification."""

View file

@ -101,7 +101,7 @@ def setup(hass, config):
message = message.async_render() message = message.async_render()
async_log_entry(hass, name, message, domain, entity_id) async_log_entry(hass, name, message, domain, entity_id)
hass.http.register_view(LogbookView(hass, config)) hass.http.register_view(LogbookView(config))
register_built_in_panel(hass, 'logbook', 'Logbook', register_built_in_panel(hass, 'logbook', 'Logbook',
'mdi:format-list-bulleted-type') 'mdi:format-list-bulleted-type')
@ -118,9 +118,8 @@ class LogbookView(HomeAssistantView):
name = 'api:logbook' name = 'api:logbook'
extra_urls = ['/api/logbook/{datetime}'] extra_urls = ['/api/logbook/{datetime}']
def __init__(self, hass, config): def __init__(self, config):
"""Initilalize the logbook view.""" """Initilalize the logbook view."""
super().__init__(hass)
self.config = config self.config = config
@asyncio.coroutine @asyncio.coroutine
@ -146,7 +145,8 @@ class LogbookView(HomeAssistantView):
events = recorder.execute(query) events = recorder.execute(query)
return _exclude_events(events, self.config) return _exclude_events(events, self.config)
events = yield from self.hass.loop.run_in_executor(None, get_results) events = yield from request.app['hass'].loop.run_in_executor(
None, get_results)
return self.json(humanify(events)) return self.json(humanify(events))

View file

@ -17,7 +17,7 @@ from homeassistant.config import load_yaml_config_file
from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity import Entity
from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.config_validation import PLATFORM_SCHEMA # noqa from homeassistant.helpers.config_validation import PLATFORM_SCHEMA # noqa
from homeassistant.components.http import HomeAssistantView from homeassistant.components.http import HomeAssistantView, KEY_AUTHENTICATED
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.util.async import run_coroutine_threadsafe from homeassistant.util.async import run_coroutine_threadsafe
from homeassistant.const import ( from homeassistant.const import (
@ -304,7 +304,7 @@ def setup(hass, config):
component = EntityComponent( component = EntityComponent(
logging.getLogger(__name__), DOMAIN, hass, SCAN_INTERVAL) logging.getLogger(__name__), DOMAIN, hass, SCAN_INTERVAL)
hass.http.register_view(MediaPlayerImageView(hass, component.entities)) hass.http.register_view(MediaPlayerImageView(component.entities))
component.setup(config) component.setup(config)
@ -736,9 +736,8 @@ class MediaPlayerImageView(HomeAssistantView):
url = "/api/media_player_proxy/{entity_id}" url = "/api/media_player_proxy/{entity_id}"
name = "api:media_player:image" name = "api:media_player:image"
def __init__(self, hass, entities): def __init__(self, entities):
"""Initialize a media player view.""" """Initialize a media player view."""
super().__init__(hass)
self.entities = entities self.entities = entities
@asyncio.coroutine @asyncio.coroutine
@ -748,14 +747,14 @@ class MediaPlayerImageView(HomeAssistantView):
if player is None: if player is None:
return web.Response(status=404) return web.Response(status=404)
authenticated = (request.authenticated or authenticated = (request[KEY_AUTHENTICATED] or
request.GET.get('token') == player.access_token) request.GET.get('token') == player.access_token)
if not authenticated: if not authenticated:
return web.Response(status=401) return web.Response(status=401)
data, content_type = yield from _async_fetch_image( data, content_type = yield from _async_fetch_image(
self.hass, player.media_image_url) request.app['hass'], player.media_image_url)
if data is None: if data is None:
return web.Response(status=500) return web.Response(status=500)

View file

@ -107,8 +107,8 @@ def get_service(hass, config):
return None return None
hass.http.register_view( hass.http.register_view(
HTML5PushRegistrationView(hass, registrations, json_path)) HTML5PushRegistrationView(registrations, json_path))
hass.http.register_view(HTML5PushCallbackView(hass, registrations)) hass.http.register_view(HTML5PushCallbackView(registrations))
gcm_api_key = config.get(ATTR_GCM_API_KEY) gcm_api_key = config.get(ATTR_GCM_API_KEY)
gcm_sender_id = config.get(ATTR_GCM_SENDER_ID) gcm_sender_id = config.get(ATTR_GCM_SENDER_ID)
@ -168,9 +168,8 @@ class HTML5PushRegistrationView(HomeAssistantView):
url = '/api/notify.html5' url = '/api/notify.html5'
name = 'api:notify.html5' name = 'api:notify.html5'
def __init__(self, hass, registrations, json_path): def __init__(self, registrations, json_path):
"""Init HTML5PushRegistrationView.""" """Init HTML5PushRegistrationView."""
super().__init__(hass)
self.registrations = registrations self.registrations = registrations
self.json_path = json_path self.json_path = json_path
@ -237,9 +236,8 @@ class HTML5PushCallbackView(HomeAssistantView):
url = '/api/notify.html5/callback' url = '/api/notify.html5/callback'
name = 'api:notify.html5/callback' name = 'api:notify.html5/callback'
def __init__(self, hass, registrations): def __init__(self, registrations):
"""Init HTML5PushCallbackView.""" """Init HTML5PushCallbackView."""
super().__init__(hass)
self.registrations = registrations self.registrations = registrations
def decode_jwt(self, token): def decode_jwt(self, token):
@ -324,7 +322,7 @@ class HTML5PushCallbackView(HomeAssistantView):
event_name = '{}.{}'.format(NOTIFY_CALLBACK_EVENT, event_name = '{}.{}'.format(NOTIFY_CALLBACK_EVENT,
event_payload[ATTR_TYPE]) event_payload[ATTR_TYPE])
self.hass.bus.fire(event_name, event_payload) request.app['hass'].bus.fire(event_name, event_payload)
return self.json({'status': 'ok', return self.json({'status': 'ok',
'event': event_payload[ATTR_TYPE]}) 'event': event_payload[ATTR_TYPE]})

View file

@ -274,7 +274,7 @@ def setup_platform(hass, config, add_devices, discovery_info=None):
hass.http.register_redirect(FITBIT_AUTH_START, fitbit_auth_start_url) hass.http.register_redirect(FITBIT_AUTH_START, fitbit_auth_start_url)
hass.http.register_view(FitbitAuthCallbackView( hass.http.register_view(FitbitAuthCallbackView(
hass, config, add_devices, oauth)) config, add_devices, oauth))
request_oauth_completion(hass) request_oauth_completion(hass)
@ -286,9 +286,8 @@ class FitbitAuthCallbackView(HomeAssistantView):
url = '/auth/fitbit/callback' url = '/auth/fitbit/callback'
name = 'auth:fitbit:callback' name = 'auth:fitbit:callback'
def __init__(self, hass, config, add_devices, oauth): def __init__(self, config, add_devices, oauth):
"""Initialize the OAuth callback view.""" """Initialize the OAuth callback view."""
super().__init__(hass)
self.config = config self.config = config
self.add_devices = add_devices self.add_devices = add_devices
self.oauth = oauth self.oauth = oauth
@ -299,6 +298,7 @@ class FitbitAuthCallbackView(HomeAssistantView):
from oauthlib.oauth2.rfc6749.errors import MismatchingStateError from oauthlib.oauth2.rfc6749.errors import MismatchingStateError
from oauthlib.oauth2.rfc6749.errors import MissingTokenError from oauthlib.oauth2.rfc6749.errors import MissingTokenError
hass = request.app['hass']
data = request.GET data = request.GET
response_message = """Fitbit has been successfully authorized! response_message = """Fitbit has been successfully authorized!
@ -306,7 +306,7 @@ class FitbitAuthCallbackView(HomeAssistantView):
if data.get('code') is not None: if data.get('code') is not None:
redirect_uri = '{}{}'.format( redirect_uri = '{}{}'.format(
self.hass.config.api.base_url, FITBIT_AUTH_CALLBACK_PATH) hass.config.api.base_url, FITBIT_AUTH_CALLBACK_PATH)
try: try:
self.oauth.fetch_access_token(data.get('code'), redirect_uri) self.oauth.fetch_access_token(data.get('code'), redirect_uri)
@ -336,12 +336,11 @@ class FitbitAuthCallbackView(HomeAssistantView):
ATTR_CLIENT_ID: self.oauth.client_id, ATTR_CLIENT_ID: self.oauth.client_id,
ATTR_CLIENT_SECRET: self.oauth.client_secret ATTR_CLIENT_SECRET: self.oauth.client_secret
} }
if not config_from_file(self.hass.config.path(FITBIT_CONFIG_FILE), if not config_from_file(hass.config.path(FITBIT_CONFIG_FILE),
config_contents): config_contents):
_LOGGER.error("Failed to save config file") _LOGGER.error("Failed to save config file")
self.hass.async_add_job(setup_platform, self.hass, self.config, hass.async_add_job(setup_platform, hass, self.config, self.add_devices)
self.add_devices)
return html_response return html_response

View file

@ -59,7 +59,7 @@ def setup_platform(hass, config, add_devices, discovery_info=None):
sensors = {} sensors = {}
hass.http.register_view(TorqueReceiveDataView( hass.http.register_view(TorqueReceiveDataView(
hass, email, vehicle, sensors, add_devices)) email, vehicle, sensors, add_devices))
return True return True
@ -69,9 +69,8 @@ class TorqueReceiveDataView(HomeAssistantView):
url = API_PATH url = API_PATH
name = 'api:torque' name = 'api:torque'
def __init__(self, hass, email, vehicle, sensors, add_devices): def __init__(self, email, vehicle, sensors, add_devices):
"""Initialize a Torque view.""" """Initialize a Torque view."""
super().__init__(hass)
self.email = email self.email = email
self.vehicle = vehicle self.vehicle = vehicle
self.sensors = sensors self.sensors = sensors
@ -80,6 +79,7 @@ class TorqueReceiveDataView(HomeAssistantView):
@callback @callback
def get(self, request): def get(self, request):
"""Handle Torque data request.""" """Handle Torque data request."""
hass = request.app['hass']
data = request.GET data = request.GET
if self.email is not None and self.email != data[SENSOR_EMAIL_FIELD]: if self.email is not None and self.email != data[SENSOR_EMAIL_FIELD]:
@ -108,7 +108,7 @@ class TorqueReceiveDataView(HomeAssistantView):
self.sensors[pid] = TorqueSensor( self.sensors[pid] = TorqueSensor(
ENTITY_NAME_FORMAT.format(self.vehicle, names[pid]), ENTITY_NAME_FORMAT.format(self.vehicle, names[pid]),
units.get(pid, None)) units.get(pid, None))
self.hass.async_add_job(self.add_devices, [self.sensors[pid]]) hass.async_add_job(self.add_devices, [self.sensors[pid]])
return None return None

View file

@ -97,6 +97,7 @@ class NetioApiView(HomeAssistantView):
@callback @callback
def get(self, request, host): def get(self, request, host):
"""Request handler.""" """Request handler."""
hass = request.app['hass']
data = request.GET data = request.GET
states, consumptions, cumulated_consumptions, start_dates = \ states, consumptions, cumulated_consumptions, start_dates = \
[], [], [], [] [], [], [], []
@ -119,7 +120,7 @@ class NetioApiView(HomeAssistantView):
ndev.start_dates = start_dates ndev.start_dates = start_dates
for dev in DEVICES[host].entities: for dev in DEVICES[host].entities:
self.hass.async_add_job(dev.async_update_ha_state()) hass.async_add_job(dev.async_update_ha_state())
return self.json(True) return self.json(True)

View file

@ -360,7 +360,6 @@ HTTP_HEADER_CONTENT_LENGTH = 'Content-Length'
HTTP_HEADER_CACHE_CONTROL = 'Cache-Control' HTTP_HEADER_CACHE_CONTROL = 'Cache-Control'
HTTP_HEADER_EXPIRES = 'Expires' HTTP_HEADER_EXPIRES = 'Expires'
HTTP_HEADER_ORIGIN = 'Origin' HTTP_HEADER_ORIGIN = 'Origin'
HTTP_HEADER_X_FORWARDED_FOR = 'X-Forwarded-For'
HTTP_HEADER_X_REQUESTED_WITH = 'X-Requested-With' HTTP_HEADER_X_REQUESTED_WITH = 'X-Requested-With'
HTTP_HEADER_ACCEPT = 'Accept' HTTP_HEADER_ACCEPT = 'Accept'
HTTP_HEADER_ACCESS_CONTROL_ALLOW_ORIGIN = 'Access-Control-Allow-Origin' HTTP_HEADER_ACCESS_CONTROL_ALLOW_ORIGIN = 'Access-Control-Allow-Origin'

View file

@ -0,0 +1,17 @@
"""Logging utilities."""
import logging
class HideSensitiveDataFilter(logging.Filter):
"""Filter API password calls."""
def __init__(self, text):
"""Initialize sensitive data filter."""
super().__init__()
self.text = text
def filter(self, record):
"""Hide sensitive data in messages."""
record.msg = record.msg.replace(self.text, '*******')
return True

View file

@ -10,6 +10,8 @@ import logging
import threading import threading
from contextlib import contextmanager from contextlib import contextmanager
from aiohttp import web
from homeassistant import core as ha, loader from homeassistant import core as ha, loader
from homeassistant.bootstrap import ( from homeassistant.bootstrap import (
setup_component, async_prepare_setup_component) setup_component, async_prepare_setup_component)
@ -22,6 +24,9 @@ from homeassistant.const import (
EVENT_STATE_CHANGED, EVENT_PLATFORM_DISCOVERED, ATTR_SERVICE, EVENT_STATE_CHANGED, EVENT_PLATFORM_DISCOVERED, ATTR_SERVICE,
ATTR_DISCOVERED, SERVER_PORT) ATTR_DISCOVERED, SERVER_PORT)
from homeassistant.components import sun, mqtt from homeassistant.components import sun, mqtt
from homeassistant.components.http.auth import auth_middleware
from homeassistant.components.http.const import (
KEY_USE_X_FORWARDED_FOR, KEY_BANS_ENABLED)
_TEST_INSTANCE_PORT = SERVER_PORT _TEST_INSTANCE_PORT = SERVER_PORT
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -210,13 +215,23 @@ def mock_http_component(hass):
"""Store registered view.""" """Store registered view."""
if isinstance(view, type): if isinstance(view, type):
# Instantiate the view, if needed # Instantiate the view, if needed
view = view(hass) view = view()
hass.http.views[view.name] = view hass.http.views[view.name] = view
hass.http.register_view = mock_register_view hass.http.register_view = mock_register_view
def mock_http_component_app(hass):
"""Create an aiohttp.web.Application instance for testing."""
hass.http.api_password = None
app = web.Application(middlewares=[auth_middleware], loop=hass.loop)
app['hass'] = hass
app[KEY_USE_X_FORWARDED_FOR] = False
app[KEY_BANS_ENABLED] = False
return app
def mock_mqtt_component(hass): def mock_mqtt_component(hass):
"""Mock the MQTT component.""" """Mock the MQTT component."""
with mock.patch('homeassistant.components.mqtt.MQTT') as mock_mqtt: with mock.patch('homeassistant.components.mqtt.MQTT') as mock_mqtt:

View file

@ -27,8 +27,8 @@ def test_fetching_url(aioclient_mock, hass, test_client):
resp = yield from client.get('/api/camera_proxy/camera.config_test') resp = yield from client.get('/api/camera_proxy/camera.config_test')
assert aioclient_mock.call_count == 1
assert resp.status == 200 assert resp.status == 200
assert aioclient_mock.call_count == 1
body = yield from resp.text() body = yield from resp.text()
assert body == 'hello world' assert body == 'hello world'

View file

@ -0,0 +1 @@
"""Tests for the HTTP component."""

View file

@ -0,0 +1,169 @@
"""The tests for the Home Assistant HTTP component."""
# pylint: disable=protected-access
import logging
from ipaddress import ip_address, ip_network
from unittest.mock import patch
import requests
from homeassistant import bootstrap, const
import homeassistant.components.http as http
from homeassistant.components.http.const import (
KEY_TRUSTED_NETWORKS, KEY_USE_X_FORWARDED_FOR, HTTP_HEADER_X_FORWARDED_FOR)
from tests.common import get_test_instance_port, get_test_home_assistant
API_PASSWORD = 'test1234'
SERVER_PORT = get_test_instance_port()
HTTP_BASE = '127.0.0.1:{}'.format(SERVER_PORT)
HTTP_BASE_URL = 'http://{}'.format(HTTP_BASE)
HA_HEADERS = {
const.HTTP_HEADER_HA_AUTH: API_PASSWORD,
const.HTTP_HEADER_CONTENT_TYPE: const.CONTENT_TYPE_JSON,
}
# Don't add 127.0.0.1/::1 as trusted, as it may interfere with other test cases
TRUSTED_NETWORKS = ['192.0.2.0/24', '2001:DB8:ABCD::/48', '100.64.0.1',
'FD01:DB8::1']
TRUSTED_ADDRESSES = ['100.64.0.1', '192.0.2.100', 'FD01:DB8::1',
'2001:DB8:ABCD::1']
UNTRUSTED_ADDRESSES = ['198.51.100.1', '2001:DB8:FA1::1', '127.0.0.1', '::1']
hass = None
def _url(path=''):
"""Helper method to generate URLs."""
return HTTP_BASE_URL + path
# pylint: disable=invalid-name
def setUpModule():
"""Initialize a Home Assistant server."""
global hass
hass = get_test_home_assistant()
bootstrap.setup_component(
hass, http.DOMAIN, {
http.DOMAIN: {
http.CONF_API_PASSWORD: API_PASSWORD,
http.CONF_SERVER_PORT: SERVER_PORT,
}
}
)
bootstrap.setup_component(hass, 'api')
hass.http.app[KEY_TRUSTED_NETWORKS] = [
ip_network(trusted_network)
for trusted_network in TRUSTED_NETWORKS]
hass.start()
# pylint: disable=invalid-name
def tearDownModule():
"""Stop the Home Assistant server."""
hass.stop()
class TestHttp:
"""Test HTTP component."""
def test_access_denied_without_password(self):
"""Test access without password."""
req = requests.get(_url(const.URL_API))
assert req.status_code == 401
def test_access_denied_with_wrong_password_in_header(self):
"""Test access with wrong password."""
req = requests.get(
_url(const.URL_API),
headers={const.HTTP_HEADER_HA_AUTH: 'wrongpassword'})
assert req.status_code == 401
def test_access_denied_with_x_forwarded_for(self, caplog):
"""Test access denied through the X-Forwarded-For http header."""
hass.http.use_x_forwarded_for = True
for remote_addr in UNTRUSTED_ADDRESSES:
req = requests.get(_url(const.URL_API), headers={
HTTP_HEADER_X_FORWARDED_FOR: remote_addr})
assert req.status_code == 401, \
"{} shouldn't be trusted".format(remote_addr)
def test_access_denied_with_untrusted_ip(self, caplog):
"""Test access with an untrusted ip address."""
for remote_addr in UNTRUSTED_ADDRESSES:
with patch('homeassistant.components.http.'
'util.get_real_ip',
return_value=ip_address(remote_addr)):
req = requests.get(
_url(const.URL_API), params={'api_password': ''})
assert req.status_code == 401, \
"{} shouldn't be trusted".format(remote_addr)
def test_access_with_password_in_header(self, caplog):
"""Test access with password in URL."""
# Hide logging from requests package that we use to test logging
caplog.set_level(
logging.WARNING, logger='requests.packages.urllib3.connectionpool')
req = requests.get(
_url(const.URL_API),
headers={const.HTTP_HEADER_HA_AUTH: API_PASSWORD})
assert req.status_code == 200
logs = caplog.text
assert const.URL_API in logs
assert API_PASSWORD not in logs
def test_access_denied_with_wrong_password_in_url(self):
"""Test access with wrong password."""
req = requests.get(
_url(const.URL_API), params={'api_password': 'wrongpassword'})
assert req.status_code == 401
def test_access_with_password_in_url(self, caplog):
"""Test access with password in URL."""
# Hide logging from requests package that we use to test logging
caplog.set_level(
logging.WARNING, logger='requests.packages.urllib3.connectionpool')
req = requests.get(
_url(const.URL_API), params={'api_password': API_PASSWORD})
assert req.status_code == 200
logs = caplog.text
assert const.URL_API in logs
assert API_PASSWORD not in logs
def test_access_granted_with_x_forwarded_for(self, caplog):
"""Test access denied through the X-Forwarded-For http header."""
hass.http.app[KEY_USE_X_FORWARDED_FOR] = True
for remote_addr in TRUSTED_ADDRESSES:
req = requests.get(_url(const.URL_API), headers={
HTTP_HEADER_X_FORWARDED_FOR: remote_addr})
assert req.status_code == 200, \
"{} should be trusted".format(remote_addr)
def test_access_granted_with_trusted_ip(self, caplog):
"""Test access with trusted addresses."""
for remote_addr in TRUSTED_ADDRESSES:
with patch('homeassistant.components.http.'
'auth.get_real_ip',
return_value=ip_address(remote_addr)):
req = requests.get(
_url(const.URL_API), params={'api_password': ''})
assert req.status_code == 200, \
'{} should be trusted'.format(remote_addr)

View file

@ -0,0 +1,118 @@
"""The tests for the Home Assistant HTTP component."""
# pylint: disable=protected-access
from ipaddress import ip_address
from unittest.mock import patch, mock_open
import requests
from homeassistant import bootstrap, const
import homeassistant.components.http as http
from homeassistant.components.http.const import (
KEY_BANS_ENABLED, KEY_LOGIN_THRESHOLD, KEY_BANNED_IPS)
from homeassistant.components.http.ban import IpBan, IP_BANS_FILE
from tests.common import get_test_instance_port, get_test_home_assistant
API_PASSWORD = 'test1234'
SERVER_PORT = get_test_instance_port()
HTTP_BASE = '127.0.0.1:{}'.format(SERVER_PORT)
HTTP_BASE_URL = 'http://{}'.format(HTTP_BASE)
HA_HEADERS = {
const.HTTP_HEADER_HA_AUTH: API_PASSWORD,
const.HTTP_HEADER_CONTENT_TYPE: const.CONTENT_TYPE_JSON,
}
BANNED_IPS = ['200.201.202.203', '100.64.0.2']
hass = None
def _url(path=''):
"""Helper method to generate URLs."""
return HTTP_BASE_URL + path
# pylint: disable=invalid-name
def setUpModule():
"""Initialize a Home Assistant server."""
global hass
hass = get_test_home_assistant()
bootstrap.setup_component(
hass, http.DOMAIN, {
http.DOMAIN: {
http.CONF_API_PASSWORD: API_PASSWORD,
http.CONF_SERVER_PORT: SERVER_PORT,
}
}
)
bootstrap.setup_component(hass, 'api')
hass.http.app[KEY_BANNED_IPS] = [IpBan(banned_ip) for banned_ip
in BANNED_IPS]
hass.start()
# pylint: disable=invalid-name
def tearDownModule():
"""Stop the Home Assistant server."""
hass.stop()
class TestHttp:
"""Test HTTP component."""
def test_access_from_banned_ip(self):
"""Test accessing to server from banned IP. Both trusted and not."""
hass.http.app[KEY_BANS_ENABLED] = True
for remote_addr in BANNED_IPS:
with patch('homeassistant.components.http.'
'ban.get_real_ip',
return_value=ip_address(remote_addr)):
req = requests.get(
_url(const.URL_API))
assert req.status_code == 403
def test_access_from_banned_ip_when_ban_is_off(self):
"""Test accessing to server from banned IP when feature is off"""
hass.http.app[KEY_BANS_ENABLED] = False
for remote_addr in BANNED_IPS:
with patch('homeassistant.components.http.'
'ban.get_real_ip',
return_value=ip_address(remote_addr)):
req = requests.get(
_url(const.URL_API),
headers={const.HTTP_HEADER_HA_AUTH: API_PASSWORD})
assert req.status_code == 200
def test_ip_bans_file_creation(self):
"""Testing if banned IP file created"""
hass.http.app[KEY_BANS_ENABLED] = True
hass.http.app[KEY_LOGIN_THRESHOLD] = 1
m = mock_open()
def call_server():
with patch('homeassistant.components.http.'
'ban.get_real_ip',
return_value=ip_address("200.201.202.204")):
print("GETTING API")
return requests.get(
_url(const.URL_API),
headers={const.HTTP_HEADER_HA_AUTH: 'Wrong password'})
with patch('homeassistant.components.http.ban.open', m, create=True):
req = call_server()
assert req.status_code == 401
assert len(hass.http.app[KEY_BANNED_IPS]) == len(BANNED_IPS)
assert m.call_count == 0
req = call_server()
assert req.status_code == 401
assert len(hass.http.app[KEY_BANNED_IPS]) == len(BANNED_IPS) + 1
m.assert_called_once_with(hass.config.path(IP_BANS_FILE), 'a')
req = call_server()
assert req.status_code == 403
assert m.call_count == 1

View file

@ -0,0 +1,111 @@
"""The tests for the Home Assistant HTTP component."""
import requests
from homeassistant import bootstrap, const
import homeassistant.components.http as http
from tests.common import get_test_instance_port, get_test_home_assistant
API_PASSWORD = 'test1234'
SERVER_PORT = get_test_instance_port()
HTTP_BASE = '127.0.0.1:{}'.format(SERVER_PORT)
HTTP_BASE_URL = 'http://{}'.format(HTTP_BASE)
HA_HEADERS = {
const.HTTP_HEADER_HA_AUTH: API_PASSWORD,
const.HTTP_HEADER_CONTENT_TYPE: const.CONTENT_TYPE_JSON,
}
CORS_ORIGINS = [HTTP_BASE_URL, HTTP_BASE]
hass = None
def _url(path=''):
"""Helper method to generate URLs."""
return HTTP_BASE_URL + path
# pylint: disable=invalid-name
def setUpModule():
"""Initialize a Home Assistant server."""
global hass
hass = get_test_home_assistant()
bootstrap.setup_component(
hass, http.DOMAIN, {
http.DOMAIN: {
http.CONF_API_PASSWORD: API_PASSWORD,
http.CONF_SERVER_PORT: SERVER_PORT,
http.CONF_CORS_ORIGINS: CORS_ORIGINS,
}
}
)
bootstrap.setup_component(hass, 'api')
hass.start()
# pylint: disable=invalid-name
def tearDownModule():
"""Stop the Home Assistant server."""
hass.stop()
class TestHttp:
"""Test HTTP component."""
def test_cors_allowed_with_password_in_url(self):
"""Test cross origin resource sharing with password in url."""
req = requests.get(_url(const.URL_API),
params={'api_password': API_PASSWORD},
headers={const.HTTP_HEADER_ORIGIN: HTTP_BASE_URL})
allow_origin = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_ORIGIN
assert req.status_code == 200
assert req.headers.get(allow_origin) == HTTP_BASE_URL
def test_cors_allowed_with_password_in_header(self):
"""Test cross origin resource sharing with password in header."""
headers = {
const.HTTP_HEADER_HA_AUTH: API_PASSWORD,
const.HTTP_HEADER_ORIGIN: HTTP_BASE_URL
}
req = requests.get(_url(const.URL_API), headers=headers)
allow_origin = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_ORIGIN
assert req.status_code == 200
assert req.headers.get(allow_origin) == HTTP_BASE_URL
def test_cors_denied_without_origin_header(self):
"""Test cross origin resource sharing with password in header."""
headers = {
const.HTTP_HEADER_HA_AUTH: API_PASSWORD
}
req = requests.get(_url(const.URL_API), headers=headers)
allow_origin = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_ORIGIN
allow_headers = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_HEADERS
assert req.status_code == 200
assert allow_origin not in req.headers
assert allow_headers not in req.headers
def test_cors_preflight_allowed(self):
"""Test cross origin resource sharing preflight (OPTIONS) request."""
headers = {
const.HTTP_HEADER_ORIGIN: HTTP_BASE_URL,
'Access-Control-Request-Method': 'GET',
'Access-Control-Request-Headers': 'x-ha-access'
}
req = requests.options(_url(const.URL_API), headers=headers)
allow_origin = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_ORIGIN
allow_headers = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_HEADERS
assert req.status_code == 200
assert req.headers.get(allow_origin) == HTTP_BASE_URL
assert req.headers.get(allow_headers) == \
const.HTTP_HEADER_HA_AUTH.upper()

View file

@ -3,10 +3,10 @@ import asyncio
import json import json
from unittest.mock import patch, MagicMock, mock_open from unittest.mock import patch, MagicMock, mock_open
from aiohttp import web
from homeassistant.components.notify import html5 from homeassistant.components.notify import html5
from tests.common import mock_http_component_app
SUBSCRIPTION_1 = { SUBSCRIPTION_1 = {
'browser': 'chrome', 'browser': 'chrome',
'subscription': { 'subscription': {
@ -121,7 +121,8 @@ class TestHtml5Notify(object):
assert view.json_path == hass.config.path.return_value assert view.json_path == hass.config.path.return_value
assert view.registrations == {} assert view.registrations == {}
app = web.Application(loop=loop) hass.loop = loop
app = mock_http_component_app(hass)
view.register(app.router) view.register(app.router)
client = yield from test_client(app) client = yield from test_client(app)
hass.http.is_banned_ip.return_value = False hass.http.is_banned_ip.return_value = False
@ -153,7 +154,8 @@ class TestHtml5Notify(object):
view = hass.mock_calls[1][1][0] view = hass.mock_calls[1][1][0]
app = web.Application(loop=loop) hass.loop = loop
app = mock_http_component_app(hass)
view.register(app.router) view.register(app.router)
client = yield from test_client(app) client = yield from test_client(app)
hass.http.is_banned_ip.return_value = False hass.http.is_banned_ip.return_value = False
@ -208,7 +210,8 @@ class TestHtml5Notify(object):
assert view.json_path == hass.config.path.return_value assert view.json_path == hass.config.path.return_value
assert view.registrations == config assert view.registrations == config
app = web.Application(loop=loop) hass.loop = loop
app = mock_http_component_app(hass)
view.register(app.router) view.register(app.router)
client = yield from test_client(app) client = yield from test_client(app)
hass.http.is_banned_ip.return_value = False hass.http.is_banned_ip.return_value = False
@ -253,7 +256,8 @@ class TestHtml5Notify(object):
assert view.json_path == hass.config.path.return_value assert view.json_path == hass.config.path.return_value
assert view.registrations == config assert view.registrations == config
app = web.Application(loop=loop) hass.loop = loop
app = mock_http_component_app(hass)
view.register(app.router) view.register(app.router)
client = yield from test_client(app) client = yield from test_client(app)
hass.http.is_banned_ip.return_value = False hass.http.is_banned_ip.return_value = False
@ -296,7 +300,8 @@ class TestHtml5Notify(object):
assert view.json_path == hass.config.path.return_value assert view.json_path == hass.config.path.return_value
assert view.registrations == config assert view.registrations == config
app = web.Application(loop=loop) hass.loop = loop
app = mock_http_component_app(hass)
view.register(app.router) view.register(app.router)
client = yield from test_client(app) client = yield from test_client(app)
hass.http.is_banned_ip.return_value = False hass.http.is_banned_ip.return_value = False
@ -331,7 +336,8 @@ class TestHtml5Notify(object):
view = hass.mock_calls[2][1][0] view = hass.mock_calls[2][1][0]
app = web.Application(loop=loop) hass.loop = loop
app = mock_http_component_app(hass)
view.register(app.router) view.register(app.router)
client = yield from test_client(app) client = yield from test_client(app)
hass.http.is_banned_ip.return_value = False hass.http.is_banned_ip.return_value = False
@ -387,7 +393,8 @@ class TestHtml5Notify(object):
bearer_token = "Bearer {}".format(push_payload['data']['jwt']) bearer_token = "Bearer {}".format(push_payload['data']['jwt'])
app = web.Application(loop=loop) hass.loop = loop
app = mock_http_component_app(hass)
view.register(app.router) view.register(app.router)
client = yield from test_client(app) client = yield from test_client(app)
hass.http.is_banned_ip.return_value = False hass.http.is_banned_ip.return_value = False

View file

@ -6,7 +6,7 @@ import unittest
import requests import requests
import homeassistant.bootstrap as bootstrap import homeassistant.bootstrap as bootstrap
from homeassistant.components import frontend, http from homeassistant.components import http
from homeassistant.const import HTTP_HEADER_HA_AUTH from homeassistant.const import HTTP_HEADER_HA_AUTH
from tests.common import get_test_instance_port, get_test_home_assistant from tests.common import get_test_instance_port, get_test_home_assistant
@ -45,7 +45,6 @@ def setUpModule():
def tearDownModule(): def tearDownModule():
"""Stop everything that was started.""" """Stop everything that was started."""
hass.stop() hass.stop()
frontend.PANELS = {}
class TestFrontend(unittest.TestCase): class TestFrontend(unittest.TestCase):

View file

@ -1,285 +0,0 @@
"""The tests for the Home Assistant HTTP component."""
# pylint: disable=protected-access
import logging
from ipaddress import ip_network
from unittest.mock import patch, mock_open
import requests
from homeassistant import bootstrap, const
import homeassistant.components.http as http
from tests.common import get_test_instance_port, get_test_home_assistant
API_PASSWORD = 'test1234'
SERVER_PORT = get_test_instance_port()
HTTP_BASE = '127.0.0.1:{}'.format(SERVER_PORT)
HTTP_BASE_URL = 'http://{}'.format(HTTP_BASE)
HA_HEADERS = {
const.HTTP_HEADER_HA_AUTH: API_PASSWORD,
const.HTTP_HEADER_CONTENT_TYPE: const.CONTENT_TYPE_JSON,
}
# Don't add 127.0.0.1/::1 as trusted, as it may interfere with other test cases
TRUSTED_NETWORKS = ['192.0.2.0/24', '2001:DB8:ABCD::/48', '100.64.0.1',
'FD01:DB8::1']
TRUSTED_ADDRESSES = ['100.64.0.1', '192.0.2.100', 'FD01:DB8::1',
'2001:DB8:ABCD::1']
UNTRUSTED_ADDRESSES = ['198.51.100.1', '2001:DB8:FA1::1', '127.0.0.1', '::1']
BANNED_IPS = ['200.201.202.203', '100.64.0.1']
CORS_ORIGINS = [HTTP_BASE_URL, HTTP_BASE]
hass = None
def _url(path=''):
"""Helper method to generate URLs."""
return HTTP_BASE_URL + path
# pylint: disable=invalid-name
def setUpModule():
"""Initialize a Home Assistant server."""
global hass
hass = get_test_home_assistant()
hass.bus.listen('test_event', lambda _: _)
hass.states.set('test.test', 'a_state')
bootstrap.setup_component(
hass, http.DOMAIN, {
http.DOMAIN: {
http.CONF_API_PASSWORD: API_PASSWORD,
http.CONF_SERVER_PORT: SERVER_PORT,
http.CONF_CORS_ORIGINS: CORS_ORIGINS,
}
}
)
bootstrap.setup_component(hass, 'api')
hass.http.trusted_networks = [
ip_network(trusted_network)
for trusted_network in TRUSTED_NETWORKS]
hass.http.ip_bans = [http.IpBan(banned_ip)
for banned_ip in BANNED_IPS]
hass.start()
# pylint: disable=invalid-name
def tearDownModule():
"""Stop the Home Assistant server."""
hass.stop()
class TestHttp:
"""Test HTTP component."""
def test_access_denied_without_password(self):
"""Test access without password."""
req = requests.get(_url(const.URL_API))
assert req.status_code == 401
def test_access_denied_with_wrong_password_in_header(self):
"""Test access with wrong password."""
req = requests.get(
_url(const.URL_API),
headers={const.HTTP_HEADER_HA_AUTH: 'wrongpassword'})
assert req.status_code == 401
def test_access_denied_with_x_forwarded_for(self, caplog):
"""Test access denied through the X-Forwarded-For http header."""
hass.http.use_x_forwarded_for = True
for remote_addr in UNTRUSTED_ADDRESSES:
req = requests.get(_url(const.URL_API), headers={
const.HTTP_HEADER_X_FORWARDED_FOR: remote_addr})
assert req.status_code == 401, \
"{} shouldn't be trusted".format(remote_addr)
def test_access_denied_with_untrusted_ip(self, caplog):
"""Test access with an untrusted ip address."""
for remote_addr in UNTRUSTED_ADDRESSES:
with patch('homeassistant.components.http.'
'HomeAssistantWSGI.get_real_ip',
return_value=remote_addr):
req = requests.get(
_url(const.URL_API), params={'api_password': ''})
assert req.status_code == 401, \
"{} shouldn't be trusted".format(remote_addr)
def test_access_with_password_in_header(self, caplog):
"""Test access with password in URL."""
# Hide logging from requests package that we use to test logging
caplog.set_level(
logging.WARNING, logger='requests.packages.urllib3.connectionpool')
req = requests.get(
_url(const.URL_API),
headers={const.HTTP_HEADER_HA_AUTH: API_PASSWORD})
assert req.status_code == 200
logs = caplog.text
# assert const.URL_API in logs
assert API_PASSWORD not in logs
def test_access_denied_with_wrong_password_in_url(self):
"""Test access with wrong password."""
req = requests.get(
_url(const.URL_API), params={'api_password': 'wrongpassword'})
assert req.status_code == 401
def test_access_with_password_in_url(self, caplog):
"""Test access with password in URL."""
# Hide logging from requests package that we use to test logging
caplog.set_level(
logging.WARNING, logger='requests.packages.urllib3.connectionpool')
req = requests.get(
_url(const.URL_API), params={'api_password': API_PASSWORD})
assert req.status_code == 200
logs = caplog.text
# assert const.URL_API in logs
assert API_PASSWORD not in logs
def test_access_granted_with_x_forwarded_for(self, caplog):
"""Test access denied through the X-Forwarded-For http header."""
hass.http.use_x_forwarded_for = True
for remote_addr in TRUSTED_ADDRESSES:
req = requests.get(_url(const.URL_API), headers={
const.HTTP_HEADER_X_FORWARDED_FOR: remote_addr})
assert req.status_code == 200, \
"{} should be trusted".format(remote_addr)
def test_access_granted_with_trusted_ip(self, caplog):
"""Test access with trusted addresses."""
for remote_addr in TRUSTED_ADDRESSES:
with patch('homeassistant.components.http.'
'HomeAssistantWSGI.get_real_ip',
return_value=remote_addr):
req = requests.get(
_url(const.URL_API), params={'api_password': ''})
assert req.status_code == 200, \
'{} should be trusted'.format(remote_addr)
def test_cors_allowed_with_password_in_url(self):
"""Test cross origin resource sharing with password in url."""
req = requests.get(_url(const.URL_API),
params={'api_password': API_PASSWORD},
headers={const.HTTP_HEADER_ORIGIN: HTTP_BASE_URL})
allow_origin = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_ORIGIN
assert req.status_code == 200
assert req.headers.get(allow_origin) == HTTP_BASE_URL
def test_cors_allowed_with_password_in_header(self):
"""Test cross origin resource sharing with password in header."""
headers = {
const.HTTP_HEADER_HA_AUTH: API_PASSWORD,
const.HTTP_HEADER_ORIGIN: HTTP_BASE_URL
}
req = requests.get(_url(const.URL_API), headers=headers)
allow_origin = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_ORIGIN
assert req.status_code == 200
assert req.headers.get(allow_origin) == HTTP_BASE_URL
def test_cors_denied_without_origin_header(self):
"""Test cross origin resource sharing with password in header."""
headers = {
const.HTTP_HEADER_HA_AUTH: API_PASSWORD
}
req = requests.get(_url(const.URL_API), headers=headers)
allow_origin = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_ORIGIN
allow_headers = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_HEADERS
assert req.status_code == 200
assert allow_origin not in req.headers
assert allow_headers not in req.headers
def test_cors_preflight_allowed(self):
"""Test cross origin resource sharing preflight (OPTIONS) request."""
headers = {
const.HTTP_HEADER_ORIGIN: HTTP_BASE_URL,
'Access-Control-Request-Method': 'GET',
'Access-Control-Request-Headers': 'x-ha-access'
}
req = requests.options(_url(const.URL_API), headers=headers)
allow_origin = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_ORIGIN
allow_headers = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_HEADERS
assert req.status_code == 200
assert req.headers.get(allow_origin) == HTTP_BASE_URL
assert req.headers.get(allow_headers) == \
const.HTTP_HEADER_HA_AUTH.upper()
def test_access_from_banned_ip(self):
"""Test accessing to server from banned IP. Both trusted and not."""
hass.http.is_ban_enabled = True
for remote_addr in BANNED_IPS:
with patch('homeassistant.components.http.'
'HomeAssistantWSGI.get_real_ip',
return_value=remote_addr):
req = requests.get(
_url(const.URL_API))
assert req.status_code == 403
def test_access_from_banned_ip_when_ban_is_off(self):
"""Test accessing to server from banned IP when feature is off"""
hass.http.is_ban_enabled = False
for remote_addr in BANNED_IPS:
with patch('homeassistant.components.http.'
'HomeAssistantWSGI.get_real_ip',
return_value=remote_addr):
req = requests.get(
_url(const.URL_API),
headers={const.HTTP_HEADER_HA_AUTH: API_PASSWORD})
assert req.status_code == 200
def test_ip_bans_file_creation(self):
"""Testing if banned IP file created"""
hass.http.is_ban_enabled = True
hass.http.login_threshold = 1
m = mock_open()
def call_server():
with patch('homeassistant.components.http.'
'HomeAssistantWSGI.get_real_ip',
return_value="200.201.202.204"):
return requests.get(
_url(const.URL_API),
headers={const.HTTP_HEADER_HA_AUTH: 'Wrong password'})
with patch('homeassistant.components.http.open', m, create=True):
req = call_server()
assert req.status_code == 401
assert len(hass.http.ip_bans) == len(BANNED_IPS)
assert m.call_count == 0
req = call_server()
assert req.status_code == 401
assert len(hass.http.ip_bans) == len(BANNED_IPS) + 1
m.assert_called_once_with(hass.config.path(http.IP_BANS), 'a')
req = call_server()
assert req.status_code == 403
assert m.call_count == 1

View file

@ -165,7 +165,15 @@ class TestCheckConfig(unittest.TestCase):
self.assertDictEqual({ self.assertDictEqual({
'components': {'http': {'api_password': 'abc123', 'components': {'http': {'api_password': 'abc123',
'cors_allowed_origins': [],
'development': '0',
'ip_ban_enabled': True,
'login_attempts_threshold': -1,
'server_host': '0.0.0.0',
'server_port': 8123, 'server_port': 8123,
'ssl_certificate': None,
'ssl_key': None,
'trusted_networks': [],
'use_x_forwarded_for': False}}, 'use_x_forwarded_for': False}},
'except': {}, 'except': {},
'secret_cache': {secrets_path: {'http_pw': 'abc123'}}, 'secret_cache': {secrets_path: {'http_pw': 'abc123'}},