Use modern WS API for auth integration + add auth provider type to refresh token info (#72552)
This commit is contained in:
parent
d092861926
commit
ff3374b4e0
2 changed files with 63 additions and 89 deletions
|
@ -148,43 +148,6 @@ from homeassistant.util import dt as dt_util
|
|||
from . import indieauth, login_flow, mfa_setup_flow
|
||||
|
||||
DOMAIN = "auth"
|
||||
WS_TYPE_CURRENT_USER = "auth/current_user"
|
||||
SCHEMA_WS_CURRENT_USER = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
|
||||
{vol.Required("type"): WS_TYPE_CURRENT_USER}
|
||||
)
|
||||
|
||||
WS_TYPE_LONG_LIVED_ACCESS_TOKEN = "auth/long_lived_access_token"
|
||||
SCHEMA_WS_LONG_LIVED_ACCESS_TOKEN = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
|
||||
{
|
||||
vol.Required("type"): WS_TYPE_LONG_LIVED_ACCESS_TOKEN,
|
||||
vol.Required("lifespan"): int, # days
|
||||
vol.Required("client_name"): str,
|
||||
vol.Optional("client_icon"): str,
|
||||
}
|
||||
)
|
||||
|
||||
WS_TYPE_REFRESH_TOKENS = "auth/refresh_tokens"
|
||||
SCHEMA_WS_REFRESH_TOKENS = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
|
||||
{vol.Required("type"): WS_TYPE_REFRESH_TOKENS}
|
||||
)
|
||||
|
||||
WS_TYPE_DELETE_REFRESH_TOKEN = "auth/delete_refresh_token"
|
||||
SCHEMA_WS_DELETE_REFRESH_TOKEN = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
|
||||
{
|
||||
vol.Required("type"): WS_TYPE_DELETE_REFRESH_TOKEN,
|
||||
vol.Required("refresh_token_id"): str,
|
||||
}
|
||||
)
|
||||
|
||||
WS_TYPE_SIGN_PATH = "auth/sign_path"
|
||||
SCHEMA_WS_SIGN_PATH = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
|
||||
{
|
||||
vol.Required("type"): WS_TYPE_SIGN_PATH,
|
||||
vol.Required("path"): str,
|
||||
vol.Optional("expires", default=30): int,
|
||||
}
|
||||
)
|
||||
|
||||
RESULT_TYPE_CREDENTIALS = "credentials"
|
||||
|
||||
|
||||
|
@ -204,27 +167,11 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
hass.http.register_view(LinkUserView(retrieve_result))
|
||||
hass.http.register_view(OAuth2AuthorizeCallbackView())
|
||||
|
||||
websocket_api.async_register_command(
|
||||
hass, WS_TYPE_CURRENT_USER, websocket_current_user, SCHEMA_WS_CURRENT_USER
|
||||
)
|
||||
websocket_api.async_register_command(
|
||||
hass,
|
||||
WS_TYPE_LONG_LIVED_ACCESS_TOKEN,
|
||||
websocket_create_long_lived_access_token,
|
||||
SCHEMA_WS_LONG_LIVED_ACCESS_TOKEN,
|
||||
)
|
||||
websocket_api.async_register_command(
|
||||
hass, WS_TYPE_REFRESH_TOKENS, websocket_refresh_tokens, SCHEMA_WS_REFRESH_TOKENS
|
||||
)
|
||||
websocket_api.async_register_command(
|
||||
hass,
|
||||
WS_TYPE_DELETE_REFRESH_TOKEN,
|
||||
websocket_delete_refresh_token,
|
||||
SCHEMA_WS_DELETE_REFRESH_TOKEN,
|
||||
)
|
||||
websocket_api.async_register_command(
|
||||
hass, WS_TYPE_SIGN_PATH, websocket_sign_path, SCHEMA_WS_SIGN_PATH
|
||||
)
|
||||
websocket_api.async_register_command(hass, websocket_current_user)
|
||||
websocket_api.async_register_command(hass, websocket_create_long_lived_access_token)
|
||||
websocket_api.async_register_command(hass, websocket_refresh_tokens)
|
||||
websocket_api.async_register_command(hass, websocket_delete_refresh_token)
|
||||
websocket_api.async_register_command(hass, websocket_sign_path)
|
||||
|
||||
await login_flow.async_setup(hass, store_result)
|
||||
await mfa_setup_flow.async_setup(hass)
|
||||
|
@ -476,6 +423,7 @@ def _create_auth_code_store():
|
|||
return store_result, retrieve_result
|
||||
|
||||
|
||||
@websocket_api.websocket_command({vol.Required("type"): "auth/current_user"})
|
||||
@websocket_api.ws_require_user()
|
||||
@websocket_api.async_response
|
||||
async def websocket_current_user(
|
||||
|
@ -513,6 +461,14 @@ async def websocket_current_user(
|
|||
)
|
||||
|
||||
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "auth/long_lived_access_token",
|
||||
vol.Required("lifespan"): int, # days
|
||||
vol.Required("client_name"): str,
|
||||
vol.Optional("client_icon"): str,
|
||||
}
|
||||
)
|
||||
@websocket_api.ws_require_user()
|
||||
@websocket_api.async_response
|
||||
async def websocket_create_long_lived_access_token(
|
||||
|
@ -530,13 +486,13 @@ async def websocket_create_long_lived_access_token(
|
|||
try:
|
||||
access_token = hass.auth.async_create_access_token(refresh_token)
|
||||
except InvalidAuthError as exc:
|
||||
return websocket_api.error_message(
|
||||
msg["id"], websocket_api.const.ERR_UNAUTHORIZED, str(exc)
|
||||
)
|
||||
connection.send_error(msg["id"], websocket_api.const.ERR_UNAUTHORIZED, str(exc))
|
||||
return
|
||||
|
||||
connection.send_message(websocket_api.result_message(msg["id"], access_token))
|
||||
connection.send_result(msg["id"], access_token)
|
||||
|
||||
|
||||
@websocket_api.websocket_command({vol.Required("type"): "auth/refresh_tokens"})
|
||||
@websocket_api.ws_require_user()
|
||||
@callback
|
||||
def websocket_refresh_tokens(
|
||||
|
@ -544,27 +500,38 @@ def websocket_refresh_tokens(
|
|||
):
|
||||
"""Return metadata of users refresh tokens."""
|
||||
current_id = connection.refresh_token_id
|
||||
connection.send_message(
|
||||
websocket_api.result_message(
|
||||
msg["id"],
|
||||
[
|
||||
{
|
||||
"id": refresh.id,
|
||||
"client_id": refresh.client_id,
|
||||
"client_name": refresh.client_name,
|
||||
"client_icon": refresh.client_icon,
|
||||
"type": refresh.token_type,
|
||||
"created_at": refresh.created_at,
|
||||
"is_current": refresh.id == current_id,
|
||||
"last_used_at": refresh.last_used_at,
|
||||
"last_used_ip": refresh.last_used_ip,
|
||||
}
|
||||
for refresh in connection.user.refresh_tokens.values()
|
||||
],
|
||||
|
||||
tokens = []
|
||||
for refresh in connection.user.refresh_tokens.values():
|
||||
if refresh.credential:
|
||||
auth_provider_type = refresh.credential.auth_provider_type
|
||||
else:
|
||||
auth_provider_type = None
|
||||
|
||||
tokens.append(
|
||||
{
|
||||
"id": refresh.id,
|
||||
"client_id": refresh.client_id,
|
||||
"client_name": refresh.client_name,
|
||||
"client_icon": refresh.client_icon,
|
||||
"type": refresh.token_type,
|
||||
"created_at": refresh.created_at,
|
||||
"is_current": refresh.id == current_id,
|
||||
"last_used_at": refresh.last_used_at,
|
||||
"last_used_ip": refresh.last_used_ip,
|
||||
"auth_provider_type": auth_provider_type,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
connection.send_result(msg["id"], tokens)
|
||||
|
||||
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "auth/delete_refresh_token",
|
||||
vol.Required("refresh_token_id"): str,
|
||||
}
|
||||
)
|
||||
@websocket_api.ws_require_user()
|
||||
@websocket_api.async_response
|
||||
async def websocket_delete_refresh_token(
|
||||
|
@ -574,15 +541,21 @@ async def websocket_delete_refresh_token(
|
|||
refresh_token = connection.user.refresh_tokens.get(msg["refresh_token_id"])
|
||||
|
||||
if refresh_token is None:
|
||||
return websocket_api.error_message(
|
||||
msg["id"], "invalid_token_id", "Received invalid token"
|
||||
)
|
||||
connection.send_error(msg["id"], "invalid_token_id", "Received invalid token")
|
||||
return
|
||||
|
||||
await hass.auth.async_remove_refresh_token(refresh_token)
|
||||
|
||||
connection.send_message(websocket_api.result_message(msg["id"], {}))
|
||||
connection.send_result(msg["id"], {})
|
||||
|
||||
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "auth/sign_path",
|
||||
vol.Required("path"): str,
|
||||
vol.Optional("expires", default=30): int,
|
||||
}
|
||||
)
|
||||
@websocket_api.ws_require_user()
|
||||
@callback
|
||||
def websocket_sign_path(
|
||||
|
|
|
@ -193,7 +193,7 @@ async def test_ws_current_user(hass, hass_ws_client, hass_access_token):
|
|||
user = refresh_token.user
|
||||
client = await hass_ws_client(hass, hass_access_token)
|
||||
|
||||
await client.send_json({"id": 5, "type": auth.WS_TYPE_CURRENT_USER})
|
||||
await client.send_json({"id": 5, "type": "auth/current_user"})
|
||||
|
||||
result = await client.receive_json()
|
||||
assert result["success"], result
|
||||
|
@ -410,7 +410,7 @@ async def test_ws_long_lived_access_token(hass, hass_ws_client, hass_access_toke
|
|||
await ws_client.send_json(
|
||||
{
|
||||
"id": 5,
|
||||
"type": auth.WS_TYPE_LONG_LIVED_ACCESS_TOKEN,
|
||||
"type": "auth/long_lived_access_token",
|
||||
"client_name": "GPS Logger",
|
||||
"lifespan": 365,
|
||||
}
|
||||
|
@ -434,7 +434,7 @@ async def test_ws_refresh_tokens(hass, hass_ws_client, hass_access_token):
|
|||
|
||||
ws_client = await hass_ws_client(hass, hass_access_token)
|
||||
|
||||
await ws_client.send_json({"id": 5, "type": auth.WS_TYPE_REFRESH_TOKENS})
|
||||
await ws_client.send_json({"id": 5, "type": "auth/refresh_tokens"})
|
||||
|
||||
result = await ws_client.receive_json()
|
||||
assert result["success"], result
|
||||
|
@ -450,6 +450,7 @@ async def test_ws_refresh_tokens(hass, hass_ws_client, hass_access_token):
|
|||
assert token["is_current"] is True
|
||||
assert token["last_used_at"] == refresh_token.last_used_at.isoformat()
|
||||
assert token["last_used_ip"] == refresh_token.last_used_ip
|
||||
assert token["auth_provider_type"] == "homeassistant"
|
||||
|
||||
|
||||
async def test_ws_delete_refresh_token(
|
||||
|
@ -468,7 +469,7 @@ async def test_ws_delete_refresh_token(
|
|||
await ws_client.send_json(
|
||||
{
|
||||
"id": 5,
|
||||
"type": auth.WS_TYPE_DELETE_REFRESH_TOKEN,
|
||||
"type": "auth/delete_refresh_token",
|
||||
"refresh_token_id": refresh_token.id,
|
||||
}
|
||||
)
|
||||
|
@ -490,7 +491,7 @@ async def test_ws_sign_path(hass, hass_ws_client, hass_access_token):
|
|||
await ws_client.send_json(
|
||||
{
|
||||
"id": 5,
|
||||
"type": auth.WS_TYPE_SIGN_PATH,
|
||||
"type": "auth/sign_path",
|
||||
"path": "/api/hello",
|
||||
"expires": 20,
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue