Introduce only_supervisor for @websocket_api.ws_require_user() (#61298)
This commit is contained in:
parent
1f1a29cada
commit
a13ae85982
7 changed files with 70 additions and 13 deletions
|
@ -20,6 +20,7 @@ from homeassistant.const import (
|
|||
ATTR_MANUFACTURER,
|
||||
ATTR_NAME,
|
||||
EVENT_CORE_CONFIG_UPDATE,
|
||||
HASSIO_USER_NAME,
|
||||
SERVICE_HOMEASSISTANT_RESTART,
|
||||
SERVICE_HOMEASSISTANT_STOP,
|
||||
Platform,
|
||||
|
@ -440,11 +441,11 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: # noqa:
|
|||
|
||||
# Migrate old name
|
||||
if user.name == "Hass.io":
|
||||
await hass.auth.async_update_user(user, name="Supervisor")
|
||||
await hass.auth.async_update_user(user, name=HASSIO_USER_NAME)
|
||||
|
||||
if refresh_token is None:
|
||||
user = await hass.auth.async_create_system_user(
|
||||
"Supervisor", group_ids=[GROUP_ID_ADMIN]
|
||||
HASSIO_USER_NAME, group_ids=[GROUP_ID_ADMIN]
|
||||
)
|
||||
refresh_token = await hass.auth.async_create_refresh_token(user)
|
||||
data["hassio_user"] = user.id
|
||||
|
|
|
@ -113,7 +113,7 @@ def ws_info(
|
|||
connection.send_result(msg["id"], recorder_info)
|
||||
|
||||
|
||||
@websocket_api.require_admin
|
||||
@websocket_api.ws_require_user(only_supervisor=True)
|
||||
@websocket_api.websocket_command({vol.Required("type"): "backup/start"})
|
||||
@websocket_api.async_response
|
||||
async def ws_backup_start(
|
||||
|
@ -131,7 +131,7 @@ async def ws_backup_start(
|
|||
connection.send_result(msg["id"])
|
||||
|
||||
|
||||
@websocket_api.require_admin
|
||||
@websocket_api.ws_require_user(only_supervisor=True)
|
||||
@websocket_api.websocket_command({vol.Required("type"): "backup/end"})
|
||||
@websocket_api.async_response
|
||||
async def ws_backup_end(
|
||||
|
|
|
@ -8,6 +8,7 @@ from typing import Any
|
|||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.const import HASSIO_USER_NAME
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.exceptions import Unauthorized
|
||||
|
||||
|
@ -70,6 +71,7 @@ def ws_require_user(
|
|||
allow_system_user: bool = True,
|
||||
only_active_user: bool = True,
|
||||
only_inactive_user: bool = False,
|
||||
only_supervisor: bool = False,
|
||||
) -> Callable[[const.WebSocketCommandHandler], const.WebSocketCommandHandler]:
|
||||
"""Decorate function validating login user exist in current WS connection.
|
||||
|
||||
|
@ -111,6 +113,10 @@ def ws_require_user(
|
|||
output_error("only_inactive_user", "Not allowed as active user")
|
||||
return
|
||||
|
||||
if only_supervisor and connection.user.name != HASSIO_USER_NAME:
|
||||
output_error("only_supervisor", "Only allowed as Supervisor")
|
||||
return
|
||||
|
||||
return func(hass, connection, msg)
|
||||
|
||||
return check_current_user
|
||||
|
|
|
@ -756,3 +756,6 @@ ENTITY_CATEGORIES: Final[list[str]] = [
|
|||
CAST_APP_ID_HOMEASSISTANT_MEDIA: Final = "B45F4572"
|
||||
# The ID of the Home Assistant Lovelace Cast App
|
||||
CAST_APP_ID_HOMEASSISTANT_LOVELACE: Final = "A078F6B0"
|
||||
|
||||
# User used by Supervisor
|
||||
HASSIO_USER_NAME = "Supervisor"
|
||||
|
|
|
@ -360,9 +360,11 @@ async def test_recorder_info_migration_queue_exhausted(hass, hass_ws_client):
|
|||
assert response["result"]["thread_running"] is True
|
||||
|
||||
|
||||
async def test_backup_start_no_recorder(hass, hass_ws_client):
|
||||
async def test_backup_start_no_recorder(
|
||||
hass, hass_ws_client, hass_supervisor_access_token
|
||||
):
|
||||
"""Test getting backup start when recorder is not present."""
|
||||
client = await hass_ws_client()
|
||||
client = await hass_ws_client(hass, hass_supervisor_access_token)
|
||||
|
||||
await client.send_json({"id": 1, "type": "backup/start"})
|
||||
response = await client.receive_json()
|
||||
|
@ -370,9 +372,9 @@ async def test_backup_start_no_recorder(hass, hass_ws_client):
|
|||
assert response["error"]["code"] == "unknown_command"
|
||||
|
||||
|
||||
async def test_backup_start_timeout(hass, hass_ws_client):
|
||||
async def test_backup_start_timeout(hass, hass_ws_client, hass_supervisor_access_token):
|
||||
"""Test getting backup start when recorder is not present."""
|
||||
client = await hass_ws_client()
|
||||
client = await hass_ws_client(hass, hass_supervisor_access_token)
|
||||
await async_init_recorder_component(hass)
|
||||
|
||||
# Ensure there are no queued events
|
||||
|
@ -388,9 +390,9 @@ async def test_backup_start_timeout(hass, hass_ws_client):
|
|||
await client.send_json({"id": 2, "type": "backup/end"})
|
||||
|
||||
|
||||
async def test_backup_end(hass, hass_ws_client):
|
||||
async def test_backup_end(hass, hass_ws_client, hass_supervisor_access_token):
|
||||
"""Test backup start."""
|
||||
client = await hass_ws_client()
|
||||
client = await hass_ws_client(hass, hass_supervisor_access_token)
|
||||
await async_init_recorder_component(hass)
|
||||
|
||||
# Ensure there are no queued events
|
||||
|
@ -405,9 +407,11 @@ async def test_backup_end(hass, hass_ws_client):
|
|||
assert response["success"]
|
||||
|
||||
|
||||
async def test_backup_end_without_start(hass, hass_ws_client):
|
||||
async def test_backup_end_without_start(
|
||||
hass, hass_ws_client, hass_supervisor_access_token
|
||||
):
|
||||
"""Test backup start."""
|
||||
client = await hass_ws_client()
|
||||
client = await hass_ws_client(hass, hass_supervisor_access_token)
|
||||
await async_init_recorder_component(hass)
|
||||
|
||||
# Ensure there are no queued events
|
||||
|
|
|
@ -66,3 +66,26 @@ async def test_async_response_request_context(hass, websocket_client):
|
|||
assert msg["id"] == 7
|
||||
assert not msg["success"]
|
||||
assert msg["error"]["code"] == "not_found"
|
||||
|
||||
|
||||
async def test_supervisor_only(hass, websocket_client):
|
||||
"""Test that only the Supervisor can make requests."""
|
||||
|
||||
@websocket_api.ws_require_user(only_supervisor=True)
|
||||
@websocket_api.websocket_command({"type": "test-require-supervisor-user"})
|
||||
def require_supervisor_request(hass, connection, msg):
|
||||
connection.send_result(msg["id"])
|
||||
|
||||
websocket_api.async_register_command(hass, require_supervisor_request)
|
||||
|
||||
await websocket_client.send_json(
|
||||
{
|
||||
"id": 5,
|
||||
"type": "test-require-supervisor-user",
|
||||
}
|
||||
)
|
||||
|
||||
msg = await websocket_client.receive_json()
|
||||
assert msg["id"] == 5
|
||||
assert not msg["success"]
|
||||
assert msg["error"]["code"] == "only_supervisor"
|
||||
|
|
|
@ -26,7 +26,7 @@ from homeassistant.components.websocket_api.auth import (
|
|||
TYPE_AUTH_REQUIRED,
|
||||
)
|
||||
from homeassistant.components.websocket_api.http import URL
|
||||
from homeassistant.const import ATTR_NOW, EVENT_TIME_CHANGED
|
||||
from homeassistant.const import ATTR_NOW, EVENT_TIME_CHANGED, HASSIO_USER_NAME
|
||||
from homeassistant.helpers import config_entry_oauth2_flow, event
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.util import location
|
||||
|
@ -405,6 +405,26 @@ def hass_read_only_access_token(hass, hass_read_only_user, local_auth):
|
|||
return hass.auth.async_create_access_token(refresh_token)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def hass_supervisor_user(hass, local_auth):
|
||||
"""Return the Home Assistant Supervisor user."""
|
||||
admin_group = hass.loop.run_until_complete(
|
||||
hass.auth.async_get_group(GROUP_ID_ADMIN)
|
||||
)
|
||||
return MockUser(
|
||||
name=HASSIO_USER_NAME, groups=[admin_group], system_generated=True
|
||||
).add_to_hass(hass)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def hass_supervisor_access_token(hass, hass_supervisor_user, local_auth):
|
||||
"""Return a Home Assistant Supervisor access token."""
|
||||
refresh_token = hass.loop.run_until_complete(
|
||||
hass.auth.async_create_refresh_token(hass_supervisor_user)
|
||||
)
|
||||
return hass.auth.async_create_access_token(refresh_token)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def legacy_auth(hass):
|
||||
"""Load legacy API password provider."""
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue