Minor refactor of zha config flow (#82200)

* Minor refactor of zha config flow

* Move ZhaRadioManager to a separate module
This commit is contained in:
Erik Montnemery 2022-11-16 17:13:23 +01:00 committed by GitHub
parent f952b74b74
commit bb64b39d0e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 283 additions and 215 deletions

View file

@ -1,44 +1,34 @@
"""Config flow for ZHA.""" """Config flow for ZHA."""
from __future__ import annotations from __future__ import annotations
import asyncio
import collections import collections
import contextlib
import copy import copy
import json import json
import logging
import os
from typing import Any from typing import Any
import serial.tools.list_ports import serial.tools.list_ports
import voluptuous as vol import voluptuous as vol
from zigpy.application import ControllerApplication
import zigpy.backups import zigpy.backups
from zigpy.config import CONF_DEVICE, CONF_DEVICE_PATH from zigpy.config import CONF_DEVICE, CONF_DEVICE_PATH
from zigpy.exceptions import NetworkNotFormed
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.components import onboarding, usb, zeroconf from homeassistant.components import onboarding, usb, zeroconf
from homeassistant.components.file_upload import process_uploaded_file from homeassistant.components.file_upload import process_uploaded_file
from homeassistant.const import CONF_NAME from homeassistant.const import CONF_NAME
from homeassistant.core import callback from homeassistant.core import HomeAssistant, callback
from homeassistant.data_entry_flow import FlowHandler, FlowResult from homeassistant.data_entry_flow import FlowHandler, FlowResult
from homeassistant.helpers.selector import FileSelector, FileSelectorConfig from homeassistant.helpers.selector import FileSelector, FileSelectorConfig
from homeassistant.util import dt from homeassistant.util import dt
from .core.const import ( from .core.const import (
CONF_BAUDRATE, CONF_BAUDRATE,
CONF_DATABASE,
CONF_FLOWCONTROL, CONF_FLOWCONTROL,
CONF_RADIO_TYPE, CONF_RADIO_TYPE,
CONF_ZIGPY,
DATA_ZHA,
DATA_ZHA_CONFIG,
DEFAULT_DATABASE_NAME,
DOMAIN, DOMAIN,
EZSP_OVERWRITE_EUI64, EZSP_OVERWRITE_EUI64,
RadioType, RadioType,
) )
from .radio_manager import ZhaRadioManager
CONF_MANUAL_PATH = "Enter Manually" CONF_MANUAL_PATH = "Enter Manually"
SUPPORTED_PORT_SETTINGS = ( SUPPORTED_PORT_SETTINGS = (
@ -47,16 +37,6 @@ SUPPORTED_PORT_SETTINGS = (
) )
DECONZ_DOMAIN = "deconz" DECONZ_DOMAIN = "deconz"
# Only the common radio types will be autoprobed, ordered by new device popularity.
# XBee takes too long to probe since it scans through all possible bauds and likely has
# very few users to begin with.
AUTOPROBE_RADIOS = (
RadioType.ezsp,
RadioType.znp,
RadioType.deconz,
RadioType.zigate,
)
FORMATION_STRATEGY = "formation_strategy" FORMATION_STRATEGY = "formation_strategy"
FORMATION_FORM_NEW_NETWORK = "form_new_network" FORMATION_FORM_NEW_NETWORK = "form_new_network"
FORMATION_REUSE_SETTINGS = "reuse_settings" FORMATION_REUSE_SETTINGS = "reuse_settings"
@ -74,8 +54,6 @@ UPLOADED_BACKUP_FILE = "uploaded_backup_file"
DEFAULT_ZHA_ZEROCONF_PORT = 6638 DEFAULT_ZHA_ZEROCONF_PORT = 6638
ESPHOME_API_PORT = 6053 ESPHOME_API_PORT = 6053
CONNECT_DELAY_S = 1.0
HARDWARE_DISCOVERY_SCHEMA = vol.Schema( HARDWARE_DISCOVERY_SCHEMA = vol.Schema(
{ {
vol.Required("name"): str, vol.Required("name"): str,
@ -84,8 +62,6 @@ HARDWARE_DISCOVERY_SCHEMA = vol.Schema(
} }
) )
_LOGGER = logging.getLogger(__name__)
def _format_backup_choice( def _format_backup_choice(
backup: zigpy.backups.NetworkBackup, *, pan_ids: bool = True backup: zigpy.backups.NetworkBackup, *, pan_ids: bool = True
@ -134,110 +110,44 @@ def _prevent_overwrite_ezsp_ieee(
class BaseZhaFlow(FlowHandler): class BaseZhaFlow(FlowHandler):
"""Mixin for common ZHA flow steps and forms.""" """Mixin for common ZHA flow steps and forms."""
_hass: HomeAssistant
def __init__(self) -> None: def __init__(self) -> None:
"""Initialize flow instance.""" """Initialize flow instance."""
super().__init__() super().__init__()
self._device_path: str | None = None self._hass = None # type: ignore[assignment]
self._device_settings: dict[str, Any] | None = None self._radio_mgr = ZhaRadioManager()
self._radio_type: RadioType | None = None
self._title: str | None = None self._title: str | None = None
self._current_settings: zigpy.backups.NetworkBackup | None = None
self._backups: list[zigpy.backups.NetworkBackup] = []
self._chosen_backup: zigpy.backups.NetworkBackup | None = None
@contextlib.asynccontextmanager @property
async def _connect_zigpy_app(self) -> ControllerApplication: def hass(self):
"""Connect to the radio with the current config and then clean up.""" """Return hass."""
assert self._radio_type is not None return self._hass
config = self.hass.data.get(DATA_ZHA, {}).get(DATA_ZHA_CONFIG, {}) @hass.setter
app_config = config.get(CONF_ZIGPY, {}).copy() def hass(self, hass):
"""Set hass."""
self._hass = hass
self._radio_mgr.hass = hass
database_path = config.get( async def _async_create_radio_entry(self) -> FlowResult:
CONF_DATABASE, """Create a config entry with the current flow state."""
self.hass.config.path(DEFAULT_DATABASE_NAME),
)
# Don't create `zigbee.db` if it doesn't already exist
if not await self.hass.async_add_executor_job(os.path.exists, database_path):
database_path = None
app_config[CONF_DATABASE] = database_path
app_config[CONF_DEVICE] = self._device_settings
app_config = self._radio_type.controller.SCHEMA(app_config)
app = await self._radio_type.controller.new(
app_config, auto_form=False, start_radio=False
)
try:
await app.connect()
yield app
finally:
await app.disconnect()
await asyncio.sleep(CONNECT_DELAY_S)
async def _restore_backup(
self, backup: zigpy.backups.NetworkBackup, **kwargs: Any
) -> None:
"""Restore the provided network backup, passing through kwargs."""
if self._current_settings is not None and self._current_settings.supersedes(
self._chosen_backup
):
return
async with self._connect_zigpy_app() as app:
await app.backups.restore_backup(backup, **kwargs)
def _parse_radio_type(self, radio_type: str) -> RadioType:
"""Parse a radio type name, accounting for past aliases."""
if radio_type == "efr32":
return RadioType.ezsp
return RadioType[radio_type]
async def _detect_radio_type(self) -> bool:
"""Probe all radio types on the current port."""
for radio in AUTOPROBE_RADIOS:
_LOGGER.debug("Attempting to probe radio type %s", radio)
dev_config = radio.controller.SCHEMA_DEVICE(
{CONF_DEVICE_PATH: self._device_path}
)
probe_result = await radio.controller.probe(dev_config)
if not probe_result:
continue
# Radio library probing can succeed and return new device settings
if isinstance(probe_result, dict):
dev_config = probe_result
self._radio_type = radio
self._device_settings = dev_config
return True
return False
async def _async_create_radio_entity(self) -> FlowResult:
"""Create a config entity with the current flow state."""
assert self._title is not None assert self._title is not None
assert self._radio_type is not None assert self._radio_mgr.radio_type is not None
assert self._device_path is not None assert self._radio_mgr.device_path is not None
assert self._device_settings is not None assert self._radio_mgr.device_settings is not None
device_settings = self._device_settings.copy() device_settings = self._radio_mgr.device_settings.copy()
device_settings[CONF_DEVICE_PATH] = await self.hass.async_add_executor_job( device_settings[CONF_DEVICE_PATH] = await self.hass.async_add_executor_job(
usb.get_serial_by_id, self._device_path usb.get_serial_by_id, self._radio_mgr.device_path
) )
return self.async_create_entry( return self.async_create_entry(
title=self._title, title=self._title,
data={ data={
CONF_DEVICE: device_settings, CONF_DEVICE: device_settings,
CONF_RADIO_TYPE: self._radio_type.name, CONF_RADIO_TYPE: self._radio_mgr.radio_type.name,
}, },
) )
@ -264,9 +174,9 @@ class BaseZhaFlow(FlowHandler):
return await self.async_step_manual_pick_radio_type() return await self.async_step_manual_pick_radio_type()
port = ports[list_of_ports.index(user_selection)] port = ports[list_of_ports.index(user_selection)]
self._device_path = port.device self._radio_mgr.device_path = port.device
if not await self._detect_radio_type(): if not await self._radio_mgr.detect_radio_type():
# Did not autodetect anything, proceed to manual selection # Did not autodetect anything, proceed to manual selection
return await self.async_step_manual_pick_radio_type() return await self.async_step_manual_pick_radio_type()
@ -282,9 +192,9 @@ class BaseZhaFlow(FlowHandler):
# Pre-select the currently configured port # Pre-select the currently configured port
default_port = vol.UNDEFINED default_port = vol.UNDEFINED
if self._device_path is not None: if self._radio_mgr.device_path is not None:
for description, port in zip(list_of_ports, ports): for description, port in zip(list_of_ports, ports):
if port.device == self._device_path: if port.device == self._radio_mgr.device_path:
default_port = description default_port = description
break break
else: else:
@ -304,14 +214,16 @@ class BaseZhaFlow(FlowHandler):
) -> FlowResult: ) -> FlowResult:
"""Manually select the radio type.""" """Manually select the radio type."""
if user_input is not None: if user_input is not None:
self._radio_type = RadioType.get_by_description(user_input[CONF_RADIO_TYPE]) self._radio_mgr.radio_type = RadioType.get_by_description(
user_input[CONF_RADIO_TYPE]
)
return await self.async_step_manual_port_config() return await self.async_step_manual_port_config()
# Pre-select the current radio type # Pre-select the current radio type
default = vol.UNDEFINED default = vol.UNDEFINED
if self._radio_type is not None: if self._radio_mgr.radio_type is not None:
default = self._radio_type.description default = self._radio_mgr.radio_type.description
schema = { schema = {
vol.Required(CONF_RADIO_TYPE, default=default): vol.In(RadioType.list()) vol.Required(CONF_RADIO_TYPE, default=default): vol.In(RadioType.list())
@ -326,35 +238,43 @@ class BaseZhaFlow(FlowHandler):
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> FlowResult: ) -> FlowResult:
"""Enter port settings specific for this type of radio.""" """Enter port settings specific for this type of radio."""
assert self._radio_type is not None assert self._radio_mgr.radio_type is not None
errors = {} errors = {}
if user_input is not None: if user_input is not None:
self._title = user_input[CONF_DEVICE_PATH] self._title = user_input[CONF_DEVICE_PATH]
self._device_path = user_input[CONF_DEVICE_PATH] self._radio_mgr.device_path = user_input[CONF_DEVICE_PATH]
self._device_settings = user_input.copy() self._radio_mgr.device_settings = user_input.copy()
if await self._radio_type.controller.probe(user_input): if await self._radio_mgr.radio_type.controller.probe(user_input):
return await self.async_step_choose_formation_strategy() return await self.async_step_choose_formation_strategy()
errors["base"] = "cannot_connect" errors["base"] = "cannot_connect"
schema = { schema = {
vol.Required( vol.Required(
CONF_DEVICE_PATH, default=self._device_path or vol.UNDEFINED CONF_DEVICE_PATH, default=self._radio_mgr.device_path or vol.UNDEFINED
): str ): str
} }
source = self.context.get("source") source = self.context.get("source")
for param, value in self._radio_type.controller.SCHEMA_DEVICE.schema.items(): for (
param,
value,
) in self._radio_mgr.radio_type.controller.SCHEMA_DEVICE.schema.items():
if param not in SUPPORTED_PORT_SETTINGS: if param not in SUPPORTED_PORT_SETTINGS:
continue continue
if source == config_entries.SOURCE_ZEROCONF and param == CONF_BAUDRATE: if source == config_entries.SOURCE_ZEROCONF and param == CONF_BAUDRATE:
value = 115200 value = 115200
param = vol.Required(CONF_BAUDRATE, default=value) param = vol.Required(CONF_BAUDRATE, default=value)
elif self._device_settings is not None and param in self._device_settings: elif (
param = vol.Required(str(param), default=self._device_settings[param]) self._radio_mgr.device_settings is not None
and param in self._radio_mgr.device_settings
):
param = vol.Required(
str(param), default=self._radio_mgr.device_settings[param]
)
schema[param] = value schema[param] = value
@ -364,43 +284,26 @@ class BaseZhaFlow(FlowHandler):
errors=errors, errors=errors,
) )
async def _async_load_network_settings(self) -> None:
"""Connect to the radio and load its current network settings."""
async with self._connect_zigpy_app() as app:
# Check if the stick has any settings and load them
try:
await app.load_network_info()
except NetworkNotFormed:
pass
else:
self._current_settings = zigpy.backups.NetworkBackup(
network_info=app.state.network_info,
node_info=app.state.node_info,
)
# The list of backups will always exist
self._backups = app.backups.backups.copy()
async def async_step_choose_formation_strategy( async def async_step_choose_formation_strategy(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> FlowResult: ) -> FlowResult:
"""Choose how to deal with the current radio's settings.""" """Choose how to deal with the current radio's settings."""
await self._async_load_network_settings() await self._radio_mgr.async_load_network_settings()
strategies = [] strategies = []
# Check if we have any automatic backups *and* if the backups differ from # Check if we have any automatic backups *and* if the backups differ from
# the current radio settings, if they exist (since restoring would be redundant) # the current radio settings, if they exist (since restoring would be redundant)
if self._backups and ( if self._radio_mgr.backups and (
self._current_settings is None self._radio_mgr.current_settings is None
or any( or any(
not backup.is_compatible_with(self._current_settings) not backup.is_compatible_with(self._radio_mgr.current_settings)
for backup in self._backups for backup in self._radio_mgr.backups
) )
): ):
strategies.append(CHOOSE_AUTOMATIC_BACKUP) strategies.append(CHOOSE_AUTOMATIC_BACKUP)
if self._current_settings is not None: if self._radio_mgr.current_settings is not None:
strategies.append(FORMATION_REUSE_SETTINGS) strategies.append(FORMATION_REUSE_SETTINGS)
strategies.append(FORMATION_UPLOAD_MANUAL_BACKUP) strategies.append(FORMATION_UPLOAD_MANUAL_BACKUP)
@ -415,16 +318,14 @@ class BaseZhaFlow(FlowHandler):
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> FlowResult: ) -> FlowResult:
"""Reuse the existing network settings on the stick.""" """Reuse the existing network settings on the stick."""
return await self._async_create_radio_entity() return await self._async_create_radio_entry()
async def async_step_form_new_network( async def async_step_form_new_network(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> FlowResult: ) -> FlowResult:
"""Form a brand new network.""" """Form a brand new network."""
async with self._connect_zigpy_app() as app: await self._radio_mgr.async_form_network()
await app.form_network() return await self._async_create_radio_entry()
return await self._async_create_radio_entity()
def _parse_uploaded_backup( def _parse_uploaded_backup(
self, uploaded_file_id: str self, uploaded_file_id: str
@ -443,7 +344,7 @@ class BaseZhaFlow(FlowHandler):
if user_input is not None: if user_input is not None:
try: try:
self._chosen_backup = await self.hass.async_add_executor_job( self._radio_mgr.chosen_backup = await self.hass.async_add_executor_job(
self._parse_uploaded_backup, user_input[UPLOADED_BACKUP_FILE] self._parse_uploaded_backup, user_input[UPLOADED_BACKUP_FILE]
) )
except ValueError: except ValueError:
@ -470,23 +371,24 @@ class BaseZhaFlow(FlowHandler):
if self.show_advanced_options: if self.show_advanced_options:
# Always show the PAN IDs when in advanced mode # Always show the PAN IDs when in advanced mode
choices = [ choices = [
_format_backup_choice(backup, pan_ids=True) for backup in self._backups _format_backup_choice(backup, pan_ids=True)
for backup in self._radio_mgr.backups
] ]
else: else:
# Only show the PAN IDs for multiple backups taken on the same day # Only show the PAN IDs for multiple backups taken on the same day
num_backups_on_date = collections.Counter( num_backups_on_date = collections.Counter(
backup.backup_time.date() for backup in self._backups backup.backup_time.date() for backup in self._radio_mgr.backups
) )
choices = [ choices = [
_format_backup_choice( _format_backup_choice(
backup, pan_ids=(num_backups_on_date[backup.backup_time.date()] > 1) backup, pan_ids=(num_backups_on_date[backup.backup_time.date()] > 1)
) )
for backup in self._backups for backup in self._radio_mgr.backups
] ]
if user_input is not None: if user_input is not None:
index = choices.index(user_input[CHOOSE_AUTOMATIC_BACKUP]) index = choices.index(user_input[CHOOSE_AUTOMATIC_BACKUP])
self._chosen_backup = self._backups[index] self._radio_mgr.chosen_backup = self._radio_mgr.backups[index]
return await self.async_step_maybe_confirm_ezsp_restore() return await self.async_step_maybe_confirm_ezsp_restore()
@ -505,46 +407,47 @@ class BaseZhaFlow(FlowHandler):
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> FlowResult: ) -> FlowResult:
"""Confirm restore for EZSP radios that require permanent IEEE writes.""" """Confirm restore for EZSP radios that require permanent IEEE writes."""
assert self._chosen_backup is not None assert self._radio_mgr.chosen_backup is not None
if self._radio_type != RadioType.ezsp: if self._radio_mgr.radio_type != RadioType.ezsp:
await self._restore_backup(self._chosen_backup) await self._radio_mgr.restore_backup(self._radio_mgr.chosen_backup)
return await self._async_create_radio_entity() return await self._async_create_radio_entry()
# We have no way to partially load network settings if no network is formed # We have no way to partially load network settings if no network is formed
if self._current_settings is None: if self._radio_mgr.current_settings is None:
# Since we are going to be restoring the backup anyways, write it to the # Since we are going to be restoring the backup anyways, write it to the
# radio without overwriting the IEEE but don't take a backup with these # radio without overwriting the IEEE but don't take a backup with these
# temporary settings # temporary settings
temp_backup = _prevent_overwrite_ezsp_ieee(self._chosen_backup) temp_backup = _prevent_overwrite_ezsp_ieee(self._radio_mgr.chosen_backup)
await self._restore_backup(temp_backup, create_new=False) await self._radio_mgr.restore_backup(temp_backup, create_new=False)
await self._async_load_network_settings() await self._radio_mgr.async_load_network_settings()
assert self._current_settings is not None assert self._radio_mgr.current_settings is not None
if ( if (
self._current_settings.node_info.ieee == self._chosen_backup.node_info.ieee self._radio_mgr.current_settings.node_info.ieee
or not self._current_settings.network_info.metadata["ezsp"][ == self._radio_mgr.chosen_backup.node_info.ieee
or not self._radio_mgr.current_settings.network_info.metadata["ezsp"][
"can_write_custom_eui64" "can_write_custom_eui64"
] ]
): ):
# No point in prompting the user if the backup doesn't have a new IEEE # No point in prompting the user if the backup doesn't have a new IEEE
# address or if there is no way to overwrite the IEEE address a second time # address or if there is no way to overwrite the IEEE address a second time
await self._restore_backup(self._chosen_backup) await self._radio_mgr.restore_backup(self._radio_mgr.chosen_backup)
return await self._async_create_radio_entity() return await self._async_create_radio_entry()
if user_input is not None: if user_input is not None:
backup = self._chosen_backup backup = self._radio_mgr.chosen_backup
if user_input[OVERWRITE_COORDINATOR_IEEE]: if user_input[OVERWRITE_COORDINATOR_IEEE]:
backup = _allow_overwrite_ezsp_ieee(backup) backup = _allow_overwrite_ezsp_ieee(backup)
# If the user declined to overwrite the IEEE *and* we wrote the backup to # If the user declined to overwrite the IEEE *and* we wrote the backup to
# their empty radio above, restoring it again would be redundant. # their empty radio above, restoring it again would be redundant.
await self._restore_backup(backup) await self._radio_mgr.restore_backup(backup)
return await self._async_create_radio_entity() return await self._async_create_radio_entry()
return self.async_show_form( return self.async_show_form(
step_id="maybe_confirm_ezsp_restore", step_id="maybe_confirm_ezsp_restore",
@ -608,13 +511,16 @@ class ZhaConfigFlowHandler(BaseZhaFlow, config_entries.ConfigFlow, domain=DOMAIN
# config flow logic that interacts with hardware! # config flow logic that interacts with hardware!
if user_input is not None or not onboarding.async_is_onboarded(self.hass): if user_input is not None or not onboarding.async_is_onboarded(self.hass):
# Probe the radio type if we don't have one yet # Probe the radio type if we don't have one yet
if self._radio_type is None and not await self._detect_radio_type(): if (
self._radio_mgr.radio_type is None
and not await self._radio_mgr.detect_radio_type()
):
# This path probably will not happen now that we have # This path probably will not happen now that we have
# more precise USB matching unless there is a problem # more precise USB matching unless there is a problem
# with the device # with the device
return self.async_abort(reason="usb_probe_failed") return self.async_abort(reason="usb_probe_failed")
if self._device_settings is None: if self._radio_mgr.device_settings is None:
return await self.async_step_manual_port_config() return await self.async_step_manual_port_config()
return await self.async_step_choose_formation_strategy() return await self.async_step_choose_formation_strategy()
@ -647,7 +553,7 @@ class ZhaConfigFlowHandler(BaseZhaFlow, config_entries.ConfigFlow, domain=DOMAIN
if entry.source != config_entries.SOURCE_IGNORE: if entry.source != config_entries.SOURCE_IGNORE:
return self.async_abort(reason="not_zha_device") return self.async_abort(reason="not_zha_device")
self._device_path = dev_path self._radio_mgr.device_path = dev_path
self._title = description or usb.human_readable_device_name( self._title = description or usb.human_readable_device_name(
dev_path, dev_path,
serial_number, serial_number,
@ -673,13 +579,13 @@ class ZhaConfigFlowHandler(BaseZhaFlow, config_entries.ConfigFlow, domain=DOMAIN
port = DEFAULT_ZHA_ZEROCONF_PORT port = DEFAULT_ZHA_ZEROCONF_PORT
if "radio_type" in discovery_info.properties: if "radio_type" in discovery_info.properties:
self._radio_type = self._parse_radio_type( self._radio_mgr.radio_type = self._radio_mgr.parse_radio_type(
discovery_info.properties["radio_type"] discovery_info.properties["radio_type"]
) )
elif "efr32" in local_name: elif "efr32" in local_name:
self._radio_type = RadioType.ezsp self._radio_mgr.radio_type = RadioType.ezsp
else: else:
self._radio_type = RadioType.znp self._radio_mgr.radio_type = RadioType.znp
node_name = local_name[: -len(".local")] node_name = local_name[: -len(".local")]
device_path = f"socket://{discovery_info.host}:{port}" device_path = f"socket://{discovery_info.host}:{port}"
@ -691,7 +597,7 @@ class ZhaConfigFlowHandler(BaseZhaFlow, config_entries.ConfigFlow, domain=DOMAIN
self.context["title_placeholders"] = {CONF_NAME: node_name} self.context["title_placeholders"] = {CONF_NAME: node_name}
self._title = device_path self._title = device_path
self._device_path = device_path self._radio_mgr.device_path = device_path
return await self.async_step_confirm() return await self.async_step_confirm()
@ -705,7 +611,7 @@ class ZhaConfigFlowHandler(BaseZhaFlow, config_entries.ConfigFlow, domain=DOMAIN
return self.async_abort(reason="invalid_hardware_data") return self.async_abort(reason="invalid_hardware_data")
name = discovery_data["name"] name = discovery_data["name"]
radio_type = self._parse_radio_type(discovery_data["radio_type"]) radio_type = self._radio_mgr.parse_radio_type(discovery_data["radio_type"])
try: try:
device_settings = radio_type.controller.SCHEMA_DEVICE( device_settings = radio_type.controller.SCHEMA_DEVICE(
@ -720,9 +626,9 @@ class ZhaConfigFlowHandler(BaseZhaFlow, config_entries.ConfigFlow, domain=DOMAIN
) )
self._title = name self._title = name
self._radio_type = radio_type self._radio_mgr.radio_type = radio_type
self._device_path = device_settings[CONF_DEVICE_PATH] self._radio_mgr.device_path = device_settings[CONF_DEVICE_PATH]
self._device_settings = device_settings self._radio_mgr.device_settings = device_settings
self.context["title_placeholders"] = {CONF_NAME: name} self.context["title_placeholders"] = {CONF_NAME: name}
return await self.async_step_confirm() return await self.async_step_confirm()
@ -736,9 +642,9 @@ class ZhaOptionsFlowHandler(BaseZhaFlow, config_entries.OptionsFlow):
super().__init__() super().__init__()
self.config_entry = config_entry self.config_entry = config_entry
self._device_path = config_entry.data[CONF_DEVICE][CONF_DEVICE_PATH] self._radio_mgr.device_path = config_entry.data[CONF_DEVICE][CONF_DEVICE_PATH]
self._device_settings = config_entry.data[CONF_DEVICE] self._radio_mgr.device_settings = config_entry.data[CONF_DEVICE]
self._radio_type = RadioType[config_entry.data[CONF_RADIO_TYPE]] self._radio_mgr.radio_type = RadioType[config_entry.data[CONF_RADIO_TYPE]]
self._title = config_entry.title self._title = config_entry.title
async def async_step_init( async def async_step_init(
@ -781,9 +687,7 @@ class ZhaOptionsFlowHandler(BaseZhaFlow, config_entries.OptionsFlow):
"""Confirm the user wants to reset their current radio.""" """Confirm the user wants to reset their current radio."""
if user_input is not None: if user_input is not None:
# Reset the current adapter await self._radio_mgr.async_reset_adapter()
async with self._connect_zigpy_app() as app:
await app.reset_network_info()
return await self.async_step_instruct_unplug() return await self.async_step_instruct_unplug()
@ -800,11 +704,11 @@ class ZhaOptionsFlowHandler(BaseZhaFlow, config_entries.OptionsFlow):
return self.async_show_form(step_id="instruct_unplug") return self.async_show_form(step_id="instruct_unplug")
async def _async_create_radio_entity(self): async def _async_create_radio_entry(self):
"""Re-implementation of the base flow's final step to update the config.""" """Re-implementation of the base flow's final step to update the config."""
device_settings = self._device_settings.copy() device_settings = self._radio_mgr.device_settings.copy()
device_settings[CONF_DEVICE_PATH] = await self.hass.async_add_executor_job( device_settings[CONF_DEVICE_PATH] = await self.hass.async_add_executor_job(
usb.get_serial_by_id, self._device_path usb.get_serial_by_id, self._radio_mgr.device_path
) )
# Avoid creating both `.options` and `.data` by directly writing `data` here # Avoid creating both `.options` and `.data` by directly writing `data` here
@ -812,7 +716,7 @@ class ZhaOptionsFlowHandler(BaseZhaFlow, config_entries.OptionsFlow):
entry=self.config_entry, entry=self.config_entry,
data={ data={
CONF_DEVICE: device_settings, CONF_DEVICE: device_settings,
CONF_RADIO_TYPE: self._radio_type.name, CONF_RADIO_TYPE: self._radio_mgr.radio_type.name,
}, },
options=self.config_entry.options, options=self.config_entry.options,
) )

View file

@ -0,0 +1,158 @@
"""Config flow for ZHA."""
from __future__ import annotations
import asyncio
import contextlib
import logging
import os
from typing import Any
from zigpy.application import ControllerApplication
import zigpy.backups
from zigpy.config import CONF_DEVICE, CONF_DEVICE_PATH
from zigpy.exceptions import NetworkNotFormed
from homeassistant.core import HomeAssistant
from .core.const import (
CONF_DATABASE,
CONF_ZIGPY,
DATA_ZHA,
DATA_ZHA_CONFIG,
DEFAULT_DATABASE_NAME,
RadioType,
)
# Only the common radio types will be autoprobed, ordered by new device popularity.
# XBee takes too long to probe since it scans through all possible bauds and likely has
# very few users to begin with.
AUTOPROBE_RADIOS = (
RadioType.ezsp,
RadioType.znp,
RadioType.deconz,
RadioType.zigate,
)
CONNECT_DELAY_S = 1.0
_LOGGER = logging.getLogger(__name__)
class ZhaRadioManager:
"""Helper class with radio related functionality."""
hass: HomeAssistant
def __init__(self) -> None:
"""Initialize ZhaRadioManager instance."""
self.device_path: str | None = None
self.device_settings: dict[str, Any] | None = None
self.radio_type: RadioType | None = None
self.current_settings: zigpy.backups.NetworkBackup | None = None
self.backups: list[zigpy.backups.NetworkBackup] = []
self.chosen_backup: zigpy.backups.NetworkBackup | None = None
@contextlib.asynccontextmanager
async def _connect_zigpy_app(self) -> ControllerApplication:
"""Connect to the radio with the current config and then clean up."""
assert self.radio_type is not None
config = self.hass.data.get(DATA_ZHA, {}).get(DATA_ZHA_CONFIG, {})
app_config = config.get(CONF_ZIGPY, {}).copy()
database_path = config.get(
CONF_DATABASE,
self.hass.config.path(DEFAULT_DATABASE_NAME),
)
# Don't create `zigbee.db` if it doesn't already exist
if not await self.hass.async_add_executor_job(os.path.exists, database_path):
database_path = None
app_config[CONF_DATABASE] = database_path
app_config[CONF_DEVICE] = self.device_settings
app_config = self.radio_type.controller.SCHEMA(app_config)
app = await self.radio_type.controller.new(
app_config, auto_form=False, start_radio=False
)
try:
await app.connect()
yield app
finally:
await app.disconnect()
await asyncio.sleep(CONNECT_DELAY_S)
async def restore_backup(
self, backup: zigpy.backups.NetworkBackup, **kwargs: Any
) -> None:
"""Restore the provided network backup, passing through kwargs."""
if self.current_settings is not None and self.current_settings.supersedes(
self.chosen_backup
):
return
async with self._connect_zigpy_app() as app:
await app.backups.restore_backup(backup, **kwargs)
def parse_radio_type(self, radio_type: str) -> RadioType:
"""Parse a radio type name, accounting for past aliases."""
if radio_type == "efr32":
return RadioType.ezsp
return RadioType[radio_type]
async def detect_radio_type(self) -> bool:
"""Probe all radio types on the current port."""
for radio in AUTOPROBE_RADIOS:
_LOGGER.debug("Attempting to probe radio type %s", radio)
dev_config = radio.controller.SCHEMA_DEVICE(
{CONF_DEVICE_PATH: self.device_path}
)
probe_result = await radio.controller.probe(dev_config)
if not probe_result:
continue
# Radio library probing can succeed and return new device settings
if isinstance(probe_result, dict):
dev_config = probe_result
self.radio_type = radio
self.device_settings = dev_config
return True
return False
async def async_load_network_settings(self, create_backup: bool = False) -> None:
"""Connect to the radio and load its current network settings."""
async with self._connect_zigpy_app() as app:
# Check if the stick has any settings and load them
try:
await app.load_network_info()
except NetworkNotFormed:
pass
else:
self.current_settings = zigpy.backups.NetworkBackup(
network_info=app.state.network_info,
node_info=app.state.node_info,
)
if create_backup:
await app.backups.create_backup()
# The list of backups will always exist
self.backups = app.backups.backups.copy()
async def async_form_network(self) -> None:
"""Form a brand new network."""
async with self._connect_zigpy_app() as app:
await app.form_network()
async def async_reset_adapter(self) -> None:
"""Reset the current adapter."""
async with self._connect_zigpy_app() as app:
await app.reset_network_info()

View file

@ -49,7 +49,7 @@ def disable_platform_only():
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def reduce_reconnect_timeout(): def reduce_reconnect_timeout():
"""Reduces reconnect timeout to speed up tests.""" """Reduces reconnect timeout to speed up tests."""
with patch("homeassistant.components.zha.config_flow.CONNECT_DELAY_S", 0.01): with patch("homeassistant.components.zha.radio_manager.CONNECT_DELAY_S", 0.01):
yield yield
@ -76,12 +76,12 @@ def backup():
def mock_detect_radio_type(radio_type=RadioType.ezsp, ret=True): def mock_detect_radio_type(radio_type=RadioType.ezsp, ret=True):
"""Mock `_detect_radio_type` that just sets the appropriate attributes.""" """Mock `detect_radio_type` that just sets the appropriate attributes."""
async def detect(self): async def detect(self):
self._radio_type = radio_type self.radio_type = radio_type
self._device_settings = radio_type.controller.SCHEMA_DEVICE( self.device_settings = radio_type.controller.SCHEMA_DEVICE(
{CONF_DEVICE_PATH: self._device_path} {CONF_DEVICE_PATH: self.device_path}
) )
return ret return ret
@ -669,7 +669,7 @@ async def test_discovery_already_setup(hass):
@patch( @patch(
"homeassistant.components.zha.config_flow.ZhaConfigFlowHandler._detect_radio_type", "homeassistant.components.zha.radio_manager.ZhaRadioManager.detect_radio_type",
mock_detect_radio_type(radio_type=RadioType.deconz), mock_detect_radio_type(radio_type=RadioType.deconz),
) )
@patch("serial.tools.list_ports.comports", MagicMock(return_value=[com_port()])) @patch("serial.tools.list_ports.comports", MagicMock(return_value=[com_port()]))
@ -707,7 +707,7 @@ async def test_user_flow(hass):
@patch( @patch(
"homeassistant.components.zha.config_flow.ZhaConfigFlowHandler._detect_radio_type", "homeassistant.components.zha.radio_manager.ZhaRadioManager.detect_radio_type",
mock_detect_radio_type(ret=False), mock_detect_radio_type(ret=False),
) )
@patch("serial.tools.list_ports.comports", MagicMock(return_value=[com_port()])) @patch("serial.tools.list_ports.comports", MagicMock(return_value=[com_port()]))
@ -799,12 +799,14 @@ async def test_detect_radio_type_success(
"""Test detect radios successfully.""" """Test detect radios successfully."""
handler = config_flow.ZhaConfigFlowHandler() handler = config_flow.ZhaConfigFlowHandler()
handler._device_path = "/dev/null" handler._radio_mgr.device_path = "/dev/null"
await handler._detect_radio_type() await handler._radio_mgr.detect_radio_type()
assert handler._radio_type == RadioType.znp assert handler._radio_mgr.radio_type == RadioType.znp
assert handler._device_settings[zigpy.config.CONF_DEVICE_PATH] == "/dev/null" assert (
handler._radio_mgr.device_settings[zigpy.config.CONF_DEVICE_PATH] == "/dev/null"
)
assert bellows_probe.await_count == 1 assert bellows_probe.await_count == 1
assert znp_probe.await_count == 1 assert znp_probe.await_count == 1
@ -825,12 +827,14 @@ async def test_detect_radio_type_success_with_settings(
"""Test detect radios successfully but probing returns new settings.""" """Test detect radios successfully but probing returns new settings."""
handler = config_flow.ZhaConfigFlowHandler() handler = config_flow.ZhaConfigFlowHandler()
handler._device_path = "/dev/null" handler._radio_mgr.device_path = "/dev/null"
await handler._detect_radio_type() await handler._radio_mgr.detect_radio_type()
assert handler._radio_type == RadioType.ezsp assert handler._radio_mgr.radio_type == RadioType.ezsp
assert handler._device_settings["new_setting"] == 123 assert handler._radio_mgr.device_settings["new_setting"] == 123
assert handler._device_settings[zigpy.config.CONF_DEVICE_PATH] == "/dev/null" assert (
handler._radio_mgr.device_settings[zigpy.config.CONF_DEVICE_PATH] == "/dev/null"
)
assert bellows_probe.await_count == 1 assert bellows_probe.await_count == 1
assert znp_probe.await_count == 0 assert znp_probe.await_count == 0
@ -1047,7 +1051,7 @@ def pick_radio(hass):
port_select = f"{port}, s/n: {port.serial_number} - {port.manufacturer}" port_select = f"{port}, s/n: {port.serial_number} - {port.manufacturer}"
with patch( with patch(
"homeassistant.components.zha.config_flow.ZhaConfigFlowHandler._detect_radio_type", "homeassistant.components.zha.radio_manager.ZhaRadioManager.detect_radio_type",
mock_detect_radio_type(radio_type=radio_type), mock_detect_radio_type(radio_type=radio_type),
): ):
result = await hass.config_entries.flow.async_init( result = await hass.config_entries.flow.async_init(
@ -1631,6 +1635,7 @@ async def test_options_flow_defaults_socket(hass):
assert result5["step_id"] == "choose_formation_strategy" assert result5["step_id"] == "choose_formation_strategy"
@patch("serial.tools.list_ports.comports", MagicMock(return_value=[com_port()]))
@patch("homeassistant.components.zha.async_setup_entry", return_value=True) @patch("homeassistant.components.zha.async_setup_entry", return_value=True)
async def test_options_flow_restarts_running_zha_if_cancelled(async_setup_entry, hass): async def test_options_flow_restarts_running_zha_if_cancelled(async_setup_entry, hass):
"""Test options flow restarts a previously-running ZHA if it's cancelled.""" """Test options flow restarts a previously-running ZHA if it's cancelled."""
@ -1683,6 +1688,7 @@ async def test_options_flow_restarts_running_zha_if_cancelled(async_setup_entry,
async_setup_entry.assert_called_once_with(hass, entry) async_setup_entry.assert_called_once_with(hass, entry)
@patch("serial.tools.list_ports.comports", MagicMock(return_value=[com_port()]))
@patch("homeassistant.components.zha.async_setup_entry", AsyncMock(return_value=True)) @patch("homeassistant.components.zha.async_setup_entry", AsyncMock(return_value=True))
async def test_options_flow_migration_reset_old_adapter(hass, mock_app): async def test_options_flow_migration_reset_old_adapter(hass, mock_app):
"""Test options flow for migrating from an old radio.""" """Test options flow for migrating from an old radio."""