Revert "Revert "Add preselect_remember_me to /auth/providers"" (#106867)

This commit is contained in:
Robert Resch 2024-01-11 10:37:19 +01:00 committed by GitHub
parent b08832a89a
commit 1c669c6e84
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 57 additions and 126 deletions

View file

@ -91,6 +91,7 @@ from homeassistant.components.http.data_validator import RequestDataValidator
from homeassistant.components.http.view import HomeAssistantView from homeassistant.components.http.view import HomeAssistantView
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers.network import is_cloud_connection from homeassistant.helpers.network import is_cloud_connection
from homeassistant.util.network import is_local
from . import indieauth from . import indieauth
@ -185,7 +186,14 @@ class AuthProvidersView(HomeAssistantView):
} }
) )
return self.json(providers) preselect_remember_me = not cloud_connection and is_local(remote_address)
return self.json(
{
"providers": providers,
"preselect_remember_me": preselect_remember_me,
}
)
def _prepare_result_json( def _prepare_result_json(

View file

@ -1,11 +1,9 @@
"""Support for tracking people.""" """Support for tracking people."""
from __future__ import annotations from __future__ import annotations
from http import HTTPStatus
import logging import logging
from typing import Any from typing import Any
from aiohttp import web
import voluptuous as vol import voluptuous as vol
from homeassistant.auth import EVENT_USER_REMOVED from homeassistant.auth import EVENT_USER_REMOVED
@ -15,7 +13,6 @@ from homeassistant.components.device_tracker import (
DOMAIN as DEVICE_TRACKER_DOMAIN, DOMAIN as DEVICE_TRACKER_DOMAIN,
SourceType, SourceType,
) )
from homeassistant.components.http.view import HomeAssistantView
from homeassistant.const import ( from homeassistant.const import (
ATTR_EDITABLE, ATTR_EDITABLE,
ATTR_ENTITY_ID, ATTR_ENTITY_ID,
@ -388,8 +385,6 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
hass, DOMAIN, SERVICE_RELOAD, async_reload_yaml hass, DOMAIN, SERVICE_RELOAD, async_reload_yaml
) )
hass.http.register_view(ListPersonsView)
return True return True
@ -574,19 +569,3 @@ def _get_latest(prev: State | None, curr: State):
if prev is None or curr.last_updated > prev.last_updated: if prev is None or curr.last_updated > prev.last_updated:
return curr return curr
return prev return prev
class ListPersonsView(HomeAssistantView):
"""List all persons if request is made from a local network."""
requires_auth = False
url = "/api/person/list"
name = "api:person:list"
async def get(self, request: web.Request) -> web.Response:
"""Return a list of persons if request comes from a local IP."""
return self.json_message(
message="Not local",
status_code=HTTPStatus.BAD_REQUEST,
message_code="not_local",
)

View file

@ -6,7 +6,6 @@ from unittest.mock import patch
import pytest import pytest
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
from . import BASE_CONFIG, async_setup_auth from . import BASE_CONFIG, async_setup_auth
@ -26,22 +25,30 @@ _TRUSTED_NETWORKS_CONFIG = {
@pytest.mark.parametrize( @pytest.mark.parametrize(
("provider_configs", "ip", "expected"), ("ip", "preselect_remember_me"),
[
("192.168.1.10", True),
("::ffff:192.168.0.10", True),
("1.2.3.4", False),
("2001:db8::1", False),
],
)
@pytest.mark.parametrize(
("provider_configs", "expected"),
[ [
( (
BASE_CONFIG, BASE_CONFIG,
None,
[{"name": "Example", "type": "insecure_example", "id": None}], [{"name": "Example", "type": "insecure_example", "id": None}],
), ),
( (
[_TRUSTED_NETWORKS_CONFIG], [{"type": "homeassistant"}],
None, [
[], {
), "name": "Home Assistant Local",
( "type": "homeassistant",
[_TRUSTED_NETWORKS_CONFIG], "id": None,
"192.168.0.1", }
[{"name": "Trusted Networks", "type": "trusted_networks", "id": None}], ],
), ),
], ],
) )
@ -49,8 +56,9 @@ async def test_fetch_auth_providers(
hass: HomeAssistant, hass: HomeAssistant,
aiohttp_client: ClientSessionGenerator, aiohttp_client: ClientSessionGenerator,
provider_configs: list[dict[str, Any]], provider_configs: list[dict[str, Any]],
ip: str | None,
expected: list[dict[str, Any]], expected: list[dict[str, Any]],
ip: str,
preselect_remember_me: bool,
) -> None: ) -> None:
"""Test fetching auth providers.""" """Test fetching auth providers."""
client = await async_setup_auth( client = await async_setup_auth(
@ -58,73 +66,37 @@ async def test_fetch_auth_providers(
) )
resp = await client.get("/auth/providers") resp = await client.get("/auth/providers")
assert resp.status == HTTPStatus.OK assert resp.status == HTTPStatus.OK
assert await resp.json() == expected assert await resp.json() == {
"providers": expected,
"preselect_remember_me": preselect_remember_me,
async def _test_fetch_auth_providers_home_assistant(
hass: HomeAssistant,
aiohttp_client: ClientSessionGenerator,
ip: str,
) -> None:
"""Test fetching auth providers for homeassistant auth provider."""
client = await async_setup_auth(
hass, aiohttp_client, [{"type": "homeassistant"}], custom_ip=ip
)
expected = {
"name": "Home Assistant Local",
"type": "homeassistant",
"id": None,
} }
@pytest.mark.parametrize(
("ip", "expected"),
[
(
"192.168.0.1",
[{"name": "Trusted Networks", "type": "trusted_networks", "id": None}],
),
("::ffff:192.168.0.10", []),
("1.2.3.4", []),
("2001:db8::1", []),
],
)
async def test_fetch_auth_providers_trusted_network(
hass: HomeAssistant,
aiohttp_client: ClientSessionGenerator,
expected: list[dict[str, Any]],
ip: str,
) -> None:
"""Test fetching auth providers."""
client = await async_setup_auth(
hass, aiohttp_client, [_TRUSTED_NETWORKS_CONFIG], custom_ip=ip
)
resp = await client.get("/auth/providers") resp = await client.get("/auth/providers")
assert resp.status == HTTPStatus.OK assert resp.status == HTTPStatus.OK
assert await resp.json() == [expected] assert (await resp.json())["providers"] == expected
@pytest.mark.parametrize(
"ip",
[
"192.168.0.10",
"::ffff:192.168.0.10",
"1.2.3.4",
"2001:db8::1",
],
)
async def test_fetch_auth_providers_home_assistant_person_not_loaded(
hass: HomeAssistant,
aiohttp_client: ClientSessionGenerator,
ip: str,
) -> None:
"""Test fetching auth providers for homeassistant auth provider, where person integration is not loaded."""
await _test_fetch_auth_providers_home_assistant(hass, aiohttp_client, ip)
@pytest.mark.parametrize(
("ip", "is_local"),
[
("192.168.0.10", True),
("::ffff:192.168.0.10", True),
("1.2.3.4", False),
("2001:db8::1", False),
],
)
async def test_fetch_auth_providers_home_assistant_person_loaded(
hass: HomeAssistant,
aiohttp_client: ClientSessionGenerator,
ip: str,
is_local: bool,
) -> None:
"""Test fetching auth providers for homeassistant auth provider, where person integration is loaded."""
domain = "person"
config = {domain: {"id": "1234", "name": "test person"}}
assert await async_setup_component(hass, domain, config)
await _test_fetch_auth_providers_home_assistant(
hass,
aiohttp_client,
ip,
)
async def test_fetch_auth_providers_onboarding( async def test_fetch_auth_providers_onboarding(

View file

@ -1,5 +1,4 @@
"""The tests for the person component.""" """The tests for the person component."""
from http import HTTPStatus
from typing import Any from typing import Any
from unittest.mock import patch from unittest.mock import patch
@ -30,7 +29,7 @@ from homeassistant.setup import async_setup_component
from .conftest import DEVICE_TRACKER, DEVICE_TRACKER_2 from .conftest import DEVICE_TRACKER, DEVICE_TRACKER_2
from tests.common import MockUser, mock_component, mock_restore_cache from tests.common import MockUser, mock_component, mock_restore_cache
from tests.typing import ClientSessionGenerator, WebSocketGenerator from tests.typing import WebSocketGenerator
async def test_minimal_setup(hass: HomeAssistant) -> None: async def test_minimal_setup(hass: HomeAssistant) -> None:
@ -848,30 +847,3 @@ async def test_entities_in_person(hass: HomeAssistant) -> None:
"device_tracker.paulus_iphone", "device_tracker.paulus_iphone",
"device_tracker.paulus_ipad", "device_tracker.paulus_ipad",
] ]
async def test_list_persons(
hass: HomeAssistant,
hass_client_no_auth: ClientSessionGenerator,
hass_admin_user: MockUser,
) -> None:
"""Test listing persons from a not local ip address."""
user_id = hass_admin_user.id
admin = {"id": "1234", "name": "Admin", "user_id": user_id, "picture": "/bla"}
config = {
DOMAIN: [
admin,
{"id": "5678", "name": "Only a person"},
]
}
assert await async_setup_component(hass, DOMAIN, config)
await async_setup_component(hass, "api", {})
client = await hass_client_no_auth()
resp = await client.get("/api/person/list")
assert resp.status == HTTPStatus.BAD_REQUEST
result = await resp.json()
assert result == {"code": "not_local", "message": "Not local"}