Make SchemaFlowFormStep functions async (#82962)

* Make validate async in SchemaOptionsFlowHandler

* Adjust group

* Adjust tests

* Move all to async

* Adjust integrations

* Missed an integration

* Missed one

* Rebase to fix conflict
This commit is contained in:
epenet 2022-11-30 12:26:52 +01:00 committed by GitHub
parent 490aec0b11
commit 98f263c289
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 80 additions and 54 deletions

View file

@ -69,7 +69,7 @@ OPTIONS_SCHEMA = vol.Schema(
) )
def get_options_schema(handler: SchemaCommonFlowHandler) -> vol.Schema: async def get_options_schema(handler: SchemaCommonFlowHandler) -> vol.Schema:
"""Get options schema.""" """Get options schema."""
options_flow: SchemaOptionsFlowHandler options_flow: SchemaOptionsFlowHandler
options_flow = cast(SchemaOptionsFlowHandler, handler.parent_handler) options_flow = cast(SchemaOptionsFlowHandler, handler.parent_handler)

View file

@ -58,7 +58,7 @@ OPTIONS_SCHEMA = vol.Schema(
) )
def _options_suggested_values(handler: SchemaCommonFlowHandler) -> dict[str, Any]: async def _options_suggested_values(handler: SchemaCommonFlowHandler) -> dict[str, Any]:
parent_handler = cast(SchemaOptionsFlowHandler, handler.parent_handler) parent_handler = cast(SchemaOptionsFlowHandler, handler.parent_handler)
suggested_values = copy.deepcopy(dict(parent_handler.config_entry.data)) suggested_values = copy.deepcopy(dict(parent_handler.config_entry.data))
suggested_values.update(parent_handler.options) suggested_values.update(parent_handler.options)

View file

@ -1,7 +1,7 @@
"""Config flow for Group integration.""" """Config flow for Group integration."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable, Mapping from collections.abc import Callable, Coroutine, Mapping
from functools import partial from functools import partial
from typing import Any, cast from typing import Any, cast
@ -24,7 +24,7 @@ from .binary_sensor import CONF_ALL
from .const import CONF_HIDE_MEMBERS from .const import CONF_HIDE_MEMBERS
def basic_group_options_schema( async def basic_group_options_schema(
domain: str, handler: SchemaCommonFlowHandler domain: str, handler: SchemaCommonFlowHandler
) -> vol.Schema: ) -> vol.Schema:
"""Generate options schema.""" """Generate options schema."""
@ -52,9 +52,9 @@ def basic_group_config_schema(domain: str) -> vol.Schema:
) )
def binary_sensor_options_schema(handler: SchemaCommonFlowHandler) -> vol.Schema: async def binary_sensor_options_schema(handler: SchemaCommonFlowHandler) -> vol.Schema:
"""Generate options schema.""" """Generate options schema."""
return basic_group_options_schema("binary_sensor", handler).extend( return (await basic_group_options_schema("binary_sensor", handler)).extend(
{ {
vol.Required(CONF_ALL, default=False): selector.BooleanSelector(), vol.Required(CONF_ALL, default=False): selector.BooleanSelector(),
} }
@ -68,11 +68,11 @@ BINARY_SENSOR_CONFIG_SCHEMA = basic_group_config_schema("binary_sensor").extend(
) )
def light_switch_options_schema( async def light_switch_options_schema(
domain: str, handler: SchemaCommonFlowHandler domain: str, handler: SchemaCommonFlowHandler
) -> vol.Schema: ) -> vol.Schema:
"""Generate options schema.""" """Generate options schema."""
return basic_group_options_schema(domain, handler).extend( return (await basic_group_options_schema(domain, handler)).extend(
{ {
vol.Required( vol.Required(
CONF_ALL, default=False, description={"advanced": True} CONF_ALL, default=False, description={"advanced": True}
@ -92,19 +92,19 @@ GROUP_TYPES = [
] ]
@callback async def choose_options_step(options: dict[str, Any]) -> str:
def choose_options_step(options: dict[str, Any]) -> str:
"""Return next step_id for options flow according to group_type.""" """Return next step_id for options flow according to group_type."""
return cast(str, options["group_type"]) return cast(str, options["group_type"])
def set_group_type( def set_group_type(
group_type: str, group_type: str,
) -> Callable[[SchemaCommonFlowHandler, dict[str, Any]], dict[str, Any]]: ) -> Callable[
[SchemaCommonFlowHandler, dict[str, Any]], Coroutine[Any, Any, dict[str, Any]]
]:
"""Set group type.""" """Set group type."""
@callback async def _set_group_type(
def _set_group_type(
handler: SchemaCommonFlowHandler, user_input: dict[str, Any] handler: SchemaCommonFlowHandler, user_input: dict[str, Any]
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Add group type to user input.""" """Add group type to user input."""
@ -116,23 +116,32 @@ def set_group_type(
CONFIG_FLOW = { CONFIG_FLOW = {
"user": SchemaFlowMenuStep(GROUP_TYPES), "user": SchemaFlowMenuStep(GROUP_TYPES),
"binary_sensor": SchemaFlowFormStep( "binary_sensor": SchemaFlowFormStep(
BINARY_SENSOR_CONFIG_SCHEMA, set_group_type("binary_sensor") BINARY_SENSOR_CONFIG_SCHEMA,
validate_user_input=set_group_type("binary_sensor"),
), ),
"cover": SchemaFlowFormStep( "cover": SchemaFlowFormStep(
basic_group_config_schema("cover"), set_group_type("cover") basic_group_config_schema("cover"),
validate_user_input=set_group_type("cover"),
),
"fan": SchemaFlowFormStep(
basic_group_config_schema("fan"),
validate_user_input=set_group_type("fan"),
), ),
"fan": SchemaFlowFormStep(basic_group_config_schema("fan"), set_group_type("fan")),
"light": SchemaFlowFormStep( "light": SchemaFlowFormStep(
basic_group_config_schema("light"), set_group_type("light") basic_group_config_schema("light"),
validate_user_input=set_group_type("light"),
), ),
"lock": SchemaFlowFormStep( "lock": SchemaFlowFormStep(
basic_group_config_schema("lock"), set_group_type("lock") basic_group_config_schema("lock"),
validate_user_input=set_group_type("lock"),
), ),
"media_player": SchemaFlowFormStep( "media_player": SchemaFlowFormStep(
basic_group_config_schema("media_player"), set_group_type("media_player") basic_group_config_schema("media_player"),
validate_user_input=set_group_type("media_player"),
), ),
"switch": SchemaFlowFormStep( "switch": SchemaFlowFormStep(
basic_group_config_schema("switch"), set_group_type("switch") basic_group_config_schema("switch"),
validate_user_input=set_group_type("switch"),
), ),
} }

View file

@ -116,7 +116,7 @@ SENSOR_SETUP = {
} }
def validate_rest_setup( async def validate_rest_setup(
handler: SchemaCommonFlowHandler, user_input: dict[str, Any] handler: SchemaCommonFlowHandler, user_input: dict[str, Any]
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Validate rest setup.""" """Validate rest setup."""
@ -129,7 +129,7 @@ def validate_rest_setup(
return user_input return user_input
def validate_sensor_setup( async def validate_sensor_setup(
handler: SchemaCommonFlowHandler, user_input: dict[str, Any] handler: SchemaCommonFlowHandler, user_input: dict[str, Any]
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Validate sensor input.""" """Validate sensor input."""
@ -143,7 +143,7 @@ def validate_sensor_setup(
return {} return {}
def get_remove_sensor_schema(handler: SchemaCommonFlowHandler) -> vol.Schema: async def get_remove_sensor_schema(handler: SchemaCommonFlowHandler) -> vol.Schema:
"""Return schema for sensor removal.""" """Return schema for sensor removal."""
return vol.Schema( return vol.Schema(
{ {
@ -157,7 +157,7 @@ def get_remove_sensor_schema(handler: SchemaCommonFlowHandler) -> vol.Schema:
) )
def validate_remove_sensor( async def validate_remove_sensor(
handler: SchemaCommonFlowHandler, user_input: dict[str, Any] handler: SchemaCommonFlowHandler, user_input: dict[str, Any]
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Validate remove sensor.""" """Validate remove sensor."""

View file

@ -19,7 +19,7 @@ from homeassistant.helpers.schema_config_entry_flow import (
from .const import CONF_HYSTERESIS, CONF_LOWER, CONF_UPPER, DEFAULT_HYSTERESIS, DOMAIN from .const import CONF_HYSTERESIS, CONF_LOWER, CONF_UPPER, DEFAULT_HYSTERESIS, DOMAIN
def _validate_mode( async def _validate_mode(
handler: SchemaCommonFlowHandler, user_input: dict[str, Any] handler: SchemaCommonFlowHandler, user_input: dict[str, Any]
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Validate the threshold mode, and set limits to None if not set.""" """Validate the threshold mode, and set limits to None if not set."""

View file

@ -47,7 +47,7 @@ METER_TYPES = [
] ]
def _validate_config( async def _validate_config(
handler: SchemaCommonFlowHandler, user_input: dict[str, Any] handler: SchemaCommonFlowHandler, user_input: dict[str, Any]
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Validate config.""" """Validate config."""

View file

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Callable, Mapping from collections.abc import Callable, Coroutine, Mapping
import copy import copy
from dataclasses import dataclass from dataclasses import dataclass
import types import types
@ -32,7 +32,7 @@ class SchemaFlowFormStep(SchemaFlowStep):
"""Define a config or options flow form step.""" """Define a config or options flow form step."""
schema: vol.Schema | Callable[ schema: vol.Schema | Callable[
[SchemaCommonFlowHandler], vol.Schema | None [SchemaCommonFlowHandler], Coroutine[Any, Any, vol.Schema | None]
] | None = None ] | None = None
"""Optional voluptuous schema, or function which returns a schema or None, for """Optional voluptuous schema, or function which returns a schema or None, for
requesting and validating user input. requesting and validating user input.
@ -44,7 +44,7 @@ class SchemaFlowFormStep(SchemaFlowStep):
""" """
validate_user_input: Callable[ validate_user_input: Callable[
[SchemaCommonFlowHandler, dict[str, Any]], dict[str, Any] [SchemaCommonFlowHandler, dict[str, Any]], Coroutine[Any, Any, dict[str, Any]]
] | None = None ] | None = None
"""Optional function to validate user input. """Optional function to validate user input.
@ -54,7 +54,9 @@ class SchemaFlowFormStep(SchemaFlowStep):
- The `validate_user_input` should raise `SchemaFlowError` if user input is invalid. - The `validate_user_input` should raise `SchemaFlowError` if user input is invalid.
""" """
next_step: Callable[[dict[str, Any]], str | None] | str | None = None next_step: Callable[
[dict[str, Any]], Coroutine[Any, Any, str | None]
] | str | None = None
"""Optional property to identify next step. """Optional property to identify next step.
- If `next_step` is a function, it is called if the schema validates successfully or - If `next_step` is a function, it is called if the schema validates successfully or
@ -65,7 +67,7 @@ class SchemaFlowFormStep(SchemaFlowStep):
""" """
suggested_values: Callable[ suggested_values: Callable[
[SchemaCommonFlowHandler], dict[str, Any] [SchemaCommonFlowHandler], Coroutine[Any, Any, dict[str, Any]]
] | None | UndefinedType = UNDEFINED ] | None | UndefinedType = UNDEFINED
"""Optional property to populate suggested values. """Optional property to populate suggested values.
@ -127,12 +129,12 @@ class SchemaCommonFlowHandler:
return await self._async_form_step(step_id, user_input) return await self._async_form_step(step_id, user_input)
return await self._async_menu_step(step_id, user_input) return await self._async_menu_step(step_id, user_input)
def _get_schema(self, form_step: SchemaFlowFormStep) -> vol.Schema | None: async def _get_schema(self, form_step: SchemaFlowFormStep) -> vol.Schema | None:
if form_step.schema is None: if form_step.schema is None:
return None return None
if isinstance(form_step.schema, vol.Schema): if isinstance(form_step.schema, vol.Schema):
return form_step.schema return form_step.schema
return form_step.schema(self) return await form_step.schema(self)
async def _async_form_step( async def _async_form_step(
self, step_id: str, user_input: dict[str, Any] | None = None self, step_id: str, user_input: dict[str, Any] | None = None
@ -142,7 +144,7 @@ class SchemaCommonFlowHandler:
if ( if (
user_input is not None user_input is not None
and (data_schema := self._get_schema(form_step)) and (data_schema := await self._get_schema(form_step))
and data_schema.schema and data_schema.schema
and not self._handler.show_advanced_options and not self._handler.show_advanced_options
): ):
@ -160,35 +162,35 @@ class SchemaCommonFlowHandler:
if user_input is not None and form_step.validate_user_input is not None: if user_input is not None and form_step.validate_user_input is not None:
# Do extra validation of user input # Do extra validation of user input
try: try:
user_input = form_step.validate_user_input(self, user_input) user_input = await form_step.validate_user_input(self, user_input)
except SchemaFlowError as exc: except SchemaFlowError as exc:
return self._show_next_step(step_id, exc, user_input) return await self._show_next_step(step_id, exc, user_input)
if user_input is not None: if user_input is not None:
# User input was validated successfully, update options # User input was validated successfully, update options
self._options.update(user_input) self._options.update(user_input)
if user_input is not None or form_step.schema is None: if user_input is not None or form_step.schema is None:
return self._show_next_step_or_create_entry(form_step) return await self._show_next_step_or_create_entry(form_step)
return self._show_next_step(step_id) return await self._show_next_step(step_id)
def _show_next_step_or_create_entry( async def _show_next_step_or_create_entry(
self, form_step: SchemaFlowFormStep self, form_step: SchemaFlowFormStep
) -> FlowResult: ) -> FlowResult:
next_step_id_or_end_flow: str | None next_step_id_or_end_flow: str | None
if callable(form_step.next_step): if callable(form_step.next_step):
next_step_id_or_end_flow = form_step.next_step(self._options) next_step_id_or_end_flow = await form_step.next_step(self._options)
else: else:
next_step_id_or_end_flow = form_step.next_step next_step_id_or_end_flow = form_step.next_step
if next_step_id_or_end_flow is None: if next_step_id_or_end_flow is None:
# Flow done, create entry or update config entry options # Flow done, create entry or update config entry options
return self._handler.async_create_entry(data=self._options) return self._handler.async_create_entry(data=self._options)
return self._show_next_step(next_step_id_or_end_flow) return await self._show_next_step(next_step_id_or_end_flow)
def _show_next_step( async def _show_next_step(
self, self,
next_step_id: str, next_step_id: str,
error: SchemaFlowError | None = None, error: SchemaFlowError | None = None,
@ -204,14 +206,14 @@ class SchemaCommonFlowHandler:
form_step = cast(SchemaFlowFormStep, self._flow[next_step_id]) form_step = cast(SchemaFlowFormStep, self._flow[next_step_id])
if (data_schema := self._get_schema(form_step)) is None: if (data_schema := await self._get_schema(form_step)) is None:
return self._show_next_step_or_create_entry(form_step) return await self._show_next_step_or_create_entry(form_step)
suggested_values: dict[str, Any] = {} suggested_values: dict[str, Any] = {}
if form_step.suggested_values is UNDEFINED: if form_step.suggested_values is UNDEFINED:
suggested_values = self._options suggested_values = self._options
elif form_step.suggested_values: elif form_step.suggested_values:
suggested_values = form_step.suggested_values(self) suggested_values = await form_step.suggested_values(self)
if user_input: if user_input:
# We don't want to mutate the existing options # We don't want to mutate the existing options

View file

@ -303,9 +303,12 @@ async def test_menu_step(hass: HomeAssistant) -> None:
MENU_1 = ["option1", "option2"] MENU_1 = ["option1", "option2"]
MENU_2 = ["option3", "option4"] MENU_2 = ["option3", "option4"]
async def _option1_next_step(_: dict[str, Any]) -> str:
return "menu2"
CONFIG_FLOW: dict[str, SchemaFlowFormStep | SchemaFlowMenuStep] = { CONFIG_FLOW: dict[str, SchemaFlowFormStep | SchemaFlowMenuStep] = {
"user": SchemaFlowMenuStep(MENU_1), "user": SchemaFlowMenuStep(MENU_1),
"option1": SchemaFlowFormStep(vol.Schema({}), next_step=lambda _: "menu2"), "option1": SchemaFlowFormStep(vol.Schema({}), next_step=_option1_next_step),
"menu2": SchemaFlowMenuStep(MENU_2), "menu2": SchemaFlowMenuStep(MENU_2),
"option3": SchemaFlowFormStep(vol.Schema({}), next_step="option4"), "option3": SchemaFlowFormStep(vol.Schema({}), next_step="option4"),
"option4": SchemaFlowFormStep(vol.Schema({})), "option4": SchemaFlowFormStep(vol.Schema({})),
@ -384,10 +387,13 @@ async def test_schema_none(hass: HomeAssistant) -> None:
async def test_last_step(hass: HomeAssistant) -> None: async def test_last_step(hass: HomeAssistant) -> None:
"""Test SchemaFlowFormStep with schema set to None.""" """Test SchemaFlowFormStep with schema set to None."""
async def _step2_next_step(_: dict[str, Any]) -> str:
return "step3"
CONFIG_FLOW: dict[str, SchemaFlowFormStep | SchemaFlowMenuStep] = { CONFIG_FLOW: dict[str, SchemaFlowFormStep | SchemaFlowMenuStep] = {
"user": SchemaFlowFormStep(next_step="step1"), "user": SchemaFlowFormStep(next_step="step1"),
"step1": SchemaFlowFormStep(vol.Schema({}), next_step="step2"), "step1": SchemaFlowFormStep(vol.Schema({}), next_step="step2"),
"step2": SchemaFlowFormStep(vol.Schema({}), next_step=lambda _: "step3"), "step2": SchemaFlowFormStep(vol.Schema({}), next_step=_step2_next_step),
"step3": SchemaFlowFormStep(vol.Schema({}), next_step=None), "step3": SchemaFlowFormStep(vol.Schema({}), next_step=None),
} }
@ -422,10 +428,16 @@ async def test_last_step(hass: HomeAssistant) -> None:
async def test_next_step_function(hass: HomeAssistant) -> None: async def test_next_step_function(hass: HomeAssistant) -> None:
"""Test SchemaFlowFormStep with a next_step function.""" """Test SchemaFlowFormStep with a next_step function."""
async def _step1_next_step(_: dict[str, Any]) -> str:
return "step2"
async def _step2_next_step(_: dict[str, Any]) -> None:
return None
CONFIG_FLOW: dict[str, SchemaFlowFormStep | SchemaFlowMenuStep] = { CONFIG_FLOW: dict[str, SchemaFlowFormStep | SchemaFlowMenuStep] = {
"user": SchemaFlowFormStep(next_step="step1"), "user": SchemaFlowFormStep(next_step="step1"),
"step1": SchemaFlowFormStep(vol.Schema({}), next_step=lambda _: "step2"), "step1": SchemaFlowFormStep(vol.Schema({}), next_step=_step1_next_step),
"step2": SchemaFlowFormStep(vol.Schema({}), next_step=lambda _: None), "step2": SchemaFlowFormStep(vol.Schema({}), next_step=_step2_next_step),
} }
class TestConfigFlow(SchemaConfigFlowHandler, domain=TEST_DOMAIN): class TestConfigFlow(SchemaConfigFlowHandler, domain=TEST_DOMAIN):
@ -459,19 +471,22 @@ async def test_suggested_values(
{vol.Optional("option1", default="a very reasonable default"): str} {vol.Optional("option1", default="a very reasonable default"): str}
) )
def _validate_user_input( async def _validate_user_input(
handler: SchemaCommonFlowHandler, user_input: dict[str, Any] handler: SchemaCommonFlowHandler, user_input: dict[str, Any]
) -> dict[str, Any]: ) -> dict[str, Any]:
if user_input["option1"] == "not a valid value": if user_input["option1"] == "not a valid value":
raise SchemaFlowError("option1 not using a valid value") raise SchemaFlowError("option1 not using a valid value")
return user_input return user_input
async def _step_2_suggested_values(_: SchemaCommonFlowHandler) -> dict[str, Any]:
return {"option1": "a random override"}
OPTIONS_FLOW: dict[str, SchemaFlowFormStep | SchemaFlowMenuStep] = { OPTIONS_FLOW: dict[str, SchemaFlowFormStep | SchemaFlowMenuStep] = {
"init": SchemaFlowFormStep(OPTIONS_SCHEMA, next_step="step_1"), "init": SchemaFlowFormStep(OPTIONS_SCHEMA, next_step="step_1"),
"step_1": SchemaFlowFormStep(OPTIONS_SCHEMA, next_step="step_2"), "step_1": SchemaFlowFormStep(OPTIONS_SCHEMA, next_step="step_2"),
"step_2": SchemaFlowFormStep( "step_2": SchemaFlowFormStep(
OPTIONS_SCHEMA, OPTIONS_SCHEMA,
suggested_values=lambda _: {"option1": "a random override"}, suggested_values=_step_2_suggested_values,
next_step="step_3", next_step="step_3",
), ),
"step_3": SchemaFlowFormStep( "step_3": SchemaFlowFormStep(
@ -565,16 +580,16 @@ async def test_options_flow_state(hass: HomeAssistant) -> None:
{vol.Optional("option1", default="a very reasonable default"): str} {vol.Optional("option1", default="a very reasonable default"): str}
) )
def _init_schema(handler: SchemaCommonFlowHandler) -> None: async def _init_schema(handler: SchemaCommonFlowHandler) -> None:
handler.flow_state["idx"] = None handler.flow_state["idx"] = None
def _validate_step1_input( async def _validate_step1_input(
handler: SchemaCommonFlowHandler, user_input: dict[str, Any] handler: SchemaCommonFlowHandler, user_input: dict[str, Any]
) -> dict[str, Any]: ) -> dict[str, Any]:
handler.flow_state["idx"] = user_input["option1"] handler.flow_state["idx"] = user_input["option1"]
return user_input return user_input
def _validate_step2_input( async def _validate_step2_input(
handler: SchemaCommonFlowHandler, user_input: dict[str, Any] handler: SchemaCommonFlowHandler, user_input: dict[str, Any]
) -> dict[str, Any]: ) -> dict[str, Any]:
user_input["idx_from_flow_state"] = handler.flow_state["idx"] user_input["idx_from_flow_state"] = handler.flow_state["idx"]