Reduce duplicate code in ESPHome connection callback (#107338)

This commit is contained in:
J. Nick Koston 2024-01-07 22:10:58 -10:00 committed by GitHub
parent 102fdbb237
commit d609344f40
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 59 additions and 23 deletions

View file

@ -383,6 +383,17 @@ class ESPHomeManager:
self.voice_assistant_udp_server.stop()
async def on_connect(self) -> None:
"""Subscribe to states and list entities on successful API login."""
try:
await self._on_connnect()
except APIConnectionError as err:
_LOGGER.warning(
"Error getting setting up connection for %s: %s", self.host, err
)
# Re-connection logic will trigger after this
await self.cli.disconnect()
async def _on_connnect(self) -> None:
"""Subscribe to states and list entities on successful API login."""
entry = self.entry
unique_id = entry.unique_id
@ -393,16 +404,10 @@ class ESPHomeManager:
cli = self.cli
stored_device_name = entry.data.get(CONF_DEVICE_NAME)
unique_id_is_mac_address = unique_id and ":" in unique_id
try:
results = await asyncio.gather(
cli.device_info(),
cli.list_entities_services(),
)
except APIConnectionError as err:
_LOGGER.warning("Error getting device info for %s: %s", self.host, err)
# Re-connection logic will trigger after this
await cli.disconnect()
return
results = await asyncio.gather(
cli.device_info(),
cli.list_entities_services(),
)
device_info: EsphomeDeviceInfo = results[0]
entity_infos_services: tuple[list[EntityInfo], list[UserService]] = results[1]
@ -487,18 +492,12 @@ class ESPHomeManager:
)
)
try:
setup_results = await asyncio.gather(
*setup_coros_with_disconnect_callbacks,
cli.subscribe_states(entry_data.async_update_state),
cli.subscribe_service_calls(self.async_on_service_call),
cli.subscribe_home_assistant_states(self.async_on_state_subscription),
)
except APIConnectionError as err:
_LOGGER.warning("Error getting initial data for %s: %s", self.host, err)
# Re-connection logic will trigger after this
await cli.disconnect()
return
setup_results = await asyncio.gather(
*setup_coros_with_disconnect_callbacks,
cli.subscribe_states(entry_data.async_update_state),
cli.subscribe_service_calls(self.async_on_service_call),
cli.subscribe_home_assistant_states(self.async_on_state_subscription),
)
for result_idx in range(len(setup_coros_with_disconnect_callbacks)):
cancel_callback = setup_results[result_idx]

View file

@ -90,6 +90,10 @@ class BaseMockReconnectLogic(ReconnectLogic):
self._cancel_connect("forced disconnect from test")
self._is_stopped = True
async def stop(self) -> None:
"""Stop the reconnect logic."""
self.stop_callback()
@pytest.fixture
def mock_device_info() -> DeviceInfo:

View file

@ -4,6 +4,7 @@ from unittest.mock import AsyncMock, call
from aioesphomeapi import (
APIClient,
APIConnectionError,
DeviceInfo,
EntityInfo,
EntityState,
@ -510,8 +511,11 @@ async def test_connection_aborted_wrong_device(
"with mac address `11:22:33:44:55:ab`" in caplog.text
)
assert "Error getting setting up connection for" not in caplog.text
assert len(mock_client.disconnect.mock_calls) == 1
mock_client.disconnect.reset_mock()
caplog.clear()
# Make sure discovery triggers a reconnect to the correct device
# Make sure discovery triggers a reconnect
service_info = dhcp.DhcpServiceInfo(
ip="192.168.43.184",
hostname="test",
@ -533,6 +537,35 @@ async def test_connection_aborted_wrong_device(
assert "Unexpected device found at" not in caplog.text
async def test_failure_during_connect(
hass: HomeAssistant,
mock_client: APIClient,
mock_zeroconf: None,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test we disconnect when there is a failure during connection setup."""
entry = MockConfigEntry(
domain=DOMAIN,
data={
CONF_HOST: "192.168.43.183",
CONF_PORT: 6053,
CONF_PASSWORD: "",
CONF_DEVICE_NAME: "test",
},
unique_id="11:22:33:44:55:aa",
)
entry.add_to_hass(hass)
mock_client.device_info = AsyncMock(side_effect=APIConnectionError("fail"))
await hass.config_entries.async_setup(entry.entry_id)
await hass.async_block_till_done()
assert "Error getting setting up connection for" in caplog.text
# Ensure we disconnect so that the reconnect logic is triggered
assert len(mock_client.disconnect.mock_calls) == 1
async def test_state_subscription(
mock_client: APIClient,
hass: HomeAssistant,