From 1650cee16c7abe2bd85ef27886d649335a93eb3e Mon Sep 17 00:00:00 2001 From: Matthias Alphart Date: Wed, 28 Aug 2024 18:10:38 +0200 Subject: [PATCH] Check KNX integration is loaded on websocket calls (#123178) --- homeassistant/components/knx/websocket.py | 110 +++++++++++++++++++--- tests/components/knx/test_websocket.py | 27 ++++++ 2 files changed, 124 insertions(+), 13 deletions(-) diff --git a/homeassistant/components/knx/websocket.py b/homeassistant/components/knx/websocket.py index 4af3012741a..5c21a941484 100644 --- a/homeassistant/components/knx/websocket.py +++ b/homeassistant/components/knx/websocket.py @@ -2,7 +2,10 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Final +import asyncio +from collections.abc import Awaitable, Callable +from functools import wraps +from typing import TYPE_CHECKING, Any, Final, overload import knx_frontend as knx_panel import voluptuous as vol @@ -77,21 +80,92 @@ async def register_panel(hass: HomeAssistant) -> None: ) +type KnxWebSocketCommandHandler = Callable[ + [HomeAssistant, KNXModule, websocket_api.ActiveConnection, dict[str, Any]], None +] +type KnxAsyncWebSocketCommandHandler = Callable[ + [HomeAssistant, KNXModule, websocket_api.ActiveConnection, dict[str, Any]], + Awaitable[None], +] + + +@overload +def provide_knx( + func: KnxAsyncWebSocketCommandHandler, +) -> websocket_api.const.AsyncWebSocketCommandHandler: ... +@overload +def provide_knx( + func: KnxWebSocketCommandHandler, +) -> websocket_api.const.WebSocketCommandHandler: ... + + +def provide_knx( + func: KnxAsyncWebSocketCommandHandler | KnxWebSocketCommandHandler, +) -> ( + websocket_api.const.AsyncWebSocketCommandHandler + | websocket_api.const.WebSocketCommandHandler +): + """Websocket decorator to provide a KNXModule instance.""" + + def _send_not_loaded_error( + connection: websocket_api.ActiveConnection, msg_id: int + ) -> None: + connection.send_error( + msg_id, + websocket_api.const.ERR_HOME_ASSISTANT_ERROR, + "KNX integration not loaded.", + ) + + if asyncio.iscoroutinefunction(func): + + @wraps(func) + async def with_knx( + hass: HomeAssistant, + connection: websocket_api.ActiveConnection, + msg: dict[str, Any], + ) -> None: + """Add KNX Module to call function.""" + try: + knx: KNXModule = hass.data[DOMAIN] + except KeyError: + _send_not_loaded_error(connection, msg["id"]) + return + await func(hass, knx, connection, msg) + + else: + + @wraps(func) + def with_knx( + hass: HomeAssistant, + connection: websocket_api.ActiveConnection, + msg: dict[str, Any], + ) -> None: + """Add KNX Module to call function.""" + try: + knx: KNXModule = hass.data[DOMAIN] + except KeyError: + _send_not_loaded_error(connection, msg["id"]) + return + func(hass, knx, connection, msg) + + return with_knx + + @websocket_api.require_admin @websocket_api.websocket_command( { vol.Required("type"): "knx/info", } ) +@provide_knx @callback def ws_info( hass: HomeAssistant, + knx: KNXModule, connection: websocket_api.ActiveConnection, msg: dict, ) -> None: """Handle get info command.""" - knx: KNXModule = hass.data[DOMAIN] - _project_info = None if project_info := knx.project.info: _project_info = { @@ -119,13 +193,14 @@ def ws_info( } ) @websocket_api.async_response +@provide_knx async def ws_get_knx_project( hass: HomeAssistant, + knx: KNXModule, connection: websocket_api.ActiveConnection, msg: dict, ) -> None: """Handle get KNX project.""" - knx: KNXModule = hass.data[DOMAIN] knxproject = await knx.project.get_knxproject() connection.send_result( msg["id"], @@ -145,13 +220,14 @@ async def ws_get_knx_project( } ) @websocket_api.async_response +@provide_knx async def ws_project_file_process( hass: HomeAssistant, + knx: KNXModule, connection: websocket_api.ActiveConnection, msg: dict, ) -> None: """Handle get info command.""" - knx: KNXModule = hass.data[DOMAIN] try: await knx.project.process_project_file( xknx=knx.xknx, @@ -175,13 +251,14 @@ async def ws_project_file_process( } ) @websocket_api.async_response +@provide_knx async def ws_project_file_remove( hass: HomeAssistant, + knx: KNXModule, connection: websocket_api.ActiveConnection, msg: dict, ) -> None: """Handle get info command.""" - knx: KNXModule = hass.data[DOMAIN] await knx.project.remove_project_file() connection.send_result(msg["id"]) @@ -192,14 +269,15 @@ async def ws_project_file_remove( vol.Required("type"): "knx/group_monitor_info", } ) +@provide_knx @callback def ws_group_monitor_info( hass: HomeAssistant, + knx: KNXModule, connection: websocket_api.ActiveConnection, msg: dict, ) -> None: """Handle get info command of group monitor.""" - knx: KNXModule = hass.data[DOMAIN] recent_telegrams = [*knx.telegrams.recent_telegrams] connection.send_result( msg["id"], @@ -272,8 +350,10 @@ def ws_validate_entity( } ) @websocket_api.async_response +@provide_knx async def ws_create_entity( hass: HomeAssistant, + knx: KNXModule, connection: websocket_api.ActiveConnection, msg: dict, ) -> None: @@ -283,7 +363,6 @@ async def ws_create_entity( except EntityStoreValidationException as exc: connection.send_result(msg["id"], exc.validation_error) return - knx: KNXModule = hass.data[DOMAIN] try: entity_id = await knx.config_store.create_entity( # use validation result so defaults are applied @@ -308,8 +387,10 @@ async def ws_create_entity( } ) @websocket_api.async_response +@provide_knx async def ws_update_entity( hass: HomeAssistant, + knx: KNXModule, connection: websocket_api.ActiveConnection, msg: dict, ) -> None: @@ -319,7 +400,6 @@ async def ws_update_entity( except EntityStoreValidationException as exc: connection.send_result(msg["id"], exc.validation_error) return - knx: KNXModule = hass.data[DOMAIN] try: await knx.config_store.update_entity( validated_data[CONF_PLATFORM], @@ -344,13 +424,14 @@ async def ws_update_entity( } ) @websocket_api.async_response +@provide_knx async def ws_delete_entity( hass: HomeAssistant, + knx: KNXModule, connection: websocket_api.ActiveConnection, msg: dict, ) -> None: """Delete entity from entity store and remove it.""" - knx: KNXModule = hass.data[DOMAIN] try: await knx.config_store.delete_entity(msg[CONF_ENTITY_ID]) except ConfigStoreException as err: @@ -367,14 +448,15 @@ async def ws_delete_entity( vol.Required("type"): "knx/get_entity_entries", } ) +@provide_knx @callback def ws_get_entity_entries( hass: HomeAssistant, + knx: KNXModule, connection: websocket_api.ActiveConnection, msg: dict, ) -> None: """Get entities configured from entity store.""" - knx: KNXModule = hass.data[DOMAIN] entity_entries = [ entry.extended_dict for entry in knx.config_store.get_entity_entries() ] @@ -388,14 +470,15 @@ def ws_get_entity_entries( vol.Required(CONF_ENTITY_ID): str, } ) +@provide_knx @callback def ws_get_entity_config( hass: HomeAssistant, + knx: KNXModule, connection: websocket_api.ActiveConnection, msg: dict, ) -> None: """Get entity configuration from entity store.""" - knx: KNXModule = hass.data[DOMAIN] try: config_info = knx.config_store.get_entity_config(msg[CONF_ENTITY_ID]) except ConfigStoreException as err: @@ -414,14 +497,15 @@ def ws_get_entity_config( vol.Optional("area_id"): str, } ) +@provide_knx @callback def ws_create_device( hass: HomeAssistant, + knx: KNXModule, connection: websocket_api.ActiveConnection, msg: dict, ) -> None: """Create a new KNX device.""" - knx: KNXModule = hass.data[DOMAIN] identifier = f"knx_vdev_{ulid_now()}" device_registry = dr.async_get(hass) _device = device_registry.async_get_or_create( diff --git a/tests/components/knx/test_websocket.py b/tests/components/knx/test_websocket.py index 309ea111709..e747b0daade 100644 --- a/tests/components/knx/test_websocket.py +++ b/tests/components/knx/test_websocket.py @@ -3,6 +3,8 @@ from typing import Any from unittest.mock import patch +import pytest + from homeassistant.components.knx import DOMAIN, KNX_ADDRESS, SwitchSchema from homeassistant.components.knx.project import STORAGE_KEY as KNX_PROJECT_STORAGE_KEY from homeassistant.const import CONF_NAME @@ -355,3 +357,28 @@ async def test_knx_subscribe_telegrams_command_project( ) assert res["event"]["direction"] == "Incoming" assert res["event"]["timestamp"] is not None + + +@pytest.mark.parametrize( + "endpoint", + [ + "knx/info", # sync ws-command + "knx/get_knx_project", # async ws-command + ], +) +async def test_websocket_when_config_entry_unloaded( + hass: HomeAssistant, + knx: KNXTestKit, + hass_ws_client: WebSocketGenerator, + endpoint: str, +) -> None: + """Test websocket connection when config entry is unloaded.""" + await knx.setup_integration({}) + await hass.config_entries.async_unload(knx.mock_config_entry.entry_id) + client = await hass_ws_client(hass) + + await client.send_json_auto_id({"type": endpoint}) + res = await client.receive_json() + assert not res["success"] + assert res["error"]["code"] == "home_assistant_error" + assert res["error"]["message"] == "KNX integration not loaded."