Support subscribing to diagnostics messages

This commit is contained in:
Erik 2022-12-05 13:31:34 +01:00
parent ee7022dc67
commit 37da138dc9
2 changed files with 255 additions and 11 deletions

View file

@ -1,8 +1,9 @@
"""The Diagnostics integration."""
from __future__ import annotations
from collections import defaultdict
from collections.abc import Callable, Coroutine
from dataclasses import dataclass, field
from dataclasses import asdict as dataclass_asdict, dataclass
from http import HTTPStatus
import json
import logging
@ -33,6 +34,14 @@ __all__ = ["REDACTED", "async_redact_data"]
_LOGGER = logging.getLogger(__name__)
@dataclass
class DiagnosticsSubscriptionSupport:
"""Describe subscriptions supported by the platform."""
config_entry: bool
domain: bool
@dataclass
class DiagnosticsPlatformData:
"""Diagnostic platform data."""
@ -43,13 +52,21 @@ class DiagnosticsPlatformData:
device_diagnostics: Callable[
[HomeAssistant, ConfigEntry, DeviceEntry], Coroutine[Any, Any, Any]
] | None
subscription_support: DiagnosticsSubscriptionSupport
@dataclass
class DiagnosticsData:
"""Diagnostic data."""
platforms: dict[str, DiagnosticsPlatformData] = field(default_factory=dict)
def __init__(self) -> None:
"""Initialize diagnostic data."""
self.platforms: dict[str, DiagnosticsPlatformData] = {}
self.domain_subscriptions: defaultdict[
str, set[tuple[websocket_api.ActiveConnection, int]]
] = defaultdict(set)
self.config_entry_subscriptions: defaultdict[
str, set[tuple[websocket_api.ActiveConnection, int]]
] = defaultdict(set)
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
@ -62,6 +79,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
websocket_api.async_register_command(hass, handle_info)
websocket_api.async_register_command(hass, handle_get)
websocket_api.async_register_command(hass, handle_subscribe_diagnostics)
hass.http.register_view(DownloadDiagnosticsView)
return True
@ -80,15 +98,26 @@ class DiagnosticsProtocol(Protocol):
) -> Any:
"""Return diagnostics for a device."""
@callback
def async_supports_subscription(self) -> DiagnosticsSubscriptionSupport:
"""Return if the platform supports subscribing to diagnostics."""
async def _register_diagnostics_platform(
hass: HomeAssistant, integration_domain: str, platform: DiagnosticsProtocol
) -> None:
"""Register a diagnostics platform."""
diagnostics_data: DiagnosticsData = hass.data[DOMAIN]
subscription_support: DiagnosticsSubscriptionSupport
if hasattr(platform, "async_supports_subscription"):
subscription_support = platform.async_supports_subscription()
else:
subscription_support = DiagnosticsSubscriptionSupport(False, False)
diagnostics_data.platforms[integration_domain] = DiagnosticsPlatformData(
getattr(platform, "async_get_config_entry_diagnostics", None),
getattr(platform, "async_get_device_diagnostics", None),
subscription_support,
)
@ -107,6 +136,7 @@ def handle_info(
DiagnosticsType.CONFIG_ENTRY: info.config_entry_diagnostics is not None,
DiagnosticsSubType.DEVICE: info.device_diagnostics is not None,
},
"supports_subscription": dataclass_asdict(info.subscription_support),
}
for domain, info in diagnostics_data.platforms.items()
]
@ -142,10 +172,48 @@ def handle_get(
DiagnosticsType.CONFIG_ENTRY: info.config_entry_diagnostics is not None,
DiagnosticsSubType.DEVICE: info.device_diagnostics is not None,
},
"supports_subscription": dataclass_asdict(info.subscription_support),
},
)
@websocket_api.websocket_command(
{
vol.Required("type"): "diagnostics/subscribe",
vol.Required("domain"): str,
vol.Optional("config_entry"): str,
}
)
@callback
def handle_subscribe_diagnostics(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any]
) -> None:
"""Subscribe to diagnostic messages."""
diagnostics_data: DiagnosticsData = hass.data[DOMAIN]
msg_id = msg["id"]
domain = msg["domain"]
config_entry = msg.get("config_entry")
diagnostics_data.domain_subscriptions[domain].add((connection, msg_id))
if config_entry:
diagnostics_data.config_entry_subscriptions[config_entry].add(
(connection, msg_id)
)
@callback
def cancel_subscription() -> None:
diagnostics_data.domain_subscriptions[domain].remove((connection, msg["id"]))
if config_entry:
diagnostics_data.config_entry_subscriptions[config_entry].remove(
(connection, msg_id)
)
connection.subscriptions[msg["id"]] = cancel_subscription
connection.send_message(websocket_api.result_message(msg["id"]))
async def _async_get_json_file_response(
hass: HomeAssistant,
data: Any,
@ -265,3 +333,62 @@ class DownloadDiagnosticsView(http.HomeAssistantView):
return await _async_get_json_file_response(
hass, data, filename, config_entry.domain, d_id, sub_id
)
@callback
def async_has_subscription(
hass: HomeAssistant, domain: str, config_entry_id: str | None = None
) -> bool:
"""Return True if there is a matching diagnostics subscription."""
diagnostics_data: DiagnosticsData = hass.data[DOMAIN]
if (
domain in diagnostics_data.domain_subscriptions
and diagnostics_data.domain_subscriptions[domain]
):
return True
if not config_entry_id:
return False
if (
config_entry_id in diagnostics_data.config_entry_subscriptions
and diagnostics_data.config_entry_subscriptions[config_entry_id]
):
return True
return False
@callback
def async_log_object(
hass: HomeAssistant,
data: Any,
domain: str,
config_entry_id: str | None = None,
) -> None:
"""Send diagnostic data to subscribers."""
diagnostics_data: DiagnosticsData = hass.data[DOMAIN]
json_data = json.dumps({"data": data}, cls=ExtendedJSONEncoder)
domain_subs: set[tuple[websocket_api.ActiveConnection, int]] = set()
if domain in diagnostics_data.domain_subscriptions:
for conn, msg_id in diagnostics_data.domain_subscriptions[domain]:
conn.send_message(websocket_api.event_message(msg_id, json_data))
domain_subs.add((conn, msg_id))
if not config_entry_id:
return
if (
config_entry_id in diagnostics_data.config_entry_subscriptions
and diagnostics_data.config_entry_subscriptions[config_entry_id]
):
for conn, msg_id in diagnostics_data.config_entry_subscriptions[
config_entry_id
]:
if (conn, msg_id) in domain_subs:
continue
conn.send_message(websocket_api.event_message(msg_id, json_data))

View file

@ -1,10 +1,18 @@
"""Test the Diagnostics integration."""
from http import HTTPStatus
import json
from unittest.mock import AsyncMock, Mock
import pytest
from homeassistant.components.diagnostics import (
DOMAIN,
DiagnosticsSubscriptionSupport,
async_has_subscription,
async_log_object,
)
from homeassistant.components.websocket_api.const import TYPE_RESULT
from homeassistant.core import HomeAssistant
from homeassistant.helpers.device_registry import async_get
from homeassistant.helpers.system_info import async_get_system_info
from homeassistant.setup import async_setup_component
@ -18,6 +26,8 @@ from tests.common import MockConfigEntry, mock_platform
async def mock_diagnostics_integration(hass):
"""Mock a diagnostics integration."""
hass.config.components.add("fake_integration")
hass.config.components.add("fake_integration_no_subscribe")
hass.config.components.add("integration_without_diagnostics")
mock_platform(
hass,
"fake_integration.diagnostics",
@ -32,14 +42,37 @@ async def mock_diagnostics_integration(hass):
"device": "info",
}
),
async_supports_subscription=Mock(
return_value=DiagnosticsSubscriptionSupport(True, True)
),
),
)
mock_platform(
hass,
"fake_integration_no_subscribe.diagnostics",
Mock(
async_get_config_entry_diagnostics=AsyncMock(
return_value={
"config_entry": "info",
}
),
async_get_device_diagnostics=AsyncMock(
return_value={
"device": "info",
}
),
spec_set=[
"async_get_config_entry_diagnostics",
"async_get_device_diagnostics",
],
),
)
mock_platform(
hass,
"integration_without_diagnostics.diagnostics",
Mock(),
Mock(spec_set=[]),
)
assert await async_setup_component(hass, "diagnostics", {})
assert await async_setup_component(hass, DOMAIN, {})
async def test_websocket(hass, hass_ws_client):
@ -52,12 +85,22 @@ async def test_websocket(hass, hass_ws_client):
assert msg["id"] == 5
assert msg["type"] == TYPE_RESULT
assert msg["success"]
assert msg["result"] == [
{
"domain": "fake_integration",
"handlers": {"config_entry": True, "device": True},
}
]
assert len(msg["result"]) == 3
assert {
"domain": "fake_integration_no_subscribe",
"handlers": {"config_entry": True, "device": True},
"supports_subscription": {"config_entry": False, "domain": False},
} in msg["result"]
assert {
"domain": "fake_integration",
"handlers": {"config_entry": True, "device": True},
"supports_subscription": {"config_entry": True, "domain": True},
} in msg["result"]
assert {
"domain": "integration_without_diagnostics",
"handlers": {"config_entry": False, "device": False},
"supports_subscription": {"config_entry": False, "domain": False},
} in msg["result"]
await client.send_json(
{"id": 6, "type": "diagnostics/get", "domain": "fake_integration"}
@ -71,6 +114,7 @@ async def test_websocket(hass, hass_ws_client):
assert msg["result"] == {
"domain": "fake_integration",
"handlers": {"config_entry": True, "device": True},
"supports_subscription": {"config_entry": True, "domain": True},
}
@ -153,3 +197,76 @@ async def test_failure_scenarios(hass, hass_client):
f"/api/diagnostics/config_entry/{config_entry.entry_id}/device/fake_id"
)
assert response.status == HTTPStatus.NOT_FOUND
async def test_diagnostics_subscription_domain(hass: HomeAssistant, hass_ws_client):
"""Test websocket diagnostics subscription for a domain."""
client = await hass_ws_client(hass)
# Test there's no subscription
assert not async_has_subscription(hass, "fake_integration")
await client.send_json(
{"id": 1, "type": "diagnostics/subscribe", "domain": "fake_integration"}
)
response = await client.receive_json()
assert response["success"]
assert async_has_subscription(hass, "fake_integration")
# Log some data
async_log_object(hass, {"some": "data"}, "fake_integration")
await hass.async_block_till_done()
response = await client.receive_json()
assert json.loads(response["event"]) == {"data": {"some": "data"}}
# Unsubscribe
await client.send_json({"id": 8, "type": "unsubscribe_events", "subscription": 1})
response = await client.receive_json()
assert response["success"]
assert not async_has_subscription(hass, "fake_integration")
assert not async_has_subscription(hass, "fake_integration", "fake_config_entry_id")
async def test_diagnostics_subscription_config_entry(
hass: HomeAssistant, hass_ws_client
):
"""Test websocket diagnostics subscription for a config_entry."""
client = await hass_ws_client(hass)
# Test there's no subscription
assert not async_has_subscription(hass, "fake_integration", "fake_config_entry_id")
await client.send_json(
{
"id": 1,
"type": "diagnostics/subscribe",
"domain": "fake_integration",
"config_entry": "fake_config_entry_id",
}
)
response = await client.receive_json()
assert response["success"]
assert async_has_subscription(hass, "fake_integration", "fake_config_entry_id")
# Log some data for the domain
async_log_object(hass, {"some": "data"}, "fake_integration")
await hass.async_block_till_done()
response = await client.receive_json()
assert json.loads(response["event"]) == {"data": {"some": "data"}}
# Log some data for the config entry
async_log_object(hass, {"some": "data"}, "fake_integration", "fake_config_entry_id")
await hass.async_block_till_done()
response = await client.receive_json()
assert json.loads(response["event"]) == {"data": {"some": "data"}}
# Unsubscribe
await client.send_json({"id": 8, "type": "unsubscribe_events", "subscription": 1})
response = await client.receive_json()
assert response["success"]
assert not async_has_subscription(hass, "fake_integration", "fake_config_entry_id")