Refactor ESPHome connection management logic into a class (#95457)
* Refactor ESPHome setup logic into a class Avoids all the nonlocals and fixes the C901 * cleanup * touch ups * touch ups * touch ups * make easier to read * stale
This commit is contained in:
parent
a7dfe46fb1
commit
dfe7c5ebed
1 changed files with 192 additions and 129 deletions
|
@ -137,57 +137,60 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
return True
|
||||
|
||||
|
||||
async def async_setup_entry( # noqa: C901
|
||||
hass: HomeAssistant, entry: ConfigEntry
|
||||
) -> bool:
|
||||
"""Set up the esphome component."""
|
||||
host = entry.data[CONF_HOST]
|
||||
port = entry.data[CONF_PORT]
|
||||
password = entry.data[CONF_PASSWORD]
|
||||
noise_psk = entry.data.get(CONF_NOISE_PSK)
|
||||
device_id: str = None # type: ignore[assignment]
|
||||
class ESPHomeManager:
|
||||
"""Class to manage an ESPHome connection."""
|
||||
|
||||
zeroconf_instance = await zeroconf.async_get_instance(hass)
|
||||
|
||||
cli = APIClient(
|
||||
host,
|
||||
port,
|
||||
password,
|
||||
client_info=f"Home Assistant {ha_version}",
|
||||
zeroconf_instance=zeroconf_instance,
|
||||
noise_psk=noise_psk,
|
||||
__slots__ = (
|
||||
"hass",
|
||||
"host",
|
||||
"password",
|
||||
"entry",
|
||||
"cli",
|
||||
"device_id",
|
||||
"domain_data",
|
||||
"voice_assistant_udp_server",
|
||||
"reconnect_logic",
|
||||
"zeroconf_instance",
|
||||
"entry_data",
|
||||
)
|
||||
|
||||
services_issue = f"service_calls_not_enabled-{entry.unique_id}"
|
||||
if entry.options.get(CONF_ALLOW_SERVICE_CALLS, DEFAULT_ALLOW_SERVICE_CALLS):
|
||||
async_delete_issue(hass, DOMAIN, services_issue)
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
entry: ConfigEntry,
|
||||
host: str,
|
||||
password: str | None,
|
||||
cli: APIClient,
|
||||
zeroconf_instance: zeroconf.HaZeroconf,
|
||||
domain_data: DomainData,
|
||||
entry_data: RuntimeEntryData,
|
||||
) -> None:
|
||||
"""Initialize the esphome manager."""
|
||||
self.hass = hass
|
||||
self.host = host
|
||||
self.password = password
|
||||
self.entry = entry
|
||||
self.cli = cli
|
||||
self.device_id: str | None = None
|
||||
self.domain_data = domain_data
|
||||
self.voice_assistant_udp_server: VoiceAssistantUDPServer | None = None
|
||||
self.reconnect_logic: ReconnectLogic | None = None
|
||||
self.zeroconf_instance = zeroconf_instance
|
||||
self.entry_data = entry_data
|
||||
|
||||
domain_data = DomainData.get(hass)
|
||||
entry_data = RuntimeEntryData(
|
||||
client=cli,
|
||||
entry_id=entry.entry_id,
|
||||
store=domain_data.get_or_create_store(hass, entry),
|
||||
original_options=dict(entry.options),
|
||||
)
|
||||
domain_data.set_entry_data(entry, entry_data)
|
||||
|
||||
async def on_stop(event: Event) -> None:
|
||||
async def on_stop(self, event: Event) -> None:
|
||||
"""Cleanup the socket client on HA stop."""
|
||||
await _cleanup_instance(hass, entry)
|
||||
await _cleanup_instance(self.hass, self.entry)
|
||||
|
||||
# Use async_listen instead of async_listen_once so that we don't deregister
|
||||
# the callback twice when shutting down Home Assistant.
|
||||
# "Unable to remove unknown listener
|
||||
# <function EventBus.async_listen_once.<locals>.onetime_listener>"
|
||||
entry_data.cleanup_callbacks.append(
|
||||
hass.bus.async_listen(EVENT_HOMEASSISTANT_STOP, on_stop)
|
||||
)
|
||||
@property
|
||||
def services_issue(self) -> str:
|
||||
"""Return the services issue name for this entry."""
|
||||
return f"service_calls_not_enabled-{self.entry.unique_id}"
|
||||
|
||||
@callback
|
||||
def async_on_service_call(service: HomeassistantServiceCall) -> None:
|
||||
def async_on_service_call(self, service: HomeassistantServiceCall) -> None:
|
||||
"""Call service when user automation in ESPHome config is triggered."""
|
||||
device_info = entry_data.device_info
|
||||
assert device_info is not None
|
||||
hass = self.hass
|
||||
domain, service_name = service.service.split(".", 1)
|
||||
service_data = service.data
|
||||
|
||||
|
@ -201,15 +204,16 @@ async def async_setup_entry( # noqa: C901
|
|||
template.render_complex(data_template, service.variables)
|
||||
)
|
||||
except TemplateError as ex:
|
||||
_LOGGER.error("Error rendering data template for %s: %s", host, ex)
|
||||
_LOGGER.error("Error rendering data template for %s: %s", self.host, ex)
|
||||
return
|
||||
|
||||
if service.is_event:
|
||||
device_id = self.device_id
|
||||
# ESPHome uses service call packet for both events and service calls
|
||||
# Ensure the user can only send events of form 'esphome.xyz'
|
||||
if domain != "esphome":
|
||||
_LOGGER.error(
|
||||
"Can only generate events under esphome domain! (%s)", host
|
||||
"Can only generate events under esphome domain! (%s)", self.host
|
||||
)
|
||||
return
|
||||
|
||||
|
@ -226,17 +230,21 @@ async def async_setup_entry( # noqa: C901
|
|||
**service_data,
|
||||
},
|
||||
)
|
||||
elif entry.options.get(CONF_ALLOW_SERVICE_CALLS, DEFAULT_ALLOW_SERVICE_CALLS):
|
||||
elif self.entry.options.get(
|
||||
CONF_ALLOW_SERVICE_CALLS, DEFAULT_ALLOW_SERVICE_CALLS
|
||||
):
|
||||
hass.async_create_task(
|
||||
hass.services.async_call(
|
||||
domain, service_name, service_data, blocking=True
|
||||
)
|
||||
)
|
||||
else:
|
||||
device_info = self.entry_data.device_info
|
||||
assert device_info is not None
|
||||
async_create_issue(
|
||||
hass,
|
||||
DOMAIN,
|
||||
services_issue,
|
||||
self.services_issue,
|
||||
is_fixable=False,
|
||||
severity=IssueSeverity.WARNING,
|
||||
translation_key="service_calls_not_allowed",
|
||||
|
@ -256,7 +264,7 @@ async def async_setup_entry( # noqa: C901
|
|||
)
|
||||
|
||||
async def _send_home_assistant_state(
|
||||
entity_id: str, attribute: str | None, state: State | None
|
||||
self, entity_id: str, attribute: str | None, state: State | None
|
||||
) -> None:
|
||||
"""Forward Home Assistant states to ESPHome."""
|
||||
if state is None or (attribute and attribute not in state.attributes):
|
||||
|
@ -271,102 +279,102 @@ async def async_setup_entry( # noqa: C901
|
|||
else:
|
||||
send_state = attr_val
|
||||
|
||||
await cli.send_home_assistant_state(entity_id, attribute, str(send_state))
|
||||
await self.cli.send_home_assistant_state(entity_id, attribute, str(send_state))
|
||||
|
||||
@callback
|
||||
def async_on_state_subscription(
|
||||
entity_id: str, attribute: str | None = None
|
||||
self, entity_id: str, attribute: str | None = None
|
||||
) -> None:
|
||||
"""Subscribe and forward states for requested entities."""
|
||||
hass = self.hass
|
||||
|
||||
async def send_home_assistant_state_event(event: Event) -> None:
|
||||
"""Forward Home Assistant states updates to ESPHome."""
|
||||
event_data = event.data
|
||||
new_state: State | None = event_data.get("new_state")
|
||||
old_state: State | None = event_data.get("old_state")
|
||||
|
||||
if new_state is None or old_state is None:
|
||||
return
|
||||
|
||||
# Only communicate changes to the state or attribute tracked
|
||||
if event.data.get("new_state") is None or (
|
||||
event.data.get("old_state") is not None
|
||||
and "new_state" in event.data
|
||||
and (
|
||||
(
|
||||
not attribute
|
||||
and event.data["old_state"].state
|
||||
== event.data["new_state"].state
|
||||
)
|
||||
or (
|
||||
attribute
|
||||
and attribute in event.data["old_state"].attributes
|
||||
and attribute in event.data["new_state"].attributes
|
||||
and event.data["old_state"].attributes[attribute]
|
||||
== event.data["new_state"].attributes[attribute]
|
||||
)
|
||||
)
|
||||
if (not attribute and old_state.state == new_state.state) or (
|
||||
attribute
|
||||
and old_state.attributes.get(attribute)
|
||||
== new_state.attributes.get(attribute)
|
||||
):
|
||||
return
|
||||
|
||||
await _send_home_assistant_state(
|
||||
event.data["entity_id"], attribute, event.data.get("new_state")
|
||||
await self._send_home_assistant_state(
|
||||
event.data["entity_id"], attribute, new_state
|
||||
)
|
||||
|
||||
unsub = async_track_state_change_event(
|
||||
hass, [entity_id], send_home_assistant_state_event
|
||||
self.entry_data.disconnect_callbacks.append(
|
||||
async_track_state_change_event(
|
||||
hass, [entity_id], send_home_assistant_state_event
|
||||
)
|
||||
)
|
||||
entry_data.disconnect_callbacks.append(unsub)
|
||||
|
||||
# Send initial state
|
||||
hass.async_create_task(
|
||||
_send_home_assistant_state(entity_id, attribute, hass.states.get(entity_id))
|
||||
self._send_home_assistant_state(
|
||||
entity_id, attribute, hass.states.get(entity_id)
|
||||
)
|
||||
)
|
||||
|
||||
voice_assistant_udp_server: VoiceAssistantUDPServer | None = None
|
||||
|
||||
def _handle_pipeline_event(
|
||||
event_type: VoiceAssistantEventType, data: dict[str, str] | None
|
||||
self, event_type: VoiceAssistantEventType, data: dict[str, str] | None
|
||||
) -> None:
|
||||
cli.send_voice_assistant_event(event_type, data)
|
||||
self.cli.send_voice_assistant_event(event_type, data)
|
||||
|
||||
def _handle_pipeline_finished() -> None:
|
||||
nonlocal voice_assistant_udp_server
|
||||
def _handle_pipeline_finished(self) -> None:
|
||||
self.entry_data.async_set_assist_pipeline_state(False)
|
||||
|
||||
entry_data.async_set_assist_pipeline_state(False)
|
||||
if self.voice_assistant_udp_server is not None:
|
||||
self.voice_assistant_udp_server.close()
|
||||
self.voice_assistant_udp_server = None
|
||||
|
||||
if voice_assistant_udp_server is not None:
|
||||
voice_assistant_udp_server.close()
|
||||
voice_assistant_udp_server = None
|
||||
|
||||
async def _handle_pipeline_start(conversation_id: str, use_vad: bool) -> int | None:
|
||||
async def _handle_pipeline_start(
|
||||
self, conversation_id: str, use_vad: bool
|
||||
) -> int | None:
|
||||
"""Start a voice assistant pipeline."""
|
||||
nonlocal voice_assistant_udp_server
|
||||
|
||||
if voice_assistant_udp_server is not None:
|
||||
if self.voice_assistant_udp_server is not None:
|
||||
return None
|
||||
|
||||
hass = self.hass
|
||||
voice_assistant_udp_server = VoiceAssistantUDPServer(
|
||||
hass, entry_data, _handle_pipeline_event, _handle_pipeline_finished
|
||||
hass,
|
||||
self.entry_data,
|
||||
self._handle_pipeline_event,
|
||||
self._handle_pipeline_finished,
|
||||
)
|
||||
port = await voice_assistant_udp_server.start_server()
|
||||
|
||||
assert self.device_id is not None, "Device ID must be set"
|
||||
hass.async_create_background_task(
|
||||
voice_assistant_udp_server.run_pipeline(
|
||||
device_id=device_id,
|
||||
device_id=self.device_id,
|
||||
conversation_id=conversation_id or None,
|
||||
use_vad=use_vad,
|
||||
),
|
||||
"esphome.voice_assistant_udp_server.run_pipeline",
|
||||
)
|
||||
entry_data.async_set_assist_pipeline_state(True)
|
||||
self.entry_data.async_set_assist_pipeline_state(True)
|
||||
|
||||
return port
|
||||
|
||||
async def _handle_pipeline_stop() -> None:
|
||||
async def _handle_pipeline_stop(self) -> None:
|
||||
"""Stop a voice assistant pipeline."""
|
||||
nonlocal voice_assistant_udp_server
|
||||
if self.voice_assistant_udp_server is not None:
|
||||
self.voice_assistant_udp_server.stop()
|
||||
|
||||
if voice_assistant_udp_server is not None:
|
||||
voice_assistant_udp_server.stop()
|
||||
|
||||
async def on_connect() -> None:
|
||||
async def on_connect(self) -> None:
|
||||
"""Subscribe to states and list entities on successful API login."""
|
||||
nonlocal device_id
|
||||
entry = self.entry
|
||||
entry_data = self.entry_data
|
||||
reconnect_logic = self.reconnect_logic
|
||||
hass = self.hass
|
||||
cli = self.cli
|
||||
try:
|
||||
device_info = await cli.device_info()
|
||||
|
||||
|
@ -389,6 +397,7 @@ async def async_setup_entry( # noqa: C901
|
|||
entry_data.api_version = cli.api_version
|
||||
entry_data.available = True
|
||||
if entry_data.device_info.name:
|
||||
assert reconnect_logic is not None, "Reconnect logic must be set"
|
||||
reconnect_logic.name = entry_data.device_info.name
|
||||
|
||||
if device_info.bluetooth_proxy_feature_flags_compat(cli.api_version):
|
||||
|
@ -396,37 +405,38 @@ async def async_setup_entry( # noqa: C901
|
|||
await async_connect_scanner(hass, entry, cli, entry_data)
|
||||
)
|
||||
|
||||
device_id = _async_setup_device_registry(
|
||||
hass, entry, entry_data.device_info
|
||||
)
|
||||
_async_setup_device_registry(hass, entry, entry_data.device_info)
|
||||
entry_data.async_update_device_state(hass)
|
||||
|
||||
entity_infos, services = await cli.list_entities_services()
|
||||
await entry_data.async_update_static_infos(hass, entry, entity_infos)
|
||||
await _setup_services(hass, entry_data, services)
|
||||
await cli.subscribe_states(entry_data.async_update_state)
|
||||
await cli.subscribe_service_calls(async_on_service_call)
|
||||
await cli.subscribe_home_assistant_states(async_on_state_subscription)
|
||||
await cli.subscribe_service_calls(self.async_on_service_call)
|
||||
await cli.subscribe_home_assistant_states(self.async_on_state_subscription)
|
||||
|
||||
if device_info.voice_assistant_version:
|
||||
entry_data.disconnect_callbacks.append(
|
||||
await cli.subscribe_voice_assistant(
|
||||
_handle_pipeline_start,
|
||||
_handle_pipeline_stop,
|
||||
self._handle_pipeline_start,
|
||||
self._handle_pipeline_stop,
|
||||
)
|
||||
)
|
||||
|
||||
hass.async_create_task(entry_data.async_save_to_store())
|
||||
except APIConnectionError as err:
|
||||
_LOGGER.warning("Error getting initial data for %s: %s", host, err)
|
||||
_LOGGER.warning("Error getting initial data for %s: %s", self.host, err)
|
||||
# Re-connection logic will trigger after this
|
||||
await cli.disconnect()
|
||||
else:
|
||||
_async_check_firmware_version(hass, device_info, entry_data.api_version)
|
||||
_async_check_using_api_password(hass, device_info, bool(password))
|
||||
_async_check_using_api_password(hass, device_info, bool(self.password))
|
||||
|
||||
async def on_disconnect(expected_disconnect: bool) -> None:
|
||||
async def on_disconnect(self, expected_disconnect: bool) -> None:
|
||||
"""Run disconnect callbacks on API disconnect."""
|
||||
entry_data = self.entry_data
|
||||
hass = self.hass
|
||||
host = self.host
|
||||
name = entry_data.device_info.name if entry_data.device_info else host
|
||||
_LOGGER.debug(
|
||||
"%s: %s disconnected (expected=%s), running disconnected callbacks",
|
||||
|
@ -453,7 +463,7 @@ async def async_setup_entry( # noqa: C901
|
|||
# will be cleared anyway.
|
||||
entry_data.async_update_device_state(hass)
|
||||
|
||||
async def on_connect_error(err: Exception) -> None:
|
||||
async def on_connect_error(self, err: Exception) -> None:
|
||||
"""Start reauth flow if appropriate connect error type."""
|
||||
if isinstance(
|
||||
err,
|
||||
|
@ -463,32 +473,85 @@ async def async_setup_entry( # noqa: C901
|
|||
InvalidAuthAPIError,
|
||||
),
|
||||
):
|
||||
entry.async_start_reauth(hass)
|
||||
self.entry.async_start_reauth(self.hass)
|
||||
|
||||
reconnect_logic = ReconnectLogic(
|
||||
client=cli,
|
||||
on_connect=on_connect,
|
||||
on_disconnect=on_disconnect,
|
||||
async def async_start(self) -> None:
|
||||
"""Start the esphome connection manager."""
|
||||
hass = self.hass
|
||||
entry = self.entry
|
||||
entry_data = self.entry_data
|
||||
|
||||
if entry.options.get(CONF_ALLOW_SERVICE_CALLS, DEFAULT_ALLOW_SERVICE_CALLS):
|
||||
async_delete_issue(hass, DOMAIN, self.services_issue)
|
||||
|
||||
# Use async_listen instead of async_listen_once so that we don't deregister
|
||||
# the callback twice when shutting down Home Assistant.
|
||||
# "Unable to remove unknown listener
|
||||
# <function EventBus.async_listen_once.<locals>.onetime_listener>"
|
||||
entry_data.cleanup_callbacks.append(
|
||||
hass.bus.async_listen(EVENT_HOMEASSISTANT_STOP, self.on_stop)
|
||||
)
|
||||
|
||||
reconnect_logic = ReconnectLogic(
|
||||
client=self.cli,
|
||||
on_connect=self.on_connect,
|
||||
on_disconnect=self.on_disconnect,
|
||||
zeroconf_instance=self.zeroconf_instance,
|
||||
name=self.host,
|
||||
on_connect_error=self.on_connect_error,
|
||||
)
|
||||
self.reconnect_logic = reconnect_logic
|
||||
|
||||
infos, services = await entry_data.async_load_from_store()
|
||||
await entry_data.async_update_static_infos(hass, entry, infos)
|
||||
await _setup_services(hass, entry_data, services)
|
||||
|
||||
if entry_data.device_info is not None and entry_data.device_info.name:
|
||||
reconnect_logic.name = entry_data.device_info.name
|
||||
if entry.unique_id is None:
|
||||
hass.config_entries.async_update_entry(
|
||||
entry, unique_id=format_mac(entry_data.device_info.mac_address)
|
||||
)
|
||||
|
||||
await reconnect_logic.start()
|
||||
entry_data.cleanup_callbacks.append(reconnect_logic.stop_callback)
|
||||
|
||||
entry.async_on_unload(
|
||||
entry.add_update_listener(entry_data.async_update_listener)
|
||||
)
|
||||
|
||||
|
||||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Set up the esphome component."""
|
||||
host = entry.data[CONF_HOST]
|
||||
port = entry.data[CONF_PORT]
|
||||
password = entry.data[CONF_PASSWORD]
|
||||
noise_psk = entry.data.get(CONF_NOISE_PSK)
|
||||
|
||||
zeroconf_instance = await zeroconf.async_get_instance(hass)
|
||||
|
||||
cli = APIClient(
|
||||
host,
|
||||
port,
|
||||
password,
|
||||
client_info=f"Home Assistant {ha_version}",
|
||||
zeroconf_instance=zeroconf_instance,
|
||||
name=host,
|
||||
on_connect_error=on_connect_error,
|
||||
noise_psk=noise_psk,
|
||||
)
|
||||
|
||||
infos, services = await entry_data.async_load_from_store()
|
||||
await entry_data.async_update_static_infos(hass, entry, infos)
|
||||
await _setup_services(hass, entry_data, services)
|
||||
domain_data = DomainData.get(hass)
|
||||
entry_data = RuntimeEntryData(
|
||||
client=cli,
|
||||
entry_id=entry.entry_id,
|
||||
store=domain_data.get_or_create_store(hass, entry),
|
||||
original_options=dict(entry.options),
|
||||
)
|
||||
domain_data.set_entry_data(entry, entry_data)
|
||||
|
||||
if entry_data.device_info is not None and entry_data.device_info.name:
|
||||
reconnect_logic.name = entry_data.device_info.name
|
||||
if entry.unique_id is None:
|
||||
hass.config_entries.async_update_entry(
|
||||
entry, unique_id=format_mac(entry_data.device_info.mac_address)
|
||||
)
|
||||
|
||||
await reconnect_logic.start()
|
||||
entry_data.cleanup_callbacks.append(reconnect_logic.stop_callback)
|
||||
|
||||
entry.async_on_unload(entry.add_update_listener(entry_data.async_update_listener))
|
||||
manager = ESPHomeManager(
|
||||
hass, entry, host, password, cli, zeroconf_instance, domain_data, entry_data
|
||||
)
|
||||
await manager.async_start()
|
||||
|
||||
return True
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue