From b7c7e7f57b532e97d1aa9aaa60182ddcc3ea633d Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 13 Dec 2022 14:22:34 -1000 Subject: [PATCH] Try to reconnect disconnected shelly devices as soon as they discovered by zeroconf (#83872) --- .../components/shelly/config_flow.py | 25 ++++++- .../components/shelly/coordinator.py | 14 ++++ homeassistant/components/shelly/utils.py | 6 ++ tests/components/shelly/conftest.py | 6 ++ tests/components/shelly/test_config_flow.py | 73 +++++++++++++++++++ 5 files changed, 120 insertions(+), 4 deletions(-) diff --git a/homeassistant/components/shelly/config_flow.py b/homeassistant/components/shelly/config_flow.py index 9bf4a6126b0..8679edf5382 100644 --- a/homeassistant/components/shelly/config_flow.py +++ b/homeassistant/components/shelly/config_flow.py @@ -30,7 +30,7 @@ from .const import ( LOGGER, BLEScannerMode, ) -from .coordinator import get_entry_data +from .coordinator import async_reconnect_soon, get_entry_data from .utils import ( get_block_device_name, get_block_device_sleep_period, @@ -41,6 +41,7 @@ from .utils import ( get_rpc_device_name, get_rpc_device_sleep_period, get_ws_context, + mac_address_from_name, ) HOST_SCHEMA: Final = vol.Schema({vol.Required(CONF_HOST): str}) @@ -210,11 +211,25 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): step_id="credentials", data_schema=vol.Schema(schema), errors=errors ) + async def _async_discovered_mac(self, mac: str, host: str) -> None: + """Abort and reconnect soon if the device with the mac address is already configured.""" + if ( + current_entry := await self.async_set_unique_id(mac) + ) and current_entry.data[CONF_HOST] == host: + await async_reconnect_soon(self.hass, current_entry) + self._abort_if_unique_id_configured({CONF_HOST: host}) + async def async_step_zeroconf( self, discovery_info: zeroconf.ZeroconfServiceInfo ) -> FlowResult: """Handle zeroconf discovery.""" host = discovery_info.host + # First try to get the mac address from the name + # so we can avoid making another connection to the + # device if we already have it configured + if mac := mac_address_from_name(discovery_info.name): + await self._async_discovered_mac(mac, host) + try: self.info = await self._async_get_info(host) except DeviceConnectionError: @@ -222,10 +237,12 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): except FirmwareUnsupported: return self.async_abort(reason="unsupported_firmware") - await self.async_set_unique_id(self.info["mac"]) - self._abort_if_unique_id_configured({CONF_HOST: host}) - self.host = host + if not mac: + # We could not get the mac address from the name + # so need to check here since we just got the info + await self._async_discovered_mac(self.info["mac"], host) + self.host = host self.context.update( { "title_placeholders": {"name": discovery_info.name.split(".")[0]}, diff --git a/homeassistant/components/shelly/coordinator.py b/homeassistant/components/shelly/coordinator.py index 9203606230b..9ccbab66b0a 100644 --- a/homeassistant/components/shelly/coordinator.py +++ b/homeassistant/components/shelly/coordinator.py @@ -14,6 +14,7 @@ from aioshelly.exceptions import DeviceConnectionError, InvalidAuthError, RpcCal from aioshelly.rpc_device import RpcDevice, UpdateType from awesomeversion import AwesomeVersion +from homeassistant import config_entries from homeassistant.config_entries import ConfigEntry from homeassistant.const import ATTR_DEVICE_ID, CONF_HOST, EVENT_HOMEASSISTANT_STOP from homeassistant.core import CALLBACK_TYPE, Event, HomeAssistant, callback @@ -646,3 +647,16 @@ def get_rpc_coordinator_by_device_id( return coordinator return None + + +async def async_reconnect_soon( + hass: HomeAssistant, entry: config_entries.ConfigEntry +) -> None: + """Try to reconnect soon.""" + if ( + not hass.is_stopping + and entry.state == config_entries.ConfigEntryState.LOADED + and (entry_data := get_entry_data(hass).get(entry.entry_id)) + and (coordinator := entry_data.rpc) + ): + hass.async_create_task(coordinator.async_request_refresh()) diff --git a/homeassistant/components/shelly/utils.py b/homeassistant/components/shelly/utils.py index 418cec64d40..b048b219e6b 100644 --- a/homeassistant/components/shelly/utils.py +++ b/homeassistant/components/shelly/utils.py @@ -408,3 +408,9 @@ def brightness_to_percentage(brightness: int) -> int: def percentage_to_brightness(percentage: int) -> int: """Convert percentage to brightness level.""" return round(255 * percentage / 100) + + +def mac_address_from_name(name: str) -> str | None: + """Convert a name to a mac address.""" + mac = name.partition(".")[0].partition("-")[-1] + return mac.upper() if len(mac) == 12 else None diff --git a/tests/components/shelly/conftest.py b/tests/components/shelly/conftest.py index 214ed3b1503..e2ba5fc767a 100644 --- a/tests/components/shelly/conftest.py +++ b/tests/components/shelly/conftest.py @@ -301,8 +301,14 @@ async def mock_rpc_device(): {}, UpdateType.EVENT ) + def disconnected(): + rpc_device_mock.return_value.subscribe_updates.call_args[0][0]( + {}, UpdateType.DISCONNECTED + ) + device = _mock_rpc_device("0.12.0") rpc_device_mock.return_value = device + rpc_device_mock.return_value.mock_disconnected = Mock(side_effect=disconnected) rpc_device_mock.return_value.mock_update = Mock(side_effect=update) rpc_device_mock.return_value.mock_event = Mock(side_effect=event) diff --git a/tests/components/shelly/test_config_flow.py b/tests/components/shelly/test_config_flow.py index 1a1acea16a3..6795049a207 100644 --- a/tests/components/shelly/test_config_flow.py +++ b/tests/components/shelly/test_config_flow.py @@ -37,6 +37,15 @@ DISCOVERY_INFO = zeroconf.ZeroconfServiceInfo( properties={zeroconf.ATTR_PROPERTIES_ID: "shelly1pm-12345"}, type="mock_type", ) +DISCOVERY_INFO_WITH_MAC = zeroconf.ZeroconfServiceInfo( + host="1.1.1.1", + addresses=["1.1.1.1"], + hostname="mock_hostname", + name="shelly1pm-AABBCCDDEEFF", + port=None, + properties={zeroconf.ATTR_PROPERTIES_ID: "shelly1pm-AABBCCDDEEFF"}, + type="mock_type", +) MOCK_CONFIG = { "sys": { "device": {"name": "Test name"}, @@ -1064,3 +1073,67 @@ async def test_options_flow_pre_ble_device(hass, mock_pre_ble_rpc_device): assert result["reason"] == "ble_unsupported" await hass.config_entries.async_unload(entry.entry_id) + + +async def test_zeroconf_already_configured_triggers_refresh_mac_in_name( + hass, mock_rpc_device, monkeypatch +): + """Test zeroconf discovery triggers refresh when the mac is in the device name.""" + entry = MockConfigEntry( + domain="shelly", + unique_id="AABBCCDDEEFF", + data={"host": "1.1.1.1", "gen": 2, "sleep_period": 0, "model": "SHSW-1"}, + ) + entry.add_to_hass(hass) + await hass.config_entries.async_setup(entry.entry_id) + await hass.async_block_till_done() + assert len(mock_rpc_device.initialize.mock_calls) == 1 + + with patch( + "aioshelly.common.get_info", + return_value={"mac": "", "type": "SHSW-1", "auth": False}, + ): + result = await hass.config_entries.flow.async_init( + DOMAIN, + data=DISCOVERY_INFO_WITH_MAC, + context={"source": config_entries.SOURCE_ZEROCONF}, + ) + assert result["type"] == data_entry_flow.FlowResultType.ABORT + assert result["reason"] == "already_configured" + + monkeypatch.setattr(mock_rpc_device, "connected", False) + mock_rpc_device.mock_disconnected() + await hass.async_block_till_done() + assert len(mock_rpc_device.initialize.mock_calls) == 2 + + +async def test_zeroconf_already_configured_triggers_refresh( + hass, mock_rpc_device, monkeypatch +): + """Test zeroconf discovery triggers refresh when the mac is obtained via get_info.""" + entry = MockConfigEntry( + domain="shelly", + unique_id="AABBCCDDEEFF", + data={"host": "1.1.1.1", "gen": 2, "sleep_period": 0, "model": "SHSW-1"}, + ) + entry.add_to_hass(hass) + await hass.config_entries.async_setup(entry.entry_id) + await hass.async_block_till_done() + assert len(mock_rpc_device.initialize.mock_calls) == 1 + + with patch( + "aioshelly.common.get_info", + return_value={"mac": "AABBCCDDEEFF", "type": "SHSW-1", "auth": False}, + ): + result = await hass.config_entries.flow.async_init( + DOMAIN, + data=DISCOVERY_INFO, + context={"source": config_entries.SOURCE_ZEROCONF}, + ) + assert result["type"] == data_entry_flow.FlowResultType.ABORT + assert result["reason"] == "already_configured" + + monkeypatch.setattr(mock_rpc_device, "connected", False) + mock_rpc_device.mock_disconnected() + await hass.async_block_till_done() + assert len(mock_rpc_device.initialize.mock_calls) == 2