Improve / clean up Plugwise config_flow code (#127238)

This commit is contained in:
Bouwe Westerdijk 2024-10-01 21:52:16 +02:00 committed by GitHub
parent dd478fe681
commit 0616bc7fec
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 17 additions and 23 deletions

View file

@ -16,8 +16,9 @@ from plugwise.exceptions import (
import voluptuous as vol
from homeassistant.components.zeroconf import ZeroconfServiceInfo
from homeassistant.config_entries import ConfigFlow, ConfigFlowResult
from homeassistant.config_entries import SOURCE_USER, ConfigFlow, ConfigFlowResult
from homeassistant.const import (
ATTR_CONFIGURATION_URL,
CONF_BASE,
CONF_HOST,
CONF_NAME,
@ -29,13 +30,11 @@ from homeassistant.core import HomeAssistant
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from .const import (
API,
DEFAULT_PORT,
DEFAULT_USERNAME,
DOMAIN,
FLOW_SMILE,
FLOW_STRETCH,
PW_TYPE,
SMILE,
STRETCH,
STRETCH_USERNAME,
@ -43,12 +42,12 @@ from .const import (
)
def _base_gw_schema(discovery_info: ZeroconfServiceInfo | None) -> vol.Schema:
def base_schema(discovery_info: ZeroconfServiceInfo | None) -> vol.Schema:
"""Generate base schema for gateways."""
base_gw_schema = vol.Schema({vol.Required(CONF_PASSWORD): str})
schema = vol.Schema({vol.Required(CONF_PASSWORD): str})
if not discovery_info:
base_gw_schema = base_gw_schema.extend(
schema = schema.extend(
{
vol.Required(CONF_HOST): str,
vol.Optional(CONF_PORT, default=DEFAULT_PORT): int,
@ -58,13 +57,13 @@ def _base_gw_schema(discovery_info: ZeroconfServiceInfo | None) -> vol.Schema:
}
)
return base_gw_schema
return schema
async def validate_gw_input(hass: HomeAssistant, data: dict[str, Any]) -> Smile:
async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> Smile:
"""Validate whether the user input allows us to connect to the gateway.
Data has the keys from _base_gw_schema() with values provided by the user.
Data has the keys from base_schema() with values provided by the user.
"""
websession = async_get_clientsession(hass, verify_ssl=False)
api = Smile(
@ -85,7 +84,7 @@ class PlugwiseConfigFlow(ConfigFlow, domain=DOMAIN):
VERSION = 1
discovery_info: ZeroconfServiceInfo | None = None
product: str | None = None
product: str = "Unknown Smile"
_username: str = DEFAULT_USERNAME
async def async_step_zeroconf(
@ -98,7 +97,7 @@ class PlugwiseConfigFlow(ConfigFlow, domain=DOMAIN):
unique_id = discovery_info.hostname.split(".")[0].split("-")[0]
if config_entry := await self.async_set_unique_id(unique_id):
try:
await validate_gw_input(
await validate_input(
self.hass,
{
CONF_HOST: discovery_info.host,
@ -119,7 +118,7 @@ class PlugwiseConfigFlow(ConfigFlow, domain=DOMAIN):
if DEFAULT_USERNAME not in unique_id:
self._username = STRETCH_USERNAME
self.product = _product = _properties.get("product", None)
self.product = _product = _properties.get("product", "Unknown Smile")
_version = _properties.get("version", "n/a")
_name = f"{ZEROCONF_MAP.get(_product, _product)} v{_version}"
@ -137,7 +136,7 @@ class PlugwiseConfigFlow(ConfigFlow, domain=DOMAIN):
self.context.update(
{
"title_placeholders": {CONF_NAME: _name},
"configuration_url": (
ATTR_CONFIGURATION_URL: (
f"http://{discovery_info.host}:{discovery_info.port}"
),
}
@ -160,7 +159,7 @@ class PlugwiseConfigFlow(ConfigFlow, domain=DOMAIN):
self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
"""Handle the initial step when using network/gateway setups."""
errors = {}
errors: dict[str, str] = {}
if user_input is not None:
if self.discovery_info:
@ -169,7 +168,7 @@ class PlugwiseConfigFlow(ConfigFlow, domain=DOMAIN):
user_input[CONF_USERNAME] = self._username
try:
api = await validate_gw_input(self.hass, user_input)
api = await validate_input(self.hass, user_input)
except ConnectionFailedError:
errors[CONF_BASE] = "cannot_connect"
except InvalidAuthentication:
@ -188,11 +187,10 @@ class PlugwiseConfigFlow(ConfigFlow, domain=DOMAIN):
)
self._abort_if_unique_id_configured()
user_input[PW_TYPE] = API
return self.async_create_entry(title=api.smile_name, data=user_input)
return self.async_show_form(
step_id="user",
data_schema=_base_gw_schema(self.discovery_info),
step_id=SOURCE_USER,
data_schema=base_schema(self.discovery_info),
errors=errors,
)

View file

@ -12,7 +12,7 @@ from plugwise.exceptions import (
)
import pytest
from homeassistant.components.plugwise.const import API, DEFAULT_PORT, DOMAIN, PW_TYPE
from homeassistant.components.plugwise.const import DEFAULT_PORT, DOMAIN
from homeassistant.components.zeroconf import ZeroconfServiceInfo
from homeassistant.config_entries import SOURCE_USER, SOURCE_ZEROCONF
from homeassistant.const import (
@ -123,7 +123,6 @@ async def test_form(
CONF_PASSWORD: TEST_PASSWORD,
CONF_PORT: DEFAULT_PORT,
CONF_USERNAME: TEST_USERNAME,
PW_TYPE: API,
}
assert len(mock_setup_entry.mock_calls) == 1
@ -168,7 +167,6 @@ async def test_zeroconf_flow(
CONF_PASSWORD: TEST_PASSWORD,
CONF_PORT: DEFAULT_PORT,
CONF_USERNAME: TEST_USERNAME,
PW_TYPE: API,
}
assert len(mock_setup_entry.mock_calls) == 1
@ -204,7 +202,6 @@ async def test_zeroconf_flow_stretch(
CONF_PASSWORD: TEST_PASSWORD,
CONF_PORT: DEFAULT_PORT,
CONF_USERNAME: TEST_USERNAME2,
PW_TYPE: API,
}
assert len(mock_setup_entry.mock_calls) == 1
@ -308,7 +305,6 @@ async def test_flow_errors(
CONF_PASSWORD: TEST_PASSWORD,
CONF_PORT: DEFAULT_PORT,
CONF_USERNAME: TEST_USERNAME,
PW_TYPE: API,
}
assert len(mock_setup_entry.mock_calls) == 1