Add strict connection (#112387)
Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
parent
f70ce8abf9
commit
348e1df949
23 changed files with 1187 additions and 64 deletions
|
@ -28,6 +28,7 @@ from .const import ACCESS_TOKEN_EXPIRATION, GROUP_ID_ADMIN, REFRESH_TOKEN_EXPIRA
|
|||
from .mfa_modules import MultiFactorAuthModule, auth_mfa_module_from_config
|
||||
from .models import AuthFlowResult
|
||||
from .providers import AuthProvider, LoginFlow, auth_provider_from_config
|
||||
from .session import SessionManager
|
||||
|
||||
EVENT_USER_ADDED = "user_added"
|
||||
EVENT_USER_UPDATED = "user_updated"
|
||||
|
@ -85,7 +86,7 @@ async def auth_manager_from_config(
|
|||
module_hash[module.id] = module
|
||||
|
||||
manager = AuthManager(hass, store, provider_hash, module_hash)
|
||||
manager.async_setup()
|
||||
await manager.async_setup()
|
||||
return manager
|
||||
|
||||
|
||||
|
@ -180,9 +181,9 @@ class AuthManager:
|
|||
self._remove_expired_job = HassJob(
|
||||
self._async_remove_expired_refresh_tokens, job_type=HassJobType.Callback
|
||||
)
|
||||
self.session = SessionManager(hass, self)
|
||||
|
||||
@callback
|
||||
def async_setup(self) -> None:
|
||||
async def async_setup(self) -> None:
|
||||
"""Set up the auth manager."""
|
||||
hass = self.hass
|
||||
hass.async_add_shutdown_job(
|
||||
|
@ -191,6 +192,7 @@ class AuthManager:
|
|||
)
|
||||
)
|
||||
self._async_track_next_refresh_token_expiration()
|
||||
await self.session.async_setup()
|
||||
|
||||
@property
|
||||
def auth_providers(self) -> list[AuthProvider]:
|
||||
|
|
205
homeassistant/auth/session.py
Normal file
205
homeassistant/auth/session.py
Normal file
|
@ -0,0 +1,205 @@
|
|||
"""Session auth module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
import secrets
|
||||
from typing import TYPE_CHECKING, Final, TypedDict
|
||||
|
||||
from aiohttp.web import Request
|
||||
from aiohttp_session import Session, get_session, new_session
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
|
||||
from homeassistant.helpers.event import async_call_later
|
||||
from homeassistant.helpers.storage import Store
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
||||
from .models import RefreshToken
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import AuthManager
|
||||
|
||||
|
||||
TEMP_TIMEOUT = timedelta(minutes=5)
|
||||
TEMP_TIMEOUT_SECONDS = TEMP_TIMEOUT.total_seconds()
|
||||
|
||||
SESSION_ID = "id"
|
||||
STORAGE_VERSION = 1
|
||||
STORAGE_KEY = "auth.session"
|
||||
|
||||
|
||||
class StrictConnectionTempSessionData:
|
||||
"""Data for accessing unauthorized resources for a short period of time."""
|
||||
|
||||
__slots__ = ("cancel_remove", "absolute_expiry")
|
||||
|
||||
def __init__(self, cancel_remove: CALLBACK_TYPE) -> None:
|
||||
"""Initialize the temp session data."""
|
||||
self.cancel_remove: Final[CALLBACK_TYPE] = cancel_remove
|
||||
self.absolute_expiry: Final[datetime] = dt_util.utcnow() + TEMP_TIMEOUT
|
||||
|
||||
|
||||
class StoreData(TypedDict):
|
||||
"""Data to store."""
|
||||
|
||||
unauthorized_sessions: dict[str, str]
|
||||
key: str
|
||||
|
||||
|
||||
class SessionManager:
|
||||
"""Session manager."""
|
||||
|
||||
def __init__(self, hass: HomeAssistant, auth: AuthManager) -> None:
|
||||
"""Initialize the strict connection manager."""
|
||||
self._auth = auth
|
||||
self._hass = hass
|
||||
self._temp_sessions: dict[str, StrictConnectionTempSessionData] = {}
|
||||
self._strict_connection_sessions: dict[str, str] = {}
|
||||
self._store = Store[StoreData](
|
||||
hass, STORAGE_VERSION, STORAGE_KEY, private=True, atomic_writes=True
|
||||
)
|
||||
self._key: str | None = None
|
||||
self._refresh_token_revoke_callbacks: dict[str, CALLBACK_TYPE] = {}
|
||||
|
||||
@property
|
||||
def key(self) -> str:
|
||||
"""Return the encryption key."""
|
||||
if self._key is None:
|
||||
self._key = Fernet.generate_key().decode()
|
||||
self._async_schedule_save()
|
||||
return self._key
|
||||
|
||||
async def async_validate_request_for_strict_connection_session(
|
||||
self,
|
||||
request: Request,
|
||||
) -> bool:
|
||||
"""Check if a request has a valid strict connection session."""
|
||||
session = await get_session(request)
|
||||
if session.new or session.empty:
|
||||
return False
|
||||
result = self.async_validate_strict_connection_session(session)
|
||||
if result is False:
|
||||
session.invalidate()
|
||||
return result
|
||||
|
||||
@callback
|
||||
def async_validate_strict_connection_session(
|
||||
self,
|
||||
session: Session,
|
||||
) -> bool:
|
||||
"""Validate a strict connection session."""
|
||||
if not (session_id := session.get(SESSION_ID)):
|
||||
return False
|
||||
|
||||
if token_id := self._strict_connection_sessions.get(session_id):
|
||||
if self._auth.async_get_refresh_token(token_id):
|
||||
return True
|
||||
# refresh token is invalid, delete entry
|
||||
self._strict_connection_sessions.pop(session_id)
|
||||
self._async_schedule_save()
|
||||
|
||||
if data := self._temp_sessions.get(session_id):
|
||||
if dt_util.utcnow() <= data.absolute_expiry:
|
||||
return True
|
||||
# session expired, delete entry
|
||||
self._temp_sessions.pop(session_id).cancel_remove()
|
||||
|
||||
return False
|
||||
|
||||
@callback
|
||||
def _async_register_revoke_token_callback(self, refresh_token_id: str) -> None:
|
||||
"""Register a callback to revoke all sessions for a refresh token."""
|
||||
if refresh_token_id in self._refresh_token_revoke_callbacks:
|
||||
return
|
||||
|
||||
@callback
|
||||
def async_invalidate_auth_sessions() -> None:
|
||||
"""Invalidate all sessions for a refresh token."""
|
||||
self._strict_connection_sessions = {
|
||||
session_id: token_id
|
||||
for session_id, token_id in self._strict_connection_sessions.items()
|
||||
if token_id != refresh_token_id
|
||||
}
|
||||
self._async_schedule_save()
|
||||
|
||||
self._refresh_token_revoke_callbacks[refresh_token_id] = (
|
||||
self._auth.async_register_revoke_token_callback(
|
||||
refresh_token_id, async_invalidate_auth_sessions
|
||||
)
|
||||
)
|
||||
|
||||
async def async_create_session(
|
||||
self,
|
||||
request: Request,
|
||||
refresh_token: RefreshToken,
|
||||
) -> None:
|
||||
"""Create new session for given refresh token.
|
||||
|
||||
Caller needs to make sure that the refresh token is valid.
|
||||
By creating a session, we are implicitly revoking all other
|
||||
sessions for the given refresh token as there is one refresh
|
||||
token per device/user case.
|
||||
"""
|
||||
self._strict_connection_sessions = {
|
||||
session_id: token_id
|
||||
for session_id, token_id in self._strict_connection_sessions.items()
|
||||
if token_id != refresh_token.id
|
||||
}
|
||||
|
||||
self._async_register_revoke_token_callback(refresh_token.id)
|
||||
session_id = await self._async_create_new_session(request)
|
||||
self._strict_connection_sessions[session_id] = refresh_token.id
|
||||
self._async_schedule_save()
|
||||
|
||||
async def async_create_temp_unauthorized_session(self, request: Request) -> None:
|
||||
"""Create a temporary unauthorized session."""
|
||||
session_id = await self._async_create_new_session(
|
||||
request, max_age=int(TEMP_TIMEOUT_SECONDS)
|
||||
)
|
||||
|
||||
@callback
|
||||
def remove(_: datetime) -> None:
|
||||
self._temp_sessions.pop(session_id, None)
|
||||
|
||||
self._temp_sessions[session_id] = StrictConnectionTempSessionData(
|
||||
async_call_later(self._hass, TEMP_TIMEOUT_SECONDS, remove)
|
||||
)
|
||||
|
||||
async def _async_create_new_session(
|
||||
self,
|
||||
request: Request,
|
||||
*,
|
||||
max_age: int | None = None,
|
||||
) -> str:
|
||||
session_id = secrets.token_hex(64)
|
||||
|
||||
session = await new_session(request)
|
||||
session[SESSION_ID] = session_id
|
||||
if max_age is not None:
|
||||
session.max_age = max_age
|
||||
return session_id
|
||||
|
||||
@callback
|
||||
def _async_schedule_save(self, delay: float = 1) -> None:
|
||||
"""Save sessions."""
|
||||
self._store.async_delay_save(self._data_to_save, delay)
|
||||
|
||||
@callback
|
||||
def _data_to_save(self) -> StoreData:
|
||||
"""Return the data to store."""
|
||||
return StoreData(
|
||||
unauthorized_sessions=self._strict_connection_sessions,
|
||||
key=self.key,
|
||||
)
|
||||
|
||||
async def async_setup(self) -> None:
|
||||
"""Set up session manager."""
|
||||
data = await self._store.async_load()
|
||||
if data is None:
|
||||
return
|
||||
|
||||
self._key = data["key"]
|
||||
self._strict_connection_sessions = data["unauthorized_sessions"]
|
||||
for token_id in self._strict_connection_sessions.values():
|
||||
self._async_register_revoke_token_callback(token_id)
|
|
@ -162,6 +162,7 @@ from homeassistant.util import dt as dt_util
|
|||
from . import indieauth, login_flow, mfa_setup_flow
|
||||
|
||||
DOMAIN = "auth"
|
||||
STRICT_CONNECTION_URL = "/auth/strict_connection/temp_token"
|
||||
|
||||
StoreResultType = Callable[[str, Credentials], str]
|
||||
RetrieveResultType = Callable[[str, str], Credentials | None]
|
||||
|
@ -187,6 +188,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
hass.http.register_view(RevokeTokenView())
|
||||
hass.http.register_view(LinkUserView(retrieve_result))
|
||||
hass.http.register_view(OAuth2AuthorizeCallbackView())
|
||||
hass.http.register_view(StrictConnectionTempTokenView())
|
||||
|
||||
websocket_api.async_register_command(hass, websocket_current_user)
|
||||
websocket_api.async_register_command(hass, websocket_create_long_lived_access_token)
|
||||
|
@ -260,10 +262,10 @@ class TokenView(HomeAssistantView):
|
|||
return await RevokeTokenView.post(self, request) # type: ignore[arg-type]
|
||||
|
||||
if grant_type == "authorization_code":
|
||||
return await self._async_handle_auth_code(hass, data, request.remote)
|
||||
return await self._async_handle_auth_code(hass, data, request)
|
||||
|
||||
if grant_type == "refresh_token":
|
||||
return await self._async_handle_refresh_token(hass, data, request.remote)
|
||||
return await self._async_handle_refresh_token(hass, data, request)
|
||||
|
||||
return self.json(
|
||||
{"error": "unsupported_grant_type"}, status_code=HTTPStatus.BAD_REQUEST
|
||||
|
@ -273,7 +275,7 @@ class TokenView(HomeAssistantView):
|
|||
self,
|
||||
hass: HomeAssistant,
|
||||
data: MultiDictProxy[str],
|
||||
remote_addr: str | None,
|
||||
request: web.Request,
|
||||
) -> web.Response:
|
||||
"""Handle authorization code request."""
|
||||
client_id = data.get("client_id")
|
||||
|
@ -313,7 +315,7 @@ class TokenView(HomeAssistantView):
|
|||
)
|
||||
try:
|
||||
access_token = hass.auth.async_create_access_token(
|
||||
refresh_token, remote_addr
|
||||
refresh_token, request.remote
|
||||
)
|
||||
except InvalidAuthError as exc:
|
||||
return self.json(
|
||||
|
@ -321,6 +323,7 @@ class TokenView(HomeAssistantView):
|
|||
status_code=HTTPStatus.FORBIDDEN,
|
||||
)
|
||||
|
||||
await hass.auth.session.async_create_session(request, refresh_token)
|
||||
return self.json(
|
||||
{
|
||||
"access_token": access_token,
|
||||
|
@ -341,9 +344,9 @@ class TokenView(HomeAssistantView):
|
|||
self,
|
||||
hass: HomeAssistant,
|
||||
data: MultiDictProxy[str],
|
||||
remote_addr: str | None,
|
||||
request: web.Request,
|
||||
) -> web.Response:
|
||||
"""Handle authorization code request."""
|
||||
"""Handle refresh token request."""
|
||||
client_id = data.get("client_id")
|
||||
if client_id is not None and not indieauth.verify_client_id(client_id):
|
||||
return self.json(
|
||||
|
@ -381,7 +384,7 @@ class TokenView(HomeAssistantView):
|
|||
|
||||
try:
|
||||
access_token = hass.auth.async_create_access_token(
|
||||
refresh_token, remote_addr
|
||||
refresh_token, request.remote
|
||||
)
|
||||
except InvalidAuthError as exc:
|
||||
return self.json(
|
||||
|
@ -389,6 +392,7 @@ class TokenView(HomeAssistantView):
|
|||
status_code=HTTPStatus.FORBIDDEN,
|
||||
)
|
||||
|
||||
await hass.auth.session.async_create_session(request, refresh_token)
|
||||
return self.json(
|
||||
{
|
||||
"access_token": access_token,
|
||||
|
@ -437,6 +441,20 @@ class LinkUserView(HomeAssistantView):
|
|||
return self.json_message("User linked")
|
||||
|
||||
|
||||
class StrictConnectionTempTokenView(HomeAssistantView):
|
||||
"""View to get temporary strict connection token."""
|
||||
|
||||
url = STRICT_CONNECTION_URL
|
||||
name = "api:auth:strict_connection:temp_token"
|
||||
requires_auth = False
|
||||
|
||||
async def get(self, request: web.Request) -> web.Response:
|
||||
"""Get a temporary token and redirect to main page."""
|
||||
hass = request.app[KEY_HASS]
|
||||
await hass.auth.session.async_create_temp_unauthorized_session(request)
|
||||
raise web.HTTPSeeOther(location="/")
|
||||
|
||||
|
||||
@callback
|
||||
def _create_auth_code_store() -> tuple[StoreResultType, RetrieveResultType]:
|
||||
"""Create an in memory store."""
|
||||
|
|
|
@ -197,7 +197,6 @@ class HassIOIngress(HomeAssistantView):
|
|||
content_type or simple_response.content_type
|
||||
):
|
||||
simple_response.enable_compression()
|
||||
await simple_response.prepare(request)
|
||||
return simple_response
|
||||
|
||||
# Stream response
|
||||
|
|
|
@ -10,7 +10,8 @@ import os
|
|||
import socket
|
||||
import ssl
|
||||
from tempfile import NamedTemporaryFile
|
||||
from typing import Any, Final, TypedDict, cast
|
||||
from typing import Any, Final, Required, TypedDict, cast
|
||||
from urllib.parse import quote_plus, urljoin
|
||||
|
||||
from aiohttp import web
|
||||
from aiohttp.abc import AbstractStreamWriter
|
||||
|
@ -30,8 +31,20 @@ from yarl import URL
|
|||
|
||||
from homeassistant.components.network import async_get_source_ip
|
||||
from homeassistant.const import EVENT_HOMEASSISTANT_STOP, SERVER_PORT
|
||||
from homeassistant.core import Event, HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.core import (
|
||||
Event,
|
||||
HomeAssistant,
|
||||
ServiceCall,
|
||||
ServiceResponse,
|
||||
SupportsResponse,
|
||||
callback,
|
||||
)
|
||||
from homeassistant.exceptions import (
|
||||
HomeAssistantError,
|
||||
ServiceValidationError,
|
||||
Unauthorized,
|
||||
UnknownUser,
|
||||
)
|
||||
from homeassistant.helpers import storage
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.http import (
|
||||
|
@ -53,9 +66,13 @@ from homeassistant.util import dt as dt_util, ssl as ssl_util
|
|||
from homeassistant.util.async_ import create_eager_task
|
||||
from homeassistant.util.json import json_loads
|
||||
|
||||
from .auth import async_setup_auth
|
||||
from .auth import async_setup_auth, async_sign_path
|
||||
from .ban import setup_bans
|
||||
from .const import KEY_HASS_REFRESH_TOKEN_ID, KEY_HASS_USER # noqa: F401
|
||||
from .const import ( # noqa: F401
|
||||
KEY_HASS_REFRESH_TOKEN_ID,
|
||||
KEY_HASS_USER,
|
||||
StrictConnectionMode,
|
||||
)
|
||||
from .cors import setup_cors
|
||||
from .decorators import require_admin # noqa: F401
|
||||
from .forwarded import async_setup_forwarded
|
||||
|
@ -80,6 +97,7 @@ CONF_TRUSTED_PROXIES: Final = "trusted_proxies"
|
|||
CONF_LOGIN_ATTEMPTS_THRESHOLD: Final = "login_attempts_threshold"
|
||||
CONF_IP_BAN_ENABLED: Final = "ip_ban_enabled"
|
||||
CONF_SSL_PROFILE: Final = "ssl_profile"
|
||||
CONF_STRICT_CONNECTION: Final = "strict_connection"
|
||||
|
||||
SSL_MODERN: Final = "modern"
|
||||
SSL_INTERMEDIATE: Final = "intermediate"
|
||||
|
@ -129,6 +147,9 @@ HTTP_SCHEMA: Final = vol.All(
|
|||
[SSL_INTERMEDIATE, SSL_MODERN]
|
||||
),
|
||||
vol.Optional(CONF_USE_X_FRAME_OPTIONS, default=True): cv.boolean,
|
||||
vol.Optional(
|
||||
CONF_STRICT_CONNECTION, default=StrictConnectionMode.DISABLED
|
||||
): vol.In([e.value for e in StrictConnectionMode]),
|
||||
}
|
||||
),
|
||||
)
|
||||
|
@ -152,6 +173,7 @@ class ConfData(TypedDict, total=False):
|
|||
login_attempts_threshold: int
|
||||
ip_ban_enabled: bool
|
||||
ssl_profile: str
|
||||
strict_connection: Required[StrictConnectionMode]
|
||||
|
||||
|
||||
@bind_hass
|
||||
|
@ -218,6 +240,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
login_threshold=login_threshold,
|
||||
is_ban_enabled=is_ban_enabled,
|
||||
use_x_frame_options=use_x_frame_options,
|
||||
strict_connection_non_cloud=conf[CONF_STRICT_CONNECTION],
|
||||
)
|
||||
|
||||
async def stop_server(event: Event) -> None:
|
||||
|
@ -247,6 +270,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
local_ip, host, server_port, ssl_certificate is not None
|
||||
)
|
||||
|
||||
_setup_services(hass, conf)
|
||||
return True
|
||||
|
||||
|
||||
|
@ -331,6 +355,7 @@ class HomeAssistantHTTP:
|
|||
login_threshold: int,
|
||||
is_ban_enabled: bool,
|
||||
use_x_frame_options: bool,
|
||||
strict_connection_non_cloud: StrictConnectionMode,
|
||||
) -> None:
|
||||
"""Initialize the server."""
|
||||
self.app[KEY_HASS] = self.hass
|
||||
|
@ -347,7 +372,7 @@ class HomeAssistantHTTP:
|
|||
if is_ban_enabled:
|
||||
setup_bans(self.hass, self.app, login_threshold)
|
||||
|
||||
await async_setup_auth(self.hass, self.app)
|
||||
await async_setup_auth(self.hass, self.app, strict_connection_non_cloud)
|
||||
|
||||
setup_headers(self.app, use_x_frame_options)
|
||||
setup_cors(self.app, cors_origins)
|
||||
|
@ -577,3 +602,59 @@ async def start_http_server_and_save_config(
|
|||
]
|
||||
|
||||
store.async_delay_save(lambda: conf, SAVE_DELAY)
|
||||
|
||||
|
||||
@callback
|
||||
def _setup_services(hass: HomeAssistant, conf: ConfData) -> None:
|
||||
"""Set up services for HTTP component."""
|
||||
|
||||
async def create_temporary_strict_connection_url(
|
||||
call: ServiceCall,
|
||||
) -> ServiceResponse:
|
||||
"""Create a strict connection url and return it."""
|
||||
# Copied form homeassistant/helpers/service.py#_async_admin_handler
|
||||
# as the helper supports no responses yet
|
||||
if call.context.user_id:
|
||||
user = await hass.auth.async_get_user(call.context.user_id)
|
||||
if user is None:
|
||||
raise UnknownUser(context=call.context)
|
||||
if not user.is_admin:
|
||||
raise Unauthorized(context=call.context)
|
||||
|
||||
if conf[CONF_STRICT_CONNECTION] is StrictConnectionMode.DISABLED:
|
||||
raise ServiceValidationError(
|
||||
translation_domain=DOMAIN,
|
||||
translation_key="strict_connection_not_enabled_non_cloud",
|
||||
)
|
||||
|
||||
try:
|
||||
url = get_url(hass, prefer_external=True, allow_internal=False)
|
||||
except NoURLAvailableError as ex:
|
||||
raise ServiceValidationError(
|
||||
translation_domain=DOMAIN,
|
||||
translation_key="no_external_url_available",
|
||||
) from ex
|
||||
|
||||
# to avoid circular import
|
||||
# pylint: disable-next=import-outside-toplevel
|
||||
from homeassistant.components.auth import STRICT_CONNECTION_URL
|
||||
|
||||
path = async_sign_path(
|
||||
hass,
|
||||
STRICT_CONNECTION_URL,
|
||||
datetime.timedelta(hours=1),
|
||||
use_content_user=True,
|
||||
)
|
||||
url = urljoin(url, path)
|
||||
|
||||
return {
|
||||
"url": f"https://login.home-assistant.io?u={quote_plus(url)}",
|
||||
"direct_url": url,
|
||||
}
|
||||
|
||||
hass.services.async_register(
|
||||
DOMAIN,
|
||||
"create_temporary_strict_connection_url",
|
||||
create_temporary_strict_connection_url,
|
||||
supports_response=SupportsResponse.ONLY,
|
||||
)
|
||||
|
|
|
@ -4,14 +4,18 @@ from __future__ import annotations
|
|||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from datetime import timedelta
|
||||
from http import HTTPStatus
|
||||
from ipaddress import ip_address
|
||||
import logging
|
||||
import os
|
||||
import secrets
|
||||
import time
|
||||
from typing import Any, Final
|
||||
|
||||
from aiohttp import hdrs
|
||||
from aiohttp.web import Application, Request, StreamResponse, middleware
|
||||
from aiohttp.web import Application, Request, Response, StreamResponse, middleware
|
||||
from aiohttp.web_exceptions import HTTPBadRequest
|
||||
from aiohttp_session import session_middleware
|
||||
import jwt
|
||||
from jwt import api_jws
|
||||
from yarl import URL
|
||||
|
@ -27,7 +31,13 @@ from homeassistant.helpers.network import is_cloud_connection
|
|||
from homeassistant.helpers.storage import Store
|
||||
from homeassistant.util.network import is_local
|
||||
|
||||
from .const import KEY_AUTHENTICATED, KEY_HASS_REFRESH_TOKEN_ID, KEY_HASS_USER
|
||||
from .const import (
|
||||
KEY_AUTHENTICATED,
|
||||
KEY_HASS_REFRESH_TOKEN_ID,
|
||||
KEY_HASS_USER,
|
||||
StrictConnectionMode,
|
||||
)
|
||||
from .session import HomeAssistantCookieStorage
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -39,6 +49,10 @@ SAFE_QUERY_PARAMS: Final = ["height", "width"]
|
|||
STORAGE_VERSION = 1
|
||||
STORAGE_KEY = "http.auth"
|
||||
CONTENT_USER_NAME = "Home Assistant Content"
|
||||
STRICT_CONNECTION_EXCLUDED_PATH = "/api/webhook/"
|
||||
STRICT_CONNECTION_STATIC_PAGE = os.path.join(
|
||||
os.path.dirname(__file__), "strict_connection_static_page.html"
|
||||
)
|
||||
|
||||
|
||||
@callback
|
||||
|
@ -48,13 +62,16 @@ def async_sign_path(
|
|||
expiration: timedelta,
|
||||
*,
|
||||
refresh_token_id: str | None = None,
|
||||
use_content_user: bool = False,
|
||||
) -> str:
|
||||
"""Sign a path for temporary access without auth header."""
|
||||
if (secret := hass.data.get(DATA_SIGN_SECRET)) is None:
|
||||
secret = hass.data[DATA_SIGN_SECRET] = secrets.token_hex()
|
||||
|
||||
if refresh_token_id is None:
|
||||
if connection := websocket_api.current_connection.get():
|
||||
if use_content_user:
|
||||
refresh_token_id = hass.data[STORAGE_KEY]
|
||||
elif connection := websocket_api.current_connection.get():
|
||||
refresh_token_id = connection.refresh_token_id
|
||||
elif (
|
||||
request := current_request.get()
|
||||
|
@ -114,7 +131,11 @@ def async_user_not_allowed_do_auth(
|
|||
return "User cannot authenticate remotely"
|
||||
|
||||
|
||||
async def async_setup_auth(hass: HomeAssistant, app: Application) -> None:
|
||||
async def async_setup_auth(
|
||||
hass: HomeAssistant,
|
||||
app: Application,
|
||||
strict_connection_mode_non_cloud: StrictConnectionMode,
|
||||
) -> None:
|
||||
"""Create auth middleware for the app."""
|
||||
store = Store[dict[str, Any]](hass, STORAGE_VERSION, STORAGE_KEY)
|
||||
if (data := await store.async_load()) is None:
|
||||
|
@ -135,6 +156,16 @@ async def async_setup_auth(hass: HomeAssistant, app: Application) -> None:
|
|||
await store.async_save(data)
|
||||
|
||||
hass.data[STORAGE_KEY] = refresh_token.id
|
||||
strict_connection_static_file_content = None
|
||||
if strict_connection_mode_non_cloud is StrictConnectionMode.STATIC_PAGE:
|
||||
|
||||
def read_static_page() -> str:
|
||||
with open(STRICT_CONNECTION_STATIC_PAGE, encoding="utf-8") as file:
|
||||
return file.read()
|
||||
|
||||
strict_connection_static_file_content = await hass.async_add_executor_job(
|
||||
read_static_page
|
||||
)
|
||||
|
||||
@callback
|
||||
def async_validate_auth_header(request: Request) -> bool:
|
||||
|
@ -224,6 +255,22 @@ async def async_setup_auth(hass: HomeAssistant, app: Application) -> None:
|
|||
authenticated = True
|
||||
auth_type = "signed request"
|
||||
|
||||
if (
|
||||
not authenticated
|
||||
and strict_connection_mode_non_cloud is not StrictConnectionMode.DISABLED
|
||||
and not request.path.startswith(STRICT_CONNECTION_EXCLUDED_PATH)
|
||||
and not await hass.auth.session.async_validate_request_for_strict_connection_session(
|
||||
request
|
||||
)
|
||||
and (
|
||||
resp := _async_perform_action_on_non_local(
|
||||
request, strict_connection_static_file_content
|
||||
)
|
||||
)
|
||||
is not None
|
||||
):
|
||||
return resp
|
||||
|
||||
if authenticated and _LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_LOGGER.debug(
|
||||
"Authenticated %s for %s using %s",
|
||||
|
@ -235,4 +282,43 @@ async def async_setup_auth(hass: HomeAssistant, app: Application) -> None:
|
|||
request[KEY_AUTHENTICATED] = authenticated
|
||||
return await handler(request)
|
||||
|
||||
app.middlewares.append(session_middleware(HomeAssistantCookieStorage(hass)))
|
||||
app.middlewares.append(auth_middleware)
|
||||
|
||||
|
||||
@callback
|
||||
def _async_perform_action_on_non_local(
|
||||
request: Request,
|
||||
strict_connection_static_file_content: str | None,
|
||||
) -> StreamResponse | None:
|
||||
"""Perform strict connection mode action if the request is not local.
|
||||
|
||||
The function does the following:
|
||||
- Try to get the IP address of the request. If it fails, assume it's not local
|
||||
- If the request is local, return None (allow the request to continue)
|
||||
- If strict_connection_static_file_content is set, return a response with the content
|
||||
- Otherwise close the connection and raise an exception
|
||||
"""
|
||||
try:
|
||||
ip_address_ = ip_address(request.remote) # type: ignore[arg-type]
|
||||
except ValueError:
|
||||
_LOGGER.debug("Invalid IP address: %s", request.remote)
|
||||
ip_address_ = None
|
||||
|
||||
if ip_address_ and is_local(ip_address_):
|
||||
return None
|
||||
|
||||
_LOGGER.debug("Perform strict connection action for %s", ip_address_)
|
||||
if strict_connection_static_file_content:
|
||||
return Response(
|
||||
text=strict_connection_static_file_content,
|
||||
content_type="text/html",
|
||||
status=HTTPStatus.IM_A_TEAPOT,
|
||||
)
|
||||
|
||||
if transport := request.transport:
|
||||
# it should never happen that we don't have a transport
|
||||
transport.close()
|
||||
|
||||
# We need to raise an exception to stop processing the request
|
||||
raise HTTPBadRequest
|
||||
|
|
|
@ -1,8 +1,17 @@
|
|||
"""HTTP specific constants."""
|
||||
|
||||
from enum import StrEnum
|
||||
from typing import Final
|
||||
|
||||
from homeassistant.helpers.http import KEY_AUTHENTICATED, KEY_HASS # noqa: F401
|
||||
|
||||
KEY_HASS_USER: Final = "hass_user"
|
||||
KEY_HASS_REFRESH_TOKEN_ID: Final = "hass_refresh_token_id"
|
||||
|
||||
|
||||
class StrictConnectionMode(StrEnum):
|
||||
"""Enum for strict connection mode."""
|
||||
|
||||
DISABLED = "disabled"
|
||||
STATIC_PAGE = "static_page"
|
||||
DROP_CONNECTION = "drop_connection"
|
||||
|
|
5
homeassistant/components/http/icons.json
Normal file
5
homeassistant/components/http/icons.json
Normal file
|
@ -0,0 +1,5 @@
|
|||
{
|
||||
"services": {
|
||||
"create_temporary_strict_connection_url": "mdi:login-variant"
|
||||
}
|
||||
}
|
1
homeassistant/components/http/services.yaml
Normal file
1
homeassistant/components/http/services.yaml
Normal file
|
@ -0,0 +1 @@
|
|||
create_temporary_strict_connection_url: ~
|
160
homeassistant/components/http/session.py
Normal file
160
homeassistant/components/http/session.py
Normal file
|
@ -0,0 +1,160 @@
|
|||
"""Session http module."""
|
||||
|
||||
from functools import lru_cache
|
||||
import logging
|
||||
|
||||
from aiohttp.web import Request, StreamResponse
|
||||
from aiohttp_session import Session, SessionData
|
||||
from aiohttp_session.cookie_storage import EncryptedCookieStorage
|
||||
from cryptography.fernet import InvalidToken
|
||||
|
||||
from homeassistant.auth.const import REFRESH_TOKEN_EXPIRATION
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.json import json_dumps
|
||||
from homeassistant.helpers.network import is_cloud_connection
|
||||
from homeassistant.util.json import JSON_DECODE_EXCEPTIONS, json_loads
|
||||
|
||||
from .ban import process_wrong_login
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
COOKIE_NAME = "SC"
|
||||
PREFIXED_COOKIE_NAME = f"__Host-{COOKIE_NAME}"
|
||||
SESSION_CACHE_SIZE = 16
|
||||
|
||||
|
||||
def _get_cookie_name(is_secure: bool) -> str:
|
||||
"""Return the cookie name."""
|
||||
return PREFIXED_COOKIE_NAME if is_secure else COOKIE_NAME
|
||||
|
||||
|
||||
class HomeAssistantCookieStorage(EncryptedCookieStorage):
|
||||
"""Home Assistant cookie storage.
|
||||
|
||||
Own class is required:
|
||||
- to set the secure flag based on the connection type
|
||||
- to use a LRU cache for session decryption
|
||||
"""
|
||||
|
||||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
"""Initialize the cookie storage."""
|
||||
super().__init__(
|
||||
hass.auth.session.key,
|
||||
cookie_name=PREFIXED_COOKIE_NAME,
|
||||
max_age=int(REFRESH_TOKEN_EXPIRATION),
|
||||
httponly=True,
|
||||
samesite="Lax",
|
||||
secure=True,
|
||||
encoder=json_dumps,
|
||||
decoder=json_loads,
|
||||
)
|
||||
self._hass = hass
|
||||
|
||||
def _secure_connection(self, request: Request) -> bool:
|
||||
"""Return if the connection is secure (https)."""
|
||||
return is_cloud_connection(self._hass) or request.secure
|
||||
|
||||
def load_cookie(self, request: Request) -> str | None:
|
||||
"""Load cookie."""
|
||||
is_secure = self._secure_connection(request)
|
||||
cookie_name = _get_cookie_name(is_secure)
|
||||
return request.cookies.get(cookie_name)
|
||||
|
||||
@lru_cache(maxsize=SESSION_CACHE_SIZE)
|
||||
def _decrypt_cookie(self, cookie: str) -> Session | None:
|
||||
"""Decrypt and validate cookie."""
|
||||
try:
|
||||
data = SessionData( # type: ignore[misc]
|
||||
self._decoder(
|
||||
self._fernet.decrypt(
|
||||
cookie.encode("utf-8"), ttl=self.max_age
|
||||
).decode("utf-8")
|
||||
)
|
||||
)
|
||||
except (InvalidToken, TypeError, ValueError, *JSON_DECODE_EXCEPTIONS):
|
||||
_LOGGER.warning("Cannot decrypt/parse cookie value")
|
||||
return None
|
||||
|
||||
session = Session(None, data=data, new=data is None, max_age=self.max_age)
|
||||
|
||||
# Validate session if not empty
|
||||
if (
|
||||
not session.empty
|
||||
and not self._hass.auth.session.async_validate_strict_connection_session(
|
||||
session
|
||||
)
|
||||
):
|
||||
# Invalidate session as it is not valid
|
||||
session.invalidate()
|
||||
|
||||
return session
|
||||
|
||||
async def new_session(self) -> Session:
|
||||
"""Create a new session and mark it as changed."""
|
||||
session = Session(None, data=None, new=True, max_age=self.max_age)
|
||||
session.changed()
|
||||
return session
|
||||
|
||||
async def load_session(self, request: Request) -> Session:
|
||||
"""Load session."""
|
||||
# Split parent function to use lru_cache
|
||||
if (cookie := self.load_cookie(request)) is None:
|
||||
return await self.new_session()
|
||||
|
||||
if (session := self._decrypt_cookie(cookie)) is None:
|
||||
# Decrypting/parsing failed, log wrong login and create a new session
|
||||
await process_wrong_login(request)
|
||||
session = await self.new_session()
|
||||
|
||||
return session
|
||||
|
||||
async def save_session(
|
||||
self, request: Request, response: StreamResponse, session: Session
|
||||
) -> None:
|
||||
"""Save session."""
|
||||
|
||||
is_secure = self._secure_connection(request)
|
||||
cookie_name = _get_cookie_name(is_secure)
|
||||
|
||||
if session.empty:
|
||||
response.del_cookie(cookie_name)
|
||||
else:
|
||||
params = self.cookie_params.copy()
|
||||
params["secure"] = is_secure
|
||||
params["max_age"] = session.max_age
|
||||
|
||||
cookie_data = self._encoder(self._get_session_data(session)).encode("utf-8")
|
||||
response.set_cookie(
|
||||
cookie_name,
|
||||
self._fernet.encrypt(cookie_data).decode("utf-8"),
|
||||
**params,
|
||||
)
|
||||
# Add Cache-Control header to not cache the cookie as it
|
||||
# is used for session management
|
||||
self._add_cache_control_header(response)
|
||||
|
||||
@staticmethod
|
||||
def _add_cache_control_header(response: StreamResponse) -> None:
|
||||
"""Add/set cache control header to no-cache="Set-Cookie"."""
|
||||
# Structure of the Cache-Control header defined in
|
||||
# https://datatracker.ietf.org/doc/html/rfc2068#section-14.9
|
||||
if header := response.headers.get("Cache-Control"):
|
||||
directives = []
|
||||
for directive in header.split(","):
|
||||
directive = directive.strip()
|
||||
directive_lowered = directive.lower()
|
||||
if directive_lowered.startswith("no-cache"):
|
||||
if "set-cookie" in directive_lowered or directive.find("=") == -1:
|
||||
# Set-Cookie is already in the no-cache directive or
|
||||
# the whole request should not be cached -> Nothing to do
|
||||
return
|
||||
|
||||
# Add Set-Cookie to the no-cache
|
||||
# [:-1] to remove the " at the end of the directive
|
||||
directive = f"{directive[:-1]}, Set-Cookie"
|
||||
|
||||
directives.append(directive)
|
||||
header = ", ".join(directives)
|
||||
else:
|
||||
header = 'no-cache="Set-Cookie"'
|
||||
response.headers["Cache-Control"] = header
|
|
@ -0,0 +1,46 @@
|
|||
<!doctype html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>I'm a Teapot</title>
|
||||
<style>
|
||||
body {
|
||||
font-family: Arial, sans-serif;
|
||||
background-color: #f3f3f3;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
height: 100vh;
|
||||
}
|
||||
.container {
|
||||
text-align: center;
|
||||
}
|
||||
h1 {
|
||||
font-size: 36px;
|
||||
color: #333;
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
p {
|
||||
font-size: 18px;
|
||||
color: #666;
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
.teapot {
|
||||
font-size: 60px;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1>Error 418: I'm a Teapot</h1>
|
||||
<p>
|
||||
Oops! Looks like the server is taking a coffee break.<br />
|
||||
Don't worry, it'll be back to brewing your requests in no time!
|
||||
</p>
|
||||
<p class="teapot">☕</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
16
homeassistant/components/http/strings.json
Normal file
16
homeassistant/components/http/strings.json
Normal file
|
@ -0,0 +1,16 @@
|
|||
{
|
||||
"exceptions": {
|
||||
"strict_connection_not_enabled_non_cloud": {
|
||||
"message": "Strict connection is not enabled for non-cloud requests"
|
||||
},
|
||||
"no_external_url_available": {
|
||||
"message": "No external URL available"
|
||||
}
|
||||
},
|
||||
"services": {
|
||||
"create_temporary_strict_connection_url": {
|
||||
"name": "Create a temporary strict connection URL",
|
||||
"description": "Create a temporary strict connection URL, which can be used to login on another device."
|
||||
}
|
||||
}
|
||||
}
|
|
@ -7,6 +7,7 @@ aiohttp-fast-url-dispatcher==0.3.0
|
|||
aiohttp-zlib-ng==0.3.1
|
||||
aiohttp==3.9.4
|
||||
aiohttp_cors==0.7.0
|
||||
aiohttp_session==2.12.0
|
||||
astral==2.2
|
||||
async-interrupt==1.1.1
|
||||
async-upnp-client==0.38.3
|
||||
|
|
|
@ -26,6 +26,7 @@ dependencies = [
|
|||
"aiodns==3.2.0",
|
||||
"aiohttp==3.9.4",
|
||||
"aiohttp_cors==0.7.0",
|
||||
"aiohttp_session==2.12.0",
|
||||
"aiohttp-fast-url-dispatcher==0.3.0",
|
||||
"aiohttp-zlib-ng==0.3.1",
|
||||
"astral==2.2",
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
aiodns==3.2.0
|
||||
aiohttp==3.9.4
|
||||
aiohttp_cors==0.7.0
|
||||
aiohttp_session==2.12.0
|
||||
aiohttp-fast-url-dispatcher==0.3.0
|
||||
aiohttp-zlib-ng==0.3.1
|
||||
astral==2.2
|
||||
|
|
|
@ -306,7 +306,7 @@ async def test_api_get_services(
|
|||
for serv_domain in data:
|
||||
local = local_services.pop(serv_domain["domain"])
|
||||
|
||||
assert serv_domain["services"] == local
|
||||
assert serv_domain["services"].keys() == local.keys()
|
||||
|
||||
|
||||
async def test_api_call_service_no_data(
|
||||
|
|
|
@ -1,22 +1,28 @@
|
|||
"""The tests for the Home Assistant HTTP component."""
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from datetime import timedelta
|
||||
from http import HTTPStatus
|
||||
from ipaddress import ip_network
|
||||
import logging
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from aiohttp import BasicAuth, web
|
||||
from aiohttp import BasicAuth, ServerDisconnectedError, web
|
||||
from aiohttp.test_utils import TestClient
|
||||
from aiohttp.web_exceptions import HTTPUnauthorized
|
||||
from aiohttp_session import get_session
|
||||
import jwt
|
||||
import pytest
|
||||
import yarl
|
||||
from yarl import URL
|
||||
|
||||
from homeassistant.auth.const import GROUP_ID_READ_ONLY
|
||||
from homeassistant.auth.models import User
|
||||
from homeassistant.auth.models import RefreshToken, User
|
||||
from homeassistant.auth.providers import trusted_networks
|
||||
from homeassistant.auth.providers.legacy_api_password import (
|
||||
LegacyApiPasswordAuthProvider,
|
||||
)
|
||||
from homeassistant.auth.session import SESSION_ID, TEMP_TIMEOUT
|
||||
from homeassistant.components import websocket_api
|
||||
from homeassistant.components.http import KEY_HASS
|
||||
from homeassistant.components.http.auth import (
|
||||
|
@ -24,11 +30,12 @@ from homeassistant.components.http.auth import (
|
|||
DATA_SIGN_SECRET,
|
||||
SIGN_QUERY_PARAM,
|
||||
STORAGE_KEY,
|
||||
STRICT_CONNECTION_STATIC_PAGE,
|
||||
async_setup_auth,
|
||||
async_sign_path,
|
||||
async_user_not_allowed_do_auth,
|
||||
)
|
||||
from homeassistant.components.http.const import KEY_AUTHENTICATED
|
||||
from homeassistant.components.http.const import KEY_AUTHENTICATED, StrictConnectionMode
|
||||
from homeassistant.components.http.forwarded import async_setup_forwarded
|
||||
from homeassistant.components.http.request_context import (
|
||||
current_request,
|
||||
|
@ -36,13 +43,15 @@ from homeassistant.components.http.request_context import (
|
|||
)
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.util.dt import utcnow
|
||||
|
||||
from . import HTTP_HEADER_HA_AUTH
|
||||
|
||||
from tests.common import MockUser
|
||||
from tests.common import MockUser, async_fire_time_changed
|
||||
from tests.test_util import mock_real_ip
|
||||
from tests.typing import ClientSessionGenerator, WebSocketGenerator
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
API_PASSWORD = "test-password"
|
||||
|
||||
# Don't add 127.0.0.1/::1 as trusted, as it may interfere with other test cases
|
||||
|
@ -54,7 +63,13 @@ TRUSTED_NETWORKS = [
|
|||
]
|
||||
TRUSTED_ADDRESSES = ["100.64.0.1", "192.0.2.100", "FD01:DB8::1", "2001:DB8:ABCD::1"]
|
||||
EXTERNAL_ADDRESSES = ["198.51.100.1", "2001:DB8:FA1::1"]
|
||||
UNTRUSTED_ADDRESSES = [*EXTERNAL_ADDRESSES, "127.0.0.1", "::1"]
|
||||
LOCALHOST_ADDRESSES = ["127.0.0.1", "::1"]
|
||||
UNTRUSTED_ADDRESSES = [*EXTERNAL_ADDRESSES, *LOCALHOST_ADDRESSES]
|
||||
PRIVATE_ADDRESSES = [
|
||||
"192.168.10.10",
|
||||
"172.16.4.20",
|
||||
"10.100.50.5",
|
||||
]
|
||||
|
||||
|
||||
async def mock_handler(request):
|
||||
|
@ -122,7 +137,7 @@ async def test_cant_access_with_password_in_header(
|
|||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test access with password in header."""
|
||||
await async_setup_auth(hass, app)
|
||||
await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
req = await client.get("/", headers={HTTP_HEADER_HA_AUTH: API_PASSWORD})
|
||||
|
@ -139,7 +154,7 @@ async def test_cant_access_with_password_in_query(
|
|||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test access with password in URL."""
|
||||
await async_setup_auth(hass, app)
|
||||
await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
resp = await client.get("/", params={"api_password": API_PASSWORD})
|
||||
|
@ -159,7 +174,7 @@ async def test_basic_auth_does_not_work(
|
|||
legacy_auth: LegacyApiPasswordAuthProvider,
|
||||
) -> None:
|
||||
"""Test access with basic authentication."""
|
||||
await async_setup_auth(hass, app)
|
||||
await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
req = await client.get("/", auth=BasicAuth("homeassistant", API_PASSWORD))
|
||||
|
@ -183,7 +198,7 @@ async def test_cannot_access_with_trusted_ip(
|
|||
hass_owner_user: MockUser,
|
||||
) -> None:
|
||||
"""Test access with an untrusted ip address."""
|
||||
await async_setup_auth(hass, app2)
|
||||
await async_setup_auth(hass, app2, StrictConnectionMode.DISABLED)
|
||||
|
||||
set_mock_ip = mock_real_ip(app2)
|
||||
client = await aiohttp_client(app2)
|
||||
|
@ -211,7 +226,7 @@ async def test_auth_active_access_with_access_token_in_header(
|
|||
) -> None:
|
||||
"""Test access with access token in header."""
|
||||
token = hass_access_token
|
||||
await async_setup_auth(hass, app)
|
||||
await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
|
||||
client = await aiohttp_client(app)
|
||||
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||
|
||||
|
@ -247,7 +262,7 @@ async def test_auth_active_access_with_trusted_ip(
|
|||
hass_owner_user: MockUser,
|
||||
) -> None:
|
||||
"""Test access with an untrusted ip address."""
|
||||
await async_setup_auth(hass, app2)
|
||||
await async_setup_auth(hass, app2, StrictConnectionMode.DISABLED)
|
||||
|
||||
set_mock_ip = mock_real_ip(app2)
|
||||
client = await aiohttp_client(app2)
|
||||
|
@ -274,7 +289,7 @@ async def test_auth_legacy_support_api_password_cannot_access(
|
|||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test access using api_password if auth.support_legacy."""
|
||||
await async_setup_auth(hass, app)
|
||||
await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
req = await client.get("/", headers={HTTP_HEADER_HA_AUTH: API_PASSWORD})
|
||||
|
@ -296,7 +311,7 @@ async def test_auth_access_signed_path_with_refresh_token(
|
|||
"""Test access with signed url."""
|
||||
app.router.add_post("/", mock_handler)
|
||||
app.router.add_get("/another_path", mock_handler)
|
||||
await async_setup_auth(hass, app)
|
||||
await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||
|
@ -341,7 +356,7 @@ async def test_auth_access_signed_path_with_query_param(
|
|||
"""Test access with signed url and query params."""
|
||||
app.router.add_post("/", mock_handler)
|
||||
app.router.add_get("/another_path", mock_handler)
|
||||
await async_setup_auth(hass, app)
|
||||
await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||
|
@ -371,7 +386,7 @@ async def test_auth_access_signed_path_with_query_param_order(
|
|||
"""Test access with signed url and query params different order."""
|
||||
app.router.add_post("/", mock_handler)
|
||||
app.router.add_get("/another_path", mock_handler)
|
||||
await async_setup_auth(hass, app)
|
||||
await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||
|
@ -412,7 +427,7 @@ async def test_auth_access_signed_path_with_query_param_safe_param(
|
|||
"""Test access with signed url and changing a safe param."""
|
||||
app.router.add_post("/", mock_handler)
|
||||
app.router.add_get("/another_path", mock_handler)
|
||||
await async_setup_auth(hass, app)
|
||||
await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||
|
@ -451,7 +466,7 @@ async def test_auth_access_signed_path_with_query_param_tamper(
|
|||
"""Test access with signed url and query params that have been tampered with."""
|
||||
app.router.add_post("/", mock_handler)
|
||||
app.router.add_get("/another_path", mock_handler)
|
||||
await async_setup_auth(hass, app)
|
||||
await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||
|
@ -520,7 +535,7 @@ async def test_auth_access_signed_path_with_http(
|
|||
)
|
||||
|
||||
app.router.add_get("/hello", mock_handler)
|
||||
await async_setup_auth(hass, app)
|
||||
await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||
|
@ -544,7 +559,7 @@ async def test_auth_access_signed_path_with_content_user(
|
|||
hass: HomeAssistant, app, aiohttp_client: ClientSessionGenerator
|
||||
) -> None:
|
||||
"""Test access signed url uses content user."""
|
||||
await async_setup_auth(hass, app)
|
||||
await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
|
||||
signed_path = async_sign_path(hass, "/", timedelta(seconds=5))
|
||||
signature = yarl.URL(signed_path).query["authSig"]
|
||||
claims = jwt.decode(
|
||||
|
@ -564,7 +579,7 @@ async def test_local_only_user_rejected(
|
|||
) -> None:
|
||||
"""Test access with access token in header."""
|
||||
token = hass_access_token
|
||||
await async_setup_auth(hass, app)
|
||||
await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
|
||||
set_mock_ip = mock_real_ip(app)
|
||||
client = await aiohttp_client(app)
|
||||
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||
|
@ -630,7 +645,7 @@ async def test_create_user_once(hass: HomeAssistant) -> None:
|
|||
"""Test that we reuse the user."""
|
||||
cur_users = len(await hass.auth.async_get_users())
|
||||
app = web.Application()
|
||||
await async_setup_auth(hass, app)
|
||||
await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
|
||||
users = await hass.auth.async_get_users()
|
||||
assert len(users) == cur_users + 1
|
||||
|
||||
|
@ -642,7 +657,287 @@ async def test_create_user_once(hass: HomeAssistant) -> None:
|
|||
assert len(user.refresh_tokens) == 1
|
||||
assert user.system_generated
|
||||
|
||||
await async_setup_auth(hass, app)
|
||||
await async_setup_auth(hass, app, StrictConnectionMode.DISABLED)
|
||||
|
||||
# test it did not create a user
|
||||
assert len(await hass.auth.async_get_users()) == cur_users + 1
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app_strict_connection(hass):
|
||||
"""Fixture to set up a web.Application."""
|
||||
|
||||
async def handler(request):
|
||||
"""Return if request was authenticated."""
|
||||
return web.json_response(data={"authenticated": request[KEY_AUTHENTICATED]})
|
||||
|
||||
app = web.Application()
|
||||
app[KEY_HASS] = hass
|
||||
app.router.add_get("/", handler)
|
||||
async_setup_forwarded(app, True, [])
|
||||
return app
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"strict_connection_mode", [e.value for e in StrictConnectionMode]
|
||||
)
|
||||
async def test_strict_connection_non_cloud_authenticated_requests(
|
||||
hass: HomeAssistant,
|
||||
app_strict_connection: web.Application,
|
||||
aiohttp_client: ClientSessionGenerator,
|
||||
hass_access_token: str,
|
||||
strict_connection_mode: StrictConnectionMode,
|
||||
) -> None:
|
||||
"""Test authenticated requests with strict connection."""
|
||||
token = hass_access_token
|
||||
await async_setup_auth(hass, app_strict_connection, strict_connection_mode)
|
||||
set_mock_ip = mock_real_ip(app_strict_connection)
|
||||
client = await aiohttp_client(app_strict_connection)
|
||||
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||
assert refresh_token
|
||||
assert hass.auth.session._strict_connection_sessions == {}
|
||||
|
||||
signed_path = async_sign_path(
|
||||
hass, "/", timedelta(seconds=5), refresh_token_id=refresh_token.id
|
||||
)
|
||||
|
||||
for remote_addr in (*LOCALHOST_ADDRESSES, *PRIVATE_ADDRESSES, *EXTERNAL_ADDRESSES):
|
||||
set_mock_ip(remote_addr)
|
||||
|
||||
# authorized requests should work normally
|
||||
req = await client.get("/", headers={"Authorization": f"Bearer {token}"})
|
||||
assert req.status == HTTPStatus.OK
|
||||
assert await req.json() == {"authenticated": True}
|
||||
req = await client.get(signed_path)
|
||||
assert req.status == HTTPStatus.OK
|
||||
assert await req.json() == {"authenticated": True}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"strict_connection_mode", [e.value for e in StrictConnectionMode]
|
||||
)
|
||||
async def test_strict_connection_non_cloud_local_unauthenticated_requests(
|
||||
hass: HomeAssistant,
|
||||
app_strict_connection: web.Application,
|
||||
aiohttp_client: ClientSessionGenerator,
|
||||
strict_connection_mode: StrictConnectionMode,
|
||||
) -> None:
|
||||
"""Test local unauthenticated requests with strict connection."""
|
||||
await async_setup_auth(hass, app_strict_connection, strict_connection_mode)
|
||||
set_mock_ip = mock_real_ip(app_strict_connection)
|
||||
client = await aiohttp_client(app_strict_connection)
|
||||
assert hass.auth.session._strict_connection_sessions == {}
|
||||
|
||||
for remote_addr in (*LOCALHOST_ADDRESSES, *PRIVATE_ADDRESSES):
|
||||
set_mock_ip(remote_addr)
|
||||
# local requests should work normally
|
||||
req = await client.get("/")
|
||||
assert req.status == HTTPStatus.OK
|
||||
assert await req.json() == {"authenticated": False}
|
||||
|
||||
|
||||
def _add_set_cookie_endpoint(app: web.Application, refresh_token: RefreshToken) -> None:
|
||||
"""Add an endpoint to set a cookie."""
|
||||
|
||||
async def set_cookie(request: web.Request) -> web.Response:
|
||||
hass = request.app[KEY_HASS]
|
||||
# Clear all sessions
|
||||
hass.auth.session._temp_sessions.clear()
|
||||
hass.auth.session._strict_connection_sessions.clear()
|
||||
|
||||
if request.query["token"] == "refresh":
|
||||
await hass.auth.session.async_create_session(request, refresh_token)
|
||||
else:
|
||||
await hass.auth.session.async_create_temp_unauthorized_session(request)
|
||||
session = await get_session(request)
|
||||
return web.Response(text=session[SESSION_ID])
|
||||
|
||||
app.router.add_get("/test/cookie", set_cookie)
|
||||
|
||||
|
||||
async def _test_strict_connection_non_cloud_enabled_setup(
|
||||
hass: HomeAssistant,
|
||||
app: web.Application,
|
||||
aiohttp_client: ClientSessionGenerator,
|
||||
hass_access_token: str,
|
||||
strict_connection_mode: StrictConnectionMode,
|
||||
) -> tuple[TestClient, Callable[[str], None], RefreshToken]:
|
||||
"""Test external unauthenticated requests with strict connection non cloud enabled."""
|
||||
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||
assert refresh_token
|
||||
session = hass.auth.session
|
||||
assert session._strict_connection_sessions == {}
|
||||
assert session._temp_sessions == {}
|
||||
|
||||
_add_set_cookie_endpoint(app, refresh_token)
|
||||
await async_setup_auth(hass, app, strict_connection_mode)
|
||||
set_mock_ip = mock_real_ip(app)
|
||||
client = await aiohttp_client(app)
|
||||
return (client, set_mock_ip, refresh_token)
|
||||
|
||||
|
||||
async def _test_strict_connection_non_cloud_enabled_external_unauthenticated_requests(
|
||||
hass: HomeAssistant,
|
||||
app: web.Application,
|
||||
aiohttp_client: ClientSessionGenerator,
|
||||
hass_access_token: str,
|
||||
perform_unauthenticated_request: Callable[
|
||||
[HomeAssistant, TestClient], Awaitable[None]
|
||||
],
|
||||
strict_connection_mode: StrictConnectionMode,
|
||||
) -> None:
|
||||
"""Test external unauthenticated requests with strict connection non cloud enabled."""
|
||||
client, set_mock_ip, _ = await _test_strict_connection_non_cloud_enabled_setup(
|
||||
hass, app, aiohttp_client, hass_access_token, strict_connection_mode
|
||||
)
|
||||
|
||||
for remote_addr in EXTERNAL_ADDRESSES:
|
||||
set_mock_ip(remote_addr)
|
||||
await perform_unauthenticated_request(hass, client)
|
||||
|
||||
|
||||
async def _test_strict_connection_non_cloud_enabled_external_unauthenticated_requests_refresh_token(
|
||||
hass: HomeAssistant,
|
||||
app: web.Application,
|
||||
aiohttp_client: ClientSessionGenerator,
|
||||
hass_access_token: str,
|
||||
perform_unauthenticated_request: Callable[
|
||||
[HomeAssistant, TestClient], Awaitable[None]
|
||||
],
|
||||
strict_connection_mode: StrictConnectionMode,
|
||||
) -> None:
|
||||
"""Test external unauthenticated requests with strict connection non cloud enabled and refresh token cookie."""
|
||||
(
|
||||
client,
|
||||
set_mock_ip,
|
||||
refresh_token,
|
||||
) = await _test_strict_connection_non_cloud_enabled_setup(
|
||||
hass, app, aiohttp_client, hass_access_token, strict_connection_mode
|
||||
)
|
||||
session = hass.auth.session
|
||||
|
||||
# set strict connection cookie with refresh token
|
||||
set_mock_ip(LOCALHOST_ADDRESSES[0])
|
||||
session_id = await (await client.get("/test/cookie?token=refresh")).text()
|
||||
assert session._strict_connection_sessions == {session_id: refresh_token.id}
|
||||
for remote_addr in EXTERNAL_ADDRESSES:
|
||||
set_mock_ip(remote_addr)
|
||||
req = await client.get("/")
|
||||
assert req.status == HTTPStatus.OK
|
||||
assert await req.json() == {"authenticated": False}
|
||||
|
||||
# Invalidate refresh token, which should also invalidate session
|
||||
hass.auth.async_remove_refresh_token(refresh_token)
|
||||
assert session._strict_connection_sessions == {}
|
||||
for remote_addr in EXTERNAL_ADDRESSES:
|
||||
set_mock_ip(remote_addr)
|
||||
await perform_unauthenticated_request(hass, client)
|
||||
|
||||
|
||||
async def _test_strict_connection_non_cloud_enabled_external_unauthenticated_requests_temp_session(
|
||||
hass: HomeAssistant,
|
||||
app: web.Application,
|
||||
aiohttp_client: ClientSessionGenerator,
|
||||
hass_access_token: str,
|
||||
perform_unauthenticated_request: Callable[
|
||||
[HomeAssistant, TestClient], Awaitable[None]
|
||||
],
|
||||
strict_connection_mode: StrictConnectionMode,
|
||||
) -> None:
|
||||
"""Test external unauthenticated requests with strict connection non cloud enabled and temp cookie."""
|
||||
client, set_mock_ip, _ = await _test_strict_connection_non_cloud_enabled_setup(
|
||||
hass, app, aiohttp_client, hass_access_token, strict_connection_mode
|
||||
)
|
||||
session = hass.auth.session
|
||||
|
||||
# set strict connection cookie with temp session
|
||||
assert session._temp_sessions == {}
|
||||
set_mock_ip(LOCALHOST_ADDRESSES[0])
|
||||
session_id = await (await client.get("/test/cookie?token=temp")).text()
|
||||
assert client.session.cookie_jar.filter_cookies(URL("http://127.0.0.1"))
|
||||
assert session_id in session._temp_sessions
|
||||
for remote_addr in EXTERNAL_ADDRESSES:
|
||||
set_mock_ip(remote_addr)
|
||||
resp = await client.get("/")
|
||||
assert resp.status == HTTPStatus.OK
|
||||
assert await resp.json() == {"authenticated": False}
|
||||
|
||||
async_fire_time_changed(hass, utcnow() + TEMP_TIMEOUT + timedelta(minutes=1))
|
||||
await hass.async_block_till_done(wait_background_tasks=True)
|
||||
|
||||
assert session._temp_sessions == {}
|
||||
for remote_addr in EXTERNAL_ADDRESSES:
|
||||
set_mock_ip(remote_addr)
|
||||
await perform_unauthenticated_request(hass, client)
|
||||
|
||||
|
||||
async def _drop_connection_unauthorized_request(
|
||||
_: HomeAssistant, client: TestClient
|
||||
) -> None:
|
||||
with pytest.raises(ServerDisconnectedError):
|
||||
# unauthorized requests should raise ServerDisconnectedError
|
||||
await client.get("/")
|
||||
|
||||
|
||||
async def _static_page_unauthorized_request(
|
||||
hass: HomeAssistant, client: TestClient
|
||||
) -> None:
|
||||
req = await client.get("/")
|
||||
assert req.status == HTTPStatus.IM_A_TEAPOT
|
||||
|
||||
def read_static_page() -> str:
|
||||
with open(STRICT_CONNECTION_STATIC_PAGE, encoding="utf-8") as file:
|
||||
return file.read()
|
||||
|
||||
assert await req.text() == await hass.async_add_executor_job(read_static_page)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"test_func",
|
||||
[
|
||||
_test_strict_connection_non_cloud_enabled_external_unauthenticated_requests,
|
||||
_test_strict_connection_non_cloud_enabled_external_unauthenticated_requests_refresh_token,
|
||||
_test_strict_connection_non_cloud_enabled_external_unauthenticated_requests_temp_session,
|
||||
],
|
||||
ids=[
|
||||
"no cookie",
|
||||
"refresh token cookie",
|
||||
"temp session cookie",
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
("strict_connection_mode", "request_func"),
|
||||
[
|
||||
(StrictConnectionMode.DROP_CONNECTION, _drop_connection_unauthorized_request),
|
||||
(StrictConnectionMode.STATIC_PAGE, _static_page_unauthorized_request),
|
||||
],
|
||||
ids=["drop connection", "static page"],
|
||||
)
|
||||
async def test_strict_connection_non_cloud_external_unauthenticated_requests(
|
||||
hass: HomeAssistant,
|
||||
app_strict_connection: web.Application,
|
||||
aiohttp_client: ClientSessionGenerator,
|
||||
hass_access_token: str,
|
||||
test_func: Callable[
|
||||
[
|
||||
HomeAssistant,
|
||||
web.Application,
|
||||
ClientSessionGenerator,
|
||||
str,
|
||||
Callable[[HomeAssistant, TestClient], Awaitable[None]],
|
||||
StrictConnectionMode,
|
||||
],
|
||||
Awaitable[None],
|
||||
],
|
||||
strict_connection_mode: StrictConnectionMode,
|
||||
request_func: Callable[[HomeAssistant, TestClient], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Test external unauthenticated requests with strict connection non cloud."""
|
||||
await test_func(
|
||||
hass,
|
||||
app_strict_connection,
|
||||
aiohttp_client,
|
||||
hass_access_token,
|
||||
request_func,
|
||||
strict_connection_mode,
|
||||
)
|
||||
|
|
|
@ -7,6 +7,7 @@ from ipaddress import ip_network
|
|||
import logging
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
import pytest
|
||||
|
||||
|
@ -14,7 +15,10 @@ from homeassistant.auth.providers.legacy_api_password import (
|
|||
LegacyApiPasswordAuthProvider,
|
||||
)
|
||||
from homeassistant.components import http
|
||||
from homeassistant.components.http.const import StrictConnectionMode
|
||||
from homeassistant.config import async_process_ha_core_config
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import ServiceValidationError
|
||||
from homeassistant.helpers.http import KEY_HASS
|
||||
from homeassistant.helpers.network import NoURLAvailableError
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
@ -521,3 +525,78 @@ async def test_logging(
|
|||
response = await client.get("/api/states/logging.entity")
|
||||
assert response.status == HTTPStatus.OK
|
||||
assert "GET /api/states/logging.entity" not in caplog.text
|
||||
|
||||
|
||||
async def test_service_create_temporary_strict_connection_url_strict_connection_disabled(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test service create_temporary_strict_connection_url with strict_connection not enabled."""
|
||||
assert await async_setup_component(hass, http.DOMAIN, {"http": {}})
|
||||
with pytest.raises(
|
||||
ServiceValidationError,
|
||||
match="Strict connection is not enabled for non-cloud requests",
|
||||
):
|
||||
await hass.services.async_call(
|
||||
http.DOMAIN,
|
||||
"create_temporary_strict_connection_url",
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("mode"),
|
||||
[
|
||||
StrictConnectionMode.DROP_CONNECTION,
|
||||
StrictConnectionMode.STATIC_PAGE,
|
||||
],
|
||||
)
|
||||
async def test_service_create_temporary_strict_connection(
|
||||
hass: HomeAssistant, mode: StrictConnectionMode
|
||||
) -> None:
|
||||
"""Test service create_temporary_strict_connection_url."""
|
||||
assert await async_setup_component(
|
||||
hass, http.DOMAIN, {"http": {"strict_connection": mode}}
|
||||
)
|
||||
|
||||
# No external url set
|
||||
assert hass.config.external_url is None
|
||||
assert hass.config.internal_url is None
|
||||
with pytest.raises(ServiceValidationError, match="No external URL available"):
|
||||
await hass.services.async_call(
|
||||
http.DOMAIN,
|
||||
"create_temporary_strict_connection_url",
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
|
||||
# Raise if only internal url is available
|
||||
hass.config.api = Mock(use_ssl=False, port=8123, local_ip="192.168.123.123")
|
||||
with pytest.raises(ServiceValidationError, match="No external URL available"):
|
||||
await hass.services.async_call(
|
||||
http.DOMAIN,
|
||||
"create_temporary_strict_connection_url",
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
|
||||
# Set external url too
|
||||
external_url = "https://example.com"
|
||||
await async_process_ha_core_config(
|
||||
hass,
|
||||
{"external_url": external_url},
|
||||
)
|
||||
assert hass.config.external_url == external_url
|
||||
response = await hass.services.async_call(
|
||||
http.DOMAIN,
|
||||
"create_temporary_strict_connection_url",
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
assert isinstance(response, dict)
|
||||
direct_url_prefix = f"{external_url}/auth/strict_connection/temp_token?authSig="
|
||||
assert response.pop("direct_url").startswith(direct_url_prefix)
|
||||
assert response.pop("url").startswith(
|
||||
f"https://login.home-assistant.io?u={quote_plus(direct_url_prefix)}"
|
||||
)
|
||||
assert response == {} # No more keys in response
|
||||
|
|
107
tests/components/http/test_session.py
Normal file
107
tests/components/http/test_session.py
Normal file
|
@ -0,0 +1,107 @@
|
|||
"""Tests for HTTP session."""
|
||||
|
||||
from collections.abc import Callable
|
||||
import logging
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
from aiohttp import web
|
||||
from aiohttp.test_utils import make_mocked_request
|
||||
import pytest
|
||||
|
||||
from homeassistant.auth.session import SESSION_ID
|
||||
from homeassistant.components.http.session import (
|
||||
COOKIE_NAME,
|
||||
HomeAssistantCookieStorage,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
|
||||
def fake_request_with_strict_connection_cookie(cookie_value: str) -> web.Request:
|
||||
"""Return a fake request with a strict connection cookie."""
|
||||
request = make_mocked_request(
|
||||
"GET", "/", headers={"Cookie": f"{COOKIE_NAME}={cookie_value}"}
|
||||
)
|
||||
assert COOKIE_NAME in request.cookies
|
||||
return request
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cookie_storage(hass: HomeAssistant) -> HomeAssistantCookieStorage:
|
||||
"""Fixture for the cookie storage."""
|
||||
return HomeAssistantCookieStorage(hass)
|
||||
|
||||
|
||||
def _encrypt_cookie_data(cookie_storage: HomeAssistantCookieStorage, data: Any) -> str:
|
||||
"""Encrypt cookie data."""
|
||||
cookie_data = cookie_storage._encoder(data).encode("utf-8")
|
||||
return cookie_storage._fernet.encrypt(cookie_data).decode("utf-8")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"func",
|
||||
[
|
||||
lambda _: "invalid",
|
||||
lambda storage: _encrypt_cookie_data(storage, "bla"),
|
||||
lambda storage: _encrypt_cookie_data(storage, None),
|
||||
],
|
||||
)
|
||||
async def test_load_session_modified_cookies(
|
||||
cookie_storage: HomeAssistantCookieStorage,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
func: Callable[[HomeAssistantCookieStorage], str],
|
||||
) -> None:
|
||||
"""Test that on modified cookies the session is empty and the request will be logged for ban."""
|
||||
request = fake_request_with_strict_connection_cookie(func(cookie_storage))
|
||||
with patch(
|
||||
"homeassistant.components.http.session.process_wrong_login",
|
||||
) as mock_process_wrong_login:
|
||||
session = await cookie_storage.load_session(request)
|
||||
assert session.empty
|
||||
assert (
|
||||
"homeassistant.components.http.session",
|
||||
logging.WARNING,
|
||||
"Cannot decrypt/parse cookie value",
|
||||
) in caplog.record_tuples
|
||||
mock_process_wrong_login.assert_called()
|
||||
|
||||
|
||||
async def test_load_session_validate_session(
|
||||
hass: HomeAssistant,
|
||||
cookie_storage: HomeAssistantCookieStorage,
|
||||
) -> None:
|
||||
"""Test load session validates the session."""
|
||||
session = await cookie_storage.new_session()
|
||||
session[SESSION_ID] = "bla"
|
||||
request = fake_request_with_strict_connection_cookie(
|
||||
_encrypt_cookie_data(cookie_storage, cookie_storage._get_session_data(session))
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
hass.auth.session, "async_validate_strict_connection_session", return_value=True
|
||||
) as mock_validate:
|
||||
session = await cookie_storage.load_session(request)
|
||||
assert not session.empty
|
||||
assert session[SESSION_ID] == "bla"
|
||||
mock_validate.assert_called_with(session)
|
||||
|
||||
# verify lru_cache is working
|
||||
mock_validate.reset_mock()
|
||||
await cookie_storage.load_session(request)
|
||||
mock_validate.assert_not_called()
|
||||
|
||||
session = await cookie_storage.new_session()
|
||||
session[SESSION_ID] = "something"
|
||||
request = fake_request_with_strict_connection_cookie(
|
||||
_encrypt_cookie_data(cookie_storage, cookie_storage._get_session_data(session))
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
hass.auth.session,
|
||||
"async_validate_strict_connection_session",
|
||||
return_value=False,
|
||||
):
|
||||
session = await cookie_storage.load_session(request)
|
||||
assert session.empty
|
||||
assert SESSION_ID not in session
|
||||
assert session._changed
|
|
@ -14,7 +14,6 @@ from __future__ import annotations
|
|||
|
||||
import asyncio
|
||||
from collections.abc import Generator
|
||||
from http import HTTPStatus
|
||||
import logging
|
||||
import threading
|
||||
from unittest.mock import Mock, patch
|
||||
|
@ -87,6 +86,17 @@ class HLSSync:
|
|||
self._num_recvs = 0
|
||||
self._num_finished = 0
|
||||
|
||||
def on_resp():
|
||||
self._num_finished += 1
|
||||
self.check_requests_ready()
|
||||
|
||||
class SyncResponse(web.Response):
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
on_resp()
|
||||
|
||||
self.response = SyncResponse
|
||||
|
||||
def reset_request_pool(self, num_requests: int, reset_finished=True):
|
||||
"""Use to reset the request counter between segments."""
|
||||
self._num_recvs = 0
|
||||
|
@ -120,12 +130,6 @@ class HLSSync:
|
|||
self.check_requests_ready()
|
||||
return self._original_not_found()
|
||||
|
||||
def response(self, body, headers=None, status=HTTPStatus.OK):
|
||||
"""Intercept the Response call so we know when the web handler is finished."""
|
||||
self._num_finished += 1
|
||||
self.check_requests_ready()
|
||||
return self._original_response(body=body, headers=headers, status=status)
|
||||
|
||||
async def recv(self, output: StreamOutput, **kw):
|
||||
"""Intercept the recv call so we know when the response is blocking on recv."""
|
||||
self._num_recvs += 1
|
||||
|
@ -164,7 +168,7 @@ def hls_sync():
|
|||
),
|
||||
patch(
|
||||
"homeassistant.components.stream.hls.web.Response",
|
||||
side_effect=sync.response,
|
||||
new=sync.response,
|
||||
),
|
||||
):
|
||||
yield sync
|
||||
|
|
|
@ -701,7 +701,7 @@ async def test_get_services(
|
|||
assert msg["id"] == id_
|
||||
assert msg["type"] == const.TYPE_RESULT
|
||||
assert msg["success"]
|
||||
assert msg["result"] == hass.services.async_services()
|
||||
assert msg["result"].keys() == hass.services.async_services().keys()
|
||||
|
||||
|
||||
async def test_get_config(
|
||||
|
|
|
@ -7,6 +7,7 @@ from typing import Any
|
|||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from pytest_unordered import unordered
|
||||
import voluptuous as vol
|
||||
|
||||
# To prevent circular import when running just this file
|
||||
|
@ -16,6 +17,7 @@ import homeassistant.components # noqa: F401
|
|||
from homeassistant.components.group import DOMAIN as DOMAIN_GROUP, Group
|
||||
from homeassistant.components.logger import DOMAIN as DOMAIN_LOGGER
|
||||
from homeassistant.components.shell_command import DOMAIN as DOMAIN_SHELL_COMMAND
|
||||
from homeassistant.components.system_health import DOMAIN as DOMAIN_SYSTEM_HEALTH
|
||||
from homeassistant.const import (
|
||||
ATTR_ENTITY_ID,
|
||||
ENTITY_MATCH_ALL,
|
||||
|
@ -785,7 +787,7 @@ async def test_async_get_all_descriptions(hass: HomeAssistant) -> None:
|
|||
"""Test async_get_all_descriptions."""
|
||||
group_config = {DOMAIN_GROUP: {}}
|
||||
assert await async_setup_component(hass, DOMAIN_GROUP, group_config)
|
||||
assert await async_setup_component(hass, "system_health", {})
|
||||
assert await async_setup_component(hass, DOMAIN_SYSTEM_HEALTH, {})
|
||||
|
||||
with patch(
|
||||
"homeassistant.helpers.service._load_services_files",
|
||||
|
@ -795,17 +797,20 @@ async def test_async_get_all_descriptions(hass: HomeAssistant) -> None:
|
|||
|
||||
# Test we only load services.yaml for integrations with services.yaml
|
||||
# And system_health has no services
|
||||
assert proxy_load_services_files.mock_calls[0][1][1] == [
|
||||
await async_get_integration(hass, "group")
|
||||
]
|
||||
assert proxy_load_services_files.mock_calls[0][1][1] == unordered(
|
||||
[
|
||||
await async_get_integration(hass, DOMAIN_GROUP),
|
||||
await async_get_integration(hass, "http"), # system_health requires http
|
||||
]
|
||||
)
|
||||
|
||||
assert len(descriptions) == 1
|
||||
|
||||
assert "description" in descriptions["group"]["reload"]
|
||||
assert "fields" in descriptions["group"]["reload"]
|
||||
assert len(descriptions) == 2
|
||||
assert DOMAIN_GROUP in descriptions
|
||||
assert "description" in descriptions[DOMAIN_GROUP]["reload"]
|
||||
assert "fields" in descriptions[DOMAIN_GROUP]["reload"]
|
||||
|
||||
# Does not have services
|
||||
assert "system_health" not in descriptions
|
||||
assert DOMAIN_SYSTEM_HEALTH not in descriptions
|
||||
|
||||
logger_config = {DOMAIN_LOGGER: {}}
|
||||
|
||||
|
@ -833,8 +838,8 @@ async def test_async_get_all_descriptions(hass: HomeAssistant) -> None:
|
|||
await async_setup_component(hass, DOMAIN_LOGGER, logger_config)
|
||||
descriptions = await service.async_get_all_descriptions(hass)
|
||||
|
||||
assert len(descriptions) == 2
|
||||
|
||||
assert len(descriptions) == 3
|
||||
assert DOMAIN_LOGGER in descriptions
|
||||
assert descriptions[DOMAIN_LOGGER]["set_default_level"]["name"] == "Translated name"
|
||||
assert (
|
||||
descriptions[DOMAIN_LOGGER]["set_default_level"]["description"]
|
||||
|
|
|
@ -5,6 +5,7 @@ from unittest.mock import patch
|
|||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.http.const import StrictConnectionMode
|
||||
from homeassistant.config import YAML_CONFIG_FILE
|
||||
from homeassistant.scripts import check_config
|
||||
|
||||
|
@ -134,6 +135,7 @@ def test_secrets(mock_is_file, event_loop, mock_hass_config_yaml: None) -> None:
|
|||
"login_attempts_threshold": -1,
|
||||
"server_port": 8123,
|
||||
"ssl_profile": "modern",
|
||||
"strict_connection": StrictConnectionMode.DISABLED,
|
||||
"use_x_frame_options": True,
|
||||
"server_host": ["0.0.0.0", "::"],
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue