Clean up HTTP sessions and allow log out
This commit is contained in:
parent
99aa4307ef
commit
78cfed1fb0
4 changed files with 108 additions and 119 deletions
|
@ -18,7 +18,7 @@ from homeassistant.bootstrap import ERROR_LOG_FILENAME
|
||||||
from homeassistant.const import (
|
from homeassistant.const import (
|
||||||
URL_API, URL_API_STATES, URL_API_EVENTS, URL_API_SERVICES, URL_API_STREAM,
|
URL_API, URL_API_STATES, URL_API_EVENTS, URL_API_SERVICES, URL_API_STREAM,
|
||||||
URL_API_EVENT_FORWARD, URL_API_STATES_ENTITY, URL_API_COMPONENTS,
|
URL_API_EVENT_FORWARD, URL_API_STATES_ENTITY, URL_API_COMPONENTS,
|
||||||
URL_API_CONFIG, URL_API_BOOTSTRAP, URL_API_ERROR_LOG,
|
URL_API_CONFIG, URL_API_BOOTSTRAP, URL_API_ERROR_LOG, URL_API_LOG_OUT,
|
||||||
EVENT_TIME_CHANGED, EVENT_HOMEASSISTANT_STOP, MATCH_ALL,
|
EVENT_TIME_CHANGED, EVENT_HOMEASSISTANT_STOP, MATCH_ALL,
|
||||||
HTTP_OK, HTTP_CREATED, HTTP_BAD_REQUEST, HTTP_NOT_FOUND,
|
HTTP_OK, HTTP_CREATED, HTTP_BAD_REQUEST, HTTP_NOT_FOUND,
|
||||||
HTTP_UNPROCESSABLE_ENTITY)
|
HTTP_UNPROCESSABLE_ENTITY)
|
||||||
|
@ -89,6 +89,8 @@ def setup(hass, config):
|
||||||
hass.http.register_path('GET', URL_API_ERROR_LOG,
|
hass.http.register_path('GET', URL_API_ERROR_LOG,
|
||||||
_handle_get_api_error_log)
|
_handle_get_api_error_log)
|
||||||
|
|
||||||
|
hass.http.register_path('POST', URL_API_LOG_OUT, _handle_post_api_log_out)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
@ -347,6 +349,13 @@ def _handle_get_api_error_log(handler, path_match, data):
|
||||||
False)
|
False)
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_post_api_log_out(handler, path_match, data):
|
||||||
|
""" Log user out. """
|
||||||
|
handler.send_response(HTTP_OK)
|
||||||
|
handler.destroy_session()
|
||||||
|
handler.end_headers()
|
||||||
|
|
||||||
|
|
||||||
def _services_json(hass):
|
def _services_json(hass):
|
||||||
""" Generate services data to JSONify. """
|
""" Generate services data to JSONify. """
|
||||||
return [{"domain": key, "services": value}
|
return [{"domain": key, "services": value}
|
||||||
|
|
|
@ -54,8 +54,7 @@ def setup(hass, config):
|
||||||
|
|
||||||
|
|
||||||
def _handle_get_root(handler, path_match, data):
|
def _handle_get_root(handler, path_match, data):
|
||||||
""" Renders the debug interface. """
|
""" Renders the frontend. """
|
||||||
|
|
||||||
handler.send_response(HTTP_OK)
|
handler.send_response(HTTP_OK)
|
||||||
handler.send_header('Content-type', 'text/html; charset=utf-8')
|
handler.send_header('Content-type', 'text/html; charset=utf-8')
|
||||||
handler.end_headers()
|
handler.end_headers()
|
||||||
|
|
|
@ -12,10 +12,7 @@ import logging
|
||||||
import time
|
import time
|
||||||
import gzip
|
import gzip
|
||||||
import os
|
import os
|
||||||
import random
|
|
||||||
import string
|
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from homeassistant.util import Throttle
|
|
||||||
from http.server import SimpleHTTPRequestHandler, HTTPServer
|
from http.server import SimpleHTTPRequestHandler, HTTPServer
|
||||||
from http import cookies
|
from http import cookies
|
||||||
from socketserver import ThreadingMixIn
|
from socketserver import ThreadingMixIn
|
||||||
|
@ -44,40 +41,34 @@ CONF_SESSIONS_ENABLED = "sessions_enabled"
|
||||||
DATA_API_PASSWORD = 'api_password'
|
DATA_API_PASSWORD = 'api_password'
|
||||||
|
|
||||||
# Throttling time in seconds for expired sessions check
|
# Throttling time in seconds for expired sessions check
|
||||||
MIN_SEC_SESSION_CLEARING = timedelta(seconds=20)
|
SESSION_CLEAR_INTERVAL = timedelta(seconds=20)
|
||||||
SESSION_TIMEOUT_SECONDS = 1800
|
SESSION_TIMEOUT_SECONDS = 1800
|
||||||
SESSION_KEY = 'sessionId'
|
SESSION_KEY = 'sessionId'
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def setup(hass, config=None):
|
def setup(hass, config):
|
||||||
""" Sets up the HTTP API and debug interface. """
|
""" Sets up the HTTP API and debug interface. """
|
||||||
if config is None or DOMAIN not in config:
|
conf = config[DOMAIN]
|
||||||
config = {DOMAIN: {}}
|
|
||||||
|
|
||||||
api_password = util.convert(config[DOMAIN].get(CONF_API_PASSWORD), str)
|
|
||||||
|
|
||||||
|
api_password = util.convert(conf.get(CONF_API_PASSWORD), str)
|
||||||
no_password_set = api_password is None
|
no_password_set = api_password is None
|
||||||
|
|
||||||
if no_password_set:
|
if no_password_set:
|
||||||
api_password = util.get_random_string()
|
api_password = util.get_random_string()
|
||||||
|
|
||||||
# If no server host is given, accept all incoming requests
|
# If no server host is given, accept all incoming requests
|
||||||
server_host = config[DOMAIN].get(CONF_SERVER_HOST, '0.0.0.0')
|
server_host = conf.get(CONF_SERVER_HOST, '0.0.0.0')
|
||||||
|
server_port = conf.get(CONF_SERVER_PORT, SERVER_PORT)
|
||||||
server_port = config[DOMAIN].get(CONF_SERVER_PORT, SERVER_PORT)
|
development = str(conf.get(CONF_DEVELOPMENT, "")) == "1"
|
||||||
|
|
||||||
development = str(config[DOMAIN].get(CONF_DEVELOPMENT, "")) == "1"
|
|
||||||
|
|
||||||
sessions_enabled = config[DOMAIN].get(CONF_SESSIONS_ENABLED, True)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
server = HomeAssistantHTTPServer(
|
server = HomeAssistantHTTPServer(
|
||||||
(server_host, server_port), RequestHandler, hass, api_password,
|
(server_host, server_port), RequestHandler, hass, api_password,
|
||||||
development, no_password_set, sessions_enabled)
|
development, no_password_set)
|
||||||
except OSError:
|
except OSError:
|
||||||
# Happens if address already in use
|
# If address already in use
|
||||||
_LOGGER.exception("Error setting up HTTP server")
|
_LOGGER.exception("Error setting up HTTP server")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -102,8 +93,7 @@ class HomeAssistantHTTPServer(ThreadingMixIn, HTTPServer):
|
||||||
|
|
||||||
# pylint: disable=too-many-arguments
|
# pylint: disable=too-many-arguments
|
||||||
def __init__(self, server_address, request_handler_class,
|
def __init__(self, server_address, request_handler_class,
|
||||||
hass, api_password, development, no_password_set,
|
hass, api_password, development, no_password_set):
|
||||||
sessions_enabled):
|
|
||||||
super().__init__(server_address, request_handler_class)
|
super().__init__(server_address, request_handler_class)
|
||||||
|
|
||||||
self.server_address = server_address
|
self.server_address = server_address
|
||||||
|
@ -112,7 +102,7 @@ class HomeAssistantHTTPServer(ThreadingMixIn, HTTPServer):
|
||||||
self.development = development
|
self.development = development
|
||||||
self.no_password_set = no_password_set
|
self.no_password_set = no_password_set
|
||||||
self.paths = []
|
self.paths = []
|
||||||
self.sessions = SessionStore(sessions_enabled)
|
self.sessions = SessionStore()
|
||||||
|
|
||||||
# We will lazy init this one if needed
|
# We will lazy init this one if needed
|
||||||
self.event_forwarder = None
|
self.event_forwarder = None
|
||||||
|
@ -161,7 +151,8 @@ class RequestHandler(SimpleHTTPRequestHandler):
|
||||||
|
|
||||||
def __init__(self, req, client_addr, server):
|
def __init__(self, req, client_addr, server):
|
||||||
""" Contructor, call the base constructor and set up session """
|
""" Contructor, call the base constructor and set up session """
|
||||||
self._session = None
|
# Track if this was an authenticated request
|
||||||
|
self.authenticated = False
|
||||||
SimpleHTTPRequestHandler.__init__(self, req, client_addr, server)
|
SimpleHTTPRequestHandler.__init__(self, req, client_addr, server)
|
||||||
|
|
||||||
def log_message(self, fmt, *arguments):
|
def log_message(self, fmt, *arguments):
|
||||||
|
@ -201,18 +192,18 @@ class RequestHandler(SimpleHTTPRequestHandler):
|
||||||
"Error parsing JSON", HTTP_UNPROCESSABLE_ENTITY)
|
"Error parsing JSON", HTTP_UNPROCESSABLE_ENTITY)
|
||||||
return
|
return
|
||||||
|
|
||||||
self._session = self.get_session()
|
|
||||||
if self.server.no_password_set:
|
if self.server.no_password_set:
|
||||||
api_password = self.server.api_password
|
_LOGGER.warning('NO PASSWORD SET')
|
||||||
else:
|
self.authenticated = True
|
||||||
|
elif HTTP_HEADER_HA_AUTH in self.headers:
|
||||||
api_password = self.headers.get(HTTP_HEADER_HA_AUTH)
|
api_password = self.headers.get(HTTP_HEADER_HA_AUTH)
|
||||||
|
|
||||||
if not api_password and DATA_API_PASSWORD in data:
|
if not api_password and DATA_API_PASSWORD in data:
|
||||||
api_password = data[DATA_API_PASSWORD]
|
api_password = data[DATA_API_PASSWORD]
|
||||||
|
|
||||||
if not api_password and self._session is not None:
|
self.authenticated = api_password == self.server.api_password
|
||||||
api_password = self._session.cookie_values.get(
|
else:
|
||||||
CONF_API_PASSWORD)
|
self.authenticated = self.verify_session()
|
||||||
|
|
||||||
if '_METHOD' in data:
|
if '_METHOD' in data:
|
||||||
method = data.pop('_METHOD')
|
method = data.pop('_METHOD')
|
||||||
|
@ -245,18 +236,13 @@ class RequestHandler(SimpleHTTPRequestHandler):
|
||||||
|
|
||||||
# Did we find a handler for the incoming request?
|
# Did we find a handler for the incoming request?
|
||||||
if handle_request_method:
|
if handle_request_method:
|
||||||
|
|
||||||
# For some calls we need a valid password
|
# For some calls we need a valid password
|
||||||
if require_auth and api_password != self.server.api_password:
|
if require_auth and not self.authenticated:
|
||||||
self.write_json_message(
|
self.write_json_message(
|
||||||
"API password missing or incorrect.", HTTP_UNAUTHORIZED)
|
"API password missing or incorrect.", HTTP_UNAUTHORIZED)
|
||||||
|
return
|
||||||
|
|
||||||
else:
|
handle_request_method(self, path_match, data)
|
||||||
if self._session is None and require_auth:
|
|
||||||
self._session = self.server.sessions.create(
|
|
||||||
api_password)
|
|
||||||
|
|
||||||
handle_request_method(self, path_match, data)
|
|
||||||
|
|
||||||
elif path_matched_but_not_method:
|
elif path_matched_but_not_method:
|
||||||
self.send_response(HTTP_METHOD_NOT_ALLOWED)
|
self.send_response(HTTP_METHOD_NOT_ALLOWED)
|
||||||
|
@ -369,63 +355,62 @@ class RequestHandler(SimpleHTTPRequestHandler):
|
||||||
self.date_time_string(time.time()+cache_time))
|
self.date_time_string(time.time()+cache_time))
|
||||||
|
|
||||||
def set_session_cookie_header(self):
|
def set_session_cookie_header(self):
|
||||||
""" Add the header for the session cookie """
|
""" Add the header for the session cookie. """
|
||||||
if self.server.sessions.enabled and self._session is not None:
|
if not self.authenticated:
|
||||||
existing_sess_id = self.get_current_session_id()
|
return
|
||||||
|
|
||||||
if existing_sess_id != self._session.session_id:
|
current = self.get_cookie_session_id()
|
||||||
self.send_header(
|
|
||||||
'Set-Cookie',
|
|
||||||
SESSION_KEY+'='+self._session.session_id)
|
|
||||||
|
|
||||||
def get_session(self):
|
if current is not None:
|
||||||
""" Get the requested session object from cookie value """
|
self.server.sessions.extend_validation(current)
|
||||||
if self.server.sessions.enabled is not True:
|
return
|
||||||
return None
|
|
||||||
|
|
||||||
session_id = self.get_current_session_id()
|
self.send_header(
|
||||||
if session_id is not None:
|
'Set-Cookie',
|
||||||
session = self.server.sessions.get(session_id)
|
'{}={}'.format(SESSION_KEY, self.server.sessions.create())
|
||||||
if session is not None:
|
)
|
||||||
session.reset_expiry()
|
|
||||||
return session
|
|
||||||
|
|
||||||
return None
|
def verify_session(self):
|
||||||
|
""" Verify that we are in a valid session. """
|
||||||
|
return self.get_cookie_session_id() is not None
|
||||||
|
|
||||||
def get_current_session_id(self):
|
def get_cookie_session_id(self):
|
||||||
"""
|
"""
|
||||||
Extracts the current session id from the
|
Extracts the current session id from the
|
||||||
cookie or returns None if not set
|
cookie or returns None if not set or invalid
|
||||||
"""
|
"""
|
||||||
|
if 'Cookie' not in self.headers:
|
||||||
|
return None
|
||||||
|
|
||||||
cookie = cookies.SimpleCookie()
|
cookie = cookies.SimpleCookie()
|
||||||
|
try:
|
||||||
|
cookie.load(self.headers["Cookie"])
|
||||||
|
except cookies.CookieError:
|
||||||
|
return None
|
||||||
|
|
||||||
if self.headers.get('Cookie', None) is not None:
|
morsel = cookie.get(SESSION_KEY)
|
||||||
cookie.load(self.headers.get("Cookie"))
|
|
||||||
|
|
||||||
if cookie.get(SESSION_KEY, False):
|
if morsel is None:
|
||||||
return cookie[SESSION_KEY].value
|
return None
|
||||||
|
|
||||||
return None
|
current = cookie[SESSION_KEY].value
|
||||||
|
|
||||||
|
return current if self.server.sessions.is_valid(current) else None
|
||||||
|
|
||||||
|
def destroy_session(self):
|
||||||
|
""" Destroys session. """
|
||||||
|
current = self.get_cookie_session_id()
|
||||||
|
|
||||||
|
if current is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.send_header('Set-Cookie', '')
|
||||||
|
self.server.sessions.destroy(current)
|
||||||
|
|
||||||
|
|
||||||
class ServerSession:
|
def session_valid_time():
|
||||||
""" A very simple session class """
|
""" Time till when a session will be valid. """
|
||||||
def __init__(self, session_id):
|
return date_util.utcnow() + timedelta(seconds=SESSION_TIMEOUT_SECONDS)
|
||||||
""" Set up the expiry time on creation """
|
|
||||||
self._expiry = 0
|
|
||||||
self.reset_expiry()
|
|
||||||
self.cookie_values = {}
|
|
||||||
self.session_id = session_id
|
|
||||||
|
|
||||||
def reset_expiry(self):
|
|
||||||
""" Resets the expiry based on current time """
|
|
||||||
self._expiry = date_util.utcnow() + timedelta(
|
|
||||||
seconds=SESSION_TIMEOUT_SECONDS)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_expired(self):
|
|
||||||
""" Return true if the session is expired based on the expiry time """
|
|
||||||
return self._expiry < date_util.utcnow()
|
|
||||||
|
|
||||||
|
|
||||||
class SessionStore(object):
|
class SessionStore(object):
|
||||||
|
@ -433,47 +418,42 @@ class SessionStore(object):
|
||||||
def __init__(self, enabled=True):
|
def __init__(self, enabled=True):
|
||||||
""" Set up the session store """
|
""" Set up the session store """
|
||||||
self._sessions = {}
|
self._sessions = {}
|
||||||
self.enabled = enabled
|
self.lock = threading.RLock()
|
||||||
self.session_lock = threading.RLock()
|
|
||||||
|
|
||||||
@Throttle(MIN_SEC_SESSION_CLEARING)
|
@util.Throttle(SESSION_CLEAR_INTERVAL)
|
||||||
def remove_expired(self):
|
def _remove_expired(self):
|
||||||
""" Remove any expired sessions. """
|
""" Remove any expired sessions. """
|
||||||
if self.session_lock.acquire(False):
|
now = date_util.utcnow()
|
||||||
try:
|
for key in [key for key, valid_time in self._sessions.items()
|
||||||
keys = []
|
if valid_time < now]:
|
||||||
for key in self._sessions.keys():
|
self._sessions.pop(key)
|
||||||
keys.append(key)
|
|
||||||
|
|
||||||
for key in keys:
|
def is_valid(self, key):
|
||||||
if self._sessions[key].is_expired:
|
""" Return True if a valid session is given. """
|
||||||
del self._sessions[key]
|
with self.lock:
|
||||||
_LOGGER.info("Cleared expired session %s", key)
|
self._remove_expired()
|
||||||
finally:
|
|
||||||
self.session_lock.release()
|
|
||||||
|
|
||||||
def add(self, key, session):
|
return (key in self._sessions and
|
||||||
""" Add a new session to the list of tracked sessions """
|
self._sessions[key] > date_util.utcnow())
|
||||||
self.remove_expired()
|
|
||||||
with self.session_lock:
|
|
||||||
self._sessions[key] = session
|
|
||||||
|
|
||||||
def get(self, key):
|
def extend_validation(self, key):
|
||||||
""" get a session by key """
|
""" Extend a session validation time. """
|
||||||
self.remove_expired()
|
with self.lock:
|
||||||
session = self._sessions.get(key, None)
|
self._sessions[key] = session_valid_time()
|
||||||
if session is not None and session.is_expired:
|
|
||||||
return None
|
|
||||||
return session
|
|
||||||
|
|
||||||
def create(self, api_password):
|
def destroy(self, key):
|
||||||
""" Creates a new session and adds it to the sessions """
|
""" Destroy a session by key. """
|
||||||
if self.enabled is not True:
|
with self.lock:
|
||||||
return None
|
self._sessions.pop(key, None)
|
||||||
|
|
||||||
chars = string.ascii_letters + string.digits
|
def create(self):
|
||||||
session_id = ''.join([random.choice(chars) for i in range(20)])
|
""" Creates a new session. """
|
||||||
session = ServerSession(session_id)
|
with self.lock:
|
||||||
session.cookie_values[CONF_API_PASSWORD] = api_password
|
session_id = util.get_random_string(20)
|
||||||
self.add(session_id, session)
|
|
||||||
return session
|
while session_id in self._sessions:
|
||||||
|
session_id = util.get_random_string(20)
|
||||||
|
|
||||||
|
self._sessions[session_id] = session_valid_time()
|
||||||
|
|
||||||
|
return session_id
|
||||||
|
|
|
@ -164,6 +164,7 @@ URL_API_EVENT_FORWARD = "/api/event_forwarding"
|
||||||
URL_API_COMPONENTS = "/api/components"
|
URL_API_COMPONENTS = "/api/components"
|
||||||
URL_API_BOOTSTRAP = "/api/bootstrap"
|
URL_API_BOOTSTRAP = "/api/bootstrap"
|
||||||
URL_API_ERROR_LOG = "/api/error_log"
|
URL_API_ERROR_LOG = "/api/error_log"
|
||||||
|
URL_API_LOG_OUT = "/api/log_out"
|
||||||
|
|
||||||
HTTP_OK = 200
|
HTTP_OK = 200
|
||||||
HTTP_CREATED = 201
|
HTTP_CREATED = 201
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue