Refactor ESPHome connection management logic into a class (#95457)

* Refactor ESPHome setup logic into a class

Avoids all the nonlocals and fixes the C901

* cleanup

* touch ups

* touch ups

* touch ups

* make easier to read

* stale
This commit is contained in:
J. Nick Koston 2023-06-28 20:39:31 -05:00 committed by GitHub
parent a7dfe46fb1
commit dfe7c5ebed
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -137,57 +137,60 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True
async def async_setup_entry( # noqa: C901
hass: HomeAssistant, entry: ConfigEntry
) -> bool:
"""Set up the esphome component."""
host = entry.data[CONF_HOST]
port = entry.data[CONF_PORT]
password = entry.data[CONF_PASSWORD]
noise_psk = entry.data.get(CONF_NOISE_PSK)
device_id: str = None # type: ignore[assignment]
class ESPHomeManager:
"""Class to manage an ESPHome connection."""
zeroconf_instance = await zeroconf.async_get_instance(hass)
cli = APIClient(
host,
port,
password,
client_info=f"Home Assistant {ha_version}",
zeroconf_instance=zeroconf_instance,
noise_psk=noise_psk,
__slots__ = (
"hass",
"host",
"password",
"entry",
"cli",
"device_id",
"domain_data",
"voice_assistant_udp_server",
"reconnect_logic",
"zeroconf_instance",
"entry_data",
)
services_issue = f"service_calls_not_enabled-{entry.unique_id}"
if entry.options.get(CONF_ALLOW_SERVICE_CALLS, DEFAULT_ALLOW_SERVICE_CALLS):
async_delete_issue(hass, DOMAIN, services_issue)
def __init__(
self,
hass: HomeAssistant,
entry: ConfigEntry,
host: str,
password: str | None,
cli: APIClient,
zeroconf_instance: zeroconf.HaZeroconf,
domain_data: DomainData,
entry_data: RuntimeEntryData,
) -> None:
"""Initialize the esphome manager."""
self.hass = hass
self.host = host
self.password = password
self.entry = entry
self.cli = cli
self.device_id: str | None = None
self.domain_data = domain_data
self.voice_assistant_udp_server: VoiceAssistantUDPServer | None = None
self.reconnect_logic: ReconnectLogic | None = None
self.zeroconf_instance = zeroconf_instance
self.entry_data = entry_data
domain_data = DomainData.get(hass)
entry_data = RuntimeEntryData(
client=cli,
entry_id=entry.entry_id,
store=domain_data.get_or_create_store(hass, entry),
original_options=dict(entry.options),
)
domain_data.set_entry_data(entry, entry_data)
async def on_stop(event: Event) -> None:
async def on_stop(self, event: Event) -> None:
"""Cleanup the socket client on HA stop."""
await _cleanup_instance(hass, entry)
await _cleanup_instance(self.hass, self.entry)
# Use async_listen instead of async_listen_once so that we don't deregister
# the callback twice when shutting down Home Assistant.
# "Unable to remove unknown listener
# <function EventBus.async_listen_once.<locals>.onetime_listener>"
entry_data.cleanup_callbacks.append(
hass.bus.async_listen(EVENT_HOMEASSISTANT_STOP, on_stop)
)
@property
def services_issue(self) -> str:
"""Return the services issue name for this entry."""
return f"service_calls_not_enabled-{self.entry.unique_id}"
@callback
def async_on_service_call(service: HomeassistantServiceCall) -> None:
def async_on_service_call(self, service: HomeassistantServiceCall) -> None:
"""Call service when user automation in ESPHome config is triggered."""
device_info = entry_data.device_info
assert device_info is not None
hass = self.hass
domain, service_name = service.service.split(".", 1)
service_data = service.data
@ -201,15 +204,16 @@ async def async_setup_entry( # noqa: C901
template.render_complex(data_template, service.variables)
)
except TemplateError as ex:
_LOGGER.error("Error rendering data template for %s: %s", host, ex)
_LOGGER.error("Error rendering data template for %s: %s", self.host, ex)
return
if service.is_event:
device_id = self.device_id
# ESPHome uses service call packet for both events and service calls
# Ensure the user can only send events of form 'esphome.xyz'
if domain != "esphome":
_LOGGER.error(
"Can only generate events under esphome domain! (%s)", host
"Can only generate events under esphome domain! (%s)", self.host
)
return
@ -226,17 +230,21 @@ async def async_setup_entry( # noqa: C901
**service_data,
},
)
elif entry.options.get(CONF_ALLOW_SERVICE_CALLS, DEFAULT_ALLOW_SERVICE_CALLS):
elif self.entry.options.get(
CONF_ALLOW_SERVICE_CALLS, DEFAULT_ALLOW_SERVICE_CALLS
):
hass.async_create_task(
hass.services.async_call(
domain, service_name, service_data, blocking=True
)
)
else:
device_info = self.entry_data.device_info
assert device_info is not None
async_create_issue(
hass,
DOMAIN,
services_issue,
self.services_issue,
is_fixable=False,
severity=IssueSeverity.WARNING,
translation_key="service_calls_not_allowed",
@ -256,7 +264,7 @@ async def async_setup_entry( # noqa: C901
)
async def _send_home_assistant_state(
entity_id: str, attribute: str | None, state: State | None
self, entity_id: str, attribute: str | None, state: State | None
) -> None:
"""Forward Home Assistant states to ESPHome."""
if state is None or (attribute and attribute not in state.attributes):
@ -271,102 +279,102 @@ async def async_setup_entry( # noqa: C901
else:
send_state = attr_val
await cli.send_home_assistant_state(entity_id, attribute, str(send_state))
await self.cli.send_home_assistant_state(entity_id, attribute, str(send_state))
@callback
def async_on_state_subscription(
entity_id: str, attribute: str | None = None
self, entity_id: str, attribute: str | None = None
) -> None:
"""Subscribe and forward states for requested entities."""
hass = self.hass
async def send_home_assistant_state_event(event: Event) -> None:
"""Forward Home Assistant states updates to ESPHome."""
event_data = event.data
new_state: State | None = event_data.get("new_state")
old_state: State | None = event_data.get("old_state")
if new_state is None or old_state is None:
return
# Only communicate changes to the state or attribute tracked
if event.data.get("new_state") is None or (
event.data.get("old_state") is not None
and "new_state" in event.data
and (
(
not attribute
and event.data["old_state"].state
== event.data["new_state"].state
)
or (
attribute
and attribute in event.data["old_state"].attributes
and attribute in event.data["new_state"].attributes
and event.data["old_state"].attributes[attribute]
== event.data["new_state"].attributes[attribute]
)
)
if (not attribute and old_state.state == new_state.state) or (
attribute
and old_state.attributes.get(attribute)
== new_state.attributes.get(attribute)
):
return
await _send_home_assistant_state(
event.data["entity_id"], attribute, event.data.get("new_state")
await self._send_home_assistant_state(
event.data["entity_id"], attribute, new_state
)
unsub = async_track_state_change_event(
hass, [entity_id], send_home_assistant_state_event
self.entry_data.disconnect_callbacks.append(
async_track_state_change_event(
hass, [entity_id], send_home_assistant_state_event
)
)
entry_data.disconnect_callbacks.append(unsub)
# Send initial state
hass.async_create_task(
_send_home_assistant_state(entity_id, attribute, hass.states.get(entity_id))
self._send_home_assistant_state(
entity_id, attribute, hass.states.get(entity_id)
)
)
voice_assistant_udp_server: VoiceAssistantUDPServer | None = None
def _handle_pipeline_event(
event_type: VoiceAssistantEventType, data: dict[str, str] | None
self, event_type: VoiceAssistantEventType, data: dict[str, str] | None
) -> None:
cli.send_voice_assistant_event(event_type, data)
self.cli.send_voice_assistant_event(event_type, data)
def _handle_pipeline_finished() -> None:
nonlocal voice_assistant_udp_server
def _handle_pipeline_finished(self) -> None:
self.entry_data.async_set_assist_pipeline_state(False)
entry_data.async_set_assist_pipeline_state(False)
if self.voice_assistant_udp_server is not None:
self.voice_assistant_udp_server.close()
self.voice_assistant_udp_server = None
if voice_assistant_udp_server is not None:
voice_assistant_udp_server.close()
voice_assistant_udp_server = None
async def _handle_pipeline_start(conversation_id: str, use_vad: bool) -> int | None:
async def _handle_pipeline_start(
self, conversation_id: str, use_vad: bool
) -> int | None:
"""Start a voice assistant pipeline."""
nonlocal voice_assistant_udp_server
if voice_assistant_udp_server is not None:
if self.voice_assistant_udp_server is not None:
return None
hass = self.hass
voice_assistant_udp_server = VoiceAssistantUDPServer(
hass, entry_data, _handle_pipeline_event, _handle_pipeline_finished
hass,
self.entry_data,
self._handle_pipeline_event,
self._handle_pipeline_finished,
)
port = await voice_assistant_udp_server.start_server()
assert self.device_id is not None, "Device ID must be set"
hass.async_create_background_task(
voice_assistant_udp_server.run_pipeline(
device_id=device_id,
device_id=self.device_id,
conversation_id=conversation_id or None,
use_vad=use_vad,
),
"esphome.voice_assistant_udp_server.run_pipeline",
)
entry_data.async_set_assist_pipeline_state(True)
self.entry_data.async_set_assist_pipeline_state(True)
return port
async def _handle_pipeline_stop() -> None:
async def _handle_pipeline_stop(self) -> None:
"""Stop a voice assistant pipeline."""
nonlocal voice_assistant_udp_server
if self.voice_assistant_udp_server is not None:
self.voice_assistant_udp_server.stop()
if voice_assistant_udp_server is not None:
voice_assistant_udp_server.stop()
async def on_connect() -> None:
async def on_connect(self) -> None:
"""Subscribe to states and list entities on successful API login."""
nonlocal device_id
entry = self.entry
entry_data = self.entry_data
reconnect_logic = self.reconnect_logic
hass = self.hass
cli = self.cli
try:
device_info = await cli.device_info()
@ -389,6 +397,7 @@ async def async_setup_entry( # noqa: C901
entry_data.api_version = cli.api_version
entry_data.available = True
if entry_data.device_info.name:
assert reconnect_logic is not None, "Reconnect logic must be set"
reconnect_logic.name = entry_data.device_info.name
if device_info.bluetooth_proxy_feature_flags_compat(cli.api_version):
@ -396,37 +405,38 @@ async def async_setup_entry( # noqa: C901
await async_connect_scanner(hass, entry, cli, entry_data)
)
device_id = _async_setup_device_registry(
hass, entry, entry_data.device_info
)
_async_setup_device_registry(hass, entry, entry_data.device_info)
entry_data.async_update_device_state(hass)
entity_infos, services = await cli.list_entities_services()
await entry_data.async_update_static_infos(hass, entry, entity_infos)
await _setup_services(hass, entry_data, services)
await cli.subscribe_states(entry_data.async_update_state)
await cli.subscribe_service_calls(async_on_service_call)
await cli.subscribe_home_assistant_states(async_on_state_subscription)
await cli.subscribe_service_calls(self.async_on_service_call)
await cli.subscribe_home_assistant_states(self.async_on_state_subscription)
if device_info.voice_assistant_version:
entry_data.disconnect_callbacks.append(
await cli.subscribe_voice_assistant(
_handle_pipeline_start,
_handle_pipeline_stop,
self._handle_pipeline_start,
self._handle_pipeline_stop,
)
)
hass.async_create_task(entry_data.async_save_to_store())
except APIConnectionError as err:
_LOGGER.warning("Error getting initial data for %s: %s", host, err)
_LOGGER.warning("Error getting initial data for %s: %s", self.host, err)
# Re-connection logic will trigger after this
await cli.disconnect()
else:
_async_check_firmware_version(hass, device_info, entry_data.api_version)
_async_check_using_api_password(hass, device_info, bool(password))
_async_check_using_api_password(hass, device_info, bool(self.password))
async def on_disconnect(expected_disconnect: bool) -> None:
async def on_disconnect(self, expected_disconnect: bool) -> None:
"""Run disconnect callbacks on API disconnect."""
entry_data = self.entry_data
hass = self.hass
host = self.host
name = entry_data.device_info.name if entry_data.device_info else host
_LOGGER.debug(
"%s: %s disconnected (expected=%s), running disconnected callbacks",
@ -453,7 +463,7 @@ async def async_setup_entry( # noqa: C901
# will be cleared anyway.
entry_data.async_update_device_state(hass)
async def on_connect_error(err: Exception) -> None:
async def on_connect_error(self, err: Exception) -> None:
"""Start reauth flow if appropriate connect error type."""
if isinstance(
err,
@ -463,32 +473,85 @@ async def async_setup_entry( # noqa: C901
InvalidAuthAPIError,
),
):
entry.async_start_reauth(hass)
self.entry.async_start_reauth(self.hass)
reconnect_logic = ReconnectLogic(
client=cli,
on_connect=on_connect,
on_disconnect=on_disconnect,
async def async_start(self) -> None:
"""Start the esphome connection manager."""
hass = self.hass
entry = self.entry
entry_data = self.entry_data
if entry.options.get(CONF_ALLOW_SERVICE_CALLS, DEFAULT_ALLOW_SERVICE_CALLS):
async_delete_issue(hass, DOMAIN, self.services_issue)
# Use async_listen instead of async_listen_once so that we don't deregister
# the callback twice when shutting down Home Assistant.
# "Unable to remove unknown listener
# <function EventBus.async_listen_once.<locals>.onetime_listener>"
entry_data.cleanup_callbacks.append(
hass.bus.async_listen(EVENT_HOMEASSISTANT_STOP, self.on_stop)
)
reconnect_logic = ReconnectLogic(
client=self.cli,
on_connect=self.on_connect,
on_disconnect=self.on_disconnect,
zeroconf_instance=self.zeroconf_instance,
name=self.host,
on_connect_error=self.on_connect_error,
)
self.reconnect_logic = reconnect_logic
infos, services = await entry_data.async_load_from_store()
await entry_data.async_update_static_infos(hass, entry, infos)
await _setup_services(hass, entry_data, services)
if entry_data.device_info is not None and entry_data.device_info.name:
reconnect_logic.name = entry_data.device_info.name
if entry.unique_id is None:
hass.config_entries.async_update_entry(
entry, unique_id=format_mac(entry_data.device_info.mac_address)
)
await reconnect_logic.start()
entry_data.cleanup_callbacks.append(reconnect_logic.stop_callback)
entry.async_on_unload(
entry.add_update_listener(entry_data.async_update_listener)
)
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up the esphome component."""
host = entry.data[CONF_HOST]
port = entry.data[CONF_PORT]
password = entry.data[CONF_PASSWORD]
noise_psk = entry.data.get(CONF_NOISE_PSK)
zeroconf_instance = await zeroconf.async_get_instance(hass)
cli = APIClient(
host,
port,
password,
client_info=f"Home Assistant {ha_version}",
zeroconf_instance=zeroconf_instance,
name=host,
on_connect_error=on_connect_error,
noise_psk=noise_psk,
)
infos, services = await entry_data.async_load_from_store()
await entry_data.async_update_static_infos(hass, entry, infos)
await _setup_services(hass, entry_data, services)
domain_data = DomainData.get(hass)
entry_data = RuntimeEntryData(
client=cli,
entry_id=entry.entry_id,
store=domain_data.get_or_create_store(hass, entry),
original_options=dict(entry.options),
)
domain_data.set_entry_data(entry, entry_data)
if entry_data.device_info is not None and entry_data.device_info.name:
reconnect_logic.name = entry_data.device_info.name
if entry.unique_id is None:
hass.config_entries.async_update_entry(
entry, unique_id=format_mac(entry_data.device_info.mac_address)
)
await reconnect_logic.start()
entry_data.cleanup_callbacks.append(reconnect_logic.stop_callback)
entry.async_on_unload(entry.add_update_listener(entry_data.async_update_listener))
manager = ESPHomeManager(
hass, entry, host, password, cli, zeroconf_instance, domain_data, entry_data
)
await manager.async_start()
return True