Clean up HTTP sessions and allow log out
This commit is contained in:
parent
99aa4307ef
commit
78cfed1fb0
4 changed files with 108 additions and 119 deletions
|
@ -18,7 +18,7 @@ from homeassistant.bootstrap import ERROR_LOG_FILENAME
|
|||
from homeassistant.const import (
|
||||
URL_API, URL_API_STATES, URL_API_EVENTS, URL_API_SERVICES, URL_API_STREAM,
|
||||
URL_API_EVENT_FORWARD, URL_API_STATES_ENTITY, URL_API_COMPONENTS,
|
||||
URL_API_CONFIG, URL_API_BOOTSTRAP, URL_API_ERROR_LOG,
|
||||
URL_API_CONFIG, URL_API_BOOTSTRAP, URL_API_ERROR_LOG, URL_API_LOG_OUT,
|
||||
EVENT_TIME_CHANGED, EVENT_HOMEASSISTANT_STOP, MATCH_ALL,
|
||||
HTTP_OK, HTTP_CREATED, HTTP_BAD_REQUEST, HTTP_NOT_FOUND,
|
||||
HTTP_UNPROCESSABLE_ENTITY)
|
||||
|
@ -89,6 +89,8 @@ def setup(hass, config):
|
|||
hass.http.register_path('GET', URL_API_ERROR_LOG,
|
||||
_handle_get_api_error_log)
|
||||
|
||||
hass.http.register_path('POST', URL_API_LOG_OUT, _handle_post_api_log_out)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
|
@ -347,6 +349,13 @@ def _handle_get_api_error_log(handler, path_match, data):
|
|||
False)
|
||||
|
||||
|
||||
def _handle_post_api_log_out(handler, path_match, data):
|
||||
""" Log user out. """
|
||||
handler.send_response(HTTP_OK)
|
||||
handler.destroy_session()
|
||||
handler.end_headers()
|
||||
|
||||
|
||||
def _services_json(hass):
|
||||
""" Generate services data to JSONify. """
|
||||
return [{"domain": key, "services": value}
|
||||
|
|
|
@ -54,8 +54,7 @@ def setup(hass, config):
|
|||
|
||||
|
||||
def _handle_get_root(handler, path_match, data):
|
||||
""" Renders the debug interface. """
|
||||
|
||||
""" Renders the frontend. """
|
||||
handler.send_response(HTTP_OK)
|
||||
handler.send_header('Content-type', 'text/html; charset=utf-8')
|
||||
handler.end_headers()
|
||||
|
|
|
@ -12,10 +12,7 @@ import logging
|
|||
import time
|
||||
import gzip
|
||||
import os
|
||||
import random
|
||||
import string
|
||||
from datetime import timedelta
|
||||
from homeassistant.util import Throttle
|
||||
from http.server import SimpleHTTPRequestHandler, HTTPServer
|
||||
from http import cookies
|
||||
from socketserver import ThreadingMixIn
|
||||
|
@ -44,40 +41,34 @@ CONF_SESSIONS_ENABLED = "sessions_enabled"
|
|||
DATA_API_PASSWORD = 'api_password'
|
||||
|
||||
# Throttling time in seconds for expired sessions check
|
||||
MIN_SEC_SESSION_CLEARING = timedelta(seconds=20)
|
||||
SESSION_CLEAR_INTERVAL = timedelta(seconds=20)
|
||||
SESSION_TIMEOUT_SECONDS = 1800
|
||||
SESSION_KEY = 'sessionId'
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def setup(hass, config=None):
|
||||
def setup(hass, config):
|
||||
""" Sets up the HTTP API and debug interface. """
|
||||
if config is None or DOMAIN not in config:
|
||||
config = {DOMAIN: {}}
|
||||
|
||||
api_password = util.convert(config[DOMAIN].get(CONF_API_PASSWORD), str)
|
||||
conf = config[DOMAIN]
|
||||
|
||||
api_password = util.convert(conf.get(CONF_API_PASSWORD), str)
|
||||
no_password_set = api_password is None
|
||||
|
||||
if no_password_set:
|
||||
api_password = util.get_random_string()
|
||||
|
||||
# If no server host is given, accept all incoming requests
|
||||
server_host = config[DOMAIN].get(CONF_SERVER_HOST, '0.0.0.0')
|
||||
|
||||
server_port = config[DOMAIN].get(CONF_SERVER_PORT, SERVER_PORT)
|
||||
|
||||
development = str(config[DOMAIN].get(CONF_DEVELOPMENT, "")) == "1"
|
||||
|
||||
sessions_enabled = config[DOMAIN].get(CONF_SESSIONS_ENABLED, True)
|
||||
server_host = conf.get(CONF_SERVER_HOST, '0.0.0.0')
|
||||
server_port = conf.get(CONF_SERVER_PORT, SERVER_PORT)
|
||||
development = str(conf.get(CONF_DEVELOPMENT, "")) == "1"
|
||||
|
||||
try:
|
||||
server = HomeAssistantHTTPServer(
|
||||
(server_host, server_port), RequestHandler, hass, api_password,
|
||||
development, no_password_set, sessions_enabled)
|
||||
development, no_password_set)
|
||||
except OSError:
|
||||
# Happens if address already in use
|
||||
# If address already in use
|
||||
_LOGGER.exception("Error setting up HTTP server")
|
||||
return False
|
||||
|
||||
|
@ -102,8 +93,7 @@ class HomeAssistantHTTPServer(ThreadingMixIn, HTTPServer):
|
|||
|
||||
# pylint: disable=too-many-arguments
|
||||
def __init__(self, server_address, request_handler_class,
|
||||
hass, api_password, development, no_password_set,
|
||||
sessions_enabled):
|
||||
hass, api_password, development, no_password_set):
|
||||
super().__init__(server_address, request_handler_class)
|
||||
|
||||
self.server_address = server_address
|
||||
|
@ -112,7 +102,7 @@ class HomeAssistantHTTPServer(ThreadingMixIn, HTTPServer):
|
|||
self.development = development
|
||||
self.no_password_set = no_password_set
|
||||
self.paths = []
|
||||
self.sessions = SessionStore(sessions_enabled)
|
||||
self.sessions = SessionStore()
|
||||
|
||||
# We will lazy init this one if needed
|
||||
self.event_forwarder = None
|
||||
|
@ -161,7 +151,8 @@ class RequestHandler(SimpleHTTPRequestHandler):
|
|||
|
||||
def __init__(self, req, client_addr, server):
|
||||
""" Contructor, call the base constructor and set up session """
|
||||
self._session = None
|
||||
# Track if this was an authenticated request
|
||||
self.authenticated = False
|
||||
SimpleHTTPRequestHandler.__init__(self, req, client_addr, server)
|
||||
|
||||
def log_message(self, fmt, *arguments):
|
||||
|
@ -201,18 +192,18 @@ class RequestHandler(SimpleHTTPRequestHandler):
|
|||
"Error parsing JSON", HTTP_UNPROCESSABLE_ENTITY)
|
||||
return
|
||||
|
||||
self._session = self.get_session()
|
||||
if self.server.no_password_set:
|
||||
api_password = self.server.api_password
|
||||
else:
|
||||
_LOGGER.warning('NO PASSWORD SET')
|
||||
self.authenticated = True
|
||||
elif HTTP_HEADER_HA_AUTH in self.headers:
|
||||
api_password = self.headers.get(HTTP_HEADER_HA_AUTH)
|
||||
|
||||
if not api_password and DATA_API_PASSWORD in data:
|
||||
api_password = data[DATA_API_PASSWORD]
|
||||
|
||||
if not api_password and self._session is not None:
|
||||
api_password = self._session.cookie_values.get(
|
||||
CONF_API_PASSWORD)
|
||||
self.authenticated = api_password == self.server.api_password
|
||||
else:
|
||||
self.authenticated = self.verify_session()
|
||||
|
||||
if '_METHOD' in data:
|
||||
method = data.pop('_METHOD')
|
||||
|
@ -245,18 +236,13 @@ class RequestHandler(SimpleHTTPRequestHandler):
|
|||
|
||||
# Did we find a handler for the incoming request?
|
||||
if handle_request_method:
|
||||
|
||||
# For some calls we need a valid password
|
||||
if require_auth and api_password != self.server.api_password:
|
||||
if require_auth and not self.authenticated:
|
||||
self.write_json_message(
|
||||
"API password missing or incorrect.", HTTP_UNAUTHORIZED)
|
||||
return
|
||||
|
||||
else:
|
||||
if self._session is None and require_auth:
|
||||
self._session = self.server.sessions.create(
|
||||
api_password)
|
||||
|
||||
handle_request_method(self, path_match, data)
|
||||
handle_request_method(self, path_match, data)
|
||||
|
||||
elif path_matched_but_not_method:
|
||||
self.send_response(HTTP_METHOD_NOT_ALLOWED)
|
||||
|
@ -369,63 +355,62 @@ class RequestHandler(SimpleHTTPRequestHandler):
|
|||
self.date_time_string(time.time()+cache_time))
|
||||
|
||||
def set_session_cookie_header(self):
|
||||
""" Add the header for the session cookie """
|
||||
if self.server.sessions.enabled and self._session is not None:
|
||||
existing_sess_id = self.get_current_session_id()
|
||||
""" Add the header for the session cookie. """
|
||||
if not self.authenticated:
|
||||
return
|
||||
|
||||
if existing_sess_id != self._session.session_id:
|
||||
self.send_header(
|
||||
'Set-Cookie',
|
||||
SESSION_KEY+'='+self._session.session_id)
|
||||
current = self.get_cookie_session_id()
|
||||
|
||||
def get_session(self):
|
||||
""" Get the requested session object from cookie value """
|
||||
if self.server.sessions.enabled is not True:
|
||||
return None
|
||||
if current is not None:
|
||||
self.server.sessions.extend_validation(current)
|
||||
return
|
||||
|
||||
session_id = self.get_current_session_id()
|
||||
if session_id is not None:
|
||||
session = self.server.sessions.get(session_id)
|
||||
if session is not None:
|
||||
session.reset_expiry()
|
||||
return session
|
||||
self.send_header(
|
||||
'Set-Cookie',
|
||||
'{}={}'.format(SESSION_KEY, self.server.sessions.create())
|
||||
)
|
||||
|
||||
return None
|
||||
def verify_session(self):
|
||||
""" Verify that we are in a valid session. """
|
||||
return self.get_cookie_session_id() is not None
|
||||
|
||||
def get_current_session_id(self):
|
||||
def get_cookie_session_id(self):
|
||||
"""
|
||||
Extracts the current session id from the
|
||||
cookie or returns None if not set
|
||||
cookie or returns None if not set or invalid
|
||||
"""
|
||||
if 'Cookie' not in self.headers:
|
||||
return None
|
||||
|
||||
cookie = cookies.SimpleCookie()
|
||||
try:
|
||||
cookie.load(self.headers["Cookie"])
|
||||
except cookies.CookieError:
|
||||
return None
|
||||
|
||||
if self.headers.get('Cookie', None) is not None:
|
||||
cookie.load(self.headers.get("Cookie"))
|
||||
morsel = cookie.get(SESSION_KEY)
|
||||
|
||||
if cookie.get(SESSION_KEY, False):
|
||||
return cookie[SESSION_KEY].value
|
||||
if morsel is None:
|
||||
return None
|
||||
|
||||
return None
|
||||
current = cookie[SESSION_KEY].value
|
||||
|
||||
return current if self.server.sessions.is_valid(current) else None
|
||||
|
||||
def destroy_session(self):
|
||||
""" Destroys session. """
|
||||
current = self.get_cookie_session_id()
|
||||
|
||||
if current is None:
|
||||
return
|
||||
|
||||
self.send_header('Set-Cookie', '')
|
||||
self.server.sessions.destroy(current)
|
||||
|
||||
|
||||
class ServerSession:
|
||||
""" A very simple session class """
|
||||
def __init__(self, session_id):
|
||||
""" Set up the expiry time on creation """
|
||||
self._expiry = 0
|
||||
self.reset_expiry()
|
||||
self.cookie_values = {}
|
||||
self.session_id = session_id
|
||||
|
||||
def reset_expiry(self):
|
||||
""" Resets the expiry based on current time """
|
||||
self._expiry = date_util.utcnow() + timedelta(
|
||||
seconds=SESSION_TIMEOUT_SECONDS)
|
||||
|
||||
@property
|
||||
def is_expired(self):
|
||||
""" Return true if the session is expired based on the expiry time """
|
||||
return self._expiry < date_util.utcnow()
|
||||
def session_valid_time():
|
||||
""" Time till when a session will be valid. """
|
||||
return date_util.utcnow() + timedelta(seconds=SESSION_TIMEOUT_SECONDS)
|
||||
|
||||
|
||||
class SessionStore(object):
|
||||
|
@ -433,47 +418,42 @@ class SessionStore(object):
|
|||
def __init__(self, enabled=True):
|
||||
""" Set up the session store """
|
||||
self._sessions = {}
|
||||
self.enabled = enabled
|
||||
self.session_lock = threading.RLock()
|
||||
self.lock = threading.RLock()
|
||||
|
||||
@Throttle(MIN_SEC_SESSION_CLEARING)
|
||||
def remove_expired(self):
|
||||
@util.Throttle(SESSION_CLEAR_INTERVAL)
|
||||
def _remove_expired(self):
|
||||
""" Remove any expired sessions. """
|
||||
if self.session_lock.acquire(False):
|
||||
try:
|
||||
keys = []
|
||||
for key in self._sessions.keys():
|
||||
keys.append(key)
|
||||
now = date_util.utcnow()
|
||||
for key in [key for key, valid_time in self._sessions.items()
|
||||
if valid_time < now]:
|
||||
self._sessions.pop(key)
|
||||
|
||||
for key in keys:
|
||||
if self._sessions[key].is_expired:
|
||||
del self._sessions[key]
|
||||
_LOGGER.info("Cleared expired session %s", key)
|
||||
finally:
|
||||
self.session_lock.release()
|
||||
def is_valid(self, key):
|
||||
""" Return True if a valid session is given. """
|
||||
with self.lock:
|
||||
self._remove_expired()
|
||||
|
||||
def add(self, key, session):
|
||||
""" Add a new session to the list of tracked sessions """
|
||||
self.remove_expired()
|
||||
with self.session_lock:
|
||||
self._sessions[key] = session
|
||||
return (key in self._sessions and
|
||||
self._sessions[key] > date_util.utcnow())
|
||||
|
||||
def get(self, key):
|
||||
""" get a session by key """
|
||||
self.remove_expired()
|
||||
session = self._sessions.get(key, None)
|
||||
if session is not None and session.is_expired:
|
||||
return None
|
||||
return session
|
||||
def extend_validation(self, key):
|
||||
""" Extend a session validation time. """
|
||||
with self.lock:
|
||||
self._sessions[key] = session_valid_time()
|
||||
|
||||
def create(self, api_password):
|
||||
""" Creates a new session and adds it to the sessions """
|
||||
if self.enabled is not True:
|
||||
return None
|
||||
def destroy(self, key):
|
||||
""" Destroy a session by key. """
|
||||
with self.lock:
|
||||
self._sessions.pop(key, None)
|
||||
|
||||
chars = string.ascii_letters + string.digits
|
||||
session_id = ''.join([random.choice(chars) for i in range(20)])
|
||||
session = ServerSession(session_id)
|
||||
session.cookie_values[CONF_API_PASSWORD] = api_password
|
||||
self.add(session_id, session)
|
||||
return session
|
||||
def create(self):
|
||||
""" Creates a new session. """
|
||||
with self.lock:
|
||||
session_id = util.get_random_string(20)
|
||||
|
||||
while session_id in self._sessions:
|
||||
session_id = util.get_random_string(20)
|
||||
|
||||
self._sessions[session_id] = session_valid_time()
|
||||
|
||||
return session_id
|
||||
|
|
|
@ -164,6 +164,7 @@ URL_API_EVENT_FORWARD = "/api/event_forwarding"
|
|||
URL_API_COMPONENTS = "/api/components"
|
||||
URL_API_BOOTSTRAP = "/api/bootstrap"
|
||||
URL_API_ERROR_LOG = "/api/error_log"
|
||||
URL_API_LOG_OUT = "/api/log_out"
|
||||
|
||||
HTTP_OK = 200
|
||||
HTTP_CREATED = 201
|
||||
|
|
Loading…
Add table
Reference in a new issue