diff --git a/homeassistant/components/thread/discovery.py b/homeassistant/components/thread/discovery.py new file mode 100644 index 00000000000..2001626ca1b --- /dev/null +++ b/homeassistant/components/thread/discovery.py @@ -0,0 +1,154 @@ +"""The Thread integration.""" +from __future__ import annotations + +from collections.abc import Callable +import dataclasses +import logging + +from zeroconf import ServiceListener, Zeroconf +from zeroconf.asyncio import AsyncZeroconf + +from homeassistant.components import zeroconf +from homeassistant.core import HomeAssistant + +_LOGGER = logging.getLogger(__name__) + +KNOWN_BRANDS: dict[str | None, str] = { + "Apple Inc.": "apple", + "Google Inc.": "google", + "HomeAssistant": "homeassistant", +} +THREAD_TYPE = "_meshcop._udp.local." + + +@dataclasses.dataclass +class ThreadRouterDiscoveryData: + """Thread router discovery data.""" + + brand: str | None + extended_pan_id: str | None + model_name: str | None + network_name: str | None + server: str | None + vendor_name: str | None + + +class ThreadRouterDiscovery: + """mDNS based Thread router discovery.""" + + class ThreadServiceListener(ServiceListener): + """Service listener which listens for thread routers.""" + + def __init__( + self, + hass: HomeAssistant, + aiozc: AsyncZeroconf, + router_discovered: Callable, + router_removed: Callable, + ) -> None: + """Initialize.""" + self._aiozc = aiozc + self._hass = hass + self._known_routers: dict[str, tuple[str, ThreadRouterDiscoveryData]] = {} + self._router_discovered = router_discovered + self._router_removed = router_removed + + def add_service(self, zc: Zeroconf, type_: str, name: str) -> None: + """Handle service added.""" + _LOGGER.debug("add_service %s", name) + self._hass.async_create_task(self._add_update_service(type_, name)) + + def remove_service(self, zc: Zeroconf, type_: str, name: str) -> None: + """Handle service removed.""" + _LOGGER.debug("remove_service %s", name) + if name not in self._known_routers: + return + extended_mac_address, _ = self._known_routers.pop(name) + self._router_removed(extended_mac_address) + + def update_service(self, zc: Zeroconf, type_: str, name: str) -> None: + """Handle service updated.""" + _LOGGER.debug("update_service %s", name) + self._hass.async_create_task(self._add_update_service(type_, name)) + + async def _add_update_service(self, type_: str, name: str): + """Add or update a service.""" + service = None + tries = 0 + while service is None and tries < 4: + service = await self._aiozc.async_get_service_info(type_, name) + tries += 1 + + if not service: + _LOGGER.debug("_add_update_service failed to add %s, %s", type_, name) + return + + def try_decode(value: bytes | None) -> str | None: + """Try decoding UTF-8.""" + if value is None: + return None + try: + return value.decode() + except UnicodeDecodeError: + return None + + _LOGGER.debug("_add_update_service %s %s", name, service) + # We use the extended mac address as key, bail out if it's missing + try: + extended_mac_address = service.properties[b"xa"].hex() + except (KeyError, UnicodeDecodeError) as err: + _LOGGER.debug("_add_update_service failed to parse service %s", err) + return + ext_pan_id = service.properties.get(b"xp") + network_name = try_decode(service.properties.get(b"nn")) + model_name = try_decode(service.properties.get(b"mn")) + server = service.server + vendor_name = try_decode(service.properties.get(b"vn")) + data = ThreadRouterDiscoveryData( + brand=KNOWN_BRANDS.get(vendor_name), + extended_pan_id=ext_pan_id.hex() if ext_pan_id is not None else None, + model_name=model_name, + network_name=network_name, + server=server, + vendor_name=vendor_name, + ) + if name in self._known_routers and self._known_routers[name] == ( + extended_mac_address, + data, + ): + _LOGGER.debug( + "_add_update_service suppressing identical update for %s", name + ) + return + self._known_routers[name] = (extended_mac_address, data) + self._router_discovered(extended_mac_address, data) + + def __init__( + self, + hass: HomeAssistant, + router_discovered: Callable[[str, ThreadRouterDiscoveryData], None], + router_removed: Callable[[str], None], + ) -> None: + """Initialize.""" + self._hass = hass + self._aiozc: AsyncZeroconf | None = None + self._router_discovered = router_discovered + self._router_removed = router_removed + self._service_listener: ThreadRouterDiscovery.ThreadServiceListener | None = ( + None + ) + + async def async_start(self) -> None: + """Start discovery.""" + self._aiozc = aiozc = await zeroconf.async_get_async_instance(self._hass) + self._service_listener = self.ThreadServiceListener( + self._hass, aiozc, self._router_discovered, self._router_removed + ) + await aiozc.async_add_service_listener(THREAD_TYPE, self._service_listener) + + async def async_stop(self) -> None: + """Stop discovery.""" + if not self._aiozc or not self._service_listener: + return + await self._aiozc.async_remove_service_listener(self._service_listener) + self._service_listener = None diff --git a/homeassistant/components/thread/manifest.json b/homeassistant/components/thread/manifest.json index c8bc98834fd..a6e823de570 100644 --- a/homeassistant/components/thread/manifest.json +++ b/homeassistant/components/thread/manifest.json @@ -3,6 +3,7 @@ "name": "Thread", "codeowners": ["@home-assistant/core"], "config_flow": true, + "dependencies": ["zeroconf"], "documentation": "https://www.home-assistant.io/integrations/thread", "integration_type": "service", "iot_class": "local_polling", diff --git a/homeassistant/components/thread/websocket_api.py b/homeassistant/components/thread/websocket_api.py index 97303f0ea7d..54f700f3ba1 100644 --- a/homeassistant/components/thread/websocket_api.py +++ b/homeassistant/components/thread/websocket_api.py @@ -9,13 +9,14 @@ import voluptuous as vol from homeassistant.components import websocket_api from homeassistant.core import HomeAssistant, callback -from . import dataset_store +from . import dataset_store, discovery @callback def async_setup(hass: HomeAssistant) -> None: """Set up the sensor websocket API.""" websocket_api.async_register_command(hass, ws_add_dataset) + websocket_api.async_register_command(hass, ws_discover_routers) websocket_api.async_register_command(hass, ws_get_dataset) websocket_api.async_register_command(hass, ws_list_datasets) @@ -100,3 +101,59 @@ async def ws_list_datasets( ) connection.send_result(msg["id"], {"datasets": result}) + + +@websocket_api.require_admin +@websocket_api.websocket_command( + { + vol.Required("type"): "thread/discover_routers", + } +) +@websocket_api.async_response +async def ws_discover_routers( + hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any] +) -> None: + """Discover Thread routers.""" + + @callback + def router_discovered(key: str, data: discovery.ThreadRouterDiscoveryData) -> None: + """Forward router discovery or update to websocket.""" + + connection.send_message( + websocket_api.event_message( + msg["id"], + { + "type": "router_discovered", + "key": key, + "data": data, + }, + ) + ) + + @callback + def router_removed(key: str) -> None: + """Forward router discovery or update to websocket.""" + + connection.send_message( + websocket_api.event_message( + msg["id"], + { + "type": "router_removed", + "key": key, + }, + ) + ) + + @callback + def stop_discovery() -> None: + """Stop discovery.""" + hass.async_create_task(thread_discovery.async_stop()) + + # Start Thread router discovery + thread_discovery = discovery.ThreadRouterDiscovery( + hass, router_discovered, router_removed + ) + await thread_discovery.async_start() + connection.subscriptions[msg["id"]] = stop_discovery + + connection.send_message(websocket_api.result_message(msg["id"])) diff --git a/tests/components/thread/__init__.py b/tests/components/thread/__init__.py index e186d8ea0a6..c8e4453da1f 100644 --- a/tests/components/thread/__init__.py +++ b/tests/components/thread/__init__.py @@ -17,3 +17,171 @@ DATASET_3 = ( "E5AA15DD051000112233445566778899AABBCCDDEEFF030E7ef09f90a3f09f90a5f09f90a47e01" "0212340410445F2B5CA6F2A93A55CE570A70EFEECB0C0402A0F7F8" ) + + +ROUTER_DISCOVERY_GOOGLE_1 = { + "type_": "_meshcop._udp.local.", + "name": "Google-Nest-Hub-#ABED._meshcop._udp.local.", + "addresses": [b"\xc0\xa8\x00|"], + "port": 49191, + "weight": 0, + "priority": 0, + "server": "2d99f293-cd8e-2770-8dd2-6675de9fa000.local.", + "properties": { + b"rv": b"1", + b"vn": b"Google Inc.", + b"mn": b"Google Nest Hub", + b"nn": b"NEST-PAN-E1AF", + b"xp": b"\x9eu\xe2V\xf6\x14\t\xa3", + b"tv": b"1.3.0", + b"xa": b"\xf6\xa9\x9bBZg\xab\xed", + b"sb": b"\x00\x00\x01\xb1", + b"at": b"\x00\x00b\xf2\xf8$T\xe3", + b"pt": b"4\x860D", + b"sq": b"{", + b"bb": b"\xf0\xbf", + b"dn": b"DefaultDomain", + b"id": b"\xbc7@\xc3\xe9c\xaa\x875\xbe\xbe\xcd|\xc5\x03\xc7", + b"vat": b"000062f2f82454e3", + b"vcd": b"BC3740C3E963AA8735BEBECD7CC503C7", + b"vo": b"|\xd9\\", + b"vvo": b"7CD95C", + b"vxp": b"9e75e256f61409a3", + }, + "interface_index": None, +} + +ROUTER_DISCOVERY_GOOGLE_2 = { + "type": "_meshcop._udp.local.", + "name": "Google-Nest-Hub-#D8D5._meshcop._udp.local.", + "addresses": [b"\xc0\xa8\x00q"], + "port": 49191, + "weight": 0, + "priority": 0, + "server": "80adee71-a563-2cfe-4402-95a9bc6ae3a1.local.", + "properties": { + b"rv": b"1", + b"vn": b"Google Inc.", + b"mn": b"Google Nest Hub", + b"nn": b"NEST-PAN-E1AF", + b"xp": b"\x9eu\xe2V\xf6\x14\t\xa3", + b"tv": b"1.3.0", + b"xa": b"\x8e9Z\xaek\xd5\xd8\xd5", + b"sb": b"\x00\x00\x00\xb1", + b"at": b"\x00\x00b\xf2\xf8$T\xe3", + b"pt": b"4\x860D", + b"sq": b'"', + b"bb": b"\xf0\xbf", + b"dn": b"DefaultDomain", + b"id": b"\xffi]\x11\xf6\xac)\xbe\xdb\x84\xb1o{\x8c\x1e\x82", + b"vat": b"000062f2f82454e3", + b"vcd": b"FF695D11F6AC29BEDB84B16F7B8C1E82", + b"vo": b"|\xd9\\", + b"vvo": b"7CD95C", + b"vxp": b"9e75e256f61409a3", + }, + "interface_index": None, +} + +ROUTER_DISCOVERY_HASS = { + "type_": "_meshcop._udp.local.", + "name": "HomeAssistant OpenThreadBorderRouter #0BBF._meshcop._udp.local.", + "addresses": [b"\xc0\xa8\x00s"], + "port": 49153, + "weight": 0, + "priority": 0, + "server": "core-silabs-multiprotocol.local.", + "properties": { + b"rv": b"1", + b"vn": b"HomeAssistant", + b"mn": b"OpenThreadBorderRouter", + b"nn": b"OpenThread HC", + b"xp": b"\xe6\x0f\xc7\xc1\x86!,\xe5", + b"tv": b"1.3.0", + b"xa": b"\xae\xeb/YKW\x0b\xbf", + b"sb": b"\x00\x00\x01\xb1", + b"at": b"\x00\x00\x00\x00\x00\x01\x00\x00", + b"pt": b"\x8f\x06Q~", + b"sq": b"3", + b"bb": b"\xf0\xbf", + b"dn": b"DefaultDomain", + }, + "interface_index": None, +} + +ROUTER_DISCOVERY_HASS_BAD_DATA = { + "type_": "_meshcop._udp.local.", + "name": "HomeAssistant OpenThreadBorderRouter #0BBF._meshcop._udp.local.", + "addresses": [b"\xc0\xa8\x00s"], + "port": 49153, + "weight": 0, + "priority": 0, + "server": "core-silabs-multiprotocol.local.", + "properties": { + b"rv": b"1", + b"vn": b"HomeAssistant\xff", # Invalid UTF-8 + b"mn": b"OpenThreadBorderRouter", + b"nn": b"OpenThread HC", + b"xp": b"\xe6\x0f\xc7\xc1\x86!,\xe5", + b"tv": b"1.3.0", + b"xa": b"\xae\xeb/YKW\x0b\xbf", + b"sb": b"\x00\x00\x01\xb1", + b"at": b"\x00\x00\x00\x00\x00\x01\x00\x00", + b"pt": b"\x8f\x06Q~", + b"sq": b"3", + b"bb": b"\xf0\xbf", + b"dn": b"DefaultDomain", + }, + "interface_index": None, +} + +ROUTER_DISCOVERY_HASS_MISSING_DATA = { + "type_": "_meshcop._udp.local.", + "name": "HomeAssistant OpenThreadBorderRouter #0BBF._meshcop._udp.local.", + "addresses": [b"\xc0\xa8\x00s"], + "port": 49153, + "weight": 0, + "priority": 0, + "server": "core-silabs-multiprotocol.local.", + "properties": { + b"rv": b"1", + b"mn": b"OpenThreadBorderRouter", + b"nn": b"OpenThread HC", + b"xp": b"\xe6\x0f\xc7\xc1\x86!,\xe5", + b"tv": b"1.3.0", + b"xa": b"\xae\xeb/YKW\x0b\xbf", + b"sb": b"\x00\x00\x01\xb1", + b"at": b"\x00\x00\x00\x00\x00\x01\x00\x00", + b"pt": b"\x8f\x06Q~", + b"sq": b"3", + b"bb": b"\xf0\xbf", + b"dn": b"DefaultDomain", + }, + "interface_index": None, +} + + +ROUTER_DISCOVERY_HASS_MISSING_MANDATORY_DATA = { + "type_": "_meshcop._udp.local.", + "name": "HomeAssistant OpenThreadBorderRouter #0BBF._meshcop._udp.local.", + "addresses": [b"\xc0\xa8\x00s"], + "port": 49153, + "weight": 0, + "priority": 0, + "server": "core-silabs-multiprotocol.local.", + "properties": { + b"rv": b"1", + b"vn": b"HomeAssistant", + b"mn": b"OpenThreadBorderRouter", + b"nn": b"OpenThread HC", + b"xp": b"\xe6\x0f\xc7\xc1\x86!,\xe5", + b"tv": b"1.3.0", + b"sb": b"\x00\x00\x01\xb1", + b"at": b"\x00\x00\x00\x00\x00\x01\x00\x00", + b"pt": b"\x8f\x06Q~", + b"sq": b"3", + b"bb": b"\xf0\xbf", + b"dn": b"DefaultDomain", + }, + "interface_index": None, +} diff --git a/tests/components/thread/conftest.py b/tests/components/thread/conftest.py index 37555d07a90..a02b18e9eab 100644 --- a/tests/components/thread/conftest.py +++ b/tests/components/thread/conftest.py @@ -20,3 +20,8 @@ async def thread_config_entry_fixture(hass): ) config_entry.add_to_hass(hass) assert await hass.config_entries.async_setup(config_entry.entry_id) + + +@pytest.fixture(autouse=True) +def use_mocked_zeroconf(mock_async_zeroconf): + """Mock zeroconf in all tests.""" diff --git a/tests/components/thread/test_discovery.py b/tests/components/thread/test_discovery.py new file mode 100644 index 00000000000..ff77d86339e --- /dev/null +++ b/tests/components/thread/test_discovery.py @@ -0,0 +1,300 @@ +"""Test the thread websocket API.""" + +from unittest.mock import ANY, AsyncMock, Mock + +import pytest +from zeroconf.asyncio import AsyncServiceInfo + +from homeassistant.components.thread import discovery +from homeassistant.components.thread.const import DOMAIN +from homeassistant.core import HomeAssistant, callback +from homeassistant.setup import async_setup_component + +from . import ( + ROUTER_DISCOVERY_GOOGLE_1, + ROUTER_DISCOVERY_HASS, + ROUTER_DISCOVERY_HASS_BAD_DATA, + ROUTER_DISCOVERY_HASS_MISSING_DATA, + ROUTER_DISCOVERY_HASS_MISSING_MANDATORY_DATA, +) + + +async def test_discover_routers(hass: HomeAssistant, mock_async_zeroconf) -> None: + """Test discovering thread routers.""" + mock_async_zeroconf.async_add_service_listener = AsyncMock() + mock_async_zeroconf.async_remove_service_listener = AsyncMock() + mock_async_zeroconf.async_get_service_info = AsyncMock() + + assert await async_setup_component(hass, DOMAIN, {}) + await hass.async_block_till_done() + + discovered = [] + removed = [] + + @callback + def router_discovered(key: str, data: discovery.ThreadRouterDiscoveryData) -> None: + """Handle router discovered.""" + discovered.append((key, data)) + + @callback + def router_removed(key: str) -> None: + """Handle router removed.""" + removed.append(key) + + # Start Thread router discovery + thread_disovery = discovery.ThreadRouterDiscovery( + hass, router_discovered, router_removed + ) + await thread_disovery.async_start() + + mock_async_zeroconf.async_add_service_listener.assert_called_once_with( + "_meshcop._udp.local.", ANY + ) + listener: discovery.ThreadRouterDiscovery.ThreadServiceListener = ( + mock_async_zeroconf.async_add_service_listener.mock_calls[0][1][1] + ) + + # Discover a service + mock_async_zeroconf.async_get_service_info.return_value = AsyncServiceInfo( + **ROUTER_DISCOVERY_HASS + ) + listener.add_service( + None, ROUTER_DISCOVERY_HASS["type_"], ROUTER_DISCOVERY_HASS["name"] + ) + await hass.async_block_till_done() + assert len(discovered) == 1 + assert len(removed) == 0 + assert discovered[-1] == ( + "aeeb2f594b570bbf", + discovery.ThreadRouterDiscoveryData( + brand="homeassistant", + extended_pan_id="e60fc7c186212ce5", + model_name="OpenThreadBorderRouter", + network_name="OpenThread HC", + server="core-silabs-multiprotocol.local.", + vendor_name="HomeAssistant", + ), + ) + + # Discover another service - we don't care if zeroconf considers this an update + mock_async_zeroconf.async_get_service_info.return_value = AsyncServiceInfo( + **ROUTER_DISCOVERY_GOOGLE_1 + ) + listener.update_service( + None, ROUTER_DISCOVERY_GOOGLE_1["type_"], ROUTER_DISCOVERY_GOOGLE_1["name"] + ) + await hass.async_block_till_done() + assert len(discovered) == 2 + assert len(removed) == 0 + assert discovered[-1] == ( + "f6a99b425a67abed", + discovery.ThreadRouterDiscoveryData( + brand="google", + extended_pan_id="9e75e256f61409a3", + model_name="Google Nest Hub", + network_name="NEST-PAN-E1AF", + server="2d99f293-cd8e-2770-8dd2-6675de9fa000.local.", + vendor_name="Google Inc.", + ), + ) + + # Remove a service + listener.remove_service( + None, ROUTER_DISCOVERY_HASS["type_"], ROUTER_DISCOVERY_HASS["name"] + ) + await hass.async_block_till_done() + assert len(discovered) == 2 + assert len(removed) == 1 + assert removed[-1] == "aeeb2f594b570bbf" + + # Remove the service again + listener.remove_service( + None, ROUTER_DISCOVERY_HASS["type_"], ROUTER_DISCOVERY_HASS["name"] + ) + await hass.async_block_till_done() + assert len(discovered) == 2 + assert len(removed) == 1 + + # Remove an unknown service + listener.remove_service(None, ROUTER_DISCOVERY_HASS["type_"], "unknown") + await hass.async_block_till_done() + assert len(discovered) == 2 + assert len(removed) == 1 + + # Stop Thread router discovery + await thread_disovery.async_stop() + mock_async_zeroconf.async_remove_service_listener.assert_called_once_with(listener) + + +@pytest.mark.parametrize( + "data", (ROUTER_DISCOVERY_HASS_BAD_DATA, ROUTER_DISCOVERY_HASS_MISSING_DATA) +) +async def test_discover_routers_bad_data( + hass: HomeAssistant, mock_async_zeroconf, data +) -> None: + """Test discovering thread routers with bad or missing vendor mDNS data.""" + mock_async_zeroconf.async_add_service_listener = AsyncMock() + mock_async_zeroconf.async_remove_service_listener = AsyncMock() + mock_async_zeroconf.async_get_service_info = AsyncMock() + + assert await async_setup_component(hass, DOMAIN, {}) + await hass.async_block_till_done() + + # Start Thread router discovery + router_discovered_removed = Mock() + thread_disovery = discovery.ThreadRouterDiscovery( + hass, router_discovered_removed, router_discovered_removed + ) + await thread_disovery.async_start() + listener: discovery.ThreadRouterDiscovery.ThreadServiceListener = ( + mock_async_zeroconf.async_add_service_listener.mock_calls[0][1][1] + ) + + # Discover a service with bad or missing data + mock_async_zeroconf.async_get_service_info.return_value = AsyncServiceInfo(**data) + listener.add_service(None, data["type_"], data["name"]) + await hass.async_block_till_done() + router_discovered_removed.assert_called_once_with( + "aeeb2f594b570bbf", + discovery.ThreadRouterDiscoveryData( + brand=None, + extended_pan_id="e60fc7c186212ce5", + model_name="OpenThreadBorderRouter", + network_name="OpenThread HC", + server="core-silabs-multiprotocol.local.", + vendor_name=None, + ), + ) + + +async def test_discover_routers_missing_mandatory_data( + hass: HomeAssistant, mock_async_zeroconf +) -> None: + """Test discovering thread routers with missing mandatory mDNS data.""" + mock_async_zeroconf.async_add_service_listener = AsyncMock() + mock_async_zeroconf.async_remove_service_listener = AsyncMock() + mock_async_zeroconf.async_get_service_info = AsyncMock() + + assert await async_setup_component(hass, DOMAIN, {}) + await hass.async_block_till_done() + + # Start Thread router discovery + router_discovered_removed = Mock() + thread_disovery = discovery.ThreadRouterDiscovery( + hass, router_discovered_removed, router_discovered_removed + ) + await thread_disovery.async_start() + listener: discovery.ThreadRouterDiscovery.ThreadServiceListener = ( + mock_async_zeroconf.async_add_service_listener.mock_calls[0][1][1] + ) + + # Discover a service with missing mandatory data + mock_async_zeroconf.async_get_service_info.return_value = AsyncServiceInfo( + **ROUTER_DISCOVERY_HASS_MISSING_MANDATORY_DATA + ) + listener.add_service( + None, + ROUTER_DISCOVERY_HASS_MISSING_MANDATORY_DATA["type_"], + ROUTER_DISCOVERY_HASS_MISSING_MANDATORY_DATA["name"], + ) + await hass.async_block_till_done() + router_discovered_removed.assert_not_called() + + +async def test_discover_routers_get_service_info_fails( + hass: HomeAssistant, mock_async_zeroconf +) -> None: + """Test discovering thread routers with invalid mDNS data.""" + mock_async_zeroconf.async_add_service_listener = AsyncMock() + mock_async_zeroconf.async_remove_service_listener = AsyncMock() + mock_async_zeroconf.async_get_service_info = AsyncMock() + + assert await async_setup_component(hass, DOMAIN, {}) + await hass.async_block_till_done() + + # Start Thread router discovery + router_discovered_removed = Mock() + thread_disovery = discovery.ThreadRouterDiscovery( + hass, router_discovered_removed, router_discovered_removed + ) + await thread_disovery.async_start() + listener: discovery.ThreadRouterDiscovery.ThreadServiceListener = ( + mock_async_zeroconf.async_add_service_listener.mock_calls[0][1][1] + ) + + # Discover a service with missing data + mock_async_zeroconf.async_get_service_info.return_value = None + listener.add_service( + None, ROUTER_DISCOVERY_HASS["type_"], ROUTER_DISCOVERY_HASS["name"] + ) + await hass.async_block_till_done() + router_discovered_removed.assert_not_called() + + +async def test_discover_routers_update_unchanged( + hass: HomeAssistant, mock_async_zeroconf +) -> None: + """Test discovering thread routers with identical mDNS data in update.""" + mock_async_zeroconf.async_add_service_listener = AsyncMock() + mock_async_zeroconf.async_remove_service_listener = AsyncMock() + mock_async_zeroconf.async_get_service_info = AsyncMock() + + assert await async_setup_component(hass, DOMAIN, {}) + await hass.async_block_till_done() + + # Start Thread router discovery + router_discovered_removed = Mock() + thread_disovery = discovery.ThreadRouterDiscovery( + hass, router_discovered_removed, router_discovered_removed + ) + await thread_disovery.async_start() + listener: discovery.ThreadRouterDiscovery.ThreadServiceListener = ( + mock_async_zeroconf.async_add_service_listener.mock_calls[0][1][1] + ) + + # Discover a service + mock_async_zeroconf.async_get_service_info.return_value = AsyncServiceInfo( + **ROUTER_DISCOVERY_HASS + ) + listener.add_service( + None, ROUTER_DISCOVERY_HASS["type_"], ROUTER_DISCOVERY_HASS["name"] + ) + await hass.async_block_till_done() + router_discovered_removed.assert_called_once() + + # Update the service unchanged + mock_async_zeroconf.async_get_service_info.return_value = AsyncServiceInfo( + **ROUTER_DISCOVERY_HASS + ) + listener.update_service( + None, ROUTER_DISCOVERY_HASS["type_"], ROUTER_DISCOVERY_HASS["name"] + ) + await hass.async_block_till_done() + router_discovered_removed.assert_called_once() + + +async def test_discover_routers_stop_twice( + hass: HomeAssistant, mock_async_zeroconf +) -> None: + """Test discovering thread routers stopping discovery twice.""" + mock_async_zeroconf.async_add_service_listener = AsyncMock() + mock_async_zeroconf.async_remove_service_listener = AsyncMock() + mock_async_zeroconf.async_get_service_info = AsyncMock() + + assert await async_setup_component(hass, DOMAIN, {}) + await hass.async_block_till_done() + + # Start Thread router discovery + router_discovered_removed = Mock() + thread_disovery = discovery.ThreadRouterDiscovery( + hass, router_discovered_removed, router_discovered_removed + ) + await thread_disovery.async_start() + + # Stop Thread router discovery + await thread_disovery.async_stop() + mock_async_zeroconf.async_remove_service_listener.assert_called_once() + + # Stop Thread router discovery again + await thread_disovery.async_stop() + mock_async_zeroconf.async_remove_service_listener.assert_called_once() diff --git a/tests/components/thread/test_websocket_api.py b/tests/components/thread/test_websocket_api.py index 10c37258f93..7a474a19943 100644 --- a/tests/components/thread/test_websocket_api.py +++ b/tests/components/thread/test_websocket_api.py @@ -1,11 +1,21 @@ """Test the thread websocket API.""" -from homeassistant.components.thread import dataset_store +from unittest.mock import ANY, AsyncMock + +from zeroconf.asyncio import AsyncServiceInfo + +from homeassistant.components.thread import dataset_store, discovery from homeassistant.components.thread.const import DOMAIN from homeassistant.core import HomeAssistant from homeassistant.setup import async_setup_component -from . import DATASET_1, DATASET_2, DATASET_3 +from . import ( + DATASET_1, + DATASET_2, + DATASET_3, + ROUTER_DISCOVERY_GOOGLE_1, + ROUTER_DISCOVERY_HASS, +) async def test_add_dataset(hass: HomeAssistant, hass_ws_client) -> None: @@ -121,3 +131,98 @@ async def test_list_get_dataset(hass: HomeAssistant, hass_ws_client) -> None: msg = await client.receive_json() assert not msg["success"] assert msg["error"] == {"code": "not_found", "message": "unknown dataset"} + + +async def test_discover_routers( + hass: HomeAssistant, hass_ws_client, mock_async_zeroconf +) -> None: + """Test discovering thread routers.""" + mock_async_zeroconf.async_add_service_listener = AsyncMock() + mock_async_zeroconf.async_remove_service_listener = AsyncMock() + mock_async_zeroconf.async_get_service_info = AsyncMock() + + assert await async_setup_component(hass, DOMAIN, {}) + await hass.async_block_till_done() + + client = await hass_ws_client(hass) + + # Subscribe + await client.send_json({"id": 1, "type": "thread/discover_routers"}) + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] is None + + mock_async_zeroconf.async_add_service_listener.assert_called_once_with( + "_meshcop._udp.local.", ANY + ) + listener: discovery.ThreadRouterDiscovery.ThreadServiceListener = ( + mock_async_zeroconf.async_add_service_listener.mock_calls[0][1][1] + ) + + # Discover a service + mock_async_zeroconf.async_get_service_info.return_value = AsyncServiceInfo( + **ROUTER_DISCOVERY_HASS + ) + listener.add_service( + None, ROUTER_DISCOVERY_HASS["type_"], ROUTER_DISCOVERY_HASS["name"] + ) + msg = await client.receive_json() + assert msg == { + "event": { + "data": { + "brand": "homeassistant", + "extended_pan_id": "e60fc7c186212ce5", + "model_name": "OpenThreadBorderRouter", + "network_name": "OpenThread HC", + "server": "core-silabs-multiprotocol.local.", + "vendor_name": "HomeAssistant", + }, + "key": "aeeb2f594b570bbf", + "type": "router_discovered", + }, + "id": 1, + "type": "event", + } + + # Discover another service - we don't care if zeroconf considers this an update + mock_async_zeroconf.async_get_service_info.return_value = AsyncServiceInfo( + **ROUTER_DISCOVERY_GOOGLE_1 + ) + listener.update_service( + None, ROUTER_DISCOVERY_GOOGLE_1["type_"], ROUTER_DISCOVERY_GOOGLE_1["name"] + ) + msg = await client.receive_json() + assert msg == { + "event": { + "data": { + "brand": "google", + "extended_pan_id": "9e75e256f61409a3", + "model_name": "Google Nest Hub", + "network_name": "NEST-PAN-E1AF", + "server": "2d99f293-cd8e-2770-8dd2-6675de9fa000.local.", + "vendor_name": "Google Inc.", + }, + "key": "f6a99b425a67abed", + "type": "router_discovered", + }, + "id": 1, + "type": "event", + } + + # Remove a service + listener.remove_service( + None, ROUTER_DISCOVERY_HASS["type_"], ROUTER_DISCOVERY_HASS["name"] + ) + msg = await client.receive_json() + assert msg == { + "event": {"key": "aeeb2f594b570bbf", "type": "router_removed"}, + "id": 1, + "type": "event", + } + + # Unsubscribe + await client.send_json({"id": 2, "type": "unsubscribe_events", "subscription": 1}) + response = await client.receive_json() + assert response["success"] + + mock_async_zeroconf.async_remove_service_listener.assert_called_once_with(listener)