diff --git a/homeassistant/components/otbr/__init__.py b/homeassistant/components/otbr/__init__.py index d313c61c0ec..602c76f77ef 100644 --- a/homeassistant/components/otbr/__init__.py +++ b/homeassistant/components/otbr/__init__.py @@ -78,6 +78,11 @@ class OTBRData: """Create an active operational dataset.""" return await self.api.create_active_dataset(dataset) + @_handle_otbr_error + async def set_active_dataset_tlvs(self, dataset: bytes) -> None: + """Set current active operational dataset in TLVS format.""" + await self.api.set_active_dataset_tlvs(dataset) + @_handle_otbr_error async def get_extended_address(self) -> bytes: """Get extended address (EUI-64).""" diff --git a/homeassistant/components/otbr/websocket_api.py b/homeassistant/components/otbr/websocket_api.py index 3d885cd5007..aa8c1dd2dd9 100644 --- a/homeassistant/components/otbr/websocket_api.py +++ b/homeassistant/components/otbr/websocket_api.py @@ -2,9 +2,11 @@ from typing import TYPE_CHECKING import python_otbr_api +from python_otbr_api import tlv_parser +import voluptuous as vol from homeassistant.components import websocket_api -from homeassistant.components.thread import async_add_dataset +from homeassistant.components.thread import async_add_dataset, async_get_dataset from homeassistant.core import HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError @@ -20,6 +22,7 @@ def async_setup(hass: HomeAssistant) -> None: websocket_api.async_register_command(hass, websocket_info) websocket_api.async_register_command(hass, websocket_create_network) websocket_api.async_register_command(hass, websocket_get_extended_address) + websocket_api.async_register_command(hass, websocket_set_network) @websocket_api.websocket_command( @@ -111,6 +114,67 @@ async def websocket_create_network( connection.send_result(msg["id"]) +@websocket_api.websocket_command( + { + "type": "otbr/set_network", + vol.Required("dataset_id"): str, + } +) +@websocket_api.require_admin +@websocket_api.async_response +async def websocket_set_network( + hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict +) -> None: + """Set the Thread network to be used by the OTBR.""" + if DOMAIN not in hass.data: + connection.send_error(msg["id"], "not_loaded", "No OTBR API loaded") + return + + dataset_tlv = await async_get_dataset(hass, msg["dataset_id"]) + + if not dataset_tlv: + connection.send_error(msg["id"], "unknown_dataset", "Unknown dataset") + return + dataset = tlv_parser.parse_tlv(dataset_tlv) + if channel_str := dataset.get(tlv_parser.MeshcopTLVType.CHANNEL): + thread_dataset_channel = int(channel_str, base=16) + + # We currently have no way to know which channel zha is using, assume it's + # the default + zha_channel = DEFAULT_CHANNEL + + if thread_dataset_channel != zha_channel: + connection.send_error( + msg["id"], + "channel_conflict", + f"Can't connect to network on channel {thread_dataset_channel}, ZHA is " + f"using channel {zha_channel}", + ) + return + + data: OTBRData = hass.data[DOMAIN] + + try: + await data.set_enabled(False) + except HomeAssistantError as exc: + connection.send_error(msg["id"], "set_enabled_failed", str(exc)) + return + + try: + await data.set_active_dataset_tlvs(bytes.fromhex(dataset_tlv)) + except HomeAssistantError as exc: + connection.send_error(msg["id"], "set_active_dataset_tlvs_failed", str(exc)) + return + + try: + await data.set_enabled(True) + except HomeAssistantError as exc: + connection.send_error(msg["id"], "set_enabled_failed", str(exc)) + return + + connection.send_result(msg["id"]) + + @websocket_api.websocket_command( { "type": "otbr/get_extended_address", diff --git a/homeassistant/components/thread/__init__.py b/homeassistant/components/thread/__init__.py index 345fca854d2..4fc88479818 100644 --- a/homeassistant/components/thread/__init__.py +++ b/homeassistant/components/thread/__init__.py @@ -6,13 +6,19 @@ from homeassistant.core import HomeAssistant from homeassistant.helpers.typing import ConfigType from .const import DOMAIN -from .dataset_store import DatasetEntry, async_add_dataset, async_get_preferred_dataset +from .dataset_store import ( + DatasetEntry, + async_add_dataset, + async_get_dataset, + async_get_preferred_dataset, +) from .websocket_api import async_setup as async_setup_ws_api __all__ = [ "DOMAIN", "DatasetEntry", "async_add_dataset", + "async_get_dataset", "async_get_preferred_dataset", ] diff --git a/homeassistant/components/thread/dataset_store.py b/homeassistant/components/thread/dataset_store.py index b9a27b617e6..ea5a16f90cd 100644 --- a/homeassistant/components/thread/dataset_store.py +++ b/homeassistant/components/thread/dataset_store.py @@ -159,6 +159,14 @@ async def async_add_dataset(hass: HomeAssistant, source: str, tlv: str) -> None: store.async_add(source, tlv) +async def async_get_dataset(hass: HomeAssistant, dataset_id: str) -> str | None: + """Get a dataset.""" + store = await async_get_store(hass) + if (entry := store.async_get(dataset_id)) is None: + return None + return entry.tlv + + async def async_get_preferred_dataset(hass: HomeAssistant) -> str | None: """Get the preferred dataset.""" store = await async_get_store(hass) diff --git a/tests/components/otbr/test_websocket_api.py b/tests/components/otbr/test_websocket_api.py index 056563e7b87..04210a3433e 100644 --- a/tests/components/otbr/test_websocket_api.py +++ b/tests/components/otbr/test_websocket_api.py @@ -4,11 +4,11 @@ from unittest.mock import patch import pytest import python_otbr_api -from homeassistant.components import otbr +from homeassistant.components import otbr, thread from homeassistant.core import HomeAssistant from homeassistant.setup import async_setup_component -from . import BASE_URL, DATASET_CH16 +from . import BASE_URL, DATASET_CH15, DATASET_CH16 from tests.test_util.aiohttp import AiohttpClientMocker from tests.typing import WebSocketGenerator @@ -290,6 +290,190 @@ async def test_create_network_fails_5( assert msg["error"]["code"] == "get_active_dataset_tlvs_empty" +async def test_set_network( + hass: HomeAssistant, + aioclient_mock: AiohttpClientMocker, + otbr_config_entry, + websocket_client, +) -> None: + """Test set network.""" + + await thread.async_add_dataset(hass, "test", DATASET_CH15.hex()) + dataset_store = await thread.dataset_store.async_get_store(hass) + dataset_id = list(dataset_store.datasets)[1] + + with patch( + "python_otbr_api.OTBR.set_active_dataset_tlvs" + ) as set_active_dataset_tlvs_mock, patch( + "python_otbr_api.OTBR.set_enabled" + ) as set_enabled_mock: + await websocket_client.send_json_auto_id( + { + "type": "otbr/set_network", + "dataset_id": dataset_id, + } + ) + + msg = await websocket_client.receive_json() + assert msg["success"] + assert msg["result"] is None + + set_active_dataset_tlvs_mock.assert_called_once_with(DATASET_CH15) + assert len(set_enabled_mock.mock_calls) == 2 + assert set_enabled_mock.mock_calls[0][1][0] is False + assert set_enabled_mock.mock_calls[1][1][0] is True + + +async def test_set_network_no_entry( + hass: HomeAssistant, + aioclient_mock: AiohttpClientMocker, + hass_ws_client: WebSocketGenerator, +) -> None: + """Test set network.""" + await async_setup_component(hass, "otbr", {}) + websocket_client = await hass_ws_client(hass) + await websocket_client.send_json_auto_id( + { + "type": "otbr/set_network", + "dataset_id": "abc", + } + ) + + msg = await websocket_client.receive_json() + assert not msg["success"] + assert msg["error"]["code"] == "not_loaded" + + +async def test_set_network_channel_conflict( + hass: HomeAssistant, + aioclient_mock: AiohttpClientMocker, + otbr_config_entry, + websocket_client, +) -> None: + """Test set network.""" + + dataset_store = await thread.dataset_store.async_get_store(hass) + dataset_id = list(dataset_store.datasets)[0] + + await websocket_client.send_json_auto_id( + { + "type": "otbr/set_network", + "dataset_id": dataset_id, + } + ) + + msg = await websocket_client.receive_json() + + assert not msg["success"] + assert msg["error"]["code"] == "channel_conflict" + + +async def test_set_network_unknown_dataset( + hass: HomeAssistant, + aioclient_mock: AiohttpClientMocker, + otbr_config_entry, + websocket_client, +) -> None: + """Test set network.""" + + await websocket_client.send_json_auto_id( + { + "type": "otbr/set_network", + "dataset_id": "abc", + } + ) + + msg = await websocket_client.receive_json() + + assert not msg["success"] + assert msg["error"]["code"] == "unknown_dataset" + + +async def test_set_network_fails_1( + hass: HomeAssistant, + aioclient_mock: AiohttpClientMocker, + otbr_config_entry, + websocket_client, +) -> None: + """Test set network.""" + await thread.async_add_dataset(hass, "test", DATASET_CH15.hex()) + dataset_store = await thread.dataset_store.async_get_store(hass) + dataset_id = list(dataset_store.datasets)[1] + + with patch( + "python_otbr_api.OTBR.set_enabled", + side_effect=python_otbr_api.OTBRError, + ): + await websocket_client.send_json_auto_id( + { + "type": "otbr/set_network", + "dataset_id": dataset_id, + } + ) + msg = await websocket_client.receive_json() + + assert not msg["success"] + assert msg["error"]["code"] == "set_enabled_failed" + + +async def test_set_network_fails_2( + hass: HomeAssistant, + aioclient_mock: AiohttpClientMocker, + otbr_config_entry, + websocket_client, +) -> None: + """Test set network.""" + await thread.async_add_dataset(hass, "test", DATASET_CH15.hex()) + dataset_store = await thread.dataset_store.async_get_store(hass) + dataset_id = list(dataset_store.datasets)[1] + + with patch( + "python_otbr_api.OTBR.set_enabled", + ), patch( + "python_otbr_api.OTBR.set_active_dataset_tlvs", + side_effect=python_otbr_api.OTBRError, + ): + await websocket_client.send_json_auto_id( + { + "type": "otbr/set_network", + "dataset_id": dataset_id, + } + ) + msg = await websocket_client.receive_json() + + assert not msg["success"] + assert msg["error"]["code"] == "set_active_dataset_tlvs_failed" + + +async def test_set_network_fails_3( + hass: HomeAssistant, + aioclient_mock: AiohttpClientMocker, + otbr_config_entry, + websocket_client, +) -> None: + """Test set network.""" + await thread.async_add_dataset(hass, "test", DATASET_CH15.hex()) + dataset_store = await thread.dataset_store.async_get_store(hass) + dataset_id = list(dataset_store.datasets)[1] + + with patch( + "python_otbr_api.OTBR.set_enabled", + side_effect=[None, python_otbr_api.OTBRError], + ), patch( + "python_otbr_api.OTBR.set_active_dataset_tlvs", + ): + await websocket_client.send_json_auto_id( + { + "type": "otbr/set_network", + "dataset_id": dataset_id, + } + ) + msg = await websocket_client.receive_json() + + assert not msg["success"] + assert msg["error"]["code"] == "set_enabled_failed" + + async def test_get_extended_address( hass: HomeAssistant, aioclient_mock: AiohttpClientMocker, diff --git a/tests/components/thread/test_dataset_store.py b/tests/components/thread/test_dataset_store.py index 553068ab8bd..581329e860a 100644 --- a/tests/components/thread/test_dataset_store.py +++ b/tests/components/thread/test_dataset_store.py @@ -83,6 +83,17 @@ async def test_delete_preferred_dataset(hass: HomeAssistant) -> None: assert len(store.datasets) == 1 +async def test_get_dataset(hass: HomeAssistant) -> None: + """Test get the preferred dataset.""" + assert await dataset_store.async_get_dataset(hass, "blah") is None + + await dataset_store.async_add_dataset(hass, "source", DATASET_1) + store = await dataset_store.async_get_store(hass) + dataset_id = list(store.datasets.values())[0].id + + assert (await dataset_store.async_get_dataset(hass, dataset_id)) == DATASET_1 + + async def test_get_preferred_dataset(hass: HomeAssistant) -> None: """Test get the preferred dataset.""" assert await dataset_store.async_get_preferred_dataset(hass) is None