From dfe7c5ebed452d12f48bc5d53301218a9fd1faf3 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 28 Jun 2023 20:39:31 -0500 Subject: [PATCH] Refactor ESPHome connection management logic into a class (#95457) * Refactor ESPHome setup logic into a class Avoids all the nonlocals and fixes the C901 * cleanup * touch ups * touch ups * touch ups * make easier to read * stale --- homeassistant/components/esphome/__init__.py | 321 +++++++++++-------- 1 file changed, 192 insertions(+), 129 deletions(-) diff --git a/homeassistant/components/esphome/__init__.py b/homeassistant/components/esphome/__init__.py index afaefe117ba..271b0b9aa16 100644 --- a/homeassistant/components/esphome/__init__.py +++ b/homeassistant/components/esphome/__init__.py @@ -137,57 +137,60 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: return True -async def async_setup_entry( # noqa: C901 - hass: HomeAssistant, entry: ConfigEntry -) -> bool: - """Set up the esphome component.""" - host = entry.data[CONF_HOST] - port = entry.data[CONF_PORT] - password = entry.data[CONF_PASSWORD] - noise_psk = entry.data.get(CONF_NOISE_PSK) - device_id: str = None # type: ignore[assignment] +class ESPHomeManager: + """Class to manage an ESPHome connection.""" - zeroconf_instance = await zeroconf.async_get_instance(hass) - - cli = APIClient( - host, - port, - password, - client_info=f"Home Assistant {ha_version}", - zeroconf_instance=zeroconf_instance, - noise_psk=noise_psk, + __slots__ = ( + "hass", + "host", + "password", + "entry", + "cli", + "device_id", + "domain_data", + "voice_assistant_udp_server", + "reconnect_logic", + "zeroconf_instance", + "entry_data", ) - services_issue = f"service_calls_not_enabled-{entry.unique_id}" - if entry.options.get(CONF_ALLOW_SERVICE_CALLS, DEFAULT_ALLOW_SERVICE_CALLS): - async_delete_issue(hass, DOMAIN, services_issue) + def __init__( + self, + hass: HomeAssistant, + entry: ConfigEntry, + host: str, + password: str | None, + cli: APIClient, + zeroconf_instance: zeroconf.HaZeroconf, + domain_data: DomainData, + entry_data: RuntimeEntryData, + ) -> None: + """Initialize the esphome manager.""" + self.hass = hass + self.host = host + self.password = password + self.entry = entry + self.cli = cli + self.device_id: str | None = None + self.domain_data = domain_data + self.voice_assistant_udp_server: VoiceAssistantUDPServer | None = None + self.reconnect_logic: ReconnectLogic | None = None + self.zeroconf_instance = zeroconf_instance + self.entry_data = entry_data - domain_data = DomainData.get(hass) - entry_data = RuntimeEntryData( - client=cli, - entry_id=entry.entry_id, - store=domain_data.get_or_create_store(hass, entry), - original_options=dict(entry.options), - ) - domain_data.set_entry_data(entry, entry_data) - - async def on_stop(event: Event) -> None: + async def on_stop(self, event: Event) -> None: """Cleanup the socket client on HA stop.""" - await _cleanup_instance(hass, entry) + await _cleanup_instance(self.hass, self.entry) - # Use async_listen instead of async_listen_once so that we don't deregister - # the callback twice when shutting down Home Assistant. - # "Unable to remove unknown listener - # .onetime_listener>" - entry_data.cleanup_callbacks.append( - hass.bus.async_listen(EVENT_HOMEASSISTANT_STOP, on_stop) - ) + @property + def services_issue(self) -> str: + """Return the services issue name for this entry.""" + return f"service_calls_not_enabled-{self.entry.unique_id}" @callback - def async_on_service_call(service: HomeassistantServiceCall) -> None: + def async_on_service_call(self, service: HomeassistantServiceCall) -> None: """Call service when user automation in ESPHome config is triggered.""" - device_info = entry_data.device_info - assert device_info is not None + hass = self.hass domain, service_name = service.service.split(".", 1) service_data = service.data @@ -201,15 +204,16 @@ async def async_setup_entry( # noqa: C901 template.render_complex(data_template, service.variables) ) except TemplateError as ex: - _LOGGER.error("Error rendering data template for %s: %s", host, ex) + _LOGGER.error("Error rendering data template for %s: %s", self.host, ex) return if service.is_event: + device_id = self.device_id # ESPHome uses service call packet for both events and service calls # Ensure the user can only send events of form 'esphome.xyz' if domain != "esphome": _LOGGER.error( - "Can only generate events under esphome domain! (%s)", host + "Can only generate events under esphome domain! (%s)", self.host ) return @@ -226,17 +230,21 @@ async def async_setup_entry( # noqa: C901 **service_data, }, ) - elif entry.options.get(CONF_ALLOW_SERVICE_CALLS, DEFAULT_ALLOW_SERVICE_CALLS): + elif self.entry.options.get( + CONF_ALLOW_SERVICE_CALLS, DEFAULT_ALLOW_SERVICE_CALLS + ): hass.async_create_task( hass.services.async_call( domain, service_name, service_data, blocking=True ) ) else: + device_info = self.entry_data.device_info + assert device_info is not None async_create_issue( hass, DOMAIN, - services_issue, + self.services_issue, is_fixable=False, severity=IssueSeverity.WARNING, translation_key="service_calls_not_allowed", @@ -256,7 +264,7 @@ async def async_setup_entry( # noqa: C901 ) async def _send_home_assistant_state( - entity_id: str, attribute: str | None, state: State | None + self, entity_id: str, attribute: str | None, state: State | None ) -> None: """Forward Home Assistant states to ESPHome.""" if state is None or (attribute and attribute not in state.attributes): @@ -271,102 +279,102 @@ async def async_setup_entry( # noqa: C901 else: send_state = attr_val - await cli.send_home_assistant_state(entity_id, attribute, str(send_state)) + await self.cli.send_home_assistant_state(entity_id, attribute, str(send_state)) @callback def async_on_state_subscription( - entity_id: str, attribute: str | None = None + self, entity_id: str, attribute: str | None = None ) -> None: """Subscribe and forward states for requested entities.""" + hass = self.hass async def send_home_assistant_state_event(event: Event) -> None: """Forward Home Assistant states updates to ESPHome.""" + event_data = event.data + new_state: State | None = event_data.get("new_state") + old_state: State | None = event_data.get("old_state") + + if new_state is None or old_state is None: + return # Only communicate changes to the state or attribute tracked - if event.data.get("new_state") is None or ( - event.data.get("old_state") is not None - and "new_state" in event.data - and ( - ( - not attribute - and event.data["old_state"].state - == event.data["new_state"].state - ) - or ( - attribute - and attribute in event.data["old_state"].attributes - and attribute in event.data["new_state"].attributes - and event.data["old_state"].attributes[attribute] - == event.data["new_state"].attributes[attribute] - ) - ) + if (not attribute and old_state.state == new_state.state) or ( + attribute + and old_state.attributes.get(attribute) + == new_state.attributes.get(attribute) ): return - await _send_home_assistant_state( - event.data["entity_id"], attribute, event.data.get("new_state") + await self._send_home_assistant_state( + event.data["entity_id"], attribute, new_state ) - unsub = async_track_state_change_event( - hass, [entity_id], send_home_assistant_state_event + self.entry_data.disconnect_callbacks.append( + async_track_state_change_event( + hass, [entity_id], send_home_assistant_state_event + ) ) - entry_data.disconnect_callbacks.append(unsub) # Send initial state hass.async_create_task( - _send_home_assistant_state(entity_id, attribute, hass.states.get(entity_id)) + self._send_home_assistant_state( + entity_id, attribute, hass.states.get(entity_id) + ) ) - voice_assistant_udp_server: VoiceAssistantUDPServer | None = None - def _handle_pipeline_event( - event_type: VoiceAssistantEventType, data: dict[str, str] | None + self, event_type: VoiceAssistantEventType, data: dict[str, str] | None ) -> None: - cli.send_voice_assistant_event(event_type, data) + self.cli.send_voice_assistant_event(event_type, data) - def _handle_pipeline_finished() -> None: - nonlocal voice_assistant_udp_server + def _handle_pipeline_finished(self) -> None: + self.entry_data.async_set_assist_pipeline_state(False) - entry_data.async_set_assist_pipeline_state(False) + if self.voice_assistant_udp_server is not None: + self.voice_assistant_udp_server.close() + self.voice_assistant_udp_server = None - if voice_assistant_udp_server is not None: - voice_assistant_udp_server.close() - voice_assistant_udp_server = None - - async def _handle_pipeline_start(conversation_id: str, use_vad: bool) -> int | None: + async def _handle_pipeline_start( + self, conversation_id: str, use_vad: bool + ) -> int | None: """Start a voice assistant pipeline.""" - nonlocal voice_assistant_udp_server - - if voice_assistant_udp_server is not None: + if self.voice_assistant_udp_server is not None: return None + hass = self.hass voice_assistant_udp_server = VoiceAssistantUDPServer( - hass, entry_data, _handle_pipeline_event, _handle_pipeline_finished + hass, + self.entry_data, + self._handle_pipeline_event, + self._handle_pipeline_finished, ) port = await voice_assistant_udp_server.start_server() + assert self.device_id is not None, "Device ID must be set" hass.async_create_background_task( voice_assistant_udp_server.run_pipeline( - device_id=device_id, + device_id=self.device_id, conversation_id=conversation_id or None, use_vad=use_vad, ), "esphome.voice_assistant_udp_server.run_pipeline", ) - entry_data.async_set_assist_pipeline_state(True) + self.entry_data.async_set_assist_pipeline_state(True) return port - async def _handle_pipeline_stop() -> None: + async def _handle_pipeline_stop(self) -> None: """Stop a voice assistant pipeline.""" - nonlocal voice_assistant_udp_server + if self.voice_assistant_udp_server is not None: + self.voice_assistant_udp_server.stop() - if voice_assistant_udp_server is not None: - voice_assistant_udp_server.stop() - - async def on_connect() -> None: + async def on_connect(self) -> None: """Subscribe to states and list entities on successful API login.""" - nonlocal device_id + entry = self.entry + entry_data = self.entry_data + reconnect_logic = self.reconnect_logic + hass = self.hass + cli = self.cli try: device_info = await cli.device_info() @@ -389,6 +397,7 @@ async def async_setup_entry( # noqa: C901 entry_data.api_version = cli.api_version entry_data.available = True if entry_data.device_info.name: + assert reconnect_logic is not None, "Reconnect logic must be set" reconnect_logic.name = entry_data.device_info.name if device_info.bluetooth_proxy_feature_flags_compat(cli.api_version): @@ -396,37 +405,38 @@ async def async_setup_entry( # noqa: C901 await async_connect_scanner(hass, entry, cli, entry_data) ) - device_id = _async_setup_device_registry( - hass, entry, entry_data.device_info - ) + _async_setup_device_registry(hass, entry, entry_data.device_info) entry_data.async_update_device_state(hass) entity_infos, services = await cli.list_entities_services() await entry_data.async_update_static_infos(hass, entry, entity_infos) await _setup_services(hass, entry_data, services) await cli.subscribe_states(entry_data.async_update_state) - await cli.subscribe_service_calls(async_on_service_call) - await cli.subscribe_home_assistant_states(async_on_state_subscription) + await cli.subscribe_service_calls(self.async_on_service_call) + await cli.subscribe_home_assistant_states(self.async_on_state_subscription) if device_info.voice_assistant_version: entry_data.disconnect_callbacks.append( await cli.subscribe_voice_assistant( - _handle_pipeline_start, - _handle_pipeline_stop, + self._handle_pipeline_start, + self._handle_pipeline_stop, ) ) hass.async_create_task(entry_data.async_save_to_store()) except APIConnectionError as err: - _LOGGER.warning("Error getting initial data for %s: %s", host, err) + _LOGGER.warning("Error getting initial data for %s: %s", self.host, err) # Re-connection logic will trigger after this await cli.disconnect() else: _async_check_firmware_version(hass, device_info, entry_data.api_version) - _async_check_using_api_password(hass, device_info, bool(password)) + _async_check_using_api_password(hass, device_info, bool(self.password)) - async def on_disconnect(expected_disconnect: bool) -> None: + async def on_disconnect(self, expected_disconnect: bool) -> None: """Run disconnect callbacks on API disconnect.""" + entry_data = self.entry_data + hass = self.hass + host = self.host name = entry_data.device_info.name if entry_data.device_info else host _LOGGER.debug( "%s: %s disconnected (expected=%s), running disconnected callbacks", @@ -453,7 +463,7 @@ async def async_setup_entry( # noqa: C901 # will be cleared anyway. entry_data.async_update_device_state(hass) - async def on_connect_error(err: Exception) -> None: + async def on_connect_error(self, err: Exception) -> None: """Start reauth flow if appropriate connect error type.""" if isinstance( err, @@ -463,32 +473,85 @@ async def async_setup_entry( # noqa: C901 InvalidAuthAPIError, ), ): - entry.async_start_reauth(hass) + self.entry.async_start_reauth(self.hass) - reconnect_logic = ReconnectLogic( - client=cli, - on_connect=on_connect, - on_disconnect=on_disconnect, + async def async_start(self) -> None: + """Start the esphome connection manager.""" + hass = self.hass + entry = self.entry + entry_data = self.entry_data + + if entry.options.get(CONF_ALLOW_SERVICE_CALLS, DEFAULT_ALLOW_SERVICE_CALLS): + async_delete_issue(hass, DOMAIN, self.services_issue) + + # Use async_listen instead of async_listen_once so that we don't deregister + # the callback twice when shutting down Home Assistant. + # "Unable to remove unknown listener + # .onetime_listener>" + entry_data.cleanup_callbacks.append( + hass.bus.async_listen(EVENT_HOMEASSISTANT_STOP, self.on_stop) + ) + + reconnect_logic = ReconnectLogic( + client=self.cli, + on_connect=self.on_connect, + on_disconnect=self.on_disconnect, + zeroconf_instance=self.zeroconf_instance, + name=self.host, + on_connect_error=self.on_connect_error, + ) + self.reconnect_logic = reconnect_logic + + infos, services = await entry_data.async_load_from_store() + await entry_data.async_update_static_infos(hass, entry, infos) + await _setup_services(hass, entry_data, services) + + if entry_data.device_info is not None and entry_data.device_info.name: + reconnect_logic.name = entry_data.device_info.name + if entry.unique_id is None: + hass.config_entries.async_update_entry( + entry, unique_id=format_mac(entry_data.device_info.mac_address) + ) + + await reconnect_logic.start() + entry_data.cleanup_callbacks.append(reconnect_logic.stop_callback) + + entry.async_on_unload( + entry.add_update_listener(entry_data.async_update_listener) + ) + + +async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: + """Set up the esphome component.""" + host = entry.data[CONF_HOST] + port = entry.data[CONF_PORT] + password = entry.data[CONF_PASSWORD] + noise_psk = entry.data.get(CONF_NOISE_PSK) + + zeroconf_instance = await zeroconf.async_get_instance(hass) + + cli = APIClient( + host, + port, + password, + client_info=f"Home Assistant {ha_version}", zeroconf_instance=zeroconf_instance, - name=host, - on_connect_error=on_connect_error, + noise_psk=noise_psk, ) - infos, services = await entry_data.async_load_from_store() - await entry_data.async_update_static_infos(hass, entry, infos) - await _setup_services(hass, entry_data, services) + domain_data = DomainData.get(hass) + entry_data = RuntimeEntryData( + client=cli, + entry_id=entry.entry_id, + store=domain_data.get_or_create_store(hass, entry), + original_options=dict(entry.options), + ) + domain_data.set_entry_data(entry, entry_data) - if entry_data.device_info is not None and entry_data.device_info.name: - reconnect_logic.name = entry_data.device_info.name - if entry.unique_id is None: - hass.config_entries.async_update_entry( - entry, unique_id=format_mac(entry_data.device_info.mac_address) - ) - - await reconnect_logic.start() - entry_data.cleanup_callbacks.append(reconnect_logic.stop_callback) - - entry.async_on_unload(entry.add_update_listener(entry_data.async_update_listener)) + manager = ESPHomeManager( + hass, entry, host, password, cli, zeroconf_instance, domain_data, entry_data + ) + await manager.async_start() return True