diff --git a/homeassistant/components/emulated_hue/__init__.py b/homeassistant/components/emulated_hue/__init__.py index 86102242b31..1ba93da716c 100644 --- a/homeassistant/components/emulated_hue/__init__.py +++ b/homeassistant/components/emulated_hue/__init__.py @@ -138,16 +138,16 @@ async def async_setup(hass: HomeAssistant, yaml_config: ConfigType) -> bool: app._on_startup.freeze() await app.startup() - DescriptionXmlView(config).register(app, app.router) - HueUsernameView().register(app, app.router) - HueConfigView(config).register(app, app.router) - HueUnauthorizedUser().register(app, app.router) - HueAllLightsStateView(config).register(app, app.router) - HueOneLightStateView(config).register(app, app.router) - HueOneLightChangeView(config).register(app, app.router) - HueAllGroupsStateView(config).register(app, app.router) - HueGroupView(config).register(app, app.router) - HueFullStateView(config).register(app, app.router) + DescriptionXmlView(config).register(hass, app, app.router) + HueUsernameView().register(hass, app, app.router) + HueConfigView(config).register(hass, app, app.router) + HueUnauthorizedUser().register(hass, app, app.router) + HueAllLightsStateView(config).register(hass, app, app.router) + HueOneLightStateView(config).register(hass, app, app.router) + HueOneLightChangeView(config).register(hass, app, app.router) + HueAllGroupsStateView(config).register(hass, app, app.router) + HueGroupView(config).register(hass, app, app.router) + HueFullStateView(config).register(hass, app, app.router) async def _start(event: Event) -> None: """Start the bridge.""" diff --git a/homeassistant/components/http/__init__.py b/homeassistant/components/http/__init__.py index 2d306ba5ee5..fda8717c3dd 100644 --- a/homeassistant/components/http/__init__.py +++ b/homeassistant/components/http/__init__.py @@ -365,7 +365,7 @@ class HomeAssistantHTTP: class_name = view.__class__.__name__ raise AttributeError(f'{class_name} missing required attribute "name"') - view.register(self.app, self.app.router) + view.register(self.hass, self.app, self.app.router) def register_redirect( self, diff --git a/homeassistant/components/http/view.py b/homeassistant/components/http/view.py index d39fca28782..abdcfe466c1 100644 --- a/homeassistant/components/http/view.py +++ b/homeassistant/components/http/view.py @@ -19,7 +19,7 @@ import voluptuous as vol from homeassistant import exceptions from homeassistant.const import CONTENT_TYPE_JSON -from homeassistant.core import Context, is_callback +from homeassistant.core import Context, HomeAssistant, is_callback from homeassistant.helpers.json import ( find_paths_unserializable_data, json_bytes, @@ -27,7 +27,7 @@ from homeassistant.helpers.json import ( ) from homeassistant.util.json import JSON_ENCODE_EXCEPTIONS, format_unserializable_data -from .const import KEY_AUTHENTICATED, KEY_HASS +from .const import KEY_AUTHENTICATED _LOGGER = logging.getLogger(__name__) @@ -88,7 +88,9 @@ class HomeAssistantView: data["code"] = message_code return self.json(data, status_code, headers=headers) - def register(self, app: web.Application, router: web.UrlDispatcher) -> None: + def register( + self, hass: HomeAssistant, app: web.Application, router: web.UrlDispatcher + ) -> None: """Register the view with a router.""" assert self.url is not None, "No url set for view" urls = [self.url] + self.extra_urls @@ -98,7 +100,7 @@ class HomeAssistantView: if not (handler := getattr(self, method, None)): continue - handler = request_handler_factory(self, handler) + handler = request_handler_factory(hass, self, handler) for url in urls: routes.append(router.add_route(method, url, handler)) @@ -115,16 +117,17 @@ class HomeAssistantView: def request_handler_factory( - view: HomeAssistantView, handler: Callable + hass: HomeAssistant, view: HomeAssistantView, handler: Callable ) -> Callable[[web.Request], Awaitable[web.StreamResponse]]: """Wrap the handler classes.""" - assert asyncio.iscoroutinefunction(handler) or is_callback( + is_coroutinefunction = asyncio.iscoroutinefunction(handler) + assert is_coroutinefunction or is_callback( handler ), "Handler should be a coroutine or a callback." async def handle(request: web.Request) -> web.StreamResponse: """Handle incoming request.""" - if request.app[KEY_HASS].is_stopping: + if hass.is_stopping: return web.Response(status=HTTPStatus.SERVICE_UNAVAILABLE) authenticated = request.get(KEY_AUTHENTICATED, False) @@ -132,18 +135,19 @@ def request_handler_factory( if view.requires_auth and not authenticated: raise HTTPUnauthorized() - _LOGGER.debug( - "Serving %s to %s (auth: %s)", - request.path, - request.remote, - authenticated, - ) + if _LOGGER.isEnabledFor(logging.DEBUG): + _LOGGER.debug( + "Serving %s to %s (auth: %s)", + request.path, + request.remote, + authenticated, + ) try: - result = handler(request, **request.match_info) - - if asyncio.iscoroutine(result): - result = await result + if is_coroutinefunction: + result = await handler(request, **request.match_info) + else: + result = handler(request, **request.match_info) except vol.Invalid as err: raise HTTPBadRequest() from err except exceptions.ServiceNotFound as err: @@ -156,21 +160,20 @@ def request_handler_factory( return result status_code = HTTPStatus.OK - if isinstance(result, tuple): result, status_code = result if isinstance(result, bytes): - bresult = result - elif isinstance(result, str): - bresult = result.encode("utf-8") - elif result is None: - bresult = b"" - else: - raise TypeError( - f"Result should be None, string, bytes or StreamResponse. Got: {result}" - ) + return web.Response(body=result, status=status_code) - return web.Response(body=bresult, status=status_code) + if isinstance(result, str): + return web.Response(text=result, status=status_code) + + if result is None: + return web.Response(body=b"", status=status_code) + + raise TypeError( + f"Result should be None, string, bytes or StreamResponse. Got: {result}" + ) return handle diff --git a/tests/components/emulated_hue/test_hue_api.py b/tests/components/emulated_hue/test_hue_api.py index 153a9cac0ca..247a507bb69 100644 --- a/tests/components/emulated_hue/test_hue_api.py +++ b/tests/components/emulated_hue/test_hue_api.py @@ -215,13 +215,13 @@ def _mock_hue_endpoints( web_app = hass.http.app config = Config(hass, conf, "127.0.0.1") config.numbers = entity_numbers - HueUsernameView().register(web_app, web_app.router) - HueAllLightsStateView(config).register(web_app, web_app.router) - HueOneLightStateView(config).register(web_app, web_app.router) - HueOneLightChangeView(config).register(web_app, web_app.router) - HueAllGroupsStateView(config).register(web_app, web_app.router) - HueFullStateView(config).register(web_app, web_app.router) - HueConfigView(config).register(web_app, web_app.router) + HueUsernameView().register(hass, web_app, web_app.router) + HueAllLightsStateView(config).register(hass, web_app, web_app.router) + HueOneLightStateView(config).register(hass, web_app, web_app.router) + HueOneLightChangeView(config).register(hass, web_app, web_app.router) + HueAllGroupsStateView(config).register(hass, web_app, web_app.router) + HueFullStateView(config).register(hass, web_app, web_app.router) + HueConfigView(config).register(hass, web_app, web_app.router) @pytest.fixture diff --git a/tests/components/http/test_ban.py b/tests/components/http/test_ban.py index 25574833d17..b34c8866ff1 100644 --- a/tests/components/http/test_ban.py +++ b/tests/components/http/test_ban.py @@ -333,13 +333,15 @@ async def test_failed_login_attempts_counter( return None, 200 app.router.add_get( - "/auth_true", request_handler_factory(Mock(requires_auth=True), auth_handler) + "/auth_true", + request_handler_factory(hass, Mock(requires_auth=True), auth_handler), ) app.router.add_get( - "/auth_false", request_handler_factory(Mock(requires_auth=True), auth_handler) + "/auth_false", + request_handler_factory(hass, Mock(requires_auth=True), auth_handler), ) app.router.add_get( - "/", request_handler_factory(Mock(requires_auth=False), auth_handler) + "/", request_handler_factory(hass, Mock(requires_auth=False), auth_handler) ) setup_bans(hass, app, 5) diff --git a/tests/components/http/test_data_validator.py b/tests/components/http/test_data_validator.py index 04f5dbf50f0..ecff4370999 100644 --- a/tests/components/http/test_data_validator.py +++ b/tests/components/http/test_data_validator.py @@ -27,7 +27,7 @@ async def get_client(aiohttp_client, validator): """Test method.""" return b"" - TestView().register(app, app.router) + TestView().register(app["hass"], app, app.router) client = await aiohttp_client(app) return client diff --git a/tests/components/http/test_view.py b/tests/components/http/test_view.py index 059c56b715d..e52413d5225 100644 --- a/tests/components/http/test_view.py +++ b/tests/components/http/test_view.py @@ -20,13 +20,13 @@ from homeassistant.exceptions import ServiceNotFound, Unauthorized @pytest.fixture -def mock_request(): +def mock_request() -> Mock: """Mock a request.""" return Mock(app={"hass": Mock(is_stopping=False)}, match_info={}) @pytest.fixture -def mock_request_with_stopping(): +def mock_request_with_stopping() -> Mock: """Mock a request.""" return Mock(app={"hass": Mock(is_stopping=True)}, match_info={}) @@ -48,34 +48,51 @@ async def test_nan_serialized_to_null() -> None: assert json.loads(response.body.decode("utf-8")) is None -async def test_handling_unauthorized(mock_request) -> None: +async def test_handling_unauthorized(mock_request: Mock) -> None: """Test handling unauth exceptions.""" with pytest.raises(HTTPUnauthorized): await request_handler_factory( - Mock(requires_auth=False), AsyncMock(side_effect=Unauthorized) + mock_request.app["hass"], + Mock(requires_auth=False), + AsyncMock(side_effect=Unauthorized), )(mock_request) -async def test_handling_invalid_data(mock_request) -> None: +async def test_handling_invalid_data(mock_request: Mock) -> None: """Test handling unauth exceptions.""" with pytest.raises(HTTPBadRequest): await request_handler_factory( - Mock(requires_auth=False), AsyncMock(side_effect=vol.Invalid("yo")) + mock_request.app["hass"], + Mock(requires_auth=False), + AsyncMock(side_effect=vol.Invalid("yo")), )(mock_request) -async def test_handling_service_not_found(mock_request) -> None: +async def test_handling_service_not_found(mock_request: Mock) -> None: """Test handling unauth exceptions.""" with pytest.raises(HTTPInternalServerError): await request_handler_factory( + mock_request.app["hass"], Mock(requires_auth=False), AsyncMock(side_effect=ServiceNotFound("test", "test")), )(mock_request) -async def test_not_running(mock_request_with_stopping) -> None: +async def test_not_running(mock_request_with_stopping: Mock) -> None: """Test we get a 503 when not running.""" response = await request_handler_factory( - Mock(requires_auth=False), AsyncMock(side_effect=Unauthorized) + mock_request_with_stopping.app["hass"], + Mock(requires_auth=False), + AsyncMock(side_effect=Unauthorized), )(mock_request_with_stopping) assert response.status == HTTPStatus.SERVICE_UNAVAILABLE + + +async def test_invalid_handler(mock_request: Mock) -> None: + """Test an invalid handler.""" + with pytest.raises(TypeError): + await request_handler_factory( + mock_request.app["hass"], + Mock(requires_auth=False), + AsyncMock(return_value=["not valid"]), + )(mock_request)