Tons of fixes - WIP

This commit is contained in:
Paulus Schoutsen 2016-05-14 00:58:36 -07:00
parent 768c98d359
commit 15e329a588
22 changed files with 938 additions and 1604 deletions

View file

@ -1,41 +1,17 @@
"""
This module provides an API and a HTTP interface for debug purposes.
For more details about the RESTful API, please refer to the documentation at
https://home-assistant.io/developers/api/
"""
import gzip
"""This module provides WSGI application to serve the Home Assistant API."""
import hmac
import json
import logging
import ssl
import threading
import time
from datetime import timedelta
from http import cookies
from http.server import HTTPServer, SimpleHTTPRequestHandler
from socketserver import ThreadingMixIn
from urllib.parse import parse_qs, urlparse
import voluptuous as vol
import re
import homeassistant.bootstrap as bootstrap
import homeassistant.core as ha
import homeassistant.remote as rem
import homeassistant.util as util
import homeassistant.util.dt as date_util
import homeassistant.helpers.config_validation as cv
from homeassistant.const import (
CONTENT_TYPE_JSON, CONTENT_TYPE_TEXT_PLAIN, HTTP_HEADER_ACCEPT_ENCODING,
HTTP_HEADER_CACHE_CONTROL, HTTP_HEADER_CONTENT_ENCODING,
HTTP_HEADER_CONTENT_LENGTH, HTTP_HEADER_CONTENT_TYPE, HTTP_HEADER_EXPIRES,
HTTP_HEADER_HA_AUTH, HTTP_HEADER_VARY,
HTTP_HEADER_ACCESS_CONTROL_ALLOW_ORIGIN,
HTTP_HEADER_ACCESS_CONTROL_ALLOW_HEADERS, HTTP_METHOD_NOT_ALLOWED,
HTTP_NOT_FOUND, HTTP_OK, HTTP_UNAUTHORIZED, HTTP_UNPROCESSABLE_ENTITY,
ALLOWED_CORS_HEADERS,
SERVER_PORT, URL_ROOT, URL_API_EVENT_FORWARD)
from homeassistant import util
from homeassistant.const import SERVER_PORT, HTTP_HEADER_HA_AUTH
DOMAIN = "http"
REQUIREMENTS = ("eventlet==0.18.4", "static3==0.6.1", "Werkzeug==0.11.5",)
CONF_API_PASSWORD = "api_password"
CONF_SERVER_HOST = "server_host"
@ -43,61 +19,42 @@ CONF_SERVER_PORT = "server_port"
CONF_DEVELOPMENT = "development"
CONF_SSL_CERTIFICATE = 'ssl_certificate'
CONF_SSL_KEY = 'ssl_key'
CONF_CORS_ORIGINS = 'cors_allowed_origins'
DATA_API_PASSWORD = 'api_password'
# Throttling time in seconds for expired sessions check
SESSION_CLEAR_INTERVAL = timedelta(seconds=20)
SESSION_TIMEOUT_SECONDS = 1800
SESSION_KEY = 'sessionId'
_FINGERPRINT = re.compile(r'^(.+)-[a-z0-9]{32}\.(\w+)$', re.IGNORECASE)
_LOGGER = logging.getLogger(__name__)
CONFIG_SCHEMA = vol.Schema({
DOMAIN: vol.Schema({
vol.Optional(CONF_API_PASSWORD): cv.string,
vol.Optional(CONF_SERVER_HOST): cv.string,
vol.Optional(CONF_SERVER_PORT, default=SERVER_PORT):
vol.All(vol.Coerce(int), vol.Range(min=1, max=65535)),
vol.Optional(CONF_DEVELOPMENT): cv.string,
vol.Optional(CONF_SSL_CERTIFICATE): cv.isfile,
vol.Optional(CONF_SSL_KEY): cv.isfile,
vol.Optional(CONF_CORS_ORIGINS): cv.ensure_list
}),
}, extra=vol.ALLOW_EXTRA)
def setup(hass, config):
"""Set up the HTTP API and debug interface."""
conf = config.get(DOMAIN, {})
api_password = util.convert(conf.get(CONF_API_PASSWORD), str)
# If no server host is given, accept all incoming requests
server_host = conf.get(CONF_SERVER_HOST, '0.0.0.0')
server_port = conf.get(CONF_SERVER_PORT, SERVER_PORT)
development = str(conf.get(CONF_DEVELOPMENT, "")) == "1"
ssl_certificate = conf.get(CONF_SSL_CERTIFICATE)
ssl_key = conf.get(CONF_SSL_KEY)
cors_origins = conf.get(CONF_CORS_ORIGINS, [])
try:
server = HomeAssistantHTTPServer(
(server_host, server_port), RequestHandler, hass, api_password,
development, ssl_certificate, ssl_key, cors_origins)
except OSError:
# If address already in use
_LOGGER.exception("Error setting up HTTP server")
return False
server = HomeAssistantWSGI(
hass,
development=development,
server_host=server_host,
server_port=server_port,
api_password=api_password,
ssl_certificate=ssl_certificate,
ssl_key=ssl_key,
)
hass.bus.listen_once(
ha.EVENT_HOMEASSISTANT_START,
lambda event:
threading.Thread(target=server.start, daemon=True,
name='HTTP-server').start())
name='WSGI-server').start())
hass.http = server
hass.wsgi = server
hass.config.api = rem.API(server_host if server_host != '0.0.0.0'
else util.get_local_ip(),
api_password, server_port,
@ -106,413 +63,277 @@ def setup(hass, config):
return True
# pylint: disable=too-many-instance-attributes
class HomeAssistantHTTPServer(ThreadingMixIn, HTTPServer):
"""Handle HTTP requests in a threaded fashion."""
# class StaticFileServer(object):
# """Static file serving middleware."""
# pylint: disable=too-few-public-methods
allow_reuse_address = True
daemon_threads = True
# def __call__(self, environ, start_response):
# from werkzeug.wsgi import DispatcherMiddleware
# app = DispatcherMiddleware(self.base_app, self.extra_apps)
# # Strip out any cachebusting MD% fingerprints
# fingerprinted = _FINGERPRINT.match(environ['PATH_INFO'])
# if fingerprinted:
# environ['PATH_INFO'] = "{}.{}".format(*fingerprinted.groups())
# return app(environ, start_response)
class HomeAssistantWSGI(object):
"""WSGI server for Home Assistant."""
# pylint: disable=too-many-instance-attributes, too-many-locals
# pylint: disable=too-many-arguments
def __init__(self, server_address, request_handler_class,
hass, api_password, development, ssl_certificate, ssl_key,
cors_origins):
"""Initialize the server."""
super().__init__(server_address, request_handler_class)
self.server_address = server_address
def __init__(self, hass, development, api_password, ssl_certificate,
ssl_key, server_host, server_port):
"""Initilalize the WSGI Home Assistant server."""
from werkzeug.exceptions import BadRequest
from werkzeug.wrappers import BaseRequest, AcceptMixin
from werkzeug.contrib.wrappers import JSONRequestMixin
from werkzeug.routing import Map
from werkzeug.utils import cached_property
from werkzeug.wrappers import Response
class Request(BaseRequest, AcceptMixin, JSONRequestMixin):
"""Base class for incoming requests."""
@cached_property
def json(self):
"""Get the result of json.loads if possible."""
if not self.data:
return None
elif 'json' not in self.environ.get('CONTENT_TYPE', ''):
raise BadRequest('Not a JSON request')
try:
return json.loads(self.data.decode(
self.charset, self.encoding_errors))
except (TypeError, ValueError):
raise BadRequest('Unable to read JSON request')
Response.mimetype = 'text/html'
# pylint: disable=invalid-name
self.Request = Request
self.url_map = Map()
self.views = {}
self.hass = hass
self.api_password = api_password
self.extra_apps = {}
self.development = development
self.paths = []
self.sessions = SessionStore()
self.use_ssl = ssl_certificate is not None
self.cors_origins = cors_origins
# We will lazy init this one if needed
self.api_password = api_password
self.ssl_certificate = ssl_certificate
self.ssl_key = ssl_key
self.server_host = server_host
self.server_port = server_port
self.event_forwarder = None
if development:
_LOGGER.info("running http in development mode")
def register_view(self, view):
"""Register a view with the WSGI server.
if ssl_certificate is not None:
context = ssl.create_default_context(
purpose=ssl.Purpose.CLIENT_AUTH)
context.load_cert_chain(ssl_certificate, keyfile=ssl_key)
self.socket = context.wrap_socket(self.socket, server_side=True)
The view argument must inherit from the HomeAssistantView class, and
it must have (globally unique) 'url' and 'name' attributes.
"""
from werkzeug.routing import Rule
if view.name in self.views:
_LOGGER.warning("View '%s' is being overwritten", view.name)
if isinstance(view, type):
view = view(self.hass)
self.views[view.name] = view
rule = Rule(view.url, endpoint=view.name)
self.url_map.add(rule)
for url in view.extra_urls:
rule = Rule(url, endpoint=view.name)
self.url_map.add(rule)
def register_redirect(self, url, redirect_to):
"""Register a redirect with the server.
If given this must be either a string or callable. In case of a
callable its called with the url adapter that triggered the match and
the values of the URL as keyword arguments and has to return the target
for the redirect, otherwise it has to be a string with placeholders in
rule syntax.
"""
from werkzeug.routing import Rule
self.url_map.add(Rule(url, redirect_to=redirect_to))
def register_static_path(self, url_root, path):
"""Register a folder to serve as a static path."""
from static import Cling
if url_root in self.extra_apps:
_LOGGER.warning("Static path '%s' is being overwritten", path)
self.extra_apps[url_root] = Cling(path)
def start(self):
"""Start the HTTP server."""
def stop_http(event):
"""Stop the HTTP server."""
self.shutdown()
"""Start the wsgi server."""
from eventlet import wsgi
import eventlet
self.hass.bus.listen_once(ha.EVENT_HOMEASSISTANT_STOP, stop_http)
sock = eventlet.listen((self.server_host, self.server_port))
if self.ssl_certificate:
eventlet.wrap_ssl(sock, certfile=self.ssl_certificate,
keyfile=self.ssl_key, server_side=True)
wsgi.server(sock, self)
protocol = 'https' if self.use_ssl else 'http'
def dispatch_request(self, request):
"""Handle incoming request."""
from werkzeug.exceptions import (
MethodNotAllowed, NotFound, BadRequest, Unauthorized,
)
from werkzeug.routing import RequestRedirect
_LOGGER.info(
"Starting web interface at %s://%s:%d",
protocol, self.server_address[0], self.server_address[1])
# 31-1-2015: Refactored frontend/api components out of this component
# To prevent stuff from breaking, load the two extracted components
bootstrap.setup_component(self.hass, 'api')
bootstrap.setup_component(self.hass, 'frontend')
self.serve_forever()
def register_path(self, method, url, callback, require_auth=True):
"""Register a path with the server."""
self.paths.append((method, url, callback, require_auth))
def log_message(self, fmt, *args):
"""Redirect built-in log to HA logging."""
# pylint: disable=no-self-use
_LOGGER.info(fmt, *args)
# pylint: disable=too-many-public-methods,too-many-locals
class RequestHandler(SimpleHTTPRequestHandler):
"""Handle incoming HTTP requests.
We extend from SimpleHTTPRequestHandler instead of Base so we
can use the guess content type methods.
"""
server_version = "HomeAssistant/1.0"
def __init__(self, req, client_addr, server):
"""Constructor, call the base constructor and set up session."""
# Track if this was an authenticated request
self.authenticated = False
SimpleHTTPRequestHandler.__init__(self, req, client_addr, server)
self.protocol_version = 'HTTP/1.1'
def log_message(self, fmt, *arguments):
"""Redirect built-in log to HA logging."""
if self.server.api_password is None:
_LOGGER.info(fmt, *arguments)
else:
_LOGGER.info(
fmt, *(arg.replace(self.server.api_password, '*******')
if isinstance(arg, str) else arg for arg in arguments))
def _handle_request(self, method): # pylint: disable=too-many-branches
"""Perform some common checks and call appropriate method."""
url = urlparse(self.path)
# Read query input. parse_qs gives a list for each value, we want last
data = {key: data[-1] for key, data in parse_qs(url.query).items()}
# Did we get post input ?
content_length = int(self.headers.get(HTTP_HEADER_CONTENT_LENGTH, 0))
if content_length:
body_content = self.rfile.read(content_length).decode("UTF-8")
try:
data.update(json.loads(body_content))
except (TypeError, ValueError):
# TypeError if JSON object is not a dict
# ValueError if we could not parse JSON
_LOGGER.exception(
"Exception parsing JSON: %s", body_content)
self.write_json_message(
"Error parsing JSON", HTTP_UNPROCESSABLE_ENTITY)
return
if self.verify_session():
# The user has a valid session already
self.authenticated = True
elif self.server.api_password is None:
# No password is set, so everyone is authenticated
self.authenticated = True
elif hmac.compare_digest(self.headers.get(HTTP_HEADER_HA_AUTH, ''),
self.server.api_password):
# A valid auth header has been set
self.authenticated = True
elif hmac.compare_digest(data.get(DATA_API_PASSWORD, ''),
self.server.api_password):
# A valid password has been specified
self.authenticated = True
else:
self.authenticated = False
# we really shouldn't need to forward the password from here
if url.path not in [URL_ROOT, URL_API_EVENT_FORWARD]:
data.pop(DATA_API_PASSWORD, None)
if '_METHOD' in data:
method = data.pop('_METHOD')
# Var to keep track if we found a path that matched a handler but
# the method was different
path_matched_but_not_method = False
# Var to hold the handler for this path and method if found
handle_request_method = False
require_auth = True
# Check every handler to find matching result
for t_method, t_path, t_handler, t_auth in self.server.paths:
# we either do string-comparison or regular expression matching
# pylint: disable=maybe-no-member
if isinstance(t_path, str):
path_match = url.path == t_path
else:
path_match = t_path.match(url.path)
if path_match and method == t_method:
# Call the method
handle_request_method = t_handler
require_auth = t_auth
break
elif path_match:
path_matched_but_not_method = True
# Did we find a handler for the incoming request?
if handle_request_method:
# For some calls we need a valid password
msg = "API password missing or incorrect."
if require_auth and not self.authenticated:
self.write_json_message(msg, HTTP_UNAUTHORIZED)
_LOGGER.warning('%s Source IP: %s',
msg,
self.client_address[0])
return
handle_request_method(self, path_match, data)
elif path_matched_but_not_method:
self.send_response(HTTP_METHOD_NOT_ALLOWED)
self.end_headers()
else:
self.send_response(HTTP_NOT_FOUND)
self.end_headers()
def do_HEAD(self): # pylint: disable=invalid-name
"""HEAD request handler."""
self._handle_request('HEAD')
def do_GET(self): # pylint: disable=invalid-name
"""GET request handler."""
self._handle_request('GET')
def do_POST(self): # pylint: disable=invalid-name
"""POST request handler."""
self._handle_request('POST')
def do_PUT(self): # pylint: disable=invalid-name
"""PUT request handler."""
self._handle_request('PUT')
def do_DELETE(self): # pylint: disable=invalid-name
"""DELETE request handler."""
self._handle_request('DELETE')
def write_json_message(self, message, status_code=HTTP_OK):
"""Helper method to return a message to the caller."""
self.write_json({'message': message}, status_code=status_code)
def write_json(self, data=None, status_code=HTTP_OK, location=None):
"""Helper method to return JSON to the caller."""
json_data = json.dumps(data, indent=4, sort_keys=True,
cls=rem.JSONEncoder).encode('UTF-8')
self.send_response(status_code)
if location:
self.send_header('Location', location)
self.set_session_cookie_header()
self.write_content(json_data, CONTENT_TYPE_JSON)
def write_text(self, message, status_code=HTTP_OK):
"""Helper method to return a text message to the caller."""
msg_data = message.encode('UTF-8')
self.send_response(status_code)
self.set_session_cookie_header()
self.write_content(msg_data, CONTENT_TYPE_TEXT_PLAIN)
def write_file(self, path, cache_headers=True):
"""Return a file to the user."""
adapter = self.url_map.bind_to_environ(request.environ)
try:
with open(path, 'rb') as inp:
self.write_file_pointer(self.guess_type(path), inp,
cache_headers)
endpoint, values = adapter.match()
return self.views[endpoint].handle_request(request, **values)
except RequestRedirect as ex:
return ex
except BadRequest as ex:
return self._handle_error(request, str(ex), 400)
except NotFound as ex:
return self._handle_error(request, str(ex), 404)
except MethodNotAllowed as ex:
return self._handle_error(request, str(ex), 405)
except Unauthorized as ex:
return self._handle_error(request, str(ex), 401)
# TODO This long chain of except blocks is silly. _handle_error should
# just take the exception as an argument and parse the status code
# itself
except IOError:
self.send_response(HTTP_NOT_FOUND)
self.end_headers()
_LOGGER.exception("Unable to serve %s", path)
def base_app(self, environ, start_response):
"""WSGI Handler of requests to base app."""
request = self.Request(environ)
response = self.dispatch_request(request)
return response(environ, start_response)
def write_file_pointer(self, content_type, inp, cache_headers=True):
"""Helper function to write a file pointer to the user."""
self.send_response(HTTP_OK)
def __call__(self, environ, start_response):
"""Handle a request for base app + extra apps."""
from werkzeug.wsgi import DispatcherMiddleware
if cache_headers:
self.set_cache_header()
self.set_session_cookie_header()
app = DispatcherMiddleware(self.base_app, self.extra_apps)
# Strip out any cachebusting MD5 fingerprints
fingerprinted = _FINGERPRINT.match(environ.get('PATH_INFO', ''))
if fingerprinted:
environ['PATH_INFO'] = "{}.{}".format(*fingerprinted.groups())
return app(environ, start_response)
self.write_content(inp.read(), content_type)
def _handle_error(self, request, message, status):
"""Handle a WSGI request error."""
from werkzeug.wrappers import Response
if request.accept_mimetypes.accept_json:
message = json.dumps({
"result": "error",
"message": message,
})
mimetype = "application/json"
else:
mimetype = "text/plain"
return Response(message, status=status, mimetype=mimetype)
def write_content(self, content, content_type=None):
"""Helper method to write content bytes to output stream."""
if content_type is not None:
self.send_header(HTTP_HEADER_CONTENT_TYPE, content_type)
if 'gzip' in self.headers.get(HTTP_HEADER_ACCEPT_ENCODING, ''):
content = gzip.compress(content)
class HomeAssistantView(object):
"""Base view for all views."""
self.send_header(HTTP_HEADER_CONTENT_ENCODING, "gzip")
self.send_header(HTTP_HEADER_VARY, HTTP_HEADER_ACCEPT_ENCODING)
extra_urls = []
requires_auth = True # Views inheriting from this class can override this
self.send_header(HTTP_HEADER_CONTENT_LENGTH, str(len(content)))
def __init__(self, hass):
"""Initilalize the base view."""
from werkzeug.wrappers import Response
cors_check = (self.headers.get("Origin") in self.server.cors_origins)
self.hass = hass
# pylint: disable=invalid-name
self.Response = Response
cors_headers = ", ".join(ALLOWED_CORS_HEADERS)
if self.server.cors_origins and cors_check:
self.send_header(HTTP_HEADER_ACCESS_CONTROL_ALLOW_ORIGIN,
self.headers.get("Origin"))
self.send_header(HTTP_HEADER_ACCESS_CONTROL_ALLOW_HEADERS,
cors_headers)
self.end_headers()
if self.command == 'HEAD':
return
self.wfile.write(content)
def set_cache_header(self):
"""Add cache headers if not in development."""
if self.server.development:
return
# 1 year in seconds
cache_time = 365 * 86400
self.send_header(
HTTP_HEADER_CACHE_CONTROL,
"public, max-age={}".format(cache_time))
self.send_header(
HTTP_HEADER_EXPIRES,
self.date_time_string(time.time()+cache_time))
def set_session_cookie_header(self):
"""Add the header for the session cookie and return session ID."""
if not self.authenticated:
return None
session_id = self.get_cookie_session_id()
if session_id is not None:
self.server.sessions.extend_validation(session_id)
return session_id
self.send_header(
'Set-Cookie',
'{}={}'.format(SESSION_KEY, self.server.sessions.create())
def handle_request(self, request, **values):
"""Handle request to url."""
from werkzeug.exceptions import (
MethodNotAllowed, Unauthorized, BadRequest,
)
return session_id
def verify_session(self):
"""Verify that we are in a valid session."""
return self.get_cookie_session_id() is not None
def get_cookie_session_id(self):
"""Extract the current session ID from the cookie.
Return None if not set or invalid.
"""
if 'Cookie' not in self.headers:
return None
cookie = cookies.SimpleCookie()
try:
cookie.load(self.headers["Cookie"])
except cookies.CookieError:
return None
handler = getattr(self, request.method.lower())
except AttributeError:
raise MethodNotAllowed
morsel = cookie.get(SESSION_KEY)
# TODO: session support + uncomment session test
if morsel is None:
return None
# Auth code verbose on purpose
authenticated = False
session_id = cookie[SESSION_KEY].value
if not self.requires_auth:
authenticated = True
if self.server.sessions.is_valid(session_id):
return session_id
elif self.hass.wsgi.api_password is None:
authenticated = True
return None
elif hmac.compare_digest(request.headers.get(HTTP_HEADER_HA_AUTH, ''),
self.hass.wsgi.api_password):
# A valid auth header has been set
authenticated = True
def destroy_session(self):
"""Destroy the session."""
session_id = self.get_cookie_session_id()
elif hmac.compare_digest(request.args.get(DATA_API_PASSWORD, ''),
self.hass.wsgi.api_password):
authenticated = True
if session_id is None:
return
else:
# Do we still want to support passing it in as post data?
try:
json_data = request.json
if (json_data is not None and
hmac.compare_digest(
json_data.get(DATA_API_PASSWORD, ''),
self.hass.wsgi.api_password)):
authenticated = True
except BadRequest:
pass
self.send_header('Set-Cookie', '')
self.server.sessions.destroy(session_id)
if not authenticated:
raise Unauthorized()
result = handler(request, **values)
def session_valid_time():
"""Time till when a session will be valid."""
return date_util.utcnow() + timedelta(seconds=SESSION_TIMEOUT_SECONDS)
if isinstance(result, self.Response):
# The method handler returned a ready-made Response, how nice of it
return result
status_code = 200
class SessionStore(object):
"""Responsible for storing and retrieving HTTP sessions."""
if isinstance(result, tuple):
result, status_code = result
def __init__(self):
"""Setup the session store."""
self._sessions = {}
self._lock = threading.RLock()
return self.Response(result, status=status_code)
@util.Throttle(SESSION_CLEAR_INTERVAL)
def _remove_expired(self):
"""Remove any expired sessions."""
now = date_util.utcnow()
for key in [key for key, valid_time in self._sessions.items()
if valid_time < now]:
self._sessions.pop(key)
def json(self, result, status_code=200):
"""Return a JSON response."""
msg = json.dumps(
result,
sort_keys=True,
cls=rem.JSONEncoder
).encode('UTF-8')
return self.Response(msg, mimetype="application/json",
status=status_code)
def is_valid(self, key):
"""Return True if a valid session is given."""
with self._lock:
self._remove_expired()
def json_message(self, error, status_code=200):
"""Return a JSON message response."""
return self.json({'message': error}, status_code)
return (key in self._sessions and
self._sessions[key] > date_util.utcnow())
def file(self, request, fil, content_type=None):
"""Return a file."""
from werkzeug.wsgi import wrap_file
from werkzeug.exceptions import NotFound
def extend_validation(self, key):
"""Extend a session validation time."""
with self._lock:
if key not in self._sessions:
return
self._sessions[key] = session_valid_time()
if isinstance(fil, str):
try:
fil = open(fil)
except IOError:
raise NotFound()
def destroy(self, key):
"""Destroy a session by key."""
with self._lock:
self._sessions.pop(key, None)
# TODO mimetypes, etc
def create(self):
"""Create a new session."""
with self._lock:
session_id = util.get_random_string(20)
while session_id in self._sessions:
session_id = util.get_random_string(20)
self._sessions[session_id] = session_valid_time()
return session_id
resp = self.Response(wrap_file(request.environ, fil))
if content_type is not None:
resp.mimetype = content_type
return resp