From a4dbb9a24e676d3d87003b678f3b14347630b1a9 Mon Sep 17 00:00:00 2001 From: epenet <6771947+epenet@users.noreply.github.com> Date: Fri, 25 Nov 2022 09:29:54 +0100 Subject: [PATCH] Add handler to validate_user_input (#82681) * Add handler to validate_user_input * Adjust group config flow --- homeassistant/components/group/config_flow.py | 9 +++++++-- homeassistant/components/scrape/config_flow.py | 9 +++++++-- homeassistant/components/threshold/config_flow.py | 9 ++++++--- homeassistant/components/utility_meter/config_flow.py | 9 ++++++--- homeassistant/helpers/schema_config_entry_flow.py | 8 +++++--- 5 files changed, 31 insertions(+), 13 deletions(-) diff --git a/homeassistant/components/group/config_flow.py b/homeassistant/components/group/config_flow.py index 5453d3024f5..1dd01f6193a 100644 --- a/homeassistant/components/group/config_flow.py +++ b/homeassistant/components/group/config_flow.py @@ -11,6 +11,7 @@ from homeassistant.const import CONF_ENTITIES from homeassistant.core import HomeAssistant, callback from homeassistant.helpers import entity_registry as er, selector from homeassistant.helpers.schema_config_entry_flow import ( + SchemaCommonFlowHandler, SchemaConfigFlowHandler, SchemaFlowFormStep, SchemaFlowMenuStep, @@ -104,11 +105,15 @@ def choose_options_step(options: dict[str, Any]) -> str: return cast(str, options["group_type"]) -def set_group_type(group_type: str) -> Callable[[dict[str, Any]], dict[str, Any]]: +def set_group_type( + group_type: str, +) -> Callable[[SchemaCommonFlowHandler, dict[str, Any]], dict[str, Any]]: """Set group type.""" @callback - def _set_group_type(user_input: dict[str, Any]) -> dict[str, Any]: + def _set_group_type( + handler: SchemaCommonFlowHandler, user_input: dict[str, Any] + ) -> dict[str, Any]: """Add group type to user input.""" return {"group_type": group_type, **user_input} diff --git a/homeassistant/components/scrape/config_flow.py b/homeassistant/components/scrape/config_flow.py index eedc584a394..b53f2f12bd2 100644 --- a/homeassistant/components/scrape/config_flow.py +++ b/homeassistant/components/scrape/config_flow.py @@ -36,6 +36,7 @@ from homeassistant.const import ( ) from homeassistant.core import async_get_hass from homeassistant.helpers.schema_config_entry_flow import ( + SchemaCommonFlowHandler, SchemaConfigFlowHandler, SchemaFlowError, SchemaFlowFormStep, @@ -113,7 +114,9 @@ SENSOR_SETUP = { } -def validate_rest_setup(user_input: dict[str, Any]) -> dict[str, Any]: +def validate_rest_setup( + handler: SchemaCommonFlowHandler, user_input: dict[str, Any] +) -> dict[str, Any]: """Validate rest setup.""" hass = async_get_hass() rest_config: dict[str, Any] = COMBINED_SCHEMA(user_input) @@ -124,7 +127,9 @@ def validate_rest_setup(user_input: dict[str, Any]) -> dict[str, Any]: return user_input -def validate_sensor_setup(user_input: dict[str, Any]) -> dict[str, Any]: +def validate_sensor_setup( + handler: SchemaCommonFlowHandler, user_input: dict[str, Any] +) -> dict[str, Any]: """Validate sensor setup.""" return { "sensor": [ diff --git a/homeassistant/components/threshold/config_flow.py b/homeassistant/components/threshold/config_flow.py index 52b2fb44be9..373e48e7ba3 100644 --- a/homeassistant/components/threshold/config_flow.py +++ b/homeassistant/components/threshold/config_flow.py @@ -10,6 +10,7 @@ from homeassistant.components.sensor import DOMAIN as SENSOR_DOMAIN from homeassistant.const import CONF_ENTITY_ID, CONF_NAME from homeassistant.helpers import selector from homeassistant.helpers.schema_config_entry_flow import ( + SchemaCommonFlowHandler, SchemaConfigFlowHandler, SchemaFlowError, SchemaFlowFormStep, @@ -18,11 +19,13 @@ from homeassistant.helpers.schema_config_entry_flow import ( from .const import CONF_HYSTERESIS, CONF_LOWER, CONF_UPPER, DEFAULT_HYSTERESIS, DOMAIN -def _validate_mode(data: Any) -> Any: +def _validate_mode( + handler: SchemaCommonFlowHandler, user_input: dict[str, Any] +) -> dict[str, Any]: """Validate the threshold mode, and set limits to None if not set.""" - if CONF_LOWER not in data and CONF_UPPER not in data: + if CONF_LOWER not in user_input and CONF_UPPER not in user_input: raise SchemaFlowError("need_lower_upper") - return {CONF_LOWER: None, CONF_UPPER: None, **data} + return {CONF_LOWER: None, CONF_UPPER: None, **user_input} OPTIONS_SCHEMA = vol.Schema( diff --git a/homeassistant/components/utility_meter/config_flow.py b/homeassistant/components/utility_meter/config_flow.py index 0f43ccf29cc..59bde5ac300 100644 --- a/homeassistant/components/utility_meter/config_flow.py +++ b/homeassistant/components/utility_meter/config_flow.py @@ -10,6 +10,7 @@ from homeassistant.components.sensor import DOMAIN as SENSOR_DOMAIN from homeassistant.const import CONF_NAME from homeassistant.helpers import selector from homeassistant.helpers.schema_config_entry_flow import ( + SchemaCommonFlowHandler, SchemaConfigFlowHandler, SchemaFlowError, SchemaFlowFormStep, @@ -46,14 +47,16 @@ METER_TYPES = [ ] -def _validate_config(data: Any) -> Any: +def _validate_config( + handler: SchemaCommonFlowHandler, user_input: dict[str, Any] +) -> dict[str, Any]: """Validate config.""" try: - vol.Unique()(data[CONF_TARIFFS]) + vol.Unique()(user_input[CONF_TARIFFS]) except vol.Invalid as exc: raise SchemaFlowError("tariffs_not_unique") from exc - return data + return user_input OPTIONS_SCHEMA = vol.Schema( diff --git a/homeassistant/helpers/schema_config_entry_flow.py b/homeassistant/helpers/schema_config_entry_flow.py index 526f59bd103..dd4de4e3f33 100644 --- a/homeassistant/helpers/schema_config_entry_flow.py +++ b/homeassistant/helpers/schema_config_entry_flow.py @@ -44,7 +44,9 @@ class SchemaFlowFormStep(SchemaFlowStep): user input is requested. """ - validate_user_input: Callable[[dict[str, Any]], dict[str, Any]] = lambda x: x + validate_user_input: Callable[ + [SchemaCommonFlowHandler, dict[str, Any]], dict[str, Any] + ] | None = None """Optional function to validate user input. - The `validate_user_input` function is called if the schema validates successfully. @@ -124,10 +126,10 @@ class SchemaCommonFlowHandler: ): user_input[str(key.schema)] = key.default() - if user_input is not None and form_step.schema is not None: + if user_input is not None and form_step.validate_user_input is not None: # Do extra validation of user input try: - user_input = form_step.validate_user_input(user_input) + user_input = form_step.validate_user_input(self, user_input) except SchemaFlowError as exc: return self._show_next_step(step_id, exc, user_input)