Use modern WS API for auth integration + add auth provider type to refresh token info (#72552)

This commit is contained in:
Paulus Schoutsen 2022-05-26 13:06:34 -07:00 committed by GitHub
parent d092861926
commit ff3374b4e0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 63 additions and 89 deletions

View file

@ -148,43 +148,6 @@ from homeassistant.util import dt as dt_util
from . import indieauth, login_flow, mfa_setup_flow
DOMAIN = "auth"
WS_TYPE_CURRENT_USER = "auth/current_user"
SCHEMA_WS_CURRENT_USER = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
{vol.Required("type"): WS_TYPE_CURRENT_USER}
)
WS_TYPE_LONG_LIVED_ACCESS_TOKEN = "auth/long_lived_access_token"
SCHEMA_WS_LONG_LIVED_ACCESS_TOKEN = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
{
vol.Required("type"): WS_TYPE_LONG_LIVED_ACCESS_TOKEN,
vol.Required("lifespan"): int, # days
vol.Required("client_name"): str,
vol.Optional("client_icon"): str,
}
)
WS_TYPE_REFRESH_TOKENS = "auth/refresh_tokens"
SCHEMA_WS_REFRESH_TOKENS = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
{vol.Required("type"): WS_TYPE_REFRESH_TOKENS}
)
WS_TYPE_DELETE_REFRESH_TOKEN = "auth/delete_refresh_token"
SCHEMA_WS_DELETE_REFRESH_TOKEN = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
{
vol.Required("type"): WS_TYPE_DELETE_REFRESH_TOKEN,
vol.Required("refresh_token_id"): str,
}
)
WS_TYPE_SIGN_PATH = "auth/sign_path"
SCHEMA_WS_SIGN_PATH = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
{
vol.Required("type"): WS_TYPE_SIGN_PATH,
vol.Required("path"): str,
vol.Optional("expires", default=30): int,
}
)
RESULT_TYPE_CREDENTIALS = "credentials"
@ -204,27 +167,11 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
hass.http.register_view(LinkUserView(retrieve_result))
hass.http.register_view(OAuth2AuthorizeCallbackView())
websocket_api.async_register_command(
hass, WS_TYPE_CURRENT_USER, websocket_current_user, SCHEMA_WS_CURRENT_USER
)
websocket_api.async_register_command(
hass,
WS_TYPE_LONG_LIVED_ACCESS_TOKEN,
websocket_create_long_lived_access_token,
SCHEMA_WS_LONG_LIVED_ACCESS_TOKEN,
)
websocket_api.async_register_command(
hass, WS_TYPE_REFRESH_TOKENS, websocket_refresh_tokens, SCHEMA_WS_REFRESH_TOKENS
)
websocket_api.async_register_command(
hass,
WS_TYPE_DELETE_REFRESH_TOKEN,
websocket_delete_refresh_token,
SCHEMA_WS_DELETE_REFRESH_TOKEN,
)
websocket_api.async_register_command(
hass, WS_TYPE_SIGN_PATH, websocket_sign_path, SCHEMA_WS_SIGN_PATH
)
websocket_api.async_register_command(hass, websocket_current_user)
websocket_api.async_register_command(hass, websocket_create_long_lived_access_token)
websocket_api.async_register_command(hass, websocket_refresh_tokens)
websocket_api.async_register_command(hass, websocket_delete_refresh_token)
websocket_api.async_register_command(hass, websocket_sign_path)
await login_flow.async_setup(hass, store_result)
await mfa_setup_flow.async_setup(hass)
@ -476,6 +423,7 @@ def _create_auth_code_store():
return store_result, retrieve_result
@websocket_api.websocket_command({vol.Required("type"): "auth/current_user"})
@websocket_api.ws_require_user()
@websocket_api.async_response
async def websocket_current_user(
@ -513,6 +461,14 @@ async def websocket_current_user(
)
@websocket_api.websocket_command(
{
vol.Required("type"): "auth/long_lived_access_token",
vol.Required("lifespan"): int, # days
vol.Required("client_name"): str,
vol.Optional("client_icon"): str,
}
)
@websocket_api.ws_require_user()
@websocket_api.async_response
async def websocket_create_long_lived_access_token(
@ -530,13 +486,13 @@ async def websocket_create_long_lived_access_token(
try:
access_token = hass.auth.async_create_access_token(refresh_token)
except InvalidAuthError as exc:
return websocket_api.error_message(
msg["id"], websocket_api.const.ERR_UNAUTHORIZED, str(exc)
)
connection.send_error(msg["id"], websocket_api.const.ERR_UNAUTHORIZED, str(exc))
return
connection.send_message(websocket_api.result_message(msg["id"], access_token))
connection.send_result(msg["id"], access_token)
@websocket_api.websocket_command({vol.Required("type"): "auth/refresh_tokens"})
@websocket_api.ws_require_user()
@callback
def websocket_refresh_tokens(
@ -544,27 +500,38 @@ def websocket_refresh_tokens(
):
"""Return metadata of users refresh tokens."""
current_id = connection.refresh_token_id
connection.send_message(
websocket_api.result_message(
msg["id"],
[
{
"id": refresh.id,
"client_id": refresh.client_id,
"client_name": refresh.client_name,
"client_icon": refresh.client_icon,
"type": refresh.token_type,
"created_at": refresh.created_at,
"is_current": refresh.id == current_id,
"last_used_at": refresh.last_used_at,
"last_used_ip": refresh.last_used_ip,
}
for refresh in connection.user.refresh_tokens.values()
],
tokens = []
for refresh in connection.user.refresh_tokens.values():
if refresh.credential:
auth_provider_type = refresh.credential.auth_provider_type
else:
auth_provider_type = None
tokens.append(
{
"id": refresh.id,
"client_id": refresh.client_id,
"client_name": refresh.client_name,
"client_icon": refresh.client_icon,
"type": refresh.token_type,
"created_at": refresh.created_at,
"is_current": refresh.id == current_id,
"last_used_at": refresh.last_used_at,
"last_used_ip": refresh.last_used_ip,
"auth_provider_type": auth_provider_type,
}
)
)
connection.send_result(msg["id"], tokens)
@websocket_api.websocket_command(
{
vol.Required("type"): "auth/delete_refresh_token",
vol.Required("refresh_token_id"): str,
}
)
@websocket_api.ws_require_user()
@websocket_api.async_response
async def websocket_delete_refresh_token(
@ -574,15 +541,21 @@ async def websocket_delete_refresh_token(
refresh_token = connection.user.refresh_tokens.get(msg["refresh_token_id"])
if refresh_token is None:
return websocket_api.error_message(
msg["id"], "invalid_token_id", "Received invalid token"
)
connection.send_error(msg["id"], "invalid_token_id", "Received invalid token")
return
await hass.auth.async_remove_refresh_token(refresh_token)
connection.send_message(websocket_api.result_message(msg["id"], {}))
connection.send_result(msg["id"], {})
@websocket_api.websocket_command(
{
vol.Required("type"): "auth/sign_path",
vol.Required("path"): str,
vol.Optional("expires", default=30): int,
}
)
@websocket_api.ws_require_user()
@callback
def websocket_sign_path(

View file

@ -193,7 +193,7 @@ async def test_ws_current_user(hass, hass_ws_client, hass_access_token):
user = refresh_token.user
client = await hass_ws_client(hass, hass_access_token)
await client.send_json({"id": 5, "type": auth.WS_TYPE_CURRENT_USER})
await client.send_json({"id": 5, "type": "auth/current_user"})
result = await client.receive_json()
assert result["success"], result
@ -410,7 +410,7 @@ async def test_ws_long_lived_access_token(hass, hass_ws_client, hass_access_toke
await ws_client.send_json(
{
"id": 5,
"type": auth.WS_TYPE_LONG_LIVED_ACCESS_TOKEN,
"type": "auth/long_lived_access_token",
"client_name": "GPS Logger",
"lifespan": 365,
}
@ -434,7 +434,7 @@ async def test_ws_refresh_tokens(hass, hass_ws_client, hass_access_token):
ws_client = await hass_ws_client(hass, hass_access_token)
await ws_client.send_json({"id": 5, "type": auth.WS_TYPE_REFRESH_TOKENS})
await ws_client.send_json({"id": 5, "type": "auth/refresh_tokens"})
result = await ws_client.receive_json()
assert result["success"], result
@ -450,6 +450,7 @@ async def test_ws_refresh_tokens(hass, hass_ws_client, hass_access_token):
assert token["is_current"] is True
assert token["last_used_at"] == refresh_token.last_used_at.isoformat()
assert token["last_used_ip"] == refresh_token.last_used_ip
assert token["auth_provider_type"] == "homeassistant"
async def test_ws_delete_refresh_token(
@ -468,7 +469,7 @@ async def test_ws_delete_refresh_token(
await ws_client.send_json(
{
"id": 5,
"type": auth.WS_TYPE_DELETE_REFRESH_TOKEN,
"type": "auth/delete_refresh_token",
"refresh_token_id": refresh_token.id,
}
)
@ -490,7 +491,7 @@ async def test_ws_sign_path(hass, hass_ws_client, hass_access_token):
await ws_client.send_json(
{
"id": 5,
"type": auth.WS_TYPE_SIGN_PATH,
"type": "auth/sign_path",
"path": "/api/hello",
"expires": 20,
}