Extend auth/providers endpoint and add /api/person/list endpoint for local ip requests (#103906)

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
Robert Resch 2023-11-24 17:11:54 +01:00 committed by GitHub
parent 512902fc59
commit 852fb58ca8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 370 additions and 75 deletions

View file

@ -22,6 +22,7 @@ from homeassistant.core import callback
from homeassistant.data_entry_flow import FlowResult from homeassistant.data_entry_flow import FlowResult
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.network import is_cloud_connection
from .. import InvalidAuthError from .. import InvalidAuthError
from ..models import Credentials, RefreshToken, UserMeta from ..models import Credentials, RefreshToken, UserMeta
@ -192,10 +193,7 @@ class TrustedNetworksAuthProvider(AuthProvider):
if any(ip_addr in trusted_proxy for trusted_proxy in self.trusted_proxies): if any(ip_addr in trusted_proxy for trusted_proxy in self.trusted_proxies):
raise InvalidAuthError("Can't allow access from a proxy server") raise InvalidAuthError("Can't allow access from a proxy server")
if "cloud" in self.hass.config.components: if is_cloud_connection(self.hass):
from hass_nabucasa import remote # pylint: disable=import-outside-toplevel
if remote.is_cloud_request.get():
raise InvalidAuthError("Can't allow access from Home Assistant Cloud") raise InvalidAuthError("Can't allow access from Home Assistant Cloud")
@callback @callback

View file

@ -71,14 +71,14 @@ from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
from http import HTTPStatus from http import HTTPStatus
from ipaddress import ip_address from ipaddress import ip_address
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any, cast
from aiohttp import web from aiohttp import web
import voluptuous as vol import voluptuous as vol
import voluptuous_serialize import voluptuous_serialize
from homeassistant import data_entry_flow from homeassistant import data_entry_flow
from homeassistant.auth import AuthManagerFlowManager from homeassistant.auth import AuthManagerFlowManager, InvalidAuthError
from homeassistant.auth.models import Credentials from homeassistant.auth.models import Credentials
from homeassistant.components import onboarding from homeassistant.components import onboarding
from homeassistant.components.http.auth import async_user_not_allowed_do_auth from homeassistant.components.http.auth import async_user_not_allowed_do_auth
@ -90,10 +90,16 @@ from homeassistant.components.http.ban import (
from homeassistant.components.http.data_validator import RequestDataValidator from homeassistant.components.http.data_validator import RequestDataValidator
from homeassistant.components.http.view import HomeAssistantView from homeassistant.components.http.view import HomeAssistantView
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers.network import is_cloud_connection
from homeassistant.util.network import is_local
from . import indieauth from . import indieauth
if TYPE_CHECKING: if TYPE_CHECKING:
from homeassistant.auth.providers.trusted_networks import (
TrustedNetworksAuthProvider,
)
from . import StoreResultType from . import StoreResultType
@ -146,13 +152,62 @@ class AuthProvidersView(HomeAssistantView):
message_code="onboarding_required", message_code="onboarding_required",
) )
return self.json( try:
[ remote_address = ip_address(request.remote) # type: ignore[arg-type]
{"name": provider.name, "id": provider.id, "type": provider.type} except ValueError:
for provider in hass.auth.auth_providers return self.json_message(
] message="Invalid remote IP",
status_code=HTTPStatus.BAD_REQUEST,
message_code="invalid_remote_ip",
) )
cloud_connection = is_cloud_connection(hass)
providers = []
for provider in hass.auth.auth_providers:
additional_data = {}
if provider.type == "trusted_networks":
if cloud_connection:
# Skip quickly as trusted networks are not available on cloud
continue
try:
cast("TrustedNetworksAuthProvider", provider).async_validate_access(
remote_address
)
except InvalidAuthError:
# Not a trusted network, so we don't expose that trusted_network authenticator is setup
continue
elif (
provider.type == "homeassistant"
and not cloud_connection
and is_local(remote_address)
and "person" in hass.config.components
):
# We are local, return user id and username
users = await provider.store.async_get_users()
additional_data["users"] = {
user.id: credentials.data["username"]
for user in users
for credentials in user.credentials
if (
credentials.auth_provider_type == provider.type
and credentials.auth_provider_id == provider.id
)
}
providers.append(
{
"name": provider.name,
"id": provider.id,
"type": provider.type,
**additional_data,
}
)
return self.json(providers)
def _prepare_result_json( def _prepare_result_json(
result: data_entry_flow.FlowResult, result: data_entry_flow.FlowResult,

View file

@ -21,6 +21,7 @@ from homeassistant.auth.models import User
from homeassistant.components import websocket_api from homeassistant.components import websocket_api
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.json import json_bytes from homeassistant.helpers.json import json_bytes
from homeassistant.helpers.network import is_cloud_connection
from homeassistant.helpers.storage import Store from homeassistant.helpers.storage import Store
from homeassistant.util.network import is_local from homeassistant.util.network import is_local
@ -98,11 +99,7 @@ def async_user_not_allowed_do_auth(
if not request: if not request:
return "No request available to validate local access" return "No request available to validate local access"
if "cloud" in hass.config.components: if is_cloud_connection(hass):
# pylint: disable-next=import-outside-toplevel
from hass_nabucasa import remote
if remote.is_cloud_request.get():
return "User is local only" return "User is local only"
try: try:

View file

@ -1,9 +1,12 @@
"""Support for tracking people.""" """Support for tracking people."""
from __future__ import annotations from __future__ import annotations
from http import HTTPStatus
from ipaddress import ip_address
import logging import logging
from typing import Any from typing import Any
from aiohttp import web
import voluptuous as vol import voluptuous as vol
from homeassistant.auth import EVENT_USER_REMOVED from homeassistant.auth import EVENT_USER_REMOVED
@ -13,6 +16,7 @@ from homeassistant.components.device_tracker import (
DOMAIN as DEVICE_TRACKER_DOMAIN, DOMAIN as DEVICE_TRACKER_DOMAIN,
SourceType, SourceType,
) )
from homeassistant.components.http.view import HomeAssistantView
from homeassistant.const import ( from homeassistant.const import (
ATTR_EDITABLE, ATTR_EDITABLE,
ATTR_ENTITY_ID, ATTR_ENTITY_ID,
@ -47,10 +51,12 @@ from homeassistant.helpers import (
) )
from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.event import async_track_state_change_event from homeassistant.helpers.event import async_track_state_change_event
from homeassistant.helpers.network import is_cloud_connection
from homeassistant.helpers.restore_state import RestoreEntity from homeassistant.helpers.restore_state import RestoreEntity
from homeassistant.helpers.storage import Store from homeassistant.helpers.storage import Store
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
from homeassistant.util.network import is_local
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -385,6 +391,8 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
hass, DOMAIN, SERVICE_RELOAD, async_reload_yaml hass, DOMAIN, SERVICE_RELOAD, async_reload_yaml
) )
hass.http.register_view(ListPersonsView)
return True return True
@ -569,3 +577,44 @@ def _get_latest(prev: State | None, curr: State):
if prev is None or curr.last_updated > prev.last_updated: if prev is None or curr.last_updated > prev.last_updated:
return curr return curr
return prev return prev
class ListPersonsView(HomeAssistantView):
"""List all persons if request is made from a local network."""
requires_auth = False
url = "/api/person/list"
name = "api:person:list"
async def get(self, request: web.Request) -> web.Response:
"""Return a list of persons if request comes from a local IP."""
try:
remote_address = ip_address(request.remote) # type: ignore[arg-type]
except ValueError:
return self.json_message(
message="Invalid remote IP",
status_code=HTTPStatus.BAD_REQUEST,
message_code="invalid_remote_ip",
)
hass: HomeAssistant = request.app["hass"]
if is_cloud_connection(hass) or not is_local(remote_address):
return self.json_message(
message="Not local",
status_code=HTTPStatus.BAD_REQUEST,
message_code="not_local",
)
yaml, storage, _ = hass.data[DOMAIN]
persons = [*yaml.async_items(), *storage.async_items()]
return self.json(
{
person[ATTR_USER_ID]: {
ATTR_NAME: person[ATTR_NAME],
CONF_PICTURE: person.get(CONF_PICTURE),
}
for person in persons
if person.get(ATTR_USER_ID)
}
)

View file

@ -3,7 +3,7 @@
"name": "Person", "name": "Person",
"after_dependencies": ["device_tracker"], "after_dependencies": ["device_tracker"],
"codeowners": [], "codeowners": [],
"dependencies": ["image_upload"], "dependencies": ["image_upload", "http"],
"documentation": "https://www.home-assistant.io/integrations/person", "documentation": "https://www.home-assistant.io/integrations/person",
"integration_type": "system", "integration_type": "system",
"iot_class": "calculated", "iot_class": "calculated",

View file

@ -17,7 +17,7 @@ from homeassistant.components import websocket_api
from homeassistant.components.http.view import HomeAssistantView from homeassistant.components.http.view import HomeAssistantView
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.network import get_url from homeassistant.helpers.network import get_url, is_cloud_connection
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
from homeassistant.util import network from homeassistant.util import network
@ -145,13 +145,8 @@ async def async_handle_webhook(
return Response(status=HTTPStatus.METHOD_NOT_ALLOWED) return Response(status=HTTPStatus.METHOD_NOT_ALLOWED)
if webhook["local_only"] in (True, None) and not isinstance(request, MockRequest): if webhook["local_only"] in (True, None) and not isinstance(request, MockRequest):
if has_cloud := "cloud" in hass.config.components: is_local = not is_cloud_connection(hass)
from hass_nabucasa import remote # pylint: disable=import-outside-toplevel if is_local:
is_local = True
if has_cloud and remote.is_cloud_request.get():
is_local = False
else:
if TYPE_CHECKING: if TYPE_CHECKING:
assert isinstance(request, Request) assert isinstance(request, Request)
assert request.remote is not None assert request.remote is not None

View file

@ -299,3 +299,14 @@ def _get_cloud_url(hass: HomeAssistant, require_current_request: bool = False) -
return normalize_url(str(cloud_url)) return normalize_url(str(cloud_url))
raise NoURLAvailableError raise NoURLAvailableError
def is_cloud_connection(hass: HomeAssistant) -> bool:
"""Return True if the current connection is a nabucasa cloud connection."""
if "cloud" not in hass.config.components:
return False
from hass_nabucasa import remote # pylint: disable=import-outside-toplevel
return remote.is_cloud_request.get()

View file

@ -1,8 +1,13 @@
"""Tests for the auth component.""" """Tests for the auth component."""
from typing import Any
from homeassistant import auth from homeassistant import auth
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from tests.common import ensure_auth_manager_loaded from tests.common import ensure_auth_manager_loaded
from tests.test_util import mock_real_ip
from tests.typing import ClientSessionGenerator
BASE_CONFIG = [ BASE_CONFIG = [
{ {
@ -18,11 +23,12 @@ EMPTY_CONFIG = []
async def async_setup_auth( async def async_setup_auth(
hass, hass: HomeAssistant,
aiohttp_client, aiohttp_client: ClientSessionGenerator,
provider_configs=BASE_CONFIG, provider_configs: list[dict[str, Any]] = BASE_CONFIG,
module_configs=EMPTY_CONFIG, module_configs=EMPTY_CONFIG,
setup_api=False, setup_api: bool = False,
custom_ip: str | None = None,
): ):
"""Set up authentication and create an HTTP client.""" """Set up authentication and create an HTTP client."""
hass.auth = await auth.auth_manager_from_config( hass.auth = await auth.auth_manager_from_config(
@ -32,4 +38,6 @@ async def async_setup_auth(
await async_setup_component(hass, "auth", {}) await async_setup_component(hass, "auth", {})
if setup_api: if setup_api:
await async_setup_component(hass, "api", {}) await async_setup_component(hass, "api", {})
if custom_ip:
mock_real_ip(hass.http.app)(custom_ip)
return await aiohttp_client(hass.http.app) return await aiohttp_client(hass.http.app)

View file

@ -1,25 +1,141 @@
"""Tests for the login flow.""" """Tests for the login flow."""
from collections.abc import Callable
from http import HTTPStatus from http import HTTPStatus
from typing import Any
from unittest.mock import patch from unittest.mock import patch
from homeassistant.core import HomeAssistant import pytest
from . import async_setup_auth from homeassistant.auth.models import User
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
from . import BASE_CONFIG, async_setup_auth
from tests.common import CLIENT_ID, CLIENT_REDIRECT_URI from tests.common import CLIENT_ID, CLIENT_REDIRECT_URI
from tests.typing import ClientSessionGenerator from tests.typing import ClientSessionGenerator
_TRUSTED_NETWORKS_CONFIG = {
"type": "trusted_networks",
"trusted_networks": ["192.168.0.1"],
"trusted_users": {
"192.168.0.1": [
"a1ab982744b64757bf80515589258924",
{"group": "system-group"},
]
},
}
@pytest.mark.parametrize(
("provider_configs", "ip", "expected"),
[
(
BASE_CONFIG,
None,
[{"name": "Example", "type": "insecure_example", "id": None}],
),
(
[_TRUSTED_NETWORKS_CONFIG],
None,
[],
),
(
[_TRUSTED_NETWORKS_CONFIG],
"192.168.0.1",
[{"name": "Trusted Networks", "type": "trusted_networks", "id": None}],
),
],
)
async def test_fetch_auth_providers( async def test_fetch_auth_providers(
hass: HomeAssistant, aiohttp_client: ClientSessionGenerator hass: HomeAssistant,
aiohttp_client: ClientSessionGenerator,
provider_configs: list[dict[str, Any]],
ip: str | None,
expected: list[dict[str, Any]],
) -> None: ) -> None:
"""Test fetching auth providers.""" """Test fetching auth providers."""
client = await async_setup_auth(hass, aiohttp_client) client = await async_setup_auth(
hass, aiohttp_client, provider_configs, custom_ip=ip
)
resp = await client.get("/auth/providers") resp = await client.get("/auth/providers")
assert resp.status == HTTPStatus.OK assert resp.status == HTTPStatus.OK
assert await resp.json() == [ assert await resp.json() == expected
{"name": "Example", "type": "insecure_example", "id": None}
]
async def _test_fetch_auth_providers_home_assistant(
hass: HomeAssistant,
aiohttp_client: ClientSessionGenerator,
ip: str,
additional_expected_fn: Callable[[User], dict[str, Any]],
) -> None:
"""Test fetching auth providers for homeassistant auth provider."""
client = await async_setup_auth(
hass, aiohttp_client, [{"type": "homeassistant"}], custom_ip=ip
)
provider = hass.auth.auth_providers[0]
credentials = await provider.async_get_or_create_credentials({"username": "hello"})
user = await hass.auth.async_get_or_create_user(credentials)
expected = {
"name": "Home Assistant Local",
"type": "homeassistant",
"id": None,
**additional_expected_fn(user),
}
resp = await client.get("/auth/providers")
assert resp.status == HTTPStatus.OK
assert await resp.json() == [expected]
@pytest.mark.parametrize(
"ip",
[
"192.168.0.10",
"::ffff:192.168.0.10",
"1.2.3.4",
"2001:db8::1",
],
)
async def test_fetch_auth_providers_home_assistant_person_not_loaded(
hass: HomeAssistant,
aiohttp_client: ClientSessionGenerator,
ip: str,
) -> None:
"""Test fetching auth providers for homeassistant auth provider, where person integration is not loaded."""
await _test_fetch_auth_providers_home_assistant(
hass, aiohttp_client, ip, lambda _: {}
)
@pytest.mark.parametrize(
("ip", "is_local"),
[
("192.168.0.10", True),
("::ffff:192.168.0.10", True),
("1.2.3.4", False),
("2001:db8::1", False),
],
)
async def test_fetch_auth_providers_home_assistant_person_loaded(
hass: HomeAssistant,
aiohttp_client: ClientSessionGenerator,
ip: str,
is_local: bool,
) -> None:
"""Test fetching auth providers for homeassistant auth provider, where person integration is loaded."""
domain = "person"
config = {domain: {"id": "1234", "name": "test person"}}
assert await async_setup_component(hass, domain, config)
await _test_fetch_auth_providers_home_assistant(
hass,
aiohttp_client,
ip,
lambda user: {"users": {user.id: user.name}} if is_local else {},
)
async def test_fetch_auth_providers_onboarding( async def test_fetch_auth_providers_onboarding(

View file

@ -1,34 +1,3 @@
"""Tests for the HTTP component.""" """Tests for the HTTP component."""
from aiohttp import web
# Relic from the past. Kept here so we can run negative tests. # Relic from the past. Kept here so we can run negative tests.
HTTP_HEADER_HA_AUTH = "X-HA-access" HTTP_HEADER_HA_AUTH = "X-HA-access"
def mock_real_ip(app):
"""Inject middleware to mock real IP.
Returns a function to set the real IP.
"""
ip_to_mock = None
def set_ip_to_mock(value):
nonlocal ip_to_mock
ip_to_mock = value
@web.middleware
async def mock_real_ip(request, handler):
"""Mock Real IP middleware."""
nonlocal ip_to_mock
request = request.clone(remote=ip_to_mock)
return await handler(request)
async def real_ip_startup(app):
"""Startup of real ip."""
app.middlewares.insert(0, mock_real_ip)
app.on_startup.append(real_ip_startup)
return set_ip_to_mock

View file

@ -35,9 +35,10 @@ from homeassistant.components.http.request_context import (
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from . import HTTP_HEADER_HA_AUTH, mock_real_ip from . import HTTP_HEADER_HA_AUTH
from tests.common import MockUser from tests.common import MockUser
from tests.test_util import mock_real_ip
from tests.typing import ClientSessionGenerator, WebSocketGenerator from tests.typing import ClientSessionGenerator, WebSocketGenerator
API_PASSWORD = "test-password" API_PASSWORD = "test-password"

View file

@ -24,9 +24,8 @@ from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from . import mock_real_ip
from tests.common import async_get_persistent_notifications from tests.common import async_get_persistent_notifications
from tests.test_util import mock_real_ip
from tests.typing import ClientSessionGenerator from tests.typing import ClientSessionGenerator
SUPERVISOR_IP = "1.2.3.4" SUPERVISOR_IP = "1.2.3.4"

View file

@ -1,4 +1,6 @@
"""The tests for the person component.""" """The tests for the person component."""
from collections.abc import Callable
from http import HTTPStatus
from typing import Any from typing import Any
from unittest.mock import patch from unittest.mock import patch
@ -29,7 +31,8 @@ from homeassistant.setup import async_setup_component
from .conftest import DEVICE_TRACKER, DEVICE_TRACKER_2 from .conftest import DEVICE_TRACKER, DEVICE_TRACKER_2
from tests.common import MockUser, mock_component, mock_restore_cache from tests.common import MockUser, mock_component, mock_restore_cache
from tests.typing import WebSocketGenerator from tests.test_util import mock_real_ip
from tests.typing import ClientSessionGenerator, WebSocketGenerator
async def test_minimal_setup(hass: HomeAssistant) -> None: async def test_minimal_setup(hass: HomeAssistant) -> None:
@ -847,3 +850,63 @@ async def test_entities_in_person(hass: HomeAssistant) -> None:
"device_tracker.paulus_iphone", "device_tracker.paulus_iphone",
"device_tracker.paulus_ipad", "device_tracker.paulus_ipad",
] ]
@pytest.mark.parametrize(
("ip", "status_code", "expected_fn"),
[
(
"192.168.0.10",
HTTPStatus.OK,
lambda user: {
user["user_id"]: {"name": user["name"], "picture": user["picture"]}
},
),
(
"::ffff:192.168.0.10",
HTTPStatus.OK,
lambda user: {
user["user_id"]: {"name": user["name"], "picture": user["picture"]}
},
),
(
"1.2.3.4",
HTTPStatus.BAD_REQUEST,
lambda _: {"code": "not_local", "message": "Not local"},
),
(
"2001:db8::1",
HTTPStatus.BAD_REQUEST,
lambda _: {"code": "not_local", "message": "Not local"},
),
],
)
async def test_list_persons(
hass: HomeAssistant,
hass_client_no_auth: ClientSessionGenerator,
hass_admin_user: MockUser,
ip: str,
status_code: HTTPStatus,
expected_fn: Callable[[dict[str, Any]], dict[str, Any]],
) -> None:
"""Test listing persons from a not local ip address."""
user_id = hass_admin_user.id
admin = {"id": "1234", "name": "Admin", "user_id": user_id, "picture": "/bla"}
config = {
DOMAIN: [
admin,
{"id": "5678", "name": "Only a person"},
]
}
assert await async_setup_component(hass, DOMAIN, config)
await async_setup_component(hass, "api", {})
mock_real_ip(hass.http.app)(ip)
client = await hass_client_no_auth()
resp = await client.get("/api/person/list")
assert resp.status == status_code
result = await resp.json()
assert result == expected_fn(admin)

View file

@ -1 +1,35 @@
"""Tests for the test utilities.""" """Test utilities."""
from collections.abc import Awaitable, Callable
from aiohttp.web import Application, Request, StreamResponse, middleware
def mock_real_ip(app: Application) -> Callable[[str], None]:
"""Inject middleware to mock real IP.
Returns a function to set the real IP.
"""
ip_to_mock: str | None = None
def set_ip_to_mock(value: str):
nonlocal ip_to_mock
ip_to_mock = value
@middleware
async def mock_real_ip(
request: Request, handler: Callable[[Request], Awaitable[StreamResponse]]
) -> StreamResponse:
"""Mock Real IP middleware."""
nonlocal ip_to_mock
request = request.clone(remote=ip_to_mock)
return await handler(request)
async def real_ip_startup(app):
"""Startup of real ip."""
app.middlewares.insert(0, mock_real_ip)
app.on_startup.append(real_ip_startup)
return set_ip_to_mock