Remove strict connection (#117933)

This commit is contained in:
Robert Resch 2024-05-24 15:50:22 +02:00 committed by GitHub
parent 6f81852eb4
commit cb62f4242e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
32 changed files with 39 additions and 1816 deletions

View file

@ -10,8 +10,7 @@ import os
import socket
import ssl
from tempfile import NamedTemporaryFile
from typing import Any, Final, Required, TypedDict, cast
from urllib.parse import quote_plus, urljoin
from typing import Any, Final, TypedDict, cast
from aiohttp import web
from aiohttp.abc import AbstractStreamWriter
@ -30,20 +29,8 @@ from yarl import URL
from homeassistant.components.network import async_get_source_ip
from homeassistant.const import EVENT_HOMEASSISTANT_STOP, SERVER_PORT
from homeassistant.core import (
Event,
HomeAssistant,
ServiceCall,
ServiceResponse,
SupportsResponse,
callback,
)
from homeassistant.exceptions import (
HomeAssistantError,
ServiceValidationError,
Unauthorized,
UnknownUser,
)
from homeassistant.core import Event, HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import storage
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.http import (
@ -66,14 +53,9 @@ from homeassistant.util import dt as dt_util, ssl as ssl_util
from homeassistant.util.async_ import create_eager_task
from homeassistant.util.json import json_loads
from .auth import async_setup_auth, async_sign_path
from .auth import async_setup_auth
from .ban import setup_bans
from .const import ( # noqa: F401
DOMAIN,
KEY_HASS_REFRESH_TOKEN_ID,
KEY_HASS_USER,
StrictConnectionMode,
)
from .const import DOMAIN, KEY_HASS_REFRESH_TOKEN_ID, KEY_HASS_USER # noqa: F401
from .cors import setup_cors
from .decorators import require_admin # noqa: F401
from .forwarded import async_setup_forwarded
@ -96,7 +78,6 @@ CONF_TRUSTED_PROXIES: Final = "trusted_proxies"
CONF_LOGIN_ATTEMPTS_THRESHOLD: Final = "login_attempts_threshold"
CONF_IP_BAN_ENABLED: Final = "ip_ban_enabled"
CONF_SSL_PROFILE: Final = "ssl_profile"
CONF_STRICT_CONNECTION: Final = "strict_connection"
SSL_MODERN: Final = "modern"
SSL_INTERMEDIATE: Final = "intermediate"
@ -146,9 +127,6 @@ HTTP_SCHEMA: Final = vol.All(
[SSL_INTERMEDIATE, SSL_MODERN]
),
vol.Optional(CONF_USE_X_FRAME_OPTIONS, default=True): cv.boolean,
vol.Optional(
CONF_STRICT_CONNECTION, default=StrictConnectionMode.DISABLED
): vol.Coerce(StrictConnectionMode),
}
),
)
@ -172,7 +150,6 @@ class ConfData(TypedDict, total=False):
login_attempts_threshold: int
ip_ban_enabled: bool
ssl_profile: str
strict_connection: Required[StrictConnectionMode]
@bind_hass
@ -241,7 +218,6 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
login_threshold=login_threshold,
is_ban_enabled=is_ban_enabled,
use_x_frame_options=use_x_frame_options,
strict_connection_non_cloud=conf[CONF_STRICT_CONNECTION],
)
async def stop_server(event: Event) -> None:
@ -271,7 +247,6 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
local_ip, host, server_port, ssl_certificate is not None
)
_setup_services(hass, conf)
return True
@ -356,7 +331,6 @@ class HomeAssistantHTTP:
login_threshold: int,
is_ban_enabled: bool,
use_x_frame_options: bool,
strict_connection_non_cloud: StrictConnectionMode,
) -> None:
"""Initialize the server."""
self.app[KEY_HASS] = self.hass
@ -373,7 +347,7 @@ class HomeAssistantHTTP:
if is_ban_enabled:
setup_bans(self.hass, self.app, login_threshold)
await async_setup_auth(self.hass, self.app, strict_connection_non_cloud)
await async_setup_auth(self.hass, self.app)
setup_headers(self.app, use_x_frame_options)
setup_cors(self.app, cors_origins)
@ -602,61 +576,3 @@ async def start_http_server_and_save_config(
]
store.async_delay_save(lambda: conf, SAVE_DELAY)
@callback
def _setup_services(hass: HomeAssistant, conf: ConfData) -> None:
"""Set up services for HTTP component."""
async def create_temporary_strict_connection_url(
call: ServiceCall,
) -> ServiceResponse:
"""Create a strict connection url and return it."""
# Copied form homeassistant/helpers/service.py#_async_admin_handler
# as the helper supports no responses yet
if call.context.user_id:
user = await hass.auth.async_get_user(call.context.user_id)
if user is None:
raise UnknownUser(context=call.context)
if not user.is_admin:
raise Unauthorized(context=call.context)
if conf[CONF_STRICT_CONNECTION] is StrictConnectionMode.DISABLED:
raise ServiceValidationError(
translation_domain=DOMAIN,
translation_key="strict_connection_not_enabled_non_cloud",
)
try:
url = get_url(
hass, prefer_external=True, allow_internal=False, allow_cloud=False
)
except NoURLAvailableError as ex:
raise ServiceValidationError(
translation_domain=DOMAIN,
translation_key="no_external_url_available",
) from ex
# to avoid circular import
# pylint: disable-next=import-outside-toplevel
from homeassistant.components.auth import STRICT_CONNECTION_URL
path = async_sign_path(
hass,
STRICT_CONNECTION_URL,
datetime.timedelta(hours=1),
use_content_user=True,
)
url = urljoin(url, path)
return {
"url": f"https://login.home-assistant.io?u={quote_plus(url)}",
"direct_url": url,
}
hass.services.async_register(
DOMAIN,
"create_temporary_strict_connection_url",
create_temporary_strict_connection_url,
supports_response=SupportsResponse.ONLY,
)

View file

@ -4,18 +4,14 @@ from __future__ import annotations
from collections.abc import Awaitable, Callable
from datetime import timedelta
from http import HTTPStatus
from ipaddress import ip_address
import logging
import os
import secrets
import time
from typing import Any, Final
from aiohttp import hdrs
from aiohttp.web import Application, Request, Response, StreamResponse, middleware
from aiohttp.web_exceptions import HTTPBadRequest
from aiohttp_session import session_middleware
from aiohttp.web import Application, Request, StreamResponse, middleware
import jwt
from jwt import api_jws
from yarl import URL
@ -25,21 +21,13 @@ from homeassistant.auth.const import GROUP_ID_READ_ONLY
from homeassistant.auth.models import User
from homeassistant.components import websocket_api
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import singleton
from homeassistant.helpers.http import current_request
from homeassistant.helpers.json import json_bytes
from homeassistant.helpers.network import is_cloud_connection
from homeassistant.helpers.storage import Store
from homeassistant.util.network import is_local
from .const import (
DOMAIN,
KEY_AUTHENTICATED,
KEY_HASS_REFRESH_TOKEN_ID,
KEY_HASS_USER,
StrictConnectionMode,
)
from .session import HomeAssistantCookieStorage
from .const import KEY_AUTHENTICATED, KEY_HASS_REFRESH_TOKEN_ID, KEY_HASS_USER
_LOGGER = logging.getLogger(__name__)
@ -51,11 +39,6 @@ SAFE_QUERY_PARAMS: Final = ["height", "width"]
STORAGE_VERSION = 1
STORAGE_KEY = "http.auth"
CONTENT_USER_NAME = "Home Assistant Content"
STRICT_CONNECTION_EXCLUDED_PATH = "/api/webhook/"
STRICT_CONNECTION_GUARD_PAGE_NAME = "strict_connection_guard_page.html"
STRICT_CONNECTION_GUARD_PAGE = os.path.join(
os.path.dirname(__file__), STRICT_CONNECTION_GUARD_PAGE_NAME
)
@callback
@ -137,7 +120,6 @@ def async_user_not_allowed_do_auth(
async def async_setup_auth(
hass: HomeAssistant,
app: Application,
strict_connection_mode_non_cloud: StrictConnectionMode,
) -> None:
"""Create auth middleware for the app."""
store = Store[dict[str, Any]](hass, STORAGE_VERSION, STORAGE_KEY)
@ -160,10 +142,6 @@ async def async_setup_auth(
hass.data[STORAGE_KEY] = refresh_token.id
if strict_connection_mode_non_cloud is StrictConnectionMode.GUARD_PAGE:
# Load the guard page content on setup
await _read_strict_connection_guard_page(hass)
@callback
def async_validate_auth_header(request: Request) -> bool:
"""Test authorization header against access token.
@ -252,37 +230,6 @@ async def async_setup_auth(
authenticated = True
auth_type = "signed request"
if not authenticated and not request.path.startswith(
STRICT_CONNECTION_EXCLUDED_PATH
):
strict_connection_mode = strict_connection_mode_non_cloud
strict_connection_func = (
_async_perform_strict_connection_action_on_non_local
)
if is_cloud_connection(hass):
from homeassistant.components.cloud.util import ( # pylint: disable=import-outside-toplevel
get_strict_connection_mode,
)
strict_connection_mode = get_strict_connection_mode(hass)
strict_connection_func = _async_perform_strict_connection_action
if (
strict_connection_mode is not StrictConnectionMode.DISABLED
and not await hass.auth.session.async_validate_request_for_strict_connection_session(
request
)
and (
resp := await strict_connection_func(
hass,
request,
strict_connection_mode is StrictConnectionMode.GUARD_PAGE,
)
)
is not None
):
return resp
if authenticated and _LOGGER.isEnabledFor(logging.DEBUG):
_LOGGER.debug(
"Authenticated %s for %s using %s",
@ -294,69 +241,4 @@ async def async_setup_auth(
request[KEY_AUTHENTICATED] = authenticated
return await handler(request)
app.middlewares.append(session_middleware(HomeAssistantCookieStorage(hass)))
app.middlewares.append(auth_middleware)
async def _async_perform_strict_connection_action_on_non_local(
hass: HomeAssistant,
request: Request,
guard_page: bool,
) -> StreamResponse | None:
"""Perform strict connection mode action if the request is not local.
The function does the following:
- Try to get the IP address of the request. If it fails, assume it's not local
- If the request is local, return None (allow the request to continue)
- If guard_page is True, return a response with the content
- Otherwise close the connection and raise an exception
"""
try:
ip_address_ = ip_address(request.remote) # type: ignore[arg-type]
except ValueError:
_LOGGER.debug("Invalid IP address: %s", request.remote)
ip_address_ = None
if ip_address_ and is_local(ip_address_):
return None
return await _async_perform_strict_connection_action(hass, request, guard_page)
async def _async_perform_strict_connection_action(
hass: HomeAssistant,
request: Request,
guard_page: bool,
) -> StreamResponse | None:
"""Perform strict connection mode action.
The function does the following:
- If guard_page is True, return a response with the content
- Otherwise close the connection and raise an exception
"""
_LOGGER.debug("Perform strict connection action for %s", request.remote)
if guard_page:
return Response(
text=await _read_strict_connection_guard_page(hass),
content_type="text/html",
status=HTTPStatus.IM_A_TEAPOT,
)
if transport := request.transport:
# it should never happen that we don't have a transport
transport.close()
# We need to raise an exception to stop processing the request
raise HTTPBadRequest
@singleton.singleton(f"{DOMAIN}_{STRICT_CONNECTION_GUARD_PAGE_NAME}")
async def _read_strict_connection_guard_page(hass: HomeAssistant) -> str:
"""Read the strict connection guard page from disk via executor."""
def read_guard_page() -> str:
with open(STRICT_CONNECTION_GUARD_PAGE, encoding="utf-8") as file:
return file.read()
return await hass.async_add_executor_job(read_guard_page)

View file

@ -1,6 +1,5 @@
"""HTTP specific constants."""
from enum import StrEnum
from typing import Final
from homeassistant.helpers.http import KEY_AUTHENTICATED, KEY_HASS # noqa: F401
@ -9,11 +8,3 @@ DOMAIN: Final = "http"
KEY_HASS_USER: Final = "hass_user"
KEY_HASS_REFRESH_TOKEN_ID: Final = "hass_refresh_token_id"
class StrictConnectionMode(StrEnum):
"""Enum for strict connection mode."""
DISABLED = "disabled"
GUARD_PAGE = "guard_page"
DROP_CONNECTION = "drop_connection"

View file

@ -1,5 +0,0 @@
{
"services": {
"create_temporary_strict_connection_url": "mdi:login-variant"
}
}

View file

@ -1 +0,0 @@
create_temporary_strict_connection_url: ~

View file

@ -1,160 +0,0 @@
"""Session http module."""
from functools import lru_cache
import logging
from aiohttp.web import Request, StreamResponse
from aiohttp_session import Session, SessionData
from aiohttp_session.cookie_storage import EncryptedCookieStorage
from cryptography.fernet import InvalidToken
from homeassistant.auth.const import REFRESH_TOKEN_EXPIRATION
from homeassistant.core import HomeAssistant
from homeassistant.helpers.json import json_dumps
from homeassistant.helpers.network import is_cloud_connection
from homeassistant.util.json import JSON_DECODE_EXCEPTIONS, json_loads
from .ban import process_wrong_login
_LOGGER = logging.getLogger(__name__)
COOKIE_NAME = "SC"
PREFIXED_COOKIE_NAME = f"__Host-{COOKIE_NAME}"
SESSION_CACHE_SIZE = 16
def _get_cookie_name(is_secure: bool) -> str:
"""Return the cookie name."""
return PREFIXED_COOKIE_NAME if is_secure else COOKIE_NAME
class HomeAssistantCookieStorage(EncryptedCookieStorage):
"""Home Assistant cookie storage.
Own class is required:
- to set the secure flag based on the connection type
- to use a LRU cache for session decryption
"""
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the cookie storage."""
super().__init__(
hass.auth.session.key,
cookie_name=PREFIXED_COOKIE_NAME,
max_age=int(REFRESH_TOKEN_EXPIRATION),
httponly=True,
samesite="Lax",
secure=True,
encoder=json_dumps,
decoder=json_loads,
)
self._hass = hass
def _secure_connection(self, request: Request) -> bool:
"""Return if the connection is secure (https)."""
return is_cloud_connection(self._hass) or request.secure
def load_cookie(self, request: Request) -> str | None:
"""Load cookie."""
is_secure = self._secure_connection(request)
cookie_name = _get_cookie_name(is_secure)
return request.cookies.get(cookie_name)
@lru_cache(maxsize=SESSION_CACHE_SIZE)
def _decrypt_cookie(self, cookie: str) -> Session | None:
"""Decrypt and validate cookie."""
try:
data = SessionData( # type: ignore[misc]
self._decoder(
self._fernet.decrypt(
cookie.encode("utf-8"), ttl=self.max_age
).decode("utf-8")
)
)
except (InvalidToken, TypeError, ValueError, *JSON_DECODE_EXCEPTIONS):
_LOGGER.warning("Cannot decrypt/parse cookie value")
return None
session = Session(None, data=data, new=data is None, max_age=self.max_age)
# Validate session if not empty
if (
not session.empty
and not self._hass.auth.session.async_validate_strict_connection_session(
session
)
):
# Invalidate session as it is not valid
session.invalidate()
return session
async def new_session(self) -> Session:
"""Create a new session and mark it as changed."""
session = Session(None, data=None, new=True, max_age=self.max_age)
session.changed()
return session
async def load_session(self, request: Request) -> Session:
"""Load session."""
# Split parent function to use lru_cache
if (cookie := self.load_cookie(request)) is None:
return await self.new_session()
if (session := self._decrypt_cookie(cookie)) is None:
# Decrypting/parsing failed, log wrong login and create a new session
await process_wrong_login(request)
session = await self.new_session()
return session
async def save_session(
self, request: Request, response: StreamResponse, session: Session
) -> None:
"""Save session."""
is_secure = self._secure_connection(request)
cookie_name = _get_cookie_name(is_secure)
if session.empty:
response.del_cookie(cookie_name)
else:
params = self.cookie_params.copy()
params["secure"] = is_secure
params["max_age"] = session.max_age
cookie_data = self._encoder(self._get_session_data(session)).encode("utf-8")
response.set_cookie(
cookie_name,
self._fernet.encrypt(cookie_data).decode("utf-8"),
**params,
)
# Add Cache-Control header to not cache the cookie as it
# is used for session management
self._add_cache_control_header(response)
@staticmethod
def _add_cache_control_header(response: StreamResponse) -> None:
"""Add/set cache control header to no-cache="Set-Cookie"."""
# Structure of the Cache-Control header defined in
# https://datatracker.ietf.org/doc/html/rfc2068#section-14.9
if header := response.headers.get("Cache-Control"):
directives = []
for directive in header.split(","):
directive = directive.strip()
directive_lowered = directive.lower()
if directive_lowered.startswith("no-cache"):
if "set-cookie" in directive_lowered or directive.find("=") == -1:
# Set-Cookie is already in the no-cache directive or
# the whole request should not be cached -> Nothing to do
return
# Add Set-Cookie to the no-cache
# [:-1] to remove the " at the end of the directive
directive = f"{directive[:-1]}, Set-Cookie"
directives.append(directive)
header = ", ".join(directives)
else:
header = 'no-cache="Set-Cookie"'
response.headers["Cache-Control"] = header

File diff suppressed because one or more lines are too long

View file

@ -1,16 +0,0 @@
{
"exceptions": {
"strict_connection_not_enabled_non_cloud": {
"message": "Strict connection is not enabled for non-cloud requests"
},
"no_external_url_available": {
"message": "No external URL available"
}
},
"services": {
"create_temporary_strict_connection_url": {
"name": "Create a temporary strict connection URL",
"description": "Create a temporary strict connection URL, which can be used to login on another device."
}
}
}