Rename strict connection static page to guard page (#116085)
This commit is contained in:
parent
07d68eacfa
commit
bcc2dd99b2
7 changed files with 35 additions and 35 deletions
|
@ -52,9 +52,9 @@ STORAGE_VERSION = 1
|
|||
STORAGE_KEY = "http.auth"
|
||||
CONTENT_USER_NAME = "Home Assistant Content"
|
||||
STRICT_CONNECTION_EXCLUDED_PATH = "/api/webhook/"
|
||||
STRICT_CONNECTION_STATIC_PAGE_NAME = "strict_connection_static_page.html"
|
||||
STRICT_CONNECTION_STATIC_PAGE = os.path.join(
|
||||
os.path.dirname(__file__), STRICT_CONNECTION_STATIC_PAGE_NAME
|
||||
STRICT_CONNECTION_GUARD_PAGE_NAME = "strict_connection_guard_page.html"
|
||||
STRICT_CONNECTION_GUARD_PAGE = os.path.join(
|
||||
os.path.dirname(__file__), STRICT_CONNECTION_GUARD_PAGE_NAME
|
||||
)
|
||||
|
||||
|
||||
|
@ -160,9 +160,9 @@ async def async_setup_auth(
|
|||
|
||||
hass.data[STORAGE_KEY] = refresh_token.id
|
||||
|
||||
if strict_connection_mode_non_cloud is StrictConnectionMode.STATIC_PAGE:
|
||||
# Load the static page content on setup
|
||||
await _read_strict_connection_static_page(hass)
|
||||
if strict_connection_mode_non_cloud is StrictConnectionMode.GUARD_PAGE:
|
||||
# Load the guard page content on setup
|
||||
await _read_strict_connection_guard_page(hass)
|
||||
|
||||
@callback
|
||||
def async_validate_auth_header(request: Request) -> bool:
|
||||
|
@ -276,7 +276,7 @@ async def async_setup_auth(
|
|||
resp := await strict_connection_func(
|
||||
hass,
|
||||
request,
|
||||
strict_connection_mode is StrictConnectionMode.STATIC_PAGE,
|
||||
strict_connection_mode is StrictConnectionMode.GUARD_PAGE,
|
||||
)
|
||||
)
|
||||
is not None
|
||||
|
@ -301,14 +301,14 @@ async def async_setup_auth(
|
|||
async def _async_perform_strict_connection_action_on_non_local(
|
||||
hass: HomeAssistant,
|
||||
request: Request,
|
||||
static_page: bool,
|
||||
guard_page: bool,
|
||||
) -> StreamResponse | None:
|
||||
"""Perform strict connection mode action if the request is not local.
|
||||
|
||||
The function does the following:
|
||||
- Try to get the IP address of the request. If it fails, assume it's not local
|
||||
- If the request is local, return None (allow the request to continue)
|
||||
- If static_page is True, return a response with the content
|
||||
- If guard_page is True, return a response with the content
|
||||
- Otherwise close the connection and raise an exception
|
||||
"""
|
||||
try:
|
||||
|
@ -320,25 +320,25 @@ async def _async_perform_strict_connection_action_on_non_local(
|
|||
if ip_address_ and is_local(ip_address_):
|
||||
return None
|
||||
|
||||
return await _async_perform_strict_connection_action(hass, request, static_page)
|
||||
return await _async_perform_strict_connection_action(hass, request, guard_page)
|
||||
|
||||
|
||||
async def _async_perform_strict_connection_action(
|
||||
hass: HomeAssistant,
|
||||
request: Request,
|
||||
static_page: bool,
|
||||
guard_page: bool,
|
||||
) -> StreamResponse | None:
|
||||
"""Perform strict connection mode action.
|
||||
|
||||
The function does the following:
|
||||
- If static_page is True, return a response with the content
|
||||
- If guard_page is True, return a response with the content
|
||||
- Otherwise close the connection and raise an exception
|
||||
"""
|
||||
|
||||
_LOGGER.debug("Perform strict connection action for %s", request.remote)
|
||||
if static_page:
|
||||
if guard_page:
|
||||
return Response(
|
||||
text=await _read_strict_connection_static_page(hass),
|
||||
text=await _read_strict_connection_guard_page(hass),
|
||||
content_type="text/html",
|
||||
status=HTTPStatus.IM_A_TEAPOT,
|
||||
)
|
||||
|
@ -351,12 +351,12 @@ async def _async_perform_strict_connection_action(
|
|||
raise HTTPBadRequest
|
||||
|
||||
|
||||
@singleton.singleton(f"{DOMAIN}_{STRICT_CONNECTION_STATIC_PAGE_NAME}")
|
||||
async def _read_strict_connection_static_page(hass: HomeAssistant) -> str:
|
||||
"""Read the strict connection static page from disk via executor."""
|
||||
@singleton.singleton(f"{DOMAIN}_{STRICT_CONNECTION_GUARD_PAGE_NAME}")
|
||||
async def _read_strict_connection_guard_page(hass: HomeAssistant) -> str:
|
||||
"""Read the strict connection guard page from disk via executor."""
|
||||
|
||||
def read_static_page() -> str:
|
||||
with open(STRICT_CONNECTION_STATIC_PAGE, encoding="utf-8") as file:
|
||||
def read_guard_page() -> str:
|
||||
with open(STRICT_CONNECTION_GUARD_PAGE, encoding="utf-8") as file:
|
||||
return file.read()
|
||||
|
||||
return await hass.async_add_executor_job(read_static_page)
|
||||
return await hass.async_add_executor_job(read_guard_page)
|
||||
|
|
|
@ -15,5 +15,5 @@ class StrictConnectionMode(StrEnum):
|
|||
"""Enum for strict connection mode."""
|
||||
|
||||
DISABLED = "disabled"
|
||||
STATIC_PAGE = "static_page"
|
||||
GUARD_PAGE = "guard_page"
|
||||
DROP_CONNECTION = "drop_connection"
|
||||
|
|
|
@ -327,7 +327,7 @@ async def test_service_create_temporary_strict_connection_url_strict_connection_
|
|||
("mode"),
|
||||
[
|
||||
StrictConnectionMode.DROP_CONNECTION,
|
||||
StrictConnectionMode.STATIC_PAGE,
|
||||
StrictConnectionMode.GUARD_PAGE,
|
||||
],
|
||||
)
|
||||
async def test_service_create_temporary_strict_connection(
|
||||
|
|
|
@ -18,7 +18,7 @@ from homeassistant.auth.session import SESSION_ID, TEMP_TIMEOUT
|
|||
from homeassistant.components.cloud.const import PREF_STRICT_CONNECTION
|
||||
from homeassistant.components.http import KEY_HASS
|
||||
from homeassistant.components.http.auth import (
|
||||
STRICT_CONNECTION_STATIC_PAGE,
|
||||
STRICT_CONNECTION_GUARD_PAGE,
|
||||
async_setup_auth,
|
||||
async_sign_path,
|
||||
)
|
||||
|
@ -213,17 +213,17 @@ async def _drop_connection_unauthorized_request(
|
|||
await client.get("/")
|
||||
|
||||
|
||||
async def _static_page_unauthorized_request(
|
||||
async def _guard_page_unauthorized_request(
|
||||
hass: HomeAssistant, client: TestClient
|
||||
) -> None:
|
||||
req = await client.get("/")
|
||||
assert req.status == HTTPStatus.IM_A_TEAPOT
|
||||
|
||||
def read_static_page() -> str:
|
||||
with open(STRICT_CONNECTION_STATIC_PAGE, encoding="utf-8") as file:
|
||||
def read_guard_page() -> str:
|
||||
with open(STRICT_CONNECTION_GUARD_PAGE, encoding="utf-8") as file:
|
||||
return file.read()
|
||||
|
||||
assert await req.text() == await hass.async_add_executor_job(read_static_page)
|
||||
assert await req.text() == await hass.async_add_executor_job(read_guard_page)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@ -243,7 +243,7 @@ async def _static_page_unauthorized_request(
|
|||
("strict_connection_mode", "request_func"),
|
||||
[
|
||||
(StrictConnectionMode.DROP_CONNECTION, _drop_connection_unauthorized_request),
|
||||
(StrictConnectionMode.STATIC_PAGE, _static_page_unauthorized_request),
|
||||
(StrictConnectionMode.GUARD_PAGE, _guard_page_unauthorized_request),
|
||||
],
|
||||
ids=["drop connection", "static page"],
|
||||
)
|
||||
|
|
|
@ -30,7 +30,7 @@ from homeassistant.components.http.auth import (
|
|||
DATA_SIGN_SECRET,
|
||||
SIGN_QUERY_PARAM,
|
||||
STORAGE_KEY,
|
||||
STRICT_CONNECTION_STATIC_PAGE,
|
||||
STRICT_CONNECTION_GUARD_PAGE,
|
||||
async_setup_auth,
|
||||
async_sign_path,
|
||||
async_user_not_allowed_do_auth,
|
||||
|
@ -879,17 +879,17 @@ async def _drop_connection_unauthorized_request(
|
|||
await client.get("/")
|
||||
|
||||
|
||||
async def _static_page_unauthorized_request(
|
||||
async def _guard_page_unauthorized_request(
|
||||
hass: HomeAssistant, client: TestClient
|
||||
) -> None:
|
||||
req = await client.get("/")
|
||||
assert req.status == HTTPStatus.IM_A_TEAPOT
|
||||
|
||||
def read_static_page() -> str:
|
||||
with open(STRICT_CONNECTION_STATIC_PAGE, encoding="utf-8") as file:
|
||||
def read_guard_page() -> str:
|
||||
with open(STRICT_CONNECTION_GUARD_PAGE, encoding="utf-8") as file:
|
||||
return file.read()
|
||||
|
||||
assert await req.text() == await hass.async_add_executor_job(read_static_page)
|
||||
assert await req.text() == await hass.async_add_executor_job(read_guard_page)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@ -909,7 +909,7 @@ async def _static_page_unauthorized_request(
|
|||
("strict_connection_mode", "request_func"),
|
||||
[
|
||||
(StrictConnectionMode.DROP_CONNECTION, _drop_connection_unauthorized_request),
|
||||
(StrictConnectionMode.STATIC_PAGE, _static_page_unauthorized_request),
|
||||
(StrictConnectionMode.GUARD_PAGE, _guard_page_unauthorized_request),
|
||||
],
|
||||
ids=["drop connection", "static page"],
|
||||
)
|
||||
|
|
|
@ -548,7 +548,7 @@ async def test_service_create_temporary_strict_connection_url_strict_connection_
|
|||
("mode"),
|
||||
[
|
||||
StrictConnectionMode.DROP_CONNECTION,
|
||||
StrictConnectionMode.STATIC_PAGE,
|
||||
StrictConnectionMode.GUARD_PAGE,
|
||||
],
|
||||
)
|
||||
async def test_service_create_temporary_strict_connection(
|
||||
|
|
Loading…
Add table
Reference in a new issue