Add strict connection (#112387)

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
Robert Resch 2024-04-12 14:47:46 +02:00 committed by GitHub
parent f70ce8abf9
commit 348e1df949
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
23 changed files with 1187 additions and 64 deletions

View file

@ -28,6 +28,7 @@ from .const import ACCESS_TOKEN_EXPIRATION, GROUP_ID_ADMIN, REFRESH_TOKEN_EXPIRA
from .mfa_modules import MultiFactorAuthModule, auth_mfa_module_from_config
from .models import AuthFlowResult
from .providers import AuthProvider, LoginFlow, auth_provider_from_config
from .session import SessionManager
EVENT_USER_ADDED = "user_added"
EVENT_USER_UPDATED = "user_updated"
@ -85,7 +86,7 @@ async def auth_manager_from_config(
module_hash[module.id] = module
manager = AuthManager(hass, store, provider_hash, module_hash)
manager.async_setup()
await manager.async_setup()
return manager
@ -180,9 +181,9 @@ class AuthManager:
self._remove_expired_job = HassJob(
self._async_remove_expired_refresh_tokens, job_type=HassJobType.Callback
)
self.session = SessionManager(hass, self)
@callback
def async_setup(self) -> None:
async def async_setup(self) -> None:
"""Set up the auth manager."""
hass = self.hass
hass.async_add_shutdown_job(
@ -191,6 +192,7 @@ class AuthManager:
)
)
self._async_track_next_refresh_token_expiration()
await self.session.async_setup()
@property
def auth_providers(self) -> list[AuthProvider]:

View file

@ -0,0 +1,205 @@
"""Session auth module."""
from __future__ import annotations
from datetime import datetime, timedelta
import secrets
from typing import TYPE_CHECKING, Final, TypedDict
from aiohttp.web import Request
from aiohttp_session import Session, get_session, new_session
from cryptography.fernet import Fernet
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
from homeassistant.helpers.event import async_call_later
from homeassistant.helpers.storage import Store
from homeassistant.util import dt as dt_util
from .models import RefreshToken
if TYPE_CHECKING:
from . import AuthManager
TEMP_TIMEOUT = timedelta(minutes=5)
TEMP_TIMEOUT_SECONDS = TEMP_TIMEOUT.total_seconds()
SESSION_ID = "id"
STORAGE_VERSION = 1
STORAGE_KEY = "auth.session"
class StrictConnectionTempSessionData:
"""Data for accessing unauthorized resources for a short period of time."""
__slots__ = ("cancel_remove", "absolute_expiry")
def __init__(self, cancel_remove: CALLBACK_TYPE) -> None:
"""Initialize the temp session data."""
self.cancel_remove: Final[CALLBACK_TYPE] = cancel_remove
self.absolute_expiry: Final[datetime] = dt_util.utcnow() + TEMP_TIMEOUT
class StoreData(TypedDict):
"""Data to store."""
unauthorized_sessions: dict[str, str]
key: str
class SessionManager:
"""Session manager."""
def __init__(self, hass: HomeAssistant, auth: AuthManager) -> None:
"""Initialize the strict connection manager."""
self._auth = auth
self._hass = hass
self._temp_sessions: dict[str, StrictConnectionTempSessionData] = {}
self._strict_connection_sessions: dict[str, str] = {}
self._store = Store[StoreData](
hass, STORAGE_VERSION, STORAGE_KEY, private=True, atomic_writes=True
)
self._key: str | None = None
self._refresh_token_revoke_callbacks: dict[str, CALLBACK_TYPE] = {}
@property
def key(self) -> str:
"""Return the encryption key."""
if self._key is None:
self._key = Fernet.generate_key().decode()
self._async_schedule_save()
return self._key
async def async_validate_request_for_strict_connection_session(
self,
request: Request,
) -> bool:
"""Check if a request has a valid strict connection session."""
session = await get_session(request)
if session.new or session.empty:
return False
result = self.async_validate_strict_connection_session(session)
if result is False:
session.invalidate()
return result
@callback
def async_validate_strict_connection_session(
self,
session: Session,
) -> bool:
"""Validate a strict connection session."""
if not (session_id := session.get(SESSION_ID)):
return False
if token_id := self._strict_connection_sessions.get(session_id):
if self._auth.async_get_refresh_token(token_id):
return True
# refresh token is invalid, delete entry
self._strict_connection_sessions.pop(session_id)
self._async_schedule_save()
if data := self._temp_sessions.get(session_id):
if dt_util.utcnow() <= data.absolute_expiry:
return True
# session expired, delete entry
self._temp_sessions.pop(session_id).cancel_remove()
return False
@callback
def _async_register_revoke_token_callback(self, refresh_token_id: str) -> None:
"""Register a callback to revoke all sessions for a refresh token."""
if refresh_token_id in self._refresh_token_revoke_callbacks:
return
@callback
def async_invalidate_auth_sessions() -> None:
"""Invalidate all sessions for a refresh token."""
self._strict_connection_sessions = {
session_id: token_id
for session_id, token_id in self._strict_connection_sessions.items()
if token_id != refresh_token_id
}
self._async_schedule_save()
self._refresh_token_revoke_callbacks[refresh_token_id] = (
self._auth.async_register_revoke_token_callback(
refresh_token_id, async_invalidate_auth_sessions
)
)
async def async_create_session(
self,
request: Request,
refresh_token: RefreshToken,
) -> None:
"""Create new session for given refresh token.
Caller needs to make sure that the refresh token is valid.
By creating a session, we are implicitly revoking all other
sessions for the given refresh token as there is one refresh
token per device/user case.
"""
self._strict_connection_sessions = {
session_id: token_id
for session_id, token_id in self._strict_connection_sessions.items()
if token_id != refresh_token.id
}
self._async_register_revoke_token_callback(refresh_token.id)
session_id = await self._async_create_new_session(request)
self._strict_connection_sessions[session_id] = refresh_token.id
self._async_schedule_save()
async def async_create_temp_unauthorized_session(self, request: Request) -> None:
"""Create a temporary unauthorized session."""
session_id = await self._async_create_new_session(
request, max_age=int(TEMP_TIMEOUT_SECONDS)
)
@callback
def remove(_: datetime) -> None:
self._temp_sessions.pop(session_id, None)
self._temp_sessions[session_id] = StrictConnectionTempSessionData(
async_call_later(self._hass, TEMP_TIMEOUT_SECONDS, remove)
)
async def _async_create_new_session(
self,
request: Request,
*,
max_age: int | None = None,
) -> str:
session_id = secrets.token_hex(64)
session = await new_session(request)
session[SESSION_ID] = session_id
if max_age is not None:
session.max_age = max_age
return session_id
@callback
def _async_schedule_save(self, delay: float = 1) -> None:
"""Save sessions."""
self._store.async_delay_save(self._data_to_save, delay)
@callback
def _data_to_save(self) -> StoreData:
"""Return the data to store."""
return StoreData(
unauthorized_sessions=self._strict_connection_sessions,
key=self.key,
)
async def async_setup(self) -> None:
"""Set up session manager."""
data = await self._store.async_load()
if data is None:
return
self._key = data["key"]
self._strict_connection_sessions = data["unauthorized_sessions"]
for token_id in self._strict_connection_sessions.values():
self._async_register_revoke_token_callback(token_id)

View file

@ -162,6 +162,7 @@ from homeassistant.util import dt as dt_util
from . import indieauth, login_flow, mfa_setup_flow
DOMAIN = "auth"
STRICT_CONNECTION_URL = "/auth/strict_connection/temp_token"
StoreResultType = Callable[[str, Credentials], str]
RetrieveResultType = Callable[[str, str], Credentials | None]
@ -187,6 +188,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
hass.http.register_view(RevokeTokenView())
hass.http.register_view(LinkUserView(retrieve_result))
hass.http.register_view(OAuth2AuthorizeCallbackView())
hass.http.register_view(StrictConnectionTempTokenView())
websocket_api.async_register_command(hass, websocket_current_user)
websocket_api.async_register_command(hass, websocket_create_long_lived_access_token)
@ -260,10 +262,10 @@ class TokenView(HomeAssistantView):
return await RevokeTokenView.post(self, request) # type: ignore[arg-type]
if grant_type == "authorization_code":
return await self._async_handle_auth_code(hass, data, request.remote)
return await self._async_handle_auth_code(hass, data, request)
if grant_type == "refresh_token":
return await self._async_handle_refresh_token(hass, data, request.remote)
return await self._async_handle_refresh_token(hass, data, request)
return self.json(
{"error": "unsupported_grant_type"}, status_code=HTTPStatus.BAD_REQUEST
@ -273,7 +275,7 @@ class TokenView(HomeAssistantView):
self,
hass: HomeAssistant,
data: MultiDictProxy[str],
remote_addr: str | None,
request: web.Request,
) -> web.Response:
"""Handle authorization code request."""
client_id = data.get("client_id")
@ -313,7 +315,7 @@ class TokenView(HomeAssistantView):
)
try:
access_token = hass.auth.async_create_access_token(
refresh_token, remote_addr
refresh_token, request.remote
)
except InvalidAuthError as exc:
return self.json(
@ -321,6 +323,7 @@ class TokenView(HomeAssistantView):
status_code=HTTPStatus.FORBIDDEN,
)
await hass.auth.session.async_create_session(request, refresh_token)
return self.json(
{
"access_token": access_token,
@ -341,9 +344,9 @@ class TokenView(HomeAssistantView):
self,
hass: HomeAssistant,
data: MultiDictProxy[str],
remote_addr: str | None,
request: web.Request,
) -> web.Response:
"""Handle authorization code request."""
"""Handle refresh token request."""
client_id = data.get("client_id")
if client_id is not None and not indieauth.verify_client_id(client_id):
return self.json(
@ -381,7 +384,7 @@ class TokenView(HomeAssistantView):
try:
access_token = hass.auth.async_create_access_token(
refresh_token, remote_addr
refresh_token, request.remote
)
except InvalidAuthError as exc:
return self.json(
@ -389,6 +392,7 @@ class TokenView(HomeAssistantView):
status_code=HTTPStatus.FORBIDDEN,
)
await hass.auth.session.async_create_session(request, refresh_token)
return self.json(
{
"access_token": access_token,
@ -437,6 +441,20 @@ class LinkUserView(HomeAssistantView):
return self.json_message("User linked")
class StrictConnectionTempTokenView(HomeAssistantView):
"""View to get temporary strict connection token."""
url = STRICT_CONNECTION_URL
name = "api:auth:strict_connection:temp_token"
requires_auth = False
async def get(self, request: web.Request) -> web.Response:
"""Get a temporary token and redirect to main page."""
hass = request.app[KEY_HASS]
await hass.auth.session.async_create_temp_unauthorized_session(request)
raise web.HTTPSeeOther(location="/")
@callback
def _create_auth_code_store() -> tuple[StoreResultType, RetrieveResultType]:
"""Create an in memory store."""

View file

@ -197,7 +197,6 @@ class HassIOIngress(HomeAssistantView):
content_type or simple_response.content_type
):
simple_response.enable_compression()
await simple_response.prepare(request)
return simple_response
# Stream response

View file

@ -10,7 +10,8 @@ import os
import socket
import ssl
from tempfile import NamedTemporaryFile
from typing import Any, Final, TypedDict, cast
from typing import Any, Final, Required, TypedDict, cast
from urllib.parse import quote_plus, urljoin
from aiohttp import web
from aiohttp.abc import AbstractStreamWriter
@ -30,8 +31,20 @@ from yarl import URL
from homeassistant.components.network import async_get_source_ip
from homeassistant.const import EVENT_HOMEASSISTANT_STOP, SERVER_PORT
from homeassistant.core import Event, HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.core import (
Event,
HomeAssistant,
ServiceCall,
ServiceResponse,
SupportsResponse,
callback,
)
from homeassistant.exceptions import (
HomeAssistantError,
ServiceValidationError,
Unauthorized,
UnknownUser,
)
from homeassistant.helpers import storage
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.http import (
@ -53,9 +66,13 @@ from homeassistant.util import dt as dt_util, ssl as ssl_util
from homeassistant.util.async_ import create_eager_task
from homeassistant.util.json import json_loads
from .auth import async_setup_auth
from .auth import async_setup_auth, async_sign_path
from .ban import setup_bans
from .const import KEY_HASS_REFRESH_TOKEN_ID, KEY_HASS_USER # noqa: F401
from .const import ( # noqa: F401
KEY_HASS_REFRESH_TOKEN_ID,
KEY_HASS_USER,
StrictConnectionMode,
)
from .cors import setup_cors
from .decorators import require_admin # noqa: F401
from .forwarded import async_setup_forwarded
@ -80,6 +97,7 @@ CONF_TRUSTED_PROXIES: Final = "trusted_proxies"
CONF_LOGIN_ATTEMPTS_THRESHOLD: Final = "login_attempts_threshold"
CONF_IP_BAN_ENABLED: Final = "ip_ban_enabled"
CONF_SSL_PROFILE: Final = "ssl_profile"
CONF_STRICT_CONNECTION: Final = "strict_connection"
SSL_MODERN: Final = "modern"
SSL_INTERMEDIATE: Final = "intermediate"
@ -129,6 +147,9 @@ HTTP_SCHEMA: Final = vol.All(
[SSL_INTERMEDIATE, SSL_MODERN]
),
vol.Optional(CONF_USE_X_FRAME_OPTIONS, default=True): cv.boolean,
vol.Optional(
CONF_STRICT_CONNECTION, default=StrictConnectionMode.DISABLED
): vol.In([e.value for e in StrictConnectionMode]),
}
),
)
@ -152,6 +173,7 @@ class ConfData(TypedDict, total=False):
login_attempts_threshold: int
ip_ban_enabled: bool
ssl_profile: str
strict_connection: Required[StrictConnectionMode]
@bind_hass
@ -218,6 +240,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
login_threshold=login_threshold,
is_ban_enabled=is_ban_enabled,
use_x_frame_options=use_x_frame_options,
strict_connection_non_cloud=conf[CONF_STRICT_CONNECTION],
)
async def stop_server(event: Event) -> None:
@ -247,6 +270,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
local_ip, host, server_port, ssl_certificate is not None
)
_setup_services(hass, conf)
return True
@ -331,6 +355,7 @@ class HomeAssistantHTTP:
login_threshold: int,
is_ban_enabled: bool,
use_x_frame_options: bool,
strict_connection_non_cloud: StrictConnectionMode,
) -> None:
"""Initialize the server."""
self.app[KEY_HASS] = self.hass
@ -347,7 +372,7 @@ class HomeAssistantHTTP:
if is_ban_enabled:
setup_bans(self.hass, self.app, login_threshold)
await async_setup_auth(self.hass, self.app)
await async_setup_auth(self.hass, self.app, strict_connection_non_cloud)
setup_headers(self.app, use_x_frame_options)
setup_cors(self.app, cors_origins)
@ -577,3 +602,59 @@ async def start_http_server_and_save_config(
]
store.async_delay_save(lambda: conf, SAVE_DELAY)
@callback
def _setup_services(hass: HomeAssistant, conf: ConfData) -> None:
"""Set up services for HTTP component."""
async def create_temporary_strict_connection_url(
call: ServiceCall,
) -> ServiceResponse:
"""Create a strict connection url and return it."""
# Copied form homeassistant/helpers/service.py#_async_admin_handler
# as the helper supports no responses yet
if call.context.user_id:
user = await hass.auth.async_get_user(call.context.user_id)
if user is None:
raise UnknownUser(context=call.context)
if not user.is_admin:
raise Unauthorized(context=call.context)
if conf[CONF_STRICT_CONNECTION] is StrictConnectionMode.DISABLED:
raise ServiceValidationError(
translation_domain=DOMAIN,
translation_key="strict_connection_not_enabled_non_cloud",
)
try:
url = get_url(hass, prefer_external=True, allow_internal=False)
except NoURLAvailableError as ex:
raise ServiceValidationError(
translation_domain=DOMAIN,
translation_key="no_external_url_available",
) from ex
# to avoid circular import
# pylint: disable-next=import-outside-toplevel
from homeassistant.components.auth import STRICT_CONNECTION_URL
path = async_sign_path(
hass,
STRICT_CONNECTION_URL,
datetime.timedelta(hours=1),
use_content_user=True,
)
url = urljoin(url, path)
return {
"url": f"https://login.home-assistant.io?u={quote_plus(url)}",
"direct_url": url,
}
hass.services.async_register(
DOMAIN,
"create_temporary_strict_connection_url",
create_temporary_strict_connection_url,
supports_response=SupportsResponse.ONLY,
)

View file

@ -4,14 +4,18 @@ from __future__ import annotations
from collections.abc import Awaitable, Callable
from datetime import timedelta
from http import HTTPStatus
from ipaddress import ip_address
import logging
import os
import secrets
import time
from typing import Any, Final
from aiohttp import hdrs
from aiohttp.web import Application, Request, StreamResponse, middleware
from aiohttp.web import Application, Request, Response, StreamResponse, middleware
from aiohttp.web_exceptions import HTTPBadRequest
from aiohttp_session import session_middleware
import jwt
from jwt import api_jws
from yarl import URL
@ -27,7 +31,13 @@ from homeassistant.helpers.network import is_cloud_connection
from homeassistant.helpers.storage import Store
from homeassistant.util.network import is_local
from .const import KEY_AUTHENTICATED, KEY_HASS_REFRESH_TOKEN_ID, KEY_HASS_USER
from .const import (
KEY_AUTHENTICATED,
KEY_HASS_REFRESH_TOKEN_ID,
KEY_HASS_USER,
StrictConnectionMode,
)
from .session import HomeAssistantCookieStorage
_LOGGER = logging.getLogger(__name__)
@ -39,6 +49,10 @@ SAFE_QUERY_PARAMS: Final = ["height", "width"]
STORAGE_VERSION = 1
STORAGE_KEY = "http.auth"
CONTENT_USER_NAME = "Home Assistant Content"
STRICT_CONNECTION_EXCLUDED_PATH = "/api/webhook/"
STRICT_CONNECTION_STATIC_PAGE = os.path.join(
os.path.dirname(__file__), "strict_connection_static_page.html"
)
@callback
@ -48,13 +62,16 @@ def async_sign_path(
expiration: timedelta,
*,
refresh_token_id: str | None = None,
use_content_user: bool = False,
) -> str:
"""Sign a path for temporary access without auth header."""
if (secret := hass.data.get(DATA_SIGN_SECRET)) is None:
secret = hass.data[DATA_SIGN_SECRET] = secrets.token_hex()
if refresh_token_id is None:
if connection := websocket_api.current_connection.get():
if use_content_user:
refresh_token_id = hass.data[STORAGE_KEY]
elif connection := websocket_api.current_connection.get():
refresh_token_id = connection.refresh_token_id
elif (
request := current_request.get()
@ -114,7 +131,11 @@ def async_user_not_allowed_do_auth(
return "User cannot authenticate remotely"
async def async_setup_auth(hass: HomeAssistant, app: Application) -> None:
async def async_setup_auth(
hass: HomeAssistant,
app: Application,
strict_connection_mode_non_cloud: StrictConnectionMode,
) -> None:
"""Create auth middleware for the app."""
store = Store[dict[str, Any]](hass, STORAGE_VERSION, STORAGE_KEY)
if (data := await store.async_load()) is None:
@ -135,6 +156,16 @@ async def async_setup_auth(hass: HomeAssistant, app: Application) -> None:
await store.async_save(data)
hass.data[STORAGE_KEY] = refresh_token.id
strict_connection_static_file_content = None
if strict_connection_mode_non_cloud is StrictConnectionMode.STATIC_PAGE:
def read_static_page() -> str:
with open(STRICT_CONNECTION_STATIC_PAGE, encoding="utf-8") as file:
return file.read()
strict_connection_static_file_content = await hass.async_add_executor_job(
read_static_page
)
@callback
def async_validate_auth_header(request: Request) -> bool:
@ -224,6 +255,22 @@ async def async_setup_auth(hass: HomeAssistant, app: Application) -> None:
authenticated = True
auth_type = "signed request"
if (
not authenticated
and strict_connection_mode_non_cloud is not StrictConnectionMode.DISABLED
and not request.path.startswith(STRICT_CONNECTION_EXCLUDED_PATH)
and not await hass.auth.session.async_validate_request_for_strict_connection_session(
request
)
and (
resp := _async_perform_action_on_non_local(
request, strict_connection_static_file_content
)
)
is not None
):
return resp
if authenticated and _LOGGER.isEnabledFor(logging.DEBUG):
_LOGGER.debug(
"Authenticated %s for %s using %s",
@ -235,4 +282,43 @@ async def async_setup_auth(hass: HomeAssistant, app: Application) -> None:
request[KEY_AUTHENTICATED] = authenticated
return await handler(request)
app.middlewares.append(session_middleware(HomeAssistantCookieStorage(hass)))
app.middlewares.append(auth_middleware)
@callback
def _async_perform_action_on_non_local(
request: Request,
strict_connection_static_file_content: str | None,
) -> StreamResponse | None:
"""Perform strict connection mode action if the request is not local.
The function does the following:
- Try to get the IP address of the request. If it fails, assume it's not local
- If the request is local, return None (allow the request to continue)
- If strict_connection_static_file_content is set, return a response with the content
- Otherwise close the connection and raise an exception
"""
try:
ip_address_ = ip_address(request.remote) # type: ignore[arg-type]
except ValueError:
_LOGGER.debug("Invalid IP address: %s", request.remote)
ip_address_ = None
if ip_address_ and is_local(ip_address_):
return None
_LOGGER.debug("Perform strict connection action for %s", ip_address_)
if strict_connection_static_file_content:
return Response(
text=strict_connection_static_file_content,
content_type="text/html",
status=HTTPStatus.IM_A_TEAPOT,
)
if transport := request.transport:
# it should never happen that we don't have a transport
transport.close()
# We need to raise an exception to stop processing the request
raise HTTPBadRequest

View file

@ -1,8 +1,17 @@
"""HTTP specific constants."""
from enum import StrEnum
from typing import Final
from homeassistant.helpers.http import KEY_AUTHENTICATED, KEY_HASS # noqa: F401
KEY_HASS_USER: Final = "hass_user"
KEY_HASS_REFRESH_TOKEN_ID: Final = "hass_refresh_token_id"
class StrictConnectionMode(StrEnum):
"""Enum for strict connection mode."""
DISABLED = "disabled"
STATIC_PAGE = "static_page"
DROP_CONNECTION = "drop_connection"

View file

@ -0,0 +1,5 @@
{
"services": {
"create_temporary_strict_connection_url": "mdi:login-variant"
}
}

View file

@ -0,0 +1 @@
create_temporary_strict_connection_url: ~

View file

@ -0,0 +1,160 @@
"""Session http module."""
from functools import lru_cache
import logging
from aiohttp.web import Request, StreamResponse
from aiohttp_session import Session, SessionData
from aiohttp_session.cookie_storage import EncryptedCookieStorage
from cryptography.fernet import InvalidToken
from homeassistant.auth.const import REFRESH_TOKEN_EXPIRATION
from homeassistant.core import HomeAssistant
from homeassistant.helpers.json import json_dumps
from homeassistant.helpers.network import is_cloud_connection
from homeassistant.util.json import JSON_DECODE_EXCEPTIONS, json_loads
from .ban import process_wrong_login
_LOGGER = logging.getLogger(__name__)
COOKIE_NAME = "SC"
PREFIXED_COOKIE_NAME = f"__Host-{COOKIE_NAME}"
SESSION_CACHE_SIZE = 16
def _get_cookie_name(is_secure: bool) -> str:
"""Return the cookie name."""
return PREFIXED_COOKIE_NAME if is_secure else COOKIE_NAME
class HomeAssistantCookieStorage(EncryptedCookieStorage):
"""Home Assistant cookie storage.
Own class is required:
- to set the secure flag based on the connection type
- to use a LRU cache for session decryption
"""
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the cookie storage."""
super().__init__(
hass.auth.session.key,
cookie_name=PREFIXED_COOKIE_NAME,
max_age=int(REFRESH_TOKEN_EXPIRATION),
httponly=True,
samesite="Lax",
secure=True,
encoder=json_dumps,
decoder=json_loads,
)
self._hass = hass
def _secure_connection(self, request: Request) -> bool:
"""Return if the connection is secure (https)."""
return is_cloud_connection(self._hass) or request.secure
def load_cookie(self, request: Request) -> str | None:
"""Load cookie."""
is_secure = self._secure_connection(request)
cookie_name = _get_cookie_name(is_secure)
return request.cookies.get(cookie_name)
@lru_cache(maxsize=SESSION_CACHE_SIZE)
def _decrypt_cookie(self, cookie: str) -> Session | None:
"""Decrypt and validate cookie."""
try:
data = SessionData( # type: ignore[misc]
self._decoder(
self._fernet.decrypt(
cookie.encode("utf-8"), ttl=self.max_age
).decode("utf-8")
)
)
except (InvalidToken, TypeError, ValueError, *JSON_DECODE_EXCEPTIONS):
_LOGGER.warning("Cannot decrypt/parse cookie value")
return None
session = Session(None, data=data, new=data is None, max_age=self.max_age)
# Validate session if not empty
if (
not session.empty
and not self._hass.auth.session.async_validate_strict_connection_session(
session
)
):
# Invalidate session as it is not valid
session.invalidate()
return session
async def new_session(self) -> Session:
"""Create a new session and mark it as changed."""
session = Session(None, data=None, new=True, max_age=self.max_age)
session.changed()
return session
async def load_session(self, request: Request) -> Session:
"""Load session."""
# Split parent function to use lru_cache
if (cookie := self.load_cookie(request)) is None:
return await self.new_session()
if (session := self._decrypt_cookie(cookie)) is None:
# Decrypting/parsing failed, log wrong login and create a new session
await process_wrong_login(request)
session = await self.new_session()
return session
async def save_session(
self, request: Request, response: StreamResponse, session: Session
) -> None:
"""Save session."""
is_secure = self._secure_connection(request)
cookie_name = _get_cookie_name(is_secure)
if session.empty:
response.del_cookie(cookie_name)
else:
params = self.cookie_params.copy()
params["secure"] = is_secure
params["max_age"] = session.max_age
cookie_data = self._encoder(self._get_session_data(session)).encode("utf-8")
response.set_cookie(
cookie_name,
self._fernet.encrypt(cookie_data).decode("utf-8"),
**params,
)
# Add Cache-Control header to not cache the cookie as it
# is used for session management
self._add_cache_control_header(response)
@staticmethod
def _add_cache_control_header(response: StreamResponse) -> None:
"""Add/set cache control header to no-cache="Set-Cookie"."""
# Structure of the Cache-Control header defined in
# https://datatracker.ietf.org/doc/html/rfc2068#section-14.9
if header := response.headers.get("Cache-Control"):
directives = []
for directive in header.split(","):
directive = directive.strip()
directive_lowered = directive.lower()
if directive_lowered.startswith("no-cache"):
if "set-cookie" in directive_lowered or directive.find("=") == -1:
# Set-Cookie is already in the no-cache directive or
# the whole request should not be cached -> Nothing to do
return
# Add Set-Cookie to the no-cache
# [:-1] to remove the " at the end of the directive
directive = f"{directive[:-1]}, Set-Cookie"
directives.append(directive)
header = ", ".join(directives)
else:
header = 'no-cache="Set-Cookie"'
response.headers["Cache-Control"] = header

View file

@ -0,0 +1,46 @@
<!doctype html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>I'm a Teapot</title>
<style>
body {
font-family: Arial, sans-serif;
background-color: #f3f3f3;
margin: 0;
padding: 0;
display: flex;
justify-content: center;
align-items: center;
height: 100vh;
}
.container {
text-align: center;
}
h1 {
font-size: 36px;
color: #333;
margin-bottom: 10px;
}
p {
font-size: 18px;
color: #666;
margin-bottom: 20px;
}
.teapot {
font-size: 60px;
}
</style>
</head>
<body>
<div class="container">
<h1>Error 418: I'm a Teapot</h1>
<p>
Oops! Looks like the server is taking a coffee break.<br />
Don't worry, it'll be back to brewing your requests in no time!
</p>
<p class="teapot">&#9749;</p>
</div>
</body>
</html>

View file

@ -0,0 +1,16 @@
{
"exceptions": {
"strict_connection_not_enabled_non_cloud": {
"message": "Strict connection is not enabled for non-cloud requests"
},
"no_external_url_available": {
"message": "No external URL available"
}
},
"services": {
"create_temporary_strict_connection_url": {
"name": "Create a temporary strict connection URL",
"description": "Create a temporary strict connection URL, which can be used to login on another device."
}
}
}

View file

@ -7,6 +7,7 @@ aiohttp-fast-url-dispatcher==0.3.0
aiohttp-zlib-ng==0.3.1
aiohttp==3.9.4
aiohttp_cors==0.7.0
aiohttp_session==2.12.0
astral==2.2
async-interrupt==1.1.1
async-upnp-client==0.38.3

View file

@ -26,6 +26,7 @@ dependencies = [
"aiodns==3.2.0",
"aiohttp==3.9.4",
"aiohttp_cors==0.7.0",
"aiohttp_session==2.12.0",
"aiohttp-fast-url-dispatcher==0.3.0",
"aiohttp-zlib-ng==0.3.1",
"astral==2.2",

View file

@ -6,6 +6,7 @@
aiodns==3.2.0
aiohttp==3.9.4
aiohttp_cors==0.7.0
aiohttp_session==2.12.0
aiohttp-fast-url-dispatcher==0.3.0
aiohttp-zlib-ng==0.3.1
astral==2.2

View file

@ -306,7 +306,7 @@ async def test_api_get_services(
for serv_domain in data:
local = local_services.pop(serv_domain["domain"])
assert serv_domain["services"] == local
assert serv_domain["services"].keys() == local.keys()
async def test_api_call_service_no_data(

View file

@ -1,22 +1,28 @@
"""The tests for the Home Assistant HTTP component."""
from collections.abc import Awaitable, Callable
from datetime import timedelta
from http import HTTPStatus
from ipaddress import ip_network
import logging
from unittest.mock import Mock, patch
from aiohttp import BasicAuth, web
from aiohttp import BasicAuth, ServerDisconnectedError, web
from aiohttp.test_utils import TestClient
from aiohttp.web_exceptions import HTTPUnauthorized
from aiohttp_session import get_session
import jwt
import pytest
import yarl
from yarl import URL
from homeassistant.auth.const import GROUP_ID_READ_ONLY
from homeassistant.auth.models import User
from homeassistant.auth.models import RefreshToken, User
from homeassistant.auth.providers import trusted_networks
from homeassistant.auth.providers.legacy_api_password import (
LegacyApiPasswordAuthProvider,
)
from homeassistant.auth.session import SESSION_ID, TEMP_TIMEOUT
from homeassistant.components import websocket_api
from homeassistant.components.http import KEY_HASS
from homeassistant.components.http.auth import (
@ -24,11 +30,12 @@ from homeassistant.components.http.auth import (
DATA_SIGN_SECRET,
SIGN_QUERY_PARAM,
STORAGE_KEY,
STRICT_CONNECTION_STATIC_PAGE,
async_setup_auth,
async_sign_path,
async_user_not_allowed_do_auth,
)
from homeassistant.components.http.const import KEY_AUTHENTICATED
from homeassistant.components.http.const import KEY_AUTHENTICATED, StrictConnectionMode
from homeassistant.components.http.forwarded import async_setup_forwarded
from homeassistant.components.http.request_context import (
current_request,
@ -36,13 +43,15 @@ from homeassistant.components.http.request_context import (
)
from homeassistant.core import HomeAssistant, callback
from homeassistant.setup import async_setup_component
from homeassistant.util.dt import utcnow
from . import HTTP_HEADER_HA_AUTH
from tests.common import MockUser
from tests.common import MockUser, async_fire_time_changed
from tests.test_util import mock_real_ip
from tests.typing import ClientSessionGenerator, WebSocketGenerator
_LOGGER = logging.getLogger(__name__)
API_PASSWORD = "test-password"
# Don't add 127.0.0.1/::1 as trusted, as it may interfere with other test cases
@ -54,7 +63,13 @@ TRUSTED_NETWORKS = [
]
TRUSTED_ADDRESSES = ["100.64.0.1", "192.0.2.100", "FD01:DB8::1", "2001:DB8:ABCD::1"]
EXTERNAL_ADDRESSES = ["198.51.100.1", "2001:DB8:FA1::1"]
UNTRUSTED_ADDRESSES = [*EXTERNAL_ADDRESSES, "127.0.0.1", "::1"]
LOCALHOST_ADDRESSES = ["127.0.0.1", "::1"]
UNTRUSTED_ADDRESSES = [*EXTERNAL_ADDRESSES, *LOCALHOST_ADDRESSES]
PRIVATE_ADDRESSES = [
"192.168.10.10",
"172.16.4.20",
"10.100.50.5",
]
async def mock_handler(request):
@ -122,7 +137,7 @@ async def test_cant_access_with_password_in_header(
hass: HomeAssistant,
) -> None:
"""Test access with password in header."""
await async_setup_auth(hass, app)
await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
client = await aiohttp_client(app)
req = await client.get("/", headers={HTTP_HEADER_HA_AUTH: API_PASSWORD})
@ -139,7 +154,7 @@ async def test_cant_access_with_password_in_query(
hass: HomeAssistant,
) -> None:
"""Test access with password in URL."""
await async_setup_auth(hass, app)
await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
client = await aiohttp_client(app)
resp = await client.get("/", params={"api_password": API_PASSWORD})
@ -159,7 +174,7 @@ async def test_basic_auth_does_not_work(
legacy_auth: LegacyApiPasswordAuthProvider,
) -> None:
"""Test access with basic authentication."""
await async_setup_auth(hass, app)
await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
client = await aiohttp_client(app)
req = await client.get("/", auth=BasicAuth("homeassistant", API_PASSWORD))
@ -183,7 +198,7 @@ async def test_cannot_access_with_trusted_ip(
hass_owner_user: MockUser,
) -> None:
"""Test access with an untrusted ip address."""
await async_setup_auth(hass, app2)
await async_setup_auth(hass, app2, StrictConnectionMode.DISABLED)
set_mock_ip = mock_real_ip(app2)
client = await aiohttp_client(app2)
@ -211,7 +226,7 @@ async def test_auth_active_access_with_access_token_in_header(
) -> None:
"""Test access with access token in header."""
token = hass_access_token
await async_setup_auth(hass, app)
await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
client = await aiohttp_client(app)
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
@ -247,7 +262,7 @@ async def test_auth_active_access_with_trusted_ip(
hass_owner_user: MockUser,
) -> None:
"""Test access with an untrusted ip address."""
await async_setup_auth(hass, app2)
await async_setup_auth(hass, app2, StrictConnectionMode.DISABLED)
set_mock_ip = mock_real_ip(app2)
client = await aiohttp_client(app2)
@ -274,7 +289,7 @@ async def test_auth_legacy_support_api_password_cannot_access(
hass: HomeAssistant,
) -> None:
"""Test access using api_password if auth.support_legacy."""
await async_setup_auth(hass, app)
await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
client = await aiohttp_client(app)
req = await client.get("/", headers={HTTP_HEADER_HA_AUTH: API_PASSWORD})
@ -296,7 +311,7 @@ async def test_auth_access_signed_path_with_refresh_token(
"""Test access with signed url."""
app.router.add_post("/", mock_handler)
app.router.add_get("/another_path", mock_handler)
await async_setup_auth(hass, app)
await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
client = await aiohttp_client(app)
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
@ -341,7 +356,7 @@ async def test_auth_access_signed_path_with_query_param(
"""Test access with signed url and query params."""
app.router.add_post("/", mock_handler)
app.router.add_get("/another_path", mock_handler)
await async_setup_auth(hass, app)
await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
client = await aiohttp_client(app)
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
@ -371,7 +386,7 @@ async def test_auth_access_signed_path_with_query_param_order(
"""Test access with signed url and query params different order."""
app.router.add_post("/", mock_handler)
app.router.add_get("/another_path", mock_handler)
await async_setup_auth(hass, app)
await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
client = await aiohttp_client(app)
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
@ -412,7 +427,7 @@ async def test_auth_access_signed_path_with_query_param_safe_param(
"""Test access with signed url and changing a safe param."""
app.router.add_post("/", mock_handler)
app.router.add_get("/another_path", mock_handler)
await async_setup_auth(hass, app)
await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
client = await aiohttp_client(app)
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
@ -451,7 +466,7 @@ async def test_auth_access_signed_path_with_query_param_tamper(
"""Test access with signed url and query params that have been tampered with."""
app.router.add_post("/", mock_handler)
app.router.add_get("/another_path", mock_handler)
await async_setup_auth(hass, app)
await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
client = await aiohttp_client(app)
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
@ -520,7 +535,7 @@ async def test_auth_access_signed_path_with_http(
)
app.router.add_get("/hello", mock_handler)
await async_setup_auth(hass, app)
await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
client = await aiohttp_client(app)
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
@ -544,7 +559,7 @@ async def test_auth_access_signed_path_with_content_user(
hass: HomeAssistant, app, aiohttp_client: ClientSessionGenerator
) -> None:
"""Test access signed url uses content user."""
await async_setup_auth(hass, app)
await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
signed_path = async_sign_path(hass, "/", timedelta(seconds=5))
signature = yarl.URL(signed_path).query["authSig"]
claims = jwt.decode(
@ -564,7 +579,7 @@ async def test_local_only_user_rejected(
) -> None:
"""Test access with access token in header."""
token = hass_access_token
await async_setup_auth(hass, app)
await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
set_mock_ip = mock_real_ip(app)
client = await aiohttp_client(app)
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
@ -630,7 +645,7 @@ async def test_create_user_once(hass: HomeAssistant) -> None:
"""Test that we reuse the user."""
cur_users = len(await hass.auth.async_get_users())
app = web.Application()
await async_setup_auth(hass, app)
await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
users = await hass.auth.async_get_users()
assert len(users) == cur_users + 1
@ -642,7 +657,287 @@ async def test_create_user_once(hass: HomeAssistant) -> None:
assert len(user.refresh_tokens) == 1
assert user.system_generated
await async_setup_auth(hass, app)
await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
# test it did not create a user
assert len(await hass.auth.async_get_users()) == cur_users + 1
@pytest.fixture
def app_strict_connection(hass):
"""Fixture to set up a web.Application."""
async def handler(request):
"""Return if request was authenticated."""
return web.json_response(data={"authenticated": request[KEY_AUTHENTICATED]})
app = web.Application()
app[KEY_HASS] = hass
app.router.add_get("/", handler)
async_setup_forwarded(app, True, [])
return app
@pytest.mark.parametrize(
"strict_connection_mode", [e.value for e in StrictConnectionMode]
)
async def test_strict_connection_non_cloud_authenticated_requests(
hass: HomeAssistant,
app_strict_connection: web.Application,
aiohttp_client: ClientSessionGenerator,
hass_access_token: str,
strict_connection_mode: StrictConnectionMode,
) -> None:
"""Test authenticated requests with strict connection."""
token = hass_access_token
await async_setup_auth(hass, app_strict_connection, strict_connection_mode)
set_mock_ip = mock_real_ip(app_strict_connection)
client = await aiohttp_client(app_strict_connection)
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
assert refresh_token
assert hass.auth.session._strict_connection_sessions == {}
signed_path = async_sign_path(
hass, "/", timedelta(seconds=5), refresh_token_id=refresh_token.id
)
for remote_addr in (*LOCALHOST_ADDRESSES, *PRIVATE_ADDRESSES, *EXTERNAL_ADDRESSES):
set_mock_ip(remote_addr)
# authorized requests should work normally
req = await client.get("/", headers={"Authorization": f"Bearer {token}"})
assert req.status == HTTPStatus.OK
assert await req.json() == {"authenticated": True}
req = await client.get(signed_path)
assert req.status == HTTPStatus.OK
assert await req.json() == {"authenticated": True}
@pytest.mark.parametrize(
"strict_connection_mode", [e.value for e in StrictConnectionMode]
)
async def test_strict_connection_non_cloud_local_unauthenticated_requests(
hass: HomeAssistant,
app_strict_connection: web.Application,
aiohttp_client: ClientSessionGenerator,
strict_connection_mode: StrictConnectionMode,
) -> None:
"""Test local unauthenticated requests with strict connection."""
await async_setup_auth(hass, app_strict_connection, strict_connection_mode)
set_mock_ip = mock_real_ip(app_strict_connection)
client = await aiohttp_client(app_strict_connection)
assert hass.auth.session._strict_connection_sessions == {}
for remote_addr in (*LOCALHOST_ADDRESSES, *PRIVATE_ADDRESSES):
set_mock_ip(remote_addr)
# local requests should work normally
req = await client.get("/")
assert req.status == HTTPStatus.OK
assert await req.json() == {"authenticated": False}
def _add_set_cookie_endpoint(app: web.Application, refresh_token: RefreshToken) -> None:
"""Add an endpoint to set a cookie."""
async def set_cookie(request: web.Request) -> web.Response:
hass = request.app[KEY_HASS]
# Clear all sessions
hass.auth.session._temp_sessions.clear()
hass.auth.session._strict_connection_sessions.clear()
if request.query["token"] == "refresh":
await hass.auth.session.async_create_session(request, refresh_token)
else:
await hass.auth.session.async_create_temp_unauthorized_session(request)
session = await get_session(request)
return web.Response(text=session[SESSION_ID])
app.router.add_get("/test/cookie", set_cookie)
async def _test_strict_connection_non_cloud_enabled_setup(
hass: HomeAssistant,
app: web.Application,
aiohttp_client: ClientSessionGenerator,
hass_access_token: str,
strict_connection_mode: StrictConnectionMode,
) -> tuple[TestClient, Callable[[str], None], RefreshToken]:
"""Test external unauthenticated requests with strict connection non cloud enabled."""
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
assert refresh_token
session = hass.auth.session
assert session._strict_connection_sessions == {}
assert session._temp_sessions == {}
_add_set_cookie_endpoint(app, refresh_token)
await async_setup_auth(hass, app, strict_connection_mode)
set_mock_ip = mock_real_ip(app)
client = await aiohttp_client(app)
return (client, set_mock_ip, refresh_token)
async def _test_strict_connection_non_cloud_enabled_external_unauthenticated_requests(
hass: HomeAssistant,
app: web.Application,
aiohttp_client: ClientSessionGenerator,
hass_access_token: str,
perform_unauthenticated_request: Callable[
[HomeAssistant, TestClient], Awaitable[None]
],
strict_connection_mode: StrictConnectionMode,
) -> None:
"""Test external unauthenticated requests with strict connection non cloud enabled."""
client, set_mock_ip, _ = await _test_strict_connection_non_cloud_enabled_setup(
hass, app, aiohttp_client, hass_access_token, strict_connection_mode
)
for remote_addr in EXTERNAL_ADDRESSES:
set_mock_ip(remote_addr)
await perform_unauthenticated_request(hass, client)
async def _test_strict_connection_non_cloud_enabled_external_unauthenticated_requests_refresh_token(
hass: HomeAssistant,
app: web.Application,
aiohttp_client: ClientSessionGenerator,
hass_access_token: str,
perform_unauthenticated_request: Callable[
[HomeAssistant, TestClient], Awaitable[None]
],
strict_connection_mode: StrictConnectionMode,
) -> None:
"""Test external unauthenticated requests with strict connection non cloud enabled and refresh token cookie."""
(
client,
set_mock_ip,
refresh_token,
) = await _test_strict_connection_non_cloud_enabled_setup(
hass, app, aiohttp_client, hass_access_token, strict_connection_mode
)
session = hass.auth.session
# set strict connection cookie with refresh token
set_mock_ip(LOCALHOST_ADDRESSES[0])
session_id = await (await client.get("/test/cookie?token=refresh")).text()
assert session._strict_connection_sessions == {session_id: refresh_token.id}
for remote_addr in EXTERNAL_ADDRESSES:
set_mock_ip(remote_addr)
req = await client.get("/")
assert req.status == HTTPStatus.OK
assert await req.json() == {"authenticated": False}
# Invalidate refresh token, which should also invalidate session
hass.auth.async_remove_refresh_token(refresh_token)
assert session._strict_connection_sessions == {}
for remote_addr in EXTERNAL_ADDRESSES:
set_mock_ip(remote_addr)
await perform_unauthenticated_request(hass, client)
async def _test_strict_connection_non_cloud_enabled_external_unauthenticated_requests_temp_session(
hass: HomeAssistant,
app: web.Application,
aiohttp_client: ClientSessionGenerator,
hass_access_token: str,
perform_unauthenticated_request: Callable[
[HomeAssistant, TestClient], Awaitable[None]
],
strict_connection_mode: StrictConnectionMode,
) -> None:
"""Test external unauthenticated requests with strict connection non cloud enabled and temp cookie."""
client, set_mock_ip, _ = await _test_strict_connection_non_cloud_enabled_setup(
hass, app, aiohttp_client, hass_access_token, strict_connection_mode
)
session = hass.auth.session
# set strict connection cookie with temp session
assert session._temp_sessions == {}
set_mock_ip(LOCALHOST_ADDRESSES[0])
session_id = await (await client.get("/test/cookie?token=temp")).text()
assert client.session.cookie_jar.filter_cookies(URL("http://127.0.0.1"))
assert session_id in session._temp_sessions
for remote_addr in EXTERNAL_ADDRESSES:
set_mock_ip(remote_addr)
resp = await client.get("/")
assert resp.status == HTTPStatus.OK
assert await resp.json() == {"authenticated": False}
async_fire_time_changed(hass, utcnow() + TEMP_TIMEOUT + timedelta(minutes=1))
await hass.async_block_till_done(wait_background_tasks=True)
assert session._temp_sessions == {}
for remote_addr in EXTERNAL_ADDRESSES:
set_mock_ip(remote_addr)
await perform_unauthenticated_request(hass, client)
async def _drop_connection_unauthorized_request(
_: HomeAssistant, client: TestClient
) -> None:
with pytest.raises(ServerDisconnectedError):
# unauthorized requests should raise ServerDisconnectedError
await client.get("/")
async def _static_page_unauthorized_request(
hass: HomeAssistant, client: TestClient
) -> None:
req = await client.get("/")
assert req.status == HTTPStatus.IM_A_TEAPOT
def read_static_page() -> str:
with open(STRICT_CONNECTION_STATIC_PAGE, encoding="utf-8") as file:
return file.read()
assert await req.text() == await hass.async_add_executor_job(read_static_page)
@pytest.mark.parametrize(
"test_func",
[
_test_strict_connection_non_cloud_enabled_external_unauthenticated_requests,
_test_strict_connection_non_cloud_enabled_external_unauthenticated_requests_refresh_token,
_test_strict_connection_non_cloud_enabled_external_unauthenticated_requests_temp_session,
],
ids=[
"no cookie",
"refresh token cookie",
"temp session cookie",
],
)
@pytest.mark.parametrize(
("strict_connection_mode", "request_func"),
[
(StrictConnectionMode.DROP_CONNECTION, _drop_connection_unauthorized_request),
(StrictConnectionMode.STATIC_PAGE, _static_page_unauthorized_request),
],
ids=["drop connection", "static page"],
)
async def test_strict_connection_non_cloud_external_unauthenticated_requests(
hass: HomeAssistant,
app_strict_connection: web.Application,
aiohttp_client: ClientSessionGenerator,
hass_access_token: str,
test_func: Callable[
[
HomeAssistant,
web.Application,
ClientSessionGenerator,
str,
Callable[[HomeAssistant, TestClient], Awaitable[None]],
StrictConnectionMode,
],
Awaitable[None],
],
strict_connection_mode: StrictConnectionMode,
request_func: Callable[[HomeAssistant, TestClient], Awaitable[None]],
) -> None:
"""Test external unauthenticated requests with strict connection non cloud."""
await test_func(
hass,
app_strict_connection,
aiohttp_client,
hass_access_token,
request_func,
strict_connection_mode,
)

View file

@ -7,6 +7,7 @@ from ipaddress import ip_network
import logging
from pathlib import Path
from unittest.mock import Mock, patch
from urllib.parse import quote_plus
import pytest
@ -14,7 +15,10 @@ from homeassistant.auth.providers.legacy_api_password import (
LegacyApiPasswordAuthProvider,
)
from homeassistant.components import http
from homeassistant.components.http.const import StrictConnectionMode
from homeassistant.config import async_process_ha_core_config
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ServiceValidationError
from homeassistant.helpers.http import KEY_HASS
from homeassistant.helpers.network import NoURLAvailableError
from homeassistant.setup import async_setup_component
@ -521,3 +525,78 @@ async def test_logging(
response = await client.get("/api/states/logging.entity")
assert response.status == HTTPStatus.OK
assert "GET /api/states/logging.entity" not in caplog.text
async def test_service_create_temporary_strict_connection_url_strict_connection_disabled(
hass: HomeAssistant,
) -> None:
"""Test service create_temporary_strict_connection_url with strict_connection not enabled."""
assert await async_setup_component(hass, http.DOMAIN, {"http": {}})
with pytest.raises(
ServiceValidationError,
match="Strict connection is not enabled for non-cloud requests",
):
await hass.services.async_call(
http.DOMAIN,
"create_temporary_strict_connection_url",
blocking=True,
return_response=True,
)
@pytest.mark.parametrize(
("mode"),
[
StrictConnectionMode.DROP_CONNECTION,
StrictConnectionMode.STATIC_PAGE,
],
)
async def test_service_create_temporary_strict_connection(
hass: HomeAssistant, mode: StrictConnectionMode
) -> None:
"""Test service create_temporary_strict_connection_url."""
assert await async_setup_component(
hass, http.DOMAIN, {"http": {"strict_connection": mode}}
)
# No external url set
assert hass.config.external_url is None
assert hass.config.internal_url is None
with pytest.raises(ServiceValidationError, match="No external URL available"):
await hass.services.async_call(
http.DOMAIN,
"create_temporary_strict_connection_url",
blocking=True,
return_response=True,
)
# Raise if only internal url is available
hass.config.api = Mock(use_ssl=False, port=8123, local_ip="192.168.123.123")
with pytest.raises(ServiceValidationError, match="No external URL available"):
await hass.services.async_call(
http.DOMAIN,
"create_temporary_strict_connection_url",
blocking=True,
return_response=True,
)
# Set external url too
external_url = "https://example.com"
await async_process_ha_core_config(
hass,
{"external_url": external_url},
)
assert hass.config.external_url == external_url
response = await hass.services.async_call(
http.DOMAIN,
"create_temporary_strict_connection_url",
blocking=True,
return_response=True,
)
assert isinstance(response, dict)
direct_url_prefix = f"{external_url}/auth/strict_connection/temp_token?authSig="
assert response.pop("direct_url").startswith(direct_url_prefix)
assert response.pop("url").startswith(
f"https://login.home-assistant.io?u={quote_plus(direct_url_prefix)}"
)
assert response == {} # No more keys in response

View file

@ -0,0 +1,107 @@
"""Tests for HTTP session."""
from collections.abc import Callable
import logging
from typing import Any
from unittest.mock import patch
from aiohttp import web
from aiohttp.test_utils import make_mocked_request
import pytest
from homeassistant.auth.session import SESSION_ID
from homeassistant.components.http.session import (
COOKIE_NAME,
HomeAssistantCookieStorage,
)
from homeassistant.core import HomeAssistant
def fake_request_with_strict_connection_cookie(cookie_value: str) -> web.Request:
"""Return a fake request with a strict connection cookie."""
request = make_mocked_request(
"GET", "/", headers={"Cookie": f"{COOKIE_NAME}={cookie_value}"}
)
assert COOKIE_NAME in request.cookies
return request
@pytest.fixture
def cookie_storage(hass: HomeAssistant) -> HomeAssistantCookieStorage:
"""Fixture for the cookie storage."""
return HomeAssistantCookieStorage(hass)
def _encrypt_cookie_data(cookie_storage: HomeAssistantCookieStorage, data: Any) -> str:
"""Encrypt cookie data."""
cookie_data = cookie_storage._encoder(data).encode("utf-8")
return cookie_storage._fernet.encrypt(cookie_data).decode("utf-8")
@pytest.mark.parametrize(
"func",
[
lambda _: "invalid",
lambda storage: _encrypt_cookie_data(storage, "bla"),
lambda storage: _encrypt_cookie_data(storage, None),
],
)
async def test_load_session_modified_cookies(
cookie_storage: HomeAssistantCookieStorage,
caplog: pytest.LogCaptureFixture,
func: Callable[[HomeAssistantCookieStorage], str],
) -> None:
"""Test that on modified cookies the session is empty and the request will be logged for ban."""
request = fake_request_with_strict_connection_cookie(func(cookie_storage))
with patch(
"homeassistant.components.http.session.process_wrong_login",
) as mock_process_wrong_login:
session = await cookie_storage.load_session(request)
assert session.empty
assert (
"homeassistant.components.http.session",
logging.WARNING,
"Cannot decrypt/parse cookie value",
) in caplog.record_tuples
mock_process_wrong_login.assert_called()
async def test_load_session_validate_session(
hass: HomeAssistant,
cookie_storage: HomeAssistantCookieStorage,
) -> None:
"""Test load session validates the session."""
session = await cookie_storage.new_session()
session[SESSION_ID] = "bla"
request = fake_request_with_strict_connection_cookie(
_encrypt_cookie_data(cookie_storage, cookie_storage._get_session_data(session))
)
with patch.object(
hass.auth.session, "async_validate_strict_connection_session", return_value=True
) as mock_validate:
session = await cookie_storage.load_session(request)
assert not session.empty
assert session[SESSION_ID] == "bla"
mock_validate.assert_called_with(session)
# verify lru_cache is working
mock_validate.reset_mock()
await cookie_storage.load_session(request)
mock_validate.assert_not_called()
session = await cookie_storage.new_session()
session[SESSION_ID] = "something"
request = fake_request_with_strict_connection_cookie(
_encrypt_cookie_data(cookie_storage, cookie_storage._get_session_data(session))
)
with patch.object(
hass.auth.session,
"async_validate_strict_connection_session",
return_value=False,
):
session = await cookie_storage.load_session(request)
assert session.empty
assert SESSION_ID not in session
assert session._changed

View file

@ -14,7 +14,6 @@ from __future__ import annotations
import asyncio
from collections.abc import Generator
from http import HTTPStatus
import logging
import threading
from unittest.mock import Mock, patch
@ -87,6 +86,17 @@ class HLSSync:
self._num_recvs = 0
self._num_finished = 0
def on_resp():
self._num_finished += 1
self.check_requests_ready()
class SyncResponse(web.Response):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
on_resp()
self.response = SyncResponse
def reset_request_pool(self, num_requests: int, reset_finished=True):
"""Use to reset the request counter between segments."""
self._num_recvs = 0
@ -120,12 +130,6 @@ class HLSSync:
self.check_requests_ready()
return self._original_not_found()
def response(self, body, headers=None, status=HTTPStatus.OK):
"""Intercept the Response call so we know when the web handler is finished."""
self._num_finished += 1
self.check_requests_ready()
return self._original_response(body=body, headers=headers, status=status)
async def recv(self, output: StreamOutput, **kw):
"""Intercept the recv call so we know when the response is blocking on recv."""
self._num_recvs += 1
@ -164,7 +168,7 @@ def hls_sync():
),
patch(
"homeassistant.components.stream.hls.web.Response",
side_effect=sync.response,
new=sync.response,
),
):
yield sync

View file

@ -701,7 +701,7 @@ async def test_get_services(
assert msg["id"] == id_
assert msg["type"] == const.TYPE_RESULT
assert msg["success"]
assert msg["result"] == hass.services.async_services()
assert msg["result"].keys() == hass.services.async_services().keys()
async def test_get_config(

View file

@ -7,6 +7,7 @@ from typing import Any
from unittest.mock import AsyncMock, Mock, patch
import pytest
from pytest_unordered import unordered
import voluptuous as vol
# To prevent circular import when running just this file
@ -16,6 +17,7 @@ import homeassistant.components # noqa: F401
from homeassistant.components.group import DOMAIN as DOMAIN_GROUP, Group
from homeassistant.components.logger import DOMAIN as DOMAIN_LOGGER
from homeassistant.components.shell_command import DOMAIN as DOMAIN_SHELL_COMMAND
from homeassistant.components.system_health import DOMAIN as DOMAIN_SYSTEM_HEALTH
from homeassistant.const import (
ATTR_ENTITY_ID,
ENTITY_MATCH_ALL,
@ -785,7 +787,7 @@ async def test_async_get_all_descriptions(hass: HomeAssistant) -> None:
"""Test async_get_all_descriptions."""
group_config = {DOMAIN_GROUP: {}}
assert await async_setup_component(hass, DOMAIN_GROUP, group_config)
assert await async_setup_component(hass, "system_health", {})
assert await async_setup_component(hass, DOMAIN_SYSTEM_HEALTH, {})
with patch(
"homeassistant.helpers.service._load_services_files",
@ -795,17 +797,20 @@ async def test_async_get_all_descriptions(hass: HomeAssistant) -> None:
# Test we only load services.yaml for integrations with services.yaml
# And system_health has no services
assert proxy_load_services_files.mock_calls[0][1][1] == [
await async_get_integration(hass, "group")
]
assert proxy_load_services_files.mock_calls[0][1][1] == unordered(
[
await async_get_integration(hass, DOMAIN_GROUP),
await async_get_integration(hass, "http"), # system_health requires http
]
)
assert len(descriptions) == 1
assert "description" in descriptions["group"]["reload"]
assert "fields" in descriptions["group"]["reload"]
assert len(descriptions) == 2
assert DOMAIN_GROUP in descriptions
assert "description" in descriptions[DOMAIN_GROUP]["reload"]
assert "fields" in descriptions[DOMAIN_GROUP]["reload"]
# Does not have services
assert "system_health" not in descriptions
assert DOMAIN_SYSTEM_HEALTH not in descriptions
logger_config = {DOMAIN_LOGGER: {}}
@ -833,8 +838,8 @@ async def test_async_get_all_descriptions(hass: HomeAssistant) -> None:
await async_setup_component(hass, DOMAIN_LOGGER, logger_config)
descriptions = await service.async_get_all_descriptions(hass)
assert len(descriptions) == 2
assert len(descriptions) == 3
assert DOMAIN_LOGGER in descriptions
assert descriptions[DOMAIN_LOGGER]["set_default_level"]["name"] == "Translated name"
assert (
descriptions[DOMAIN_LOGGER]["set_default_level"]["description"]

View file

@ -5,6 +5,7 @@ from unittest.mock import patch
import pytest
from homeassistant.components.http.const import StrictConnectionMode
from homeassistant.config import YAML_CONFIG_FILE
from homeassistant.scripts import check_config
@ -134,6 +135,7 @@ def test_secrets(mock_is_file, event_loop, mock_hass_config_yaml: None) -> None:
"login_attempts_threshold": -1,
"server_port": 8123,
"ssl_profile": "modern",
"strict_connection": StrictConnectionMode.DISABLED,
"use_x_frame_options": True,
"server_host": ["0.0.0.0", "::"],
}