diff --git a/homeassistant/components/diagnostics/__init__.py b/homeassistant/components/diagnostics/__init__.py index b54a710e807..219da41d318 100644 --- a/homeassistant/components/diagnostics/__init__.py +++ b/homeassistant/components/diagnostics/__init__.py @@ -1,8 +1,9 @@ """The Diagnostics integration.""" from __future__ import annotations +from collections import defaultdict from collections.abc import Callable, Coroutine -from dataclasses import dataclass, field +from dataclasses import asdict as dataclass_asdict, dataclass from http import HTTPStatus import json import logging @@ -33,6 +34,14 @@ __all__ = ["REDACTED", "async_redact_data"] _LOGGER = logging.getLogger(__name__) +@dataclass +class DiagnosticsSubscriptionSupport: + """Describe subscriptions supported by the platform.""" + + config_entry: bool + domain: bool + + @dataclass class DiagnosticsPlatformData: """Diagnostic platform data.""" @@ -43,13 +52,21 @@ class DiagnosticsPlatformData: device_diagnostics: Callable[ [HomeAssistant, ConfigEntry, DeviceEntry], Coroutine[Any, Any, Any] ] | None + subscription_support: DiagnosticsSubscriptionSupport -@dataclass class DiagnosticsData: """Diagnostic data.""" - platforms: dict[str, DiagnosticsPlatformData] = field(default_factory=dict) + def __init__(self) -> None: + """Initialize diagnostic data.""" + self.platforms: dict[str, DiagnosticsPlatformData] = {} + self.domain_subscriptions: defaultdict[ + str, set[tuple[websocket_api.ActiveConnection, int]] + ] = defaultdict(set) + self.config_entry_subscriptions: defaultdict[ + str, set[tuple[websocket_api.ActiveConnection, int]] + ] = defaultdict(set) async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: @@ -62,6 +79,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: websocket_api.async_register_command(hass, handle_info) websocket_api.async_register_command(hass, handle_get) + websocket_api.async_register_command(hass, handle_subscribe_diagnostics) hass.http.register_view(DownloadDiagnosticsView) return True @@ -80,15 +98,26 @@ class DiagnosticsProtocol(Protocol): ) -> Any: """Return diagnostics for a device.""" + @callback + def async_supports_subscription(self) -> DiagnosticsSubscriptionSupport: + """Return if the platform supports subscribing to diagnostics.""" + async def _register_diagnostics_platform( hass: HomeAssistant, integration_domain: str, platform: DiagnosticsProtocol ) -> None: """Register a diagnostics platform.""" diagnostics_data: DiagnosticsData = hass.data[DOMAIN] + subscription_support: DiagnosticsSubscriptionSupport + if hasattr(platform, "async_supports_subscription"): + subscription_support = platform.async_supports_subscription() + else: + subscription_support = DiagnosticsSubscriptionSupport(False, False) + diagnostics_data.platforms[integration_domain] = DiagnosticsPlatformData( getattr(platform, "async_get_config_entry_diagnostics", None), getattr(platform, "async_get_device_diagnostics", None), + subscription_support, ) @@ -107,6 +136,7 @@ def handle_info( DiagnosticsType.CONFIG_ENTRY: info.config_entry_diagnostics is not None, DiagnosticsSubType.DEVICE: info.device_diagnostics is not None, }, + "supports_subscription": dataclass_asdict(info.subscription_support), } for domain, info in diagnostics_data.platforms.items() ] @@ -142,10 +172,48 @@ def handle_get( DiagnosticsType.CONFIG_ENTRY: info.config_entry_diagnostics is not None, DiagnosticsSubType.DEVICE: info.device_diagnostics is not None, }, + "supports_subscription": dataclass_asdict(info.subscription_support), }, ) +@websocket_api.websocket_command( + { + vol.Required("type"): "diagnostics/subscribe", + vol.Required("domain"): str, + vol.Optional("config_entry"): str, + } +) +@callback +def handle_subscribe_diagnostics( + hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any] +) -> None: + """Subscribe to diagnostic messages.""" + + diagnostics_data: DiagnosticsData = hass.data[DOMAIN] + + msg_id = msg["id"] + domain = msg["domain"] + config_entry = msg.get("config_entry") + diagnostics_data.domain_subscriptions[domain].add((connection, msg_id)) + if config_entry: + diagnostics_data.config_entry_subscriptions[config_entry].add( + (connection, msg_id) + ) + + @callback + def cancel_subscription() -> None: + diagnostics_data.domain_subscriptions[domain].remove((connection, msg["id"])) + if config_entry: + diagnostics_data.config_entry_subscriptions[config_entry].remove( + (connection, msg_id) + ) + + connection.subscriptions[msg["id"]] = cancel_subscription + + connection.send_message(websocket_api.result_message(msg["id"])) + + async def _async_get_json_file_response( hass: HomeAssistant, data: Any, @@ -265,3 +333,62 @@ class DownloadDiagnosticsView(http.HomeAssistantView): return await _async_get_json_file_response( hass, data, filename, config_entry.domain, d_id, sub_id ) + + +@callback +def async_has_subscription( + hass: HomeAssistant, domain: str, config_entry_id: str | None = None +) -> bool: + """Return True if there is a matching diagnostics subscription.""" + + diagnostics_data: DiagnosticsData = hass.data[DOMAIN] + if ( + domain in diagnostics_data.domain_subscriptions + and diagnostics_data.domain_subscriptions[domain] + ): + return True + + if not config_entry_id: + return False + + if ( + config_entry_id in diagnostics_data.config_entry_subscriptions + and diagnostics_data.config_entry_subscriptions[config_entry_id] + ): + return True + + return False + + +@callback +def async_log_object( + hass: HomeAssistant, + data: Any, + domain: str, + config_entry_id: str | None = None, +) -> None: + """Send diagnostic data to subscribers.""" + diagnostics_data: DiagnosticsData = hass.data[DOMAIN] + + json_data = json.dumps({"data": data}, cls=ExtendedJSONEncoder) + + domain_subs: set[tuple[websocket_api.ActiveConnection, int]] = set() + + if domain in diagnostics_data.domain_subscriptions: + for conn, msg_id in diagnostics_data.domain_subscriptions[domain]: + conn.send_message(websocket_api.event_message(msg_id, json_data)) + domain_subs.add((conn, msg_id)) + + if not config_entry_id: + return + + if ( + config_entry_id in diagnostics_data.config_entry_subscriptions + and diagnostics_data.config_entry_subscriptions[config_entry_id] + ): + for conn, msg_id in diagnostics_data.config_entry_subscriptions[ + config_entry_id + ]: + if (conn, msg_id) in domain_subs: + continue + conn.send_message(websocket_api.event_message(msg_id, json_data)) diff --git a/tests/components/diagnostics/test_init.py b/tests/components/diagnostics/test_init.py index 11b113e30f6..30b7c0e188c 100644 --- a/tests/components/diagnostics/test_init.py +++ b/tests/components/diagnostics/test_init.py @@ -1,10 +1,18 @@ """Test the Diagnostics integration.""" from http import HTTPStatus +import json from unittest.mock import AsyncMock, Mock import pytest +from homeassistant.components.diagnostics import ( + DOMAIN, + DiagnosticsSubscriptionSupport, + async_has_subscription, + async_log_object, +) from homeassistant.components.websocket_api.const import TYPE_RESULT +from homeassistant.core import HomeAssistant from homeassistant.helpers.device_registry import async_get from homeassistant.helpers.system_info import async_get_system_info from homeassistant.setup import async_setup_component @@ -18,6 +26,8 @@ from tests.common import MockConfigEntry, mock_platform async def mock_diagnostics_integration(hass): """Mock a diagnostics integration.""" hass.config.components.add("fake_integration") + hass.config.components.add("fake_integration_no_subscribe") + hass.config.components.add("integration_without_diagnostics") mock_platform( hass, "fake_integration.diagnostics", @@ -32,14 +42,37 @@ async def mock_diagnostics_integration(hass): "device": "info", } ), + async_supports_subscription=Mock( + return_value=DiagnosticsSubscriptionSupport(True, True) + ), + ), + ) + mock_platform( + hass, + "fake_integration_no_subscribe.diagnostics", + Mock( + async_get_config_entry_diagnostics=AsyncMock( + return_value={ + "config_entry": "info", + } + ), + async_get_device_diagnostics=AsyncMock( + return_value={ + "device": "info", + } + ), + spec_set=[ + "async_get_config_entry_diagnostics", + "async_get_device_diagnostics", + ], ), ) mock_platform( hass, "integration_without_diagnostics.diagnostics", - Mock(), + Mock(spec_set=[]), ) - assert await async_setup_component(hass, "diagnostics", {}) + assert await async_setup_component(hass, DOMAIN, {}) async def test_websocket(hass, hass_ws_client): @@ -52,12 +85,22 @@ async def test_websocket(hass, hass_ws_client): assert msg["id"] == 5 assert msg["type"] == TYPE_RESULT assert msg["success"] - assert msg["result"] == [ - { - "domain": "fake_integration", - "handlers": {"config_entry": True, "device": True}, - } - ] + assert len(msg["result"]) == 3 + assert { + "domain": "fake_integration_no_subscribe", + "handlers": {"config_entry": True, "device": True}, + "supports_subscription": {"config_entry": False, "domain": False}, + } in msg["result"] + assert { + "domain": "fake_integration", + "handlers": {"config_entry": True, "device": True}, + "supports_subscription": {"config_entry": True, "domain": True}, + } in msg["result"] + assert { + "domain": "integration_without_diagnostics", + "handlers": {"config_entry": False, "device": False}, + "supports_subscription": {"config_entry": False, "domain": False}, + } in msg["result"] await client.send_json( {"id": 6, "type": "diagnostics/get", "domain": "fake_integration"} @@ -71,6 +114,7 @@ async def test_websocket(hass, hass_ws_client): assert msg["result"] == { "domain": "fake_integration", "handlers": {"config_entry": True, "device": True}, + "supports_subscription": {"config_entry": True, "domain": True}, } @@ -153,3 +197,76 @@ async def test_failure_scenarios(hass, hass_client): f"/api/diagnostics/config_entry/{config_entry.entry_id}/device/fake_id" ) assert response.status == HTTPStatus.NOT_FOUND + + +async def test_diagnostics_subscription_domain(hass: HomeAssistant, hass_ws_client): + """Test websocket diagnostics subscription for a domain.""" + + client = await hass_ws_client(hass) + + # Test there's no subscription + assert not async_has_subscription(hass, "fake_integration") + + await client.send_json( + {"id": 1, "type": "diagnostics/subscribe", "domain": "fake_integration"} + ) + response = await client.receive_json() + assert response["success"] + assert async_has_subscription(hass, "fake_integration") + + # Log some data + async_log_object(hass, {"some": "data"}, "fake_integration") + await hass.async_block_till_done() + + response = await client.receive_json() + assert json.loads(response["event"]) == {"data": {"some": "data"}} + + # Unsubscribe + await client.send_json({"id": 8, "type": "unsubscribe_events", "subscription": 1}) + response = await client.receive_json() + assert response["success"] + assert not async_has_subscription(hass, "fake_integration") + assert not async_has_subscription(hass, "fake_integration", "fake_config_entry_id") + + +async def test_diagnostics_subscription_config_entry( + hass: HomeAssistant, hass_ws_client +): + """Test websocket diagnostics subscription for a config_entry.""" + + client = await hass_ws_client(hass) + + # Test there's no subscription + assert not async_has_subscription(hass, "fake_integration", "fake_config_entry_id") + + await client.send_json( + { + "id": 1, + "type": "diagnostics/subscribe", + "domain": "fake_integration", + "config_entry": "fake_config_entry_id", + } + ) + response = await client.receive_json() + assert response["success"] + assert async_has_subscription(hass, "fake_integration", "fake_config_entry_id") + + # Log some data for the domain + async_log_object(hass, {"some": "data"}, "fake_integration") + await hass.async_block_till_done() + + response = await client.receive_json() + assert json.loads(response["event"]) == {"data": {"some": "data"}} + + # Log some data for the config entry + async_log_object(hass, {"some": "data"}, "fake_integration", "fake_config_entry_id") + await hass.async_block_till_done() + + response = await client.receive_json() + assert json.loads(response["event"]) == {"data": {"some": "data"}} + + # Unsubscribe + await client.send_json({"id": 8, "type": "unsubscribe_events", "subscription": 1}) + response = await client.receive_json() + assert response["success"] + assert not async_has_subscription(hass, "fake_integration", "fake_config_entry_id")