Extend auth/providers
endpoint and add /api/person/list
endpoint for local ip requests (#103906)
Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
parent
512902fc59
commit
852fb58ca8
14 changed files with 370 additions and 75 deletions
|
@ -22,6 +22,7 @@ from homeassistant.core import callback
|
|||
from homeassistant.data_entry_flow import FlowResult
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.network import is_cloud_connection
|
||||
|
||||
from .. import InvalidAuthError
|
||||
from ..models import Credentials, RefreshToken, UserMeta
|
||||
|
@ -192,10 +193,7 @@ class TrustedNetworksAuthProvider(AuthProvider):
|
|||
if any(ip_addr in trusted_proxy for trusted_proxy in self.trusted_proxies):
|
||||
raise InvalidAuthError("Can't allow access from a proxy server")
|
||||
|
||||
if "cloud" in self.hass.config.components:
|
||||
from hass_nabucasa import remote # pylint: disable=import-outside-toplevel
|
||||
|
||||
if remote.is_cloud_request.get():
|
||||
if is_cloud_connection(self.hass):
|
||||
raise InvalidAuthError("Can't allow access from Home Assistant Cloud")
|
||||
|
||||
@callback
|
||||
|
|
|
@ -71,14 +71,14 @@ from __future__ import annotations
|
|||
from collections.abc import Callable
|
||||
from http import HTTPStatus
|
||||
from ipaddress import ip_address
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from aiohttp import web
|
||||
import voluptuous as vol
|
||||
import voluptuous_serialize
|
||||
|
||||
from homeassistant import data_entry_flow
|
||||
from homeassistant.auth import AuthManagerFlowManager
|
||||
from homeassistant.auth import AuthManagerFlowManager, InvalidAuthError
|
||||
from homeassistant.auth.models import Credentials
|
||||
from homeassistant.components import onboarding
|
||||
from homeassistant.components.http.auth import async_user_not_allowed_do_auth
|
||||
|
@ -90,10 +90,16 @@ from homeassistant.components.http.ban import (
|
|||
from homeassistant.components.http.data_validator import RequestDataValidator
|
||||
from homeassistant.components.http.view import HomeAssistantView
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.network import is_cloud_connection
|
||||
from homeassistant.util.network import is_local
|
||||
|
||||
from . import indieauth
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from homeassistant.auth.providers.trusted_networks import (
|
||||
TrustedNetworksAuthProvider,
|
||||
)
|
||||
|
||||
from . import StoreResultType
|
||||
|
||||
|
||||
|
@ -146,13 +152,62 @@ class AuthProvidersView(HomeAssistantView):
|
|||
message_code="onboarding_required",
|
||||
)
|
||||
|
||||
return self.json(
|
||||
[
|
||||
{"name": provider.name, "id": provider.id, "type": provider.type}
|
||||
for provider in hass.auth.auth_providers
|
||||
]
|
||||
try:
|
||||
remote_address = ip_address(request.remote) # type: ignore[arg-type]
|
||||
except ValueError:
|
||||
return self.json_message(
|
||||
message="Invalid remote IP",
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
message_code="invalid_remote_ip",
|
||||
)
|
||||
|
||||
cloud_connection = is_cloud_connection(hass)
|
||||
|
||||
providers = []
|
||||
for provider in hass.auth.auth_providers:
|
||||
additional_data = {}
|
||||
|
||||
if provider.type == "trusted_networks":
|
||||
if cloud_connection:
|
||||
# Skip quickly as trusted networks are not available on cloud
|
||||
continue
|
||||
|
||||
try:
|
||||
cast("TrustedNetworksAuthProvider", provider).async_validate_access(
|
||||
remote_address
|
||||
)
|
||||
except InvalidAuthError:
|
||||
# Not a trusted network, so we don't expose that trusted_network authenticator is setup
|
||||
continue
|
||||
elif (
|
||||
provider.type == "homeassistant"
|
||||
and not cloud_connection
|
||||
and is_local(remote_address)
|
||||
and "person" in hass.config.components
|
||||
):
|
||||
# We are local, return user id and username
|
||||
users = await provider.store.async_get_users()
|
||||
additional_data["users"] = {
|
||||
user.id: credentials.data["username"]
|
||||
for user in users
|
||||
for credentials in user.credentials
|
||||
if (
|
||||
credentials.auth_provider_type == provider.type
|
||||
and credentials.auth_provider_id == provider.id
|
||||
)
|
||||
}
|
||||
|
||||
providers.append(
|
||||
{
|
||||
"name": provider.name,
|
||||
"id": provider.id,
|
||||
"type": provider.type,
|
||||
**additional_data,
|
||||
}
|
||||
)
|
||||
|
||||
return self.json(providers)
|
||||
|
||||
|
||||
def _prepare_result_json(
|
||||
result: data_entry_flow.FlowResult,
|
||||
|
|
|
@ -21,6 +21,7 @@ from homeassistant.auth.models import User
|
|||
from homeassistant.components import websocket_api
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers.json import json_bytes
|
||||
from homeassistant.helpers.network import is_cloud_connection
|
||||
from homeassistant.helpers.storage import Store
|
||||
from homeassistant.util.network import is_local
|
||||
|
||||
|
@ -98,11 +99,7 @@ def async_user_not_allowed_do_auth(
|
|||
if not request:
|
||||
return "No request available to validate local access"
|
||||
|
||||
if "cloud" in hass.config.components:
|
||||
# pylint: disable-next=import-outside-toplevel
|
||||
from hass_nabucasa import remote
|
||||
|
||||
if remote.is_cloud_request.get():
|
||||
if is_cloud_connection(hass):
|
||||
return "User is local only"
|
||||
|
||||
try:
|
||||
|
|
|
@ -1,9 +1,12 @@
|
|||
"""Support for tracking people."""
|
||||
from __future__ import annotations
|
||||
|
||||
from http import HTTPStatus
|
||||
from ipaddress import ip_address
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from aiohttp import web
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.auth import EVENT_USER_REMOVED
|
||||
|
@ -13,6 +16,7 @@ from homeassistant.components.device_tracker import (
|
|||
DOMAIN as DEVICE_TRACKER_DOMAIN,
|
||||
SourceType,
|
||||
)
|
||||
from homeassistant.components.http.view import HomeAssistantView
|
||||
from homeassistant.const import (
|
||||
ATTR_EDITABLE,
|
||||
ATTR_ENTITY_ID,
|
||||
|
@ -47,10 +51,12 @@ from homeassistant.helpers import (
|
|||
)
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
from homeassistant.helpers.event import async_track_state_change_event
|
||||
from homeassistant.helpers.network import is_cloud_connection
|
||||
from homeassistant.helpers.restore_state import RestoreEntity
|
||||
from homeassistant.helpers.storage import Store
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
from homeassistant.loader import bind_hass
|
||||
from homeassistant.util.network import is_local
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -385,6 +391,8 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
hass, DOMAIN, SERVICE_RELOAD, async_reload_yaml
|
||||
)
|
||||
|
||||
hass.http.register_view(ListPersonsView)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
|
@ -569,3 +577,44 @@ def _get_latest(prev: State | None, curr: State):
|
|||
if prev is None or curr.last_updated > prev.last_updated:
|
||||
return curr
|
||||
return prev
|
||||
|
||||
|
||||
class ListPersonsView(HomeAssistantView):
|
||||
"""List all persons if request is made from a local network."""
|
||||
|
||||
requires_auth = False
|
||||
url = "/api/person/list"
|
||||
name = "api:person:list"
|
||||
|
||||
async def get(self, request: web.Request) -> web.Response:
|
||||
"""Return a list of persons if request comes from a local IP."""
|
||||
try:
|
||||
remote_address = ip_address(request.remote) # type: ignore[arg-type]
|
||||
except ValueError:
|
||||
return self.json_message(
|
||||
message="Invalid remote IP",
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
message_code="invalid_remote_ip",
|
||||
)
|
||||
|
||||
hass: HomeAssistant = request.app["hass"]
|
||||
if is_cloud_connection(hass) or not is_local(remote_address):
|
||||
return self.json_message(
|
||||
message="Not local",
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
message_code="not_local",
|
||||
)
|
||||
|
||||
yaml, storage, _ = hass.data[DOMAIN]
|
||||
persons = [*yaml.async_items(), *storage.async_items()]
|
||||
|
||||
return self.json(
|
||||
{
|
||||
person[ATTR_USER_ID]: {
|
||||
ATTR_NAME: person[ATTR_NAME],
|
||||
CONF_PICTURE: person.get(CONF_PICTURE),
|
||||
}
|
||||
for person in persons
|
||||
if person.get(ATTR_USER_ID)
|
||||
}
|
||||
)
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
"name": "Person",
|
||||
"after_dependencies": ["device_tracker"],
|
||||
"codeowners": [],
|
||||
"dependencies": ["image_upload"],
|
||||
"dependencies": ["image_upload", "http"],
|
||||
"documentation": "https://www.home-assistant.io/integrations/person",
|
||||
"integration_type": "system",
|
||||
"iot_class": "calculated",
|
||||
|
|
|
@ -17,7 +17,7 @@ from homeassistant.components import websocket_api
|
|||
from homeassistant.components.http.view import HomeAssistantView
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers.network import get_url
|
||||
from homeassistant.helpers.network import get_url, is_cloud_connection
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
from homeassistant.loader import bind_hass
|
||||
from homeassistant.util import network
|
||||
|
@ -145,13 +145,8 @@ async def async_handle_webhook(
|
|||
return Response(status=HTTPStatus.METHOD_NOT_ALLOWED)
|
||||
|
||||
if webhook["local_only"] in (True, None) and not isinstance(request, MockRequest):
|
||||
if has_cloud := "cloud" in hass.config.components:
|
||||
from hass_nabucasa import remote # pylint: disable=import-outside-toplevel
|
||||
|
||||
is_local = True
|
||||
if has_cloud and remote.is_cloud_request.get():
|
||||
is_local = False
|
||||
else:
|
||||
is_local = not is_cloud_connection(hass)
|
||||
if is_local:
|
||||
if TYPE_CHECKING:
|
||||
assert isinstance(request, Request)
|
||||
assert request.remote is not None
|
||||
|
|
|
@ -299,3 +299,14 @@ def _get_cloud_url(hass: HomeAssistant, require_current_request: bool = False) -
|
|||
return normalize_url(str(cloud_url))
|
||||
|
||||
raise NoURLAvailableError
|
||||
|
||||
|
||||
def is_cloud_connection(hass: HomeAssistant) -> bool:
|
||||
"""Return True if the current connection is a nabucasa cloud connection."""
|
||||
|
||||
if "cloud" not in hass.config.components:
|
||||
return False
|
||||
|
||||
from hass_nabucasa import remote # pylint: disable=import-outside-toplevel
|
||||
|
||||
return remote.is_cloud_request.get()
|
||||
|
|
|
@ -1,8 +1,13 @@
|
|||
"""Tests for the auth component."""
|
||||
from typing import Any
|
||||
|
||||
from homeassistant import auth
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import ensure_auth_manager_loaded
|
||||
from tests.test_util import mock_real_ip
|
||||
from tests.typing import ClientSessionGenerator
|
||||
|
||||
BASE_CONFIG = [
|
||||
{
|
||||
|
@ -18,11 +23,12 @@ EMPTY_CONFIG = []
|
|||
|
||||
|
||||
async def async_setup_auth(
|
||||
hass,
|
||||
aiohttp_client,
|
||||
provider_configs=BASE_CONFIG,
|
||||
hass: HomeAssistant,
|
||||
aiohttp_client: ClientSessionGenerator,
|
||||
provider_configs: list[dict[str, Any]] = BASE_CONFIG,
|
||||
module_configs=EMPTY_CONFIG,
|
||||
setup_api=False,
|
||||
setup_api: bool = False,
|
||||
custom_ip: str | None = None,
|
||||
):
|
||||
"""Set up authentication and create an HTTP client."""
|
||||
hass.auth = await auth.auth_manager_from_config(
|
||||
|
@ -32,4 +38,6 @@ async def async_setup_auth(
|
|||
await async_setup_component(hass, "auth", {})
|
||||
if setup_api:
|
||||
await async_setup_component(hass, "api", {})
|
||||
if custom_ip:
|
||||
mock_real_ip(hass.http.app)(custom_ip)
|
||||
return await aiohttp_client(hass.http.app)
|
||||
|
|
|
@ -1,25 +1,141 @@
|
|||
"""Tests for the login flow."""
|
||||
from collections.abc import Callable
|
||||
from http import HTTPStatus
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
from homeassistant.core import HomeAssistant
|
||||
import pytest
|
||||
|
||||
from . import async_setup_auth
|
||||
from homeassistant.auth.models import User
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from . import BASE_CONFIG, async_setup_auth
|
||||
|
||||
from tests.common import CLIENT_ID, CLIENT_REDIRECT_URI
|
||||
from tests.typing import ClientSessionGenerator
|
||||
|
||||
_TRUSTED_NETWORKS_CONFIG = {
|
||||
"type": "trusted_networks",
|
||||
"trusted_networks": ["192.168.0.1"],
|
||||
"trusted_users": {
|
||||
"192.168.0.1": [
|
||||
"a1ab982744b64757bf80515589258924",
|
||||
{"group": "system-group"},
|
||||
]
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("provider_configs", "ip", "expected"),
|
||||
[
|
||||
(
|
||||
BASE_CONFIG,
|
||||
None,
|
||||
[{"name": "Example", "type": "insecure_example", "id": None}],
|
||||
),
|
||||
(
|
||||
[_TRUSTED_NETWORKS_CONFIG],
|
||||
None,
|
||||
[],
|
||||
),
|
||||
(
|
||||
[_TRUSTED_NETWORKS_CONFIG],
|
||||
"192.168.0.1",
|
||||
[{"name": "Trusted Networks", "type": "trusted_networks", "id": None}],
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_fetch_auth_providers(
|
||||
hass: HomeAssistant, aiohttp_client: ClientSessionGenerator
|
||||
hass: HomeAssistant,
|
||||
aiohttp_client: ClientSessionGenerator,
|
||||
provider_configs: list[dict[str, Any]],
|
||||
ip: str | None,
|
||||
expected: list[dict[str, Any]],
|
||||
) -> None:
|
||||
"""Test fetching auth providers."""
|
||||
client = await async_setup_auth(hass, aiohttp_client)
|
||||
client = await async_setup_auth(
|
||||
hass, aiohttp_client, provider_configs, custom_ip=ip
|
||||
)
|
||||
resp = await client.get("/auth/providers")
|
||||
assert resp.status == HTTPStatus.OK
|
||||
assert await resp.json() == [
|
||||
{"name": "Example", "type": "insecure_example", "id": None}
|
||||
]
|
||||
assert await resp.json() == expected
|
||||
|
||||
|
||||
async def _test_fetch_auth_providers_home_assistant(
|
||||
hass: HomeAssistant,
|
||||
aiohttp_client: ClientSessionGenerator,
|
||||
ip: str,
|
||||
additional_expected_fn: Callable[[User], dict[str, Any]],
|
||||
) -> None:
|
||||
"""Test fetching auth providers for homeassistant auth provider."""
|
||||
client = await async_setup_auth(
|
||||
hass, aiohttp_client, [{"type": "homeassistant"}], custom_ip=ip
|
||||
)
|
||||
|
||||
provider = hass.auth.auth_providers[0]
|
||||
credentials = await provider.async_get_or_create_credentials({"username": "hello"})
|
||||
user = await hass.auth.async_get_or_create_user(credentials)
|
||||
|
||||
expected = {
|
||||
"name": "Home Assistant Local",
|
||||
"type": "homeassistant",
|
||||
"id": None,
|
||||
**additional_expected_fn(user),
|
||||
}
|
||||
|
||||
resp = await client.get("/auth/providers")
|
||||
assert resp.status == HTTPStatus.OK
|
||||
assert await resp.json() == [expected]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"ip",
|
||||
[
|
||||
"192.168.0.10",
|
||||
"::ffff:192.168.0.10",
|
||||
"1.2.3.4",
|
||||
"2001:db8::1",
|
||||
],
|
||||
)
|
||||
async def test_fetch_auth_providers_home_assistant_person_not_loaded(
|
||||
hass: HomeAssistant,
|
||||
aiohttp_client: ClientSessionGenerator,
|
||||
ip: str,
|
||||
) -> None:
|
||||
"""Test fetching auth providers for homeassistant auth provider, where person integration is not loaded."""
|
||||
await _test_fetch_auth_providers_home_assistant(
|
||||
hass, aiohttp_client, ip, lambda _: {}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("ip", "is_local"),
|
||||
[
|
||||
("192.168.0.10", True),
|
||||
("::ffff:192.168.0.10", True),
|
||||
("1.2.3.4", False),
|
||||
("2001:db8::1", False),
|
||||
],
|
||||
)
|
||||
async def test_fetch_auth_providers_home_assistant_person_loaded(
|
||||
hass: HomeAssistant,
|
||||
aiohttp_client: ClientSessionGenerator,
|
||||
ip: str,
|
||||
is_local: bool,
|
||||
) -> None:
|
||||
"""Test fetching auth providers for homeassistant auth provider, where person integration is loaded."""
|
||||
domain = "person"
|
||||
config = {domain: {"id": "1234", "name": "test person"}}
|
||||
assert await async_setup_component(hass, domain, config)
|
||||
|
||||
await _test_fetch_auth_providers_home_assistant(
|
||||
hass,
|
||||
aiohttp_client,
|
||||
ip,
|
||||
lambda user: {"users": {user.id: user.name}} if is_local else {},
|
||||
)
|
||||
|
||||
|
||||
async def test_fetch_auth_providers_onboarding(
|
||||
|
|
|
@ -1,34 +1,3 @@
|
|||
"""Tests for the HTTP component."""
|
||||
from aiohttp import web
|
||||
|
||||
# Relic from the past. Kept here so we can run negative tests.
|
||||
HTTP_HEADER_HA_AUTH = "X-HA-access"
|
||||
|
||||
|
||||
def mock_real_ip(app):
|
||||
"""Inject middleware to mock real IP.
|
||||
|
||||
Returns a function to set the real IP.
|
||||
"""
|
||||
ip_to_mock = None
|
||||
|
||||
def set_ip_to_mock(value):
|
||||
nonlocal ip_to_mock
|
||||
ip_to_mock = value
|
||||
|
||||
@web.middleware
|
||||
async def mock_real_ip(request, handler):
|
||||
"""Mock Real IP middleware."""
|
||||
nonlocal ip_to_mock
|
||||
|
||||
request = request.clone(remote=ip_to_mock)
|
||||
|
||||
return await handler(request)
|
||||
|
||||
async def real_ip_startup(app):
|
||||
"""Startup of real ip."""
|
||||
app.middlewares.insert(0, mock_real_ip)
|
||||
|
||||
app.on_startup.append(real_ip_startup)
|
||||
|
||||
return set_ip_to_mock
|
||||
|
|
|
@ -35,9 +35,10 @@ from homeassistant.components.http.request_context import (
|
|||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from . import HTTP_HEADER_HA_AUTH, mock_real_ip
|
||||
from . import HTTP_HEADER_HA_AUTH
|
||||
|
||||
from tests.common import MockUser
|
||||
from tests.test_util import mock_real_ip
|
||||
from tests.typing import ClientSessionGenerator, WebSocketGenerator
|
||||
|
||||
API_PASSWORD = "test-password"
|
||||
|
|
|
@ -24,9 +24,8 @@ from homeassistant.core import HomeAssistant
|
|||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from . import mock_real_ip
|
||||
|
||||
from tests.common import async_get_persistent_notifications
|
||||
from tests.test_util import mock_real_ip
|
||||
from tests.typing import ClientSessionGenerator
|
||||
|
||||
SUPERVISOR_IP = "1.2.3.4"
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
"""The tests for the person component."""
|
||||
from collections.abc import Callable
|
||||
from http import HTTPStatus
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
|
@ -29,7 +31,8 @@ from homeassistant.setup import async_setup_component
|
|||
from .conftest import DEVICE_TRACKER, DEVICE_TRACKER_2
|
||||
|
||||
from tests.common import MockUser, mock_component, mock_restore_cache
|
||||
from tests.typing import WebSocketGenerator
|
||||
from tests.test_util import mock_real_ip
|
||||
from tests.typing import ClientSessionGenerator, WebSocketGenerator
|
||||
|
||||
|
||||
async def test_minimal_setup(hass: HomeAssistant) -> None:
|
||||
|
@ -847,3 +850,63 @@ async def test_entities_in_person(hass: HomeAssistant) -> None:
|
|||
"device_tracker.paulus_iphone",
|
||||
"device_tracker.paulus_ipad",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("ip", "status_code", "expected_fn"),
|
||||
[
|
||||
(
|
||||
"192.168.0.10",
|
||||
HTTPStatus.OK,
|
||||
lambda user: {
|
||||
user["user_id"]: {"name": user["name"], "picture": user["picture"]}
|
||||
},
|
||||
),
|
||||
(
|
||||
"::ffff:192.168.0.10",
|
||||
HTTPStatus.OK,
|
||||
lambda user: {
|
||||
user["user_id"]: {"name": user["name"], "picture": user["picture"]}
|
||||
},
|
||||
),
|
||||
(
|
||||
"1.2.3.4",
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
lambda _: {"code": "not_local", "message": "Not local"},
|
||||
),
|
||||
(
|
||||
"2001:db8::1",
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
lambda _: {"code": "not_local", "message": "Not local"},
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_list_persons(
|
||||
hass: HomeAssistant,
|
||||
hass_client_no_auth: ClientSessionGenerator,
|
||||
hass_admin_user: MockUser,
|
||||
ip: str,
|
||||
status_code: HTTPStatus,
|
||||
expected_fn: Callable[[dict[str, Any]], dict[str, Any]],
|
||||
) -> None:
|
||||
"""Test listing persons from a not local ip address."""
|
||||
|
||||
user_id = hass_admin_user.id
|
||||
admin = {"id": "1234", "name": "Admin", "user_id": user_id, "picture": "/bla"}
|
||||
config = {
|
||||
DOMAIN: [
|
||||
admin,
|
||||
{"id": "5678", "name": "Only a person"},
|
||||
]
|
||||
}
|
||||
assert await async_setup_component(hass, DOMAIN, config)
|
||||
|
||||
await async_setup_component(hass, "api", {})
|
||||
mock_real_ip(hass.http.app)(ip)
|
||||
client = await hass_client_no_auth()
|
||||
|
||||
resp = await client.get("/api/person/list")
|
||||
|
||||
assert resp.status == status_code
|
||||
result = await resp.json()
|
||||
assert result == expected_fn(admin)
|
||||
|
|
|
@ -1 +1,35 @@
|
|||
"""Tests for the test utilities."""
|
||||
"""Test utilities."""
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from aiohttp.web import Application, Request, StreamResponse, middleware
|
||||
|
||||
|
||||
def mock_real_ip(app: Application) -> Callable[[str], None]:
|
||||
"""Inject middleware to mock real IP.
|
||||
|
||||
Returns a function to set the real IP.
|
||||
"""
|
||||
ip_to_mock: str | None = None
|
||||
|
||||
def set_ip_to_mock(value: str):
|
||||
nonlocal ip_to_mock
|
||||
ip_to_mock = value
|
||||
|
||||
@middleware
|
||||
async def mock_real_ip(
|
||||
request: Request, handler: Callable[[Request], Awaitable[StreamResponse]]
|
||||
) -> StreamResponse:
|
||||
"""Mock Real IP middleware."""
|
||||
nonlocal ip_to_mock
|
||||
|
||||
request = request.clone(remote=ip_to_mock)
|
||||
|
||||
return await handler(request)
|
||||
|
||||
async def real_ip_startup(app):
|
||||
"""Startup of real ip."""
|
||||
app.middlewares.insert(0, mock_real_ip)
|
||||
|
||||
app.on_startup.append(real_ip_startup)
|
||||
|
||||
return set_ip_to_mock
|
||||
|
|
Loading…
Add table
Reference in a new issue