diff --git a/homeassistant/components/auth/__init__.py b/homeassistant/components/auth/__init__.py index 30b36a40f32..1dc483eec6e 100644 --- a/homeassistant/components/auth/__init__.py +++ b/homeassistant/components/auth/__init__.py @@ -148,43 +148,6 @@ from homeassistant.util import dt as dt_util from . import indieauth, login_flow, mfa_setup_flow DOMAIN = "auth" -WS_TYPE_CURRENT_USER = "auth/current_user" -SCHEMA_WS_CURRENT_USER = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend( - {vol.Required("type"): WS_TYPE_CURRENT_USER} -) - -WS_TYPE_LONG_LIVED_ACCESS_TOKEN = "auth/long_lived_access_token" -SCHEMA_WS_LONG_LIVED_ACCESS_TOKEN = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend( - { - vol.Required("type"): WS_TYPE_LONG_LIVED_ACCESS_TOKEN, - vol.Required("lifespan"): int, # days - vol.Required("client_name"): str, - vol.Optional("client_icon"): str, - } -) - -WS_TYPE_REFRESH_TOKENS = "auth/refresh_tokens" -SCHEMA_WS_REFRESH_TOKENS = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend( - {vol.Required("type"): WS_TYPE_REFRESH_TOKENS} -) - -WS_TYPE_DELETE_REFRESH_TOKEN = "auth/delete_refresh_token" -SCHEMA_WS_DELETE_REFRESH_TOKEN = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend( - { - vol.Required("type"): WS_TYPE_DELETE_REFRESH_TOKEN, - vol.Required("refresh_token_id"): str, - } -) - -WS_TYPE_SIGN_PATH = "auth/sign_path" -SCHEMA_WS_SIGN_PATH = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend( - { - vol.Required("type"): WS_TYPE_SIGN_PATH, - vol.Required("path"): str, - vol.Optional("expires", default=30): int, - } -) - RESULT_TYPE_CREDENTIALS = "credentials" @@ -204,27 +167,11 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: hass.http.register_view(LinkUserView(retrieve_result)) hass.http.register_view(OAuth2AuthorizeCallbackView()) - websocket_api.async_register_command( - hass, WS_TYPE_CURRENT_USER, websocket_current_user, SCHEMA_WS_CURRENT_USER - ) - websocket_api.async_register_command( - hass, - WS_TYPE_LONG_LIVED_ACCESS_TOKEN, - websocket_create_long_lived_access_token, - SCHEMA_WS_LONG_LIVED_ACCESS_TOKEN, - ) - websocket_api.async_register_command( - hass, WS_TYPE_REFRESH_TOKENS, websocket_refresh_tokens, SCHEMA_WS_REFRESH_TOKENS - ) - websocket_api.async_register_command( - hass, - WS_TYPE_DELETE_REFRESH_TOKEN, - websocket_delete_refresh_token, - SCHEMA_WS_DELETE_REFRESH_TOKEN, - ) - websocket_api.async_register_command( - hass, WS_TYPE_SIGN_PATH, websocket_sign_path, SCHEMA_WS_SIGN_PATH - ) + websocket_api.async_register_command(hass, websocket_current_user) + websocket_api.async_register_command(hass, websocket_create_long_lived_access_token) + websocket_api.async_register_command(hass, websocket_refresh_tokens) + websocket_api.async_register_command(hass, websocket_delete_refresh_token) + websocket_api.async_register_command(hass, websocket_sign_path) await login_flow.async_setup(hass, store_result) await mfa_setup_flow.async_setup(hass) @@ -476,6 +423,7 @@ def _create_auth_code_store(): return store_result, retrieve_result +@websocket_api.websocket_command({vol.Required("type"): "auth/current_user"}) @websocket_api.ws_require_user() @websocket_api.async_response async def websocket_current_user( @@ -513,6 +461,14 @@ async def websocket_current_user( ) +@websocket_api.websocket_command( + { + vol.Required("type"): "auth/long_lived_access_token", + vol.Required("lifespan"): int, # days + vol.Required("client_name"): str, + vol.Optional("client_icon"): str, + } +) @websocket_api.ws_require_user() @websocket_api.async_response async def websocket_create_long_lived_access_token( @@ -530,13 +486,13 @@ async def websocket_create_long_lived_access_token( try: access_token = hass.auth.async_create_access_token(refresh_token) except InvalidAuthError as exc: - return websocket_api.error_message( - msg["id"], websocket_api.const.ERR_UNAUTHORIZED, str(exc) - ) + connection.send_error(msg["id"], websocket_api.const.ERR_UNAUTHORIZED, str(exc)) + return - connection.send_message(websocket_api.result_message(msg["id"], access_token)) + connection.send_result(msg["id"], access_token) +@websocket_api.websocket_command({vol.Required("type"): "auth/refresh_tokens"}) @websocket_api.ws_require_user() @callback def websocket_refresh_tokens( @@ -544,27 +500,38 @@ def websocket_refresh_tokens( ): """Return metadata of users refresh tokens.""" current_id = connection.refresh_token_id - connection.send_message( - websocket_api.result_message( - msg["id"], - [ - { - "id": refresh.id, - "client_id": refresh.client_id, - "client_name": refresh.client_name, - "client_icon": refresh.client_icon, - "type": refresh.token_type, - "created_at": refresh.created_at, - "is_current": refresh.id == current_id, - "last_used_at": refresh.last_used_at, - "last_used_ip": refresh.last_used_ip, - } - for refresh in connection.user.refresh_tokens.values() - ], + + tokens = [] + for refresh in connection.user.refresh_tokens.values(): + if refresh.credential: + auth_provider_type = refresh.credential.auth_provider_type + else: + auth_provider_type = None + + tokens.append( + { + "id": refresh.id, + "client_id": refresh.client_id, + "client_name": refresh.client_name, + "client_icon": refresh.client_icon, + "type": refresh.token_type, + "created_at": refresh.created_at, + "is_current": refresh.id == current_id, + "last_used_at": refresh.last_used_at, + "last_used_ip": refresh.last_used_ip, + "auth_provider_type": auth_provider_type, + } ) - ) + + connection.send_result(msg["id"], tokens) +@websocket_api.websocket_command( + { + vol.Required("type"): "auth/delete_refresh_token", + vol.Required("refresh_token_id"): str, + } +) @websocket_api.ws_require_user() @websocket_api.async_response async def websocket_delete_refresh_token( @@ -574,15 +541,21 @@ async def websocket_delete_refresh_token( refresh_token = connection.user.refresh_tokens.get(msg["refresh_token_id"]) if refresh_token is None: - return websocket_api.error_message( - msg["id"], "invalid_token_id", "Received invalid token" - ) + connection.send_error(msg["id"], "invalid_token_id", "Received invalid token") + return await hass.auth.async_remove_refresh_token(refresh_token) - connection.send_message(websocket_api.result_message(msg["id"], {})) + connection.send_result(msg["id"], {}) +@websocket_api.websocket_command( + { + vol.Required("type"): "auth/sign_path", + vol.Required("path"): str, + vol.Optional("expires", default=30): int, + } +) @websocket_api.ws_require_user() @callback def websocket_sign_path( diff --git a/tests/components/auth/test_init.py b/tests/components/auth/test_init.py index f6d0695d97d..ef231950bd9 100644 --- a/tests/components/auth/test_init.py +++ b/tests/components/auth/test_init.py @@ -193,7 +193,7 @@ async def test_ws_current_user(hass, hass_ws_client, hass_access_token): user = refresh_token.user client = await hass_ws_client(hass, hass_access_token) - await client.send_json({"id": 5, "type": auth.WS_TYPE_CURRENT_USER}) + await client.send_json({"id": 5, "type": "auth/current_user"}) result = await client.receive_json() assert result["success"], result @@ -410,7 +410,7 @@ async def test_ws_long_lived_access_token(hass, hass_ws_client, hass_access_toke await ws_client.send_json( { "id": 5, - "type": auth.WS_TYPE_LONG_LIVED_ACCESS_TOKEN, + "type": "auth/long_lived_access_token", "client_name": "GPS Logger", "lifespan": 365, } @@ -434,7 +434,7 @@ async def test_ws_refresh_tokens(hass, hass_ws_client, hass_access_token): ws_client = await hass_ws_client(hass, hass_access_token) - await ws_client.send_json({"id": 5, "type": auth.WS_TYPE_REFRESH_TOKENS}) + await ws_client.send_json({"id": 5, "type": "auth/refresh_tokens"}) result = await ws_client.receive_json() assert result["success"], result @@ -450,6 +450,7 @@ async def test_ws_refresh_tokens(hass, hass_ws_client, hass_access_token): assert token["is_current"] is True assert token["last_used_at"] == refresh_token.last_used_at.isoformat() assert token["last_used_ip"] == refresh_token.last_used_ip + assert token["auth_provider_type"] == "homeassistant" async def test_ws_delete_refresh_token( @@ -468,7 +469,7 @@ async def test_ws_delete_refresh_token( await ws_client.send_json( { "id": 5, - "type": auth.WS_TYPE_DELETE_REFRESH_TOKEN, + "type": "auth/delete_refresh_token", "refresh_token_id": refresh_token.id, } ) @@ -490,7 +491,7 @@ async def test_ws_sign_path(hass, hass_ws_client, hass_access_token): await ws_client.send_json( { "id": 5, - "type": auth.WS_TYPE_SIGN_PATH, + "type": "auth/sign_path", "path": "/api/hello", "expires": 20, }