Extend auth/providers
endpoint and add /api/person/list
endpoint for local ip requests (#103906)
Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
parent
512902fc59
commit
852fb58ca8
14 changed files with 370 additions and 75 deletions
|
@ -22,6 +22,7 @@ from homeassistant.core import callback
|
||||||
from homeassistant.data_entry_flow import FlowResult
|
from homeassistant.data_entry_flow import FlowResult
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
import homeassistant.helpers.config_validation as cv
|
import homeassistant.helpers.config_validation as cv
|
||||||
|
from homeassistant.helpers.network import is_cloud_connection
|
||||||
|
|
||||||
from .. import InvalidAuthError
|
from .. import InvalidAuthError
|
||||||
from ..models import Credentials, RefreshToken, UserMeta
|
from ..models import Credentials, RefreshToken, UserMeta
|
||||||
|
@ -192,11 +193,8 @@ class TrustedNetworksAuthProvider(AuthProvider):
|
||||||
if any(ip_addr in trusted_proxy for trusted_proxy in self.trusted_proxies):
|
if any(ip_addr in trusted_proxy for trusted_proxy in self.trusted_proxies):
|
||||||
raise InvalidAuthError("Can't allow access from a proxy server")
|
raise InvalidAuthError("Can't allow access from a proxy server")
|
||||||
|
|
||||||
if "cloud" in self.hass.config.components:
|
if is_cloud_connection(self.hass):
|
||||||
from hass_nabucasa import remote # pylint: disable=import-outside-toplevel
|
raise InvalidAuthError("Can't allow access from Home Assistant Cloud")
|
||||||
|
|
||||||
if remote.is_cloud_request.get():
|
|
||||||
raise InvalidAuthError("Can't allow access from Home Assistant Cloud")
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_validate_refresh_token(
|
def async_validate_refresh_token(
|
||||||
|
|
|
@ -71,14 +71,14 @@ from __future__ import annotations
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from ipaddress import ip_address
|
from ipaddress import ip_address
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any, cast
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
import voluptuous_serialize
|
import voluptuous_serialize
|
||||||
|
|
||||||
from homeassistant import data_entry_flow
|
from homeassistant import data_entry_flow
|
||||||
from homeassistant.auth import AuthManagerFlowManager
|
from homeassistant.auth import AuthManagerFlowManager, InvalidAuthError
|
||||||
from homeassistant.auth.models import Credentials
|
from homeassistant.auth.models import Credentials
|
||||||
from homeassistant.components import onboarding
|
from homeassistant.components import onboarding
|
||||||
from homeassistant.components.http.auth import async_user_not_allowed_do_auth
|
from homeassistant.components.http.auth import async_user_not_allowed_do_auth
|
||||||
|
@ -90,10 +90,16 @@ from homeassistant.components.http.ban import (
|
||||||
from homeassistant.components.http.data_validator import RequestDataValidator
|
from homeassistant.components.http.data_validator import RequestDataValidator
|
||||||
from homeassistant.components.http.view import HomeAssistantView
|
from homeassistant.components.http.view import HomeAssistantView
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.helpers.network import is_cloud_connection
|
||||||
|
from homeassistant.util.network import is_local
|
||||||
|
|
||||||
from . import indieauth
|
from . import indieauth
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from homeassistant.auth.providers.trusted_networks import (
|
||||||
|
TrustedNetworksAuthProvider,
|
||||||
|
)
|
||||||
|
|
||||||
from . import StoreResultType
|
from . import StoreResultType
|
||||||
|
|
||||||
|
|
||||||
|
@ -146,12 +152,61 @@ class AuthProvidersView(HomeAssistantView):
|
||||||
message_code="onboarding_required",
|
message_code="onboarding_required",
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.json(
|
try:
|
||||||
[
|
remote_address = ip_address(request.remote) # type: ignore[arg-type]
|
||||||
{"name": provider.name, "id": provider.id, "type": provider.type}
|
except ValueError:
|
||||||
for provider in hass.auth.auth_providers
|
return self.json_message(
|
||||||
]
|
message="Invalid remote IP",
|
||||||
)
|
status_code=HTTPStatus.BAD_REQUEST,
|
||||||
|
message_code="invalid_remote_ip",
|
||||||
|
)
|
||||||
|
|
||||||
|
cloud_connection = is_cloud_connection(hass)
|
||||||
|
|
||||||
|
providers = []
|
||||||
|
for provider in hass.auth.auth_providers:
|
||||||
|
additional_data = {}
|
||||||
|
|
||||||
|
if provider.type == "trusted_networks":
|
||||||
|
if cloud_connection:
|
||||||
|
# Skip quickly as trusted networks are not available on cloud
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
cast("TrustedNetworksAuthProvider", provider).async_validate_access(
|
||||||
|
remote_address
|
||||||
|
)
|
||||||
|
except InvalidAuthError:
|
||||||
|
# Not a trusted network, so we don't expose that trusted_network authenticator is setup
|
||||||
|
continue
|
||||||
|
elif (
|
||||||
|
provider.type == "homeassistant"
|
||||||
|
and not cloud_connection
|
||||||
|
and is_local(remote_address)
|
||||||
|
and "person" in hass.config.components
|
||||||
|
):
|
||||||
|
# We are local, return user id and username
|
||||||
|
users = await provider.store.async_get_users()
|
||||||
|
additional_data["users"] = {
|
||||||
|
user.id: credentials.data["username"]
|
||||||
|
for user in users
|
||||||
|
for credentials in user.credentials
|
||||||
|
if (
|
||||||
|
credentials.auth_provider_type == provider.type
|
||||||
|
and credentials.auth_provider_id == provider.id
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
providers.append(
|
||||||
|
{
|
||||||
|
"name": provider.name,
|
||||||
|
"id": provider.id,
|
||||||
|
"type": provider.type,
|
||||||
|
**additional_data,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.json(providers)
|
||||||
|
|
||||||
|
|
||||||
def _prepare_result_json(
|
def _prepare_result_json(
|
||||||
|
|
|
@ -21,6 +21,7 @@ from homeassistant.auth.models import User
|
||||||
from homeassistant.components import websocket_api
|
from homeassistant.components import websocket_api
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.helpers.json import json_bytes
|
from homeassistant.helpers.json import json_bytes
|
||||||
|
from homeassistant.helpers.network import is_cloud_connection
|
||||||
from homeassistant.helpers.storage import Store
|
from homeassistant.helpers.storage import Store
|
||||||
from homeassistant.util.network import is_local
|
from homeassistant.util.network import is_local
|
||||||
|
|
||||||
|
@ -98,12 +99,8 @@ def async_user_not_allowed_do_auth(
|
||||||
if not request:
|
if not request:
|
||||||
return "No request available to validate local access"
|
return "No request available to validate local access"
|
||||||
|
|
||||||
if "cloud" in hass.config.components:
|
if is_cloud_connection(hass):
|
||||||
# pylint: disable-next=import-outside-toplevel
|
return "User is local only"
|
||||||
from hass_nabucasa import remote
|
|
||||||
|
|
||||||
if remote.is_cloud_request.get():
|
|
||||||
return "User is local only"
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
remote_address = ip_address(request.remote) # type: ignore[arg-type]
|
remote_address = ip_address(request.remote) # type: ignore[arg-type]
|
||||||
|
|
|
@ -1,9 +1,12 @@
|
||||||
"""Support for tracking people."""
|
"""Support for tracking people."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from http import HTTPStatus
|
||||||
|
from ipaddress import ip_address
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from aiohttp import web
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.auth import EVENT_USER_REMOVED
|
from homeassistant.auth import EVENT_USER_REMOVED
|
||||||
|
@ -13,6 +16,7 @@ from homeassistant.components.device_tracker import (
|
||||||
DOMAIN as DEVICE_TRACKER_DOMAIN,
|
DOMAIN as DEVICE_TRACKER_DOMAIN,
|
||||||
SourceType,
|
SourceType,
|
||||||
)
|
)
|
||||||
|
from homeassistant.components.http.view import HomeAssistantView
|
||||||
from homeassistant.const import (
|
from homeassistant.const import (
|
||||||
ATTR_EDITABLE,
|
ATTR_EDITABLE,
|
||||||
ATTR_ENTITY_ID,
|
ATTR_ENTITY_ID,
|
||||||
|
@ -47,10 +51,12 @@ from homeassistant.helpers import (
|
||||||
)
|
)
|
||||||
from homeassistant.helpers.entity_component import EntityComponent
|
from homeassistant.helpers.entity_component import EntityComponent
|
||||||
from homeassistant.helpers.event import async_track_state_change_event
|
from homeassistant.helpers.event import async_track_state_change_event
|
||||||
|
from homeassistant.helpers.network import is_cloud_connection
|
||||||
from homeassistant.helpers.restore_state import RestoreEntity
|
from homeassistant.helpers.restore_state import RestoreEntity
|
||||||
from homeassistant.helpers.storage import Store
|
from homeassistant.helpers.storage import Store
|
||||||
from homeassistant.helpers.typing import ConfigType
|
from homeassistant.helpers.typing import ConfigType
|
||||||
from homeassistant.loader import bind_hass
|
from homeassistant.loader import bind_hass
|
||||||
|
from homeassistant.util.network import is_local
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -385,6 +391,8 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||||
hass, DOMAIN, SERVICE_RELOAD, async_reload_yaml
|
hass, DOMAIN, SERVICE_RELOAD, async_reload_yaml
|
||||||
)
|
)
|
||||||
|
|
||||||
|
hass.http.register_view(ListPersonsView)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
@ -569,3 +577,44 @@ def _get_latest(prev: State | None, curr: State):
|
||||||
if prev is None or curr.last_updated > prev.last_updated:
|
if prev is None or curr.last_updated > prev.last_updated:
|
||||||
return curr
|
return curr
|
||||||
return prev
|
return prev
|
||||||
|
|
||||||
|
|
||||||
|
class ListPersonsView(HomeAssistantView):
|
||||||
|
"""List all persons if request is made from a local network."""
|
||||||
|
|
||||||
|
requires_auth = False
|
||||||
|
url = "/api/person/list"
|
||||||
|
name = "api:person:list"
|
||||||
|
|
||||||
|
async def get(self, request: web.Request) -> web.Response:
|
||||||
|
"""Return a list of persons if request comes from a local IP."""
|
||||||
|
try:
|
||||||
|
remote_address = ip_address(request.remote) # type: ignore[arg-type]
|
||||||
|
except ValueError:
|
||||||
|
return self.json_message(
|
||||||
|
message="Invalid remote IP",
|
||||||
|
status_code=HTTPStatus.BAD_REQUEST,
|
||||||
|
message_code="invalid_remote_ip",
|
||||||
|
)
|
||||||
|
|
||||||
|
hass: HomeAssistant = request.app["hass"]
|
||||||
|
if is_cloud_connection(hass) or not is_local(remote_address):
|
||||||
|
return self.json_message(
|
||||||
|
message="Not local",
|
||||||
|
status_code=HTTPStatus.BAD_REQUEST,
|
||||||
|
message_code="not_local",
|
||||||
|
)
|
||||||
|
|
||||||
|
yaml, storage, _ = hass.data[DOMAIN]
|
||||||
|
persons = [*yaml.async_items(), *storage.async_items()]
|
||||||
|
|
||||||
|
return self.json(
|
||||||
|
{
|
||||||
|
person[ATTR_USER_ID]: {
|
||||||
|
ATTR_NAME: person[ATTR_NAME],
|
||||||
|
CONF_PICTURE: person.get(CONF_PICTURE),
|
||||||
|
}
|
||||||
|
for person in persons
|
||||||
|
if person.get(ATTR_USER_ID)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
|
@ -3,7 +3,7 @@
|
||||||
"name": "Person",
|
"name": "Person",
|
||||||
"after_dependencies": ["device_tracker"],
|
"after_dependencies": ["device_tracker"],
|
||||||
"codeowners": [],
|
"codeowners": [],
|
||||||
"dependencies": ["image_upload"],
|
"dependencies": ["image_upload", "http"],
|
||||||
"documentation": "https://www.home-assistant.io/integrations/person",
|
"documentation": "https://www.home-assistant.io/integrations/person",
|
||||||
"integration_type": "system",
|
"integration_type": "system",
|
||||||
"iot_class": "calculated",
|
"iot_class": "calculated",
|
||||||
|
|
|
@ -17,7 +17,7 @@ from homeassistant.components import websocket_api
|
||||||
from homeassistant.components.http.view import HomeAssistantView
|
from homeassistant.components.http.view import HomeAssistantView
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.helpers import config_validation as cv
|
from homeassistant.helpers import config_validation as cv
|
||||||
from homeassistant.helpers.network import get_url
|
from homeassistant.helpers.network import get_url, is_cloud_connection
|
||||||
from homeassistant.helpers.typing import ConfigType
|
from homeassistant.helpers.typing import ConfigType
|
||||||
from homeassistant.loader import bind_hass
|
from homeassistant.loader import bind_hass
|
||||||
from homeassistant.util import network
|
from homeassistant.util import network
|
||||||
|
@ -145,13 +145,8 @@ async def async_handle_webhook(
|
||||||
return Response(status=HTTPStatus.METHOD_NOT_ALLOWED)
|
return Response(status=HTTPStatus.METHOD_NOT_ALLOWED)
|
||||||
|
|
||||||
if webhook["local_only"] in (True, None) and not isinstance(request, MockRequest):
|
if webhook["local_only"] in (True, None) and not isinstance(request, MockRequest):
|
||||||
if has_cloud := "cloud" in hass.config.components:
|
is_local = not is_cloud_connection(hass)
|
||||||
from hass_nabucasa import remote # pylint: disable=import-outside-toplevel
|
if is_local:
|
||||||
|
|
||||||
is_local = True
|
|
||||||
if has_cloud and remote.is_cloud_request.get():
|
|
||||||
is_local = False
|
|
||||||
else:
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
assert isinstance(request, Request)
|
assert isinstance(request, Request)
|
||||||
assert request.remote is not None
|
assert request.remote is not None
|
||||||
|
|
|
@ -299,3 +299,14 @@ def _get_cloud_url(hass: HomeAssistant, require_current_request: bool = False) -
|
||||||
return normalize_url(str(cloud_url))
|
return normalize_url(str(cloud_url))
|
||||||
|
|
||||||
raise NoURLAvailableError
|
raise NoURLAvailableError
|
||||||
|
|
||||||
|
|
||||||
|
def is_cloud_connection(hass: HomeAssistant) -> bool:
|
||||||
|
"""Return True if the current connection is a nabucasa cloud connection."""
|
||||||
|
|
||||||
|
if "cloud" not in hass.config.components:
|
||||||
|
return False
|
||||||
|
|
||||||
|
from hass_nabucasa import remote # pylint: disable=import-outside-toplevel
|
||||||
|
|
||||||
|
return remote.is_cloud_request.get()
|
||||||
|
|
|
@ -1,8 +1,13 @@
|
||||||
"""Tests for the auth component."""
|
"""Tests for the auth component."""
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from homeassistant import auth
|
from homeassistant import auth
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
from tests.common import ensure_auth_manager_loaded
|
from tests.common import ensure_auth_manager_loaded
|
||||||
|
from tests.test_util import mock_real_ip
|
||||||
|
from tests.typing import ClientSessionGenerator
|
||||||
|
|
||||||
BASE_CONFIG = [
|
BASE_CONFIG = [
|
||||||
{
|
{
|
||||||
|
@ -18,11 +23,12 @@ EMPTY_CONFIG = []
|
||||||
|
|
||||||
|
|
||||||
async def async_setup_auth(
|
async def async_setup_auth(
|
||||||
hass,
|
hass: HomeAssistant,
|
||||||
aiohttp_client,
|
aiohttp_client: ClientSessionGenerator,
|
||||||
provider_configs=BASE_CONFIG,
|
provider_configs: list[dict[str, Any]] = BASE_CONFIG,
|
||||||
module_configs=EMPTY_CONFIG,
|
module_configs=EMPTY_CONFIG,
|
||||||
setup_api=False,
|
setup_api: bool = False,
|
||||||
|
custom_ip: str | None = None,
|
||||||
):
|
):
|
||||||
"""Set up authentication and create an HTTP client."""
|
"""Set up authentication and create an HTTP client."""
|
||||||
hass.auth = await auth.auth_manager_from_config(
|
hass.auth = await auth.auth_manager_from_config(
|
||||||
|
@ -32,4 +38,6 @@ async def async_setup_auth(
|
||||||
await async_setup_component(hass, "auth", {})
|
await async_setup_component(hass, "auth", {})
|
||||||
if setup_api:
|
if setup_api:
|
||||||
await async_setup_component(hass, "api", {})
|
await async_setup_component(hass, "api", {})
|
||||||
|
if custom_ip:
|
||||||
|
mock_real_ip(hass.http.app)(custom_ip)
|
||||||
return await aiohttp_client(hass.http.app)
|
return await aiohttp_client(hass.http.app)
|
||||||
|
|
|
@ -1,25 +1,141 @@
|
||||||
"""Tests for the login flow."""
|
"""Tests for the login flow."""
|
||||||
|
from collections.abc import Callable
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
|
from typing import Any
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from homeassistant.core import HomeAssistant
|
import pytest
|
||||||
|
|
||||||
from . import async_setup_auth
|
from homeassistant.auth.models import User
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
|
from . import BASE_CONFIG, async_setup_auth
|
||||||
|
|
||||||
from tests.common import CLIENT_ID, CLIENT_REDIRECT_URI
|
from tests.common import CLIENT_ID, CLIENT_REDIRECT_URI
|
||||||
from tests.typing import ClientSessionGenerator
|
from tests.typing import ClientSessionGenerator
|
||||||
|
|
||||||
|
_TRUSTED_NETWORKS_CONFIG = {
|
||||||
|
"type": "trusted_networks",
|
||||||
|
"trusted_networks": ["192.168.0.1"],
|
||||||
|
"trusted_users": {
|
||||||
|
"192.168.0.1": [
|
||||||
|
"a1ab982744b64757bf80515589258924",
|
||||||
|
{"group": "system-group"},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("provider_configs", "ip", "expected"),
|
||||||
|
[
|
||||||
|
(
|
||||||
|
BASE_CONFIG,
|
||||||
|
None,
|
||||||
|
[{"name": "Example", "type": "insecure_example", "id": None}],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
[_TRUSTED_NETWORKS_CONFIG],
|
||||||
|
None,
|
||||||
|
[],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
[_TRUSTED_NETWORKS_CONFIG],
|
||||||
|
"192.168.0.1",
|
||||||
|
[{"name": "Trusted Networks", "type": "trusted_networks", "id": None}],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
async def test_fetch_auth_providers(
|
async def test_fetch_auth_providers(
|
||||||
hass: HomeAssistant, aiohttp_client: ClientSessionGenerator
|
hass: HomeAssistant,
|
||||||
|
aiohttp_client: ClientSessionGenerator,
|
||||||
|
provider_configs: list[dict[str, Any]],
|
||||||
|
ip: str | None,
|
||||||
|
expected: list[dict[str, Any]],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test fetching auth providers."""
|
"""Test fetching auth providers."""
|
||||||
client = await async_setup_auth(hass, aiohttp_client)
|
client = await async_setup_auth(
|
||||||
|
hass, aiohttp_client, provider_configs, custom_ip=ip
|
||||||
|
)
|
||||||
resp = await client.get("/auth/providers")
|
resp = await client.get("/auth/providers")
|
||||||
assert resp.status == HTTPStatus.OK
|
assert resp.status == HTTPStatus.OK
|
||||||
assert await resp.json() == [
|
assert await resp.json() == expected
|
||||||
{"name": "Example", "type": "insecure_example", "id": None}
|
|
||||||
]
|
|
||||||
|
async def _test_fetch_auth_providers_home_assistant(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
aiohttp_client: ClientSessionGenerator,
|
||||||
|
ip: str,
|
||||||
|
additional_expected_fn: Callable[[User], dict[str, Any]],
|
||||||
|
) -> None:
|
||||||
|
"""Test fetching auth providers for homeassistant auth provider."""
|
||||||
|
client = await async_setup_auth(
|
||||||
|
hass, aiohttp_client, [{"type": "homeassistant"}], custom_ip=ip
|
||||||
|
)
|
||||||
|
|
||||||
|
provider = hass.auth.auth_providers[0]
|
||||||
|
credentials = await provider.async_get_or_create_credentials({"username": "hello"})
|
||||||
|
user = await hass.auth.async_get_or_create_user(credentials)
|
||||||
|
|
||||||
|
expected = {
|
||||||
|
"name": "Home Assistant Local",
|
||||||
|
"type": "homeassistant",
|
||||||
|
"id": None,
|
||||||
|
**additional_expected_fn(user),
|
||||||
|
}
|
||||||
|
|
||||||
|
resp = await client.get("/auth/providers")
|
||||||
|
assert resp.status == HTTPStatus.OK
|
||||||
|
assert await resp.json() == [expected]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"ip",
|
||||||
|
[
|
||||||
|
"192.168.0.10",
|
||||||
|
"::ffff:192.168.0.10",
|
||||||
|
"1.2.3.4",
|
||||||
|
"2001:db8::1",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_fetch_auth_providers_home_assistant_person_not_loaded(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
aiohttp_client: ClientSessionGenerator,
|
||||||
|
ip: str,
|
||||||
|
) -> None:
|
||||||
|
"""Test fetching auth providers for homeassistant auth provider, where person integration is not loaded."""
|
||||||
|
await _test_fetch_auth_providers_home_assistant(
|
||||||
|
hass, aiohttp_client, ip, lambda _: {}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("ip", "is_local"),
|
||||||
|
[
|
||||||
|
("192.168.0.10", True),
|
||||||
|
("::ffff:192.168.0.10", True),
|
||||||
|
("1.2.3.4", False),
|
||||||
|
("2001:db8::1", False),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_fetch_auth_providers_home_assistant_person_loaded(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
aiohttp_client: ClientSessionGenerator,
|
||||||
|
ip: str,
|
||||||
|
is_local: bool,
|
||||||
|
) -> None:
|
||||||
|
"""Test fetching auth providers for homeassistant auth provider, where person integration is loaded."""
|
||||||
|
domain = "person"
|
||||||
|
config = {domain: {"id": "1234", "name": "test person"}}
|
||||||
|
assert await async_setup_component(hass, domain, config)
|
||||||
|
|
||||||
|
await _test_fetch_auth_providers_home_assistant(
|
||||||
|
hass,
|
||||||
|
aiohttp_client,
|
||||||
|
ip,
|
||||||
|
lambda user: {"users": {user.id: user.name}} if is_local else {},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def test_fetch_auth_providers_onboarding(
|
async def test_fetch_auth_providers_onboarding(
|
||||||
|
|
|
@ -1,34 +1,3 @@
|
||||||
"""Tests for the HTTP component."""
|
"""Tests for the HTTP component."""
|
||||||
from aiohttp import web
|
|
||||||
|
|
||||||
# Relic from the past. Kept here so we can run negative tests.
|
# Relic from the past. Kept here so we can run negative tests.
|
||||||
HTTP_HEADER_HA_AUTH = "X-HA-access"
|
HTTP_HEADER_HA_AUTH = "X-HA-access"
|
||||||
|
|
||||||
|
|
||||||
def mock_real_ip(app):
|
|
||||||
"""Inject middleware to mock real IP.
|
|
||||||
|
|
||||||
Returns a function to set the real IP.
|
|
||||||
"""
|
|
||||||
ip_to_mock = None
|
|
||||||
|
|
||||||
def set_ip_to_mock(value):
|
|
||||||
nonlocal ip_to_mock
|
|
||||||
ip_to_mock = value
|
|
||||||
|
|
||||||
@web.middleware
|
|
||||||
async def mock_real_ip(request, handler):
|
|
||||||
"""Mock Real IP middleware."""
|
|
||||||
nonlocal ip_to_mock
|
|
||||||
|
|
||||||
request = request.clone(remote=ip_to_mock)
|
|
||||||
|
|
||||||
return await handler(request)
|
|
||||||
|
|
||||||
async def real_ip_startup(app):
|
|
||||||
"""Startup of real ip."""
|
|
||||||
app.middlewares.insert(0, mock_real_ip)
|
|
||||||
|
|
||||||
app.on_startup.append(real_ip_startup)
|
|
||||||
|
|
||||||
return set_ip_to_mock
|
|
||||||
|
|
|
@ -35,9 +35,10 @@ from homeassistant.components.http.request_context import (
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
from . import HTTP_HEADER_HA_AUTH, mock_real_ip
|
from . import HTTP_HEADER_HA_AUTH
|
||||||
|
|
||||||
from tests.common import MockUser
|
from tests.common import MockUser
|
||||||
|
from tests.test_util import mock_real_ip
|
||||||
from tests.typing import ClientSessionGenerator, WebSocketGenerator
|
from tests.typing import ClientSessionGenerator, WebSocketGenerator
|
||||||
|
|
||||||
API_PASSWORD = "test-password"
|
API_PASSWORD = "test-password"
|
||||||
|
|
|
@ -24,9 +24,8 @@ from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
from . import mock_real_ip
|
|
||||||
|
|
||||||
from tests.common import async_get_persistent_notifications
|
from tests.common import async_get_persistent_notifications
|
||||||
|
from tests.test_util import mock_real_ip
|
||||||
from tests.typing import ClientSessionGenerator
|
from tests.typing import ClientSessionGenerator
|
||||||
|
|
||||||
SUPERVISOR_IP = "1.2.3.4"
|
SUPERVISOR_IP = "1.2.3.4"
|
||||||
|
|
|
@ -1,4 +1,6 @@
|
||||||
"""The tests for the person component."""
|
"""The tests for the person component."""
|
||||||
|
from collections.abc import Callable
|
||||||
|
from http import HTTPStatus
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
@ -29,7 +31,8 @@ from homeassistant.setup import async_setup_component
|
||||||
from .conftest import DEVICE_TRACKER, DEVICE_TRACKER_2
|
from .conftest import DEVICE_TRACKER, DEVICE_TRACKER_2
|
||||||
|
|
||||||
from tests.common import MockUser, mock_component, mock_restore_cache
|
from tests.common import MockUser, mock_component, mock_restore_cache
|
||||||
from tests.typing import WebSocketGenerator
|
from tests.test_util import mock_real_ip
|
||||||
|
from tests.typing import ClientSessionGenerator, WebSocketGenerator
|
||||||
|
|
||||||
|
|
||||||
async def test_minimal_setup(hass: HomeAssistant) -> None:
|
async def test_minimal_setup(hass: HomeAssistant) -> None:
|
||||||
|
@ -847,3 +850,63 @@ async def test_entities_in_person(hass: HomeAssistant) -> None:
|
||||||
"device_tracker.paulus_iphone",
|
"device_tracker.paulus_iphone",
|
||||||
"device_tracker.paulus_ipad",
|
"device_tracker.paulus_ipad",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("ip", "status_code", "expected_fn"),
|
||||||
|
[
|
||||||
|
(
|
||||||
|
"192.168.0.10",
|
||||||
|
HTTPStatus.OK,
|
||||||
|
lambda user: {
|
||||||
|
user["user_id"]: {"name": user["name"], "picture": user["picture"]}
|
||||||
|
},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"::ffff:192.168.0.10",
|
||||||
|
HTTPStatus.OK,
|
||||||
|
lambda user: {
|
||||||
|
user["user_id"]: {"name": user["name"], "picture": user["picture"]}
|
||||||
|
},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"1.2.3.4",
|
||||||
|
HTTPStatus.BAD_REQUEST,
|
||||||
|
lambda _: {"code": "not_local", "message": "Not local"},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"2001:db8::1",
|
||||||
|
HTTPStatus.BAD_REQUEST,
|
||||||
|
lambda _: {"code": "not_local", "message": "Not local"},
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_list_persons(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
hass_client_no_auth: ClientSessionGenerator,
|
||||||
|
hass_admin_user: MockUser,
|
||||||
|
ip: str,
|
||||||
|
status_code: HTTPStatus,
|
||||||
|
expected_fn: Callable[[dict[str, Any]], dict[str, Any]],
|
||||||
|
) -> None:
|
||||||
|
"""Test listing persons from a not local ip address."""
|
||||||
|
|
||||||
|
user_id = hass_admin_user.id
|
||||||
|
admin = {"id": "1234", "name": "Admin", "user_id": user_id, "picture": "/bla"}
|
||||||
|
config = {
|
||||||
|
DOMAIN: [
|
||||||
|
admin,
|
||||||
|
{"id": "5678", "name": "Only a person"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
assert await async_setup_component(hass, DOMAIN, config)
|
||||||
|
|
||||||
|
await async_setup_component(hass, "api", {})
|
||||||
|
mock_real_ip(hass.http.app)(ip)
|
||||||
|
client = await hass_client_no_auth()
|
||||||
|
|
||||||
|
resp = await client.get("/api/person/list")
|
||||||
|
|
||||||
|
assert resp.status == status_code
|
||||||
|
result = await resp.json()
|
||||||
|
assert result == expected_fn(admin)
|
||||||
|
|
|
@ -1 +1,35 @@
|
||||||
"""Tests for the test utilities."""
|
"""Test utilities."""
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
|
||||||
|
from aiohttp.web import Application, Request, StreamResponse, middleware
|
||||||
|
|
||||||
|
|
||||||
|
def mock_real_ip(app: Application) -> Callable[[str], None]:
|
||||||
|
"""Inject middleware to mock real IP.
|
||||||
|
|
||||||
|
Returns a function to set the real IP.
|
||||||
|
"""
|
||||||
|
ip_to_mock: str | None = None
|
||||||
|
|
||||||
|
def set_ip_to_mock(value: str):
|
||||||
|
nonlocal ip_to_mock
|
||||||
|
ip_to_mock = value
|
||||||
|
|
||||||
|
@middleware
|
||||||
|
async def mock_real_ip(
|
||||||
|
request: Request, handler: Callable[[Request], Awaitable[StreamResponse]]
|
||||||
|
) -> StreamResponse:
|
||||||
|
"""Mock Real IP middleware."""
|
||||||
|
nonlocal ip_to_mock
|
||||||
|
|
||||||
|
request = request.clone(remote=ip_to_mock)
|
||||||
|
|
||||||
|
return await handler(request)
|
||||||
|
|
||||||
|
async def real_ip_startup(app):
|
||||||
|
"""Startup of real ip."""
|
||||||
|
app.middlewares.insert(0, mock_real_ip)
|
||||||
|
|
||||||
|
app.on_startup.append(real_ip_startup)
|
||||||
|
|
||||||
|
return set_ip_to_mock
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue