From 14f68ec1a92ac7e41a1340f76255d6596affdd15 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Mon, 29 Aug 2022 19:28:42 -0400 Subject: [PATCH] Store redirect URI in context instead of asking each time (#77380) * Store redirect URI in context instead of asking each time * Fix tests --- homeassistant/components/auth/login_flow.py | 18 +++++++------- .../components/config/config_entries.py | 1 + homeassistant/data_entry_flow.py | 1 + homeassistant/helpers/data_entry_flow.py | 1 + tests/components/auth/test_init.py | 2 -- tests/components/auth/test_init_link_user.py | 1 - tests/components/auth/test_login_flow.py | 24 ++++++++++++------- .../components/philips_js/test_config_flow.py | 1 + tests/components/subaru/test_config_flow.py | 2 ++ 9 files changed, 30 insertions(+), 21 deletions(-) diff --git a/homeassistant/components/auth/login_flow.py b/homeassistant/components/auth/login_flow.py index bb13431bfa7..df076a1b4c8 100644 --- a/homeassistant/components/auth/login_flow.py +++ b/homeassistant/components/auth/login_flow.py @@ -193,7 +193,6 @@ class LoginFlowBaseView(HomeAssistantView): self, request: web.Request, client_id: str, - redirect_uri: str, result: data_entry_flow.FlowResult, ) -> web.Response: """Convert the flow result to a response.""" @@ -214,10 +213,13 @@ class LoginFlowBaseView(HomeAssistantView): hass: HomeAssistant = request.app["hass"] - if not await indieauth.verify_redirect_uri(hass, client_id, redirect_uri): + if not await indieauth.verify_redirect_uri( + hass, client_id, result["context"]["redirect_uri"] + ): return self.json_message("Invalid redirect URI", HTTPStatus.FORBIDDEN) result.pop("data") + result.pop("context") result_obj: Credentials = result.pop("result") @@ -278,6 +280,7 @@ class LoginFlowIndexView(LoginFlowBaseView): context={ "ip_address": ip_address(request.remote), # type: ignore[arg-type] "credential_only": data.get("type") == "link_user", + "redirect_uri": redirect_uri, }, ) except data_entry_flow.UnknownHandler: @@ -287,9 +290,7 @@ class LoginFlowIndexView(LoginFlowBaseView): "Handler does not support init", HTTPStatus.BAD_REQUEST ) - return await self._async_flow_result_to_response( - request, client_id, redirect_uri, result - ) + return await self._async_flow_result_to_response(request, client_id, result) class LoginFlowResourceView(LoginFlowBaseView): @@ -304,7 +305,7 @@ class LoginFlowResourceView(LoginFlowBaseView): @RequestDataValidator( vol.Schema( - {vol.Required("client_id"): str, vol.Required("redirect_uri"): str}, + {vol.Required("client_id"): str}, extra=vol.ALLOW_EXTRA, ) ) @@ -314,7 +315,6 @@ class LoginFlowResourceView(LoginFlowBaseView): ) -> web.Response: """Handle progressing a login flow request.""" client_id: str = data.pop("client_id") - redirect_uri: str = data.pop("redirect_uri") if not indieauth.verify_client_id(client_id): return self.json_message("Invalid client id", HTTPStatus.BAD_REQUEST) @@ -330,9 +330,7 @@ class LoginFlowResourceView(LoginFlowBaseView): except vol.Invalid: return self.json_message("User input malformed", HTTPStatus.BAD_REQUEST) - return await self._async_flow_result_to_response( - request, client_id, redirect_uri, result - ) + return await self._async_flow_result_to_response(request, client_id, result) async def delete(self, request: web.Request, flow_id: str) -> web.Response: """Cancel a flow in progress.""" diff --git a/homeassistant/components/config/config_entries.py b/homeassistant/components/config/config_entries.py index ac452666103..54132080f08 100644 --- a/homeassistant/components/config/config_entries.py +++ b/homeassistant/components/config/config_entries.py @@ -121,6 +121,7 @@ def _prepare_config_flow_result_json(result, prepare_result_json): data = result.copy() data["result"] = entry_json(result["result"]) data.pop("data") + data.pop("context") return data diff --git a/homeassistant/data_entry_flow.py b/homeassistant/data_entry_flow.py index cdc4023f32c..629258e01d1 100644 --- a/homeassistant/data_entry_flow.py +++ b/homeassistant/data_entry_flow.py @@ -484,6 +484,7 @@ class FlowHandler: data=data, description=description, description_placeholders=description_placeholders, + context=self.context, ) @callback diff --git a/homeassistant/helpers/data_entry_flow.py b/homeassistant/helpers/data_entry_flow.py index 428a62f0c9d..e3e4b4f0de8 100644 --- a/homeassistant/helpers/data_entry_flow.py +++ b/homeassistant/helpers/data_entry_flow.py @@ -30,6 +30,7 @@ class _BaseFlowManagerView(HomeAssistantView): data = result.copy() data.pop("result") data.pop("data") + data.pop("context") return data if "data_schema" not in result: diff --git a/tests/components/auth/test_init.py b/tests/components/auth/test_init.py index 09a74cf9bc9..6854bb92052 100644 --- a/tests/components/auth/test_init.py +++ b/tests/components/auth/test_init.py @@ -64,7 +64,6 @@ async def test_login_new_user_and_trying_refresh_token(hass, aiohttp_client): f"/auth/login_flow/{step['flow_id']}", json={ "client_id": CLIENT_ID, - "redirect_uri": CLIENT_REDIRECT_URI, "username": "test-user", "password": "test-pass", }, @@ -133,7 +132,6 @@ async def test_auth_code_checks_local_only_user(hass, aiohttp_client): f"/auth/login_flow/{step['flow_id']}", json={ "client_id": CLIENT_ID, - "redirect_uri": CLIENT_REDIRECT_URI, "username": "test-user", "password": "test-pass", }, diff --git a/tests/components/auth/test_init_link_user.py b/tests/components/auth/test_init_link_user.py index bad6e3bfefe..882371a458f 100644 --- a/tests/components/auth/test_init_link_user.py +++ b/tests/components/auth/test_init_link_user.py @@ -48,7 +48,6 @@ async def async_get_code(hass, aiohttp_client): f"/auth/login_flow/{step['flow_id']}", json={ "client_id": CLIENT_ID, - "redirect_uri": CLIENT_REDIRECT_URI, "username": "2nd-user", "password": "2nd-pass", }, diff --git a/tests/components/auth/test_login_flow.py b/tests/components/auth/test_login_flow.py index ce547149786..b3adfb93afb 100644 --- a/tests/components/auth/test_login_flow.py +++ b/tests/components/auth/test_login_flow.py @@ -61,7 +61,6 @@ async def test_invalid_username_password(hass, aiohttp_client): f"/auth/login_flow/{step['flow_id']}", json={ "client_id": CLIENT_ID, - "redirect_uri": CLIENT_REDIRECT_URI, "username": "wrong-user", "password": "test-pass", }, @@ -82,7 +81,6 @@ async def test_invalid_username_password(hass, aiohttp_client): f"/auth/login_flow/{step['flow_id']}", json={ "client_id": CLIENT_ID, - "redirect_uri": CLIENT_REDIRECT_URI, "username": "test-user", "password": "wrong-pass", }, @@ -103,7 +101,6 @@ async def test_invalid_username_password(hass, aiohttp_client): f"/auth/login_flow/{step['flow_id']}", json={ "client_id": CLIENT_ID, - "redirect_uri": "http://some-other-domain.com", "username": "wrong-user", "password": "test-pass", }, @@ -116,7 +113,21 @@ async def test_invalid_username_password(hass, aiohttp_client): assert step["step_id"] == "init" assert step["errors"]["base"] == "invalid_auth" - # Incorrect redirect URI + +async def test_invalid_redirect_uri(hass, aiohttp_client): + """Test invalid redirect URI.""" + client = await async_setup_auth(hass, aiohttp_client) + resp = await client.post( + "/auth/login_flow", + json={ + "client_id": CLIENT_ID, + "handler": ["insecure_example", None], + "redirect_uri": "https://some-other-domain.com", + }, + ) + assert resp.status == HTTPStatus.OK + step = await resp.json() + with patch( "homeassistant.components.auth.indieauth.fetch_redirect_uris", return_value=[] ), patch( @@ -126,7 +137,6 @@ async def test_invalid_username_password(hass, aiohttp_client): f"/auth/login_flow/{step['flow_id']}", json={ "client_id": CLIENT_ID, - "redirect_uri": "http://some-other-domain.com", "username": "test-user", "password": "test-pass", }, @@ -165,7 +175,6 @@ async def test_login_exist_user(hass, aiohttp_client): f"/auth/login_flow/{step['flow_id']}", json={ "client_id": CLIENT_ID, - "redirect_uri": CLIENT_REDIRECT_URI, "username": "test-user", "password": "test-pass", }, @@ -206,14 +215,13 @@ async def test_login_local_only_user(hass, aiohttp_client): f"/auth/login_flow/{step['flow_id']}", json={ "client_id": CLIENT_ID, - "redirect_uri": CLIENT_REDIRECT_URI, "username": "test-user", "password": "test-pass", }, ) - assert len(mock_not_allowed_do_auth.mock_calls) == 1 assert resp.status == HTTPStatus.FORBIDDEN + assert len(mock_not_allowed_do_auth.mock_calls) == 1 assert await resp.json() == {"message": "Login blocked: User is local only"} diff --git a/tests/components/philips_js/test_config_flow.py b/tests/components/philips_js/test_config_flow.py index c6bade94ea4..284c7e7541e 100644 --- a/tests/components/philips_js/test_config_flow.py +++ b/tests/components/philips_js/test_config_flow.py @@ -120,6 +120,7 @@ async def test_pairing(hass, mock_tv_pairable, mock_setup_entry): ) assert result == { + "context": {"source": "user", "unique_id": "ABCDEFGHIJKLF"}, "flow_id": ANY, "type": "create_entry", "description": None, diff --git a/tests/components/subaru/test_config_flow.py b/tests/components/subaru/test_config_flow.py index e14a62d432d..62f69017a82 100644 --- a/tests/components/subaru/test_config_flow.py +++ b/tests/components/subaru/test_config_flow.py @@ -117,6 +117,7 @@ async def test_user_form_pin_not_required(hass, two_factor_verify_form): assert len(mock_setup_entry.mock_calls) == 1 expected = { + "context": {"source": "user"}, "title": TEST_USERNAME, "description": None, "description_placeholders": None, @@ -286,6 +287,7 @@ async def test_pin_form_success(hass, pin_form): assert len(mock_update_saved_pin.mock_calls) == 1 assert len(mock_setup_entry.mock_calls) == 1 expected = { + "context": {"source": "user"}, "title": TEST_USERNAME, "description": None, "description_placeholders": None,