From 9feb64cebdb841ad9f94f0d197d68e146acd548f Mon Sep 17 00:00:00 2001 From: epenet <6771947+epenet@users.noreply.github.com> Date: Fri, 25 Nov 2022 10:50:38 +0100 Subject: [PATCH] Simplify schema callback in SchemaFlowFormStep (#82682) * Simplify SchemaFlowFormStep.schema callback * Expose parent handler * Adjust docstrings --- homeassistant/components/group/config_flow.py | 23 +++++--------- .../helpers/schema_config_entry_flow.py | 30 +++++++++---------- 2 files changed, 23 insertions(+), 30 deletions(-) diff --git a/homeassistant/components/group/config_flow.py b/homeassistant/components/group/config_flow.py index 1dd01f6193a..4d689914378 100644 --- a/homeassistant/components/group/config_flow.py +++ b/homeassistant/components/group/config_flow.py @@ -25,16 +25,14 @@ from .const import CONF_HIDE_MEMBERS def basic_group_options_schema( - domain: str, - handler: SchemaConfigFlowHandler | SchemaOptionsFlowHandler, - options: dict[str, Any], + domain: str, handler: SchemaCommonFlowHandler ) -> vol.Schema: """Generate options schema.""" - handler = cast(SchemaOptionsFlowHandler, handler) return vol.Schema( { vol.Required(CONF_ENTITIES): entity_selector_without_own_entities( - handler, selector.EntitySelectorConfig(domain=domain, multiple=True) + cast(SchemaOptionsFlowHandler, handler.parent_handler), + selector.EntitySelectorConfig(domain=domain, multiple=True), ), vol.Required(CONF_HIDE_MEMBERS, default=False): selector.BooleanSelector(), } @@ -54,12 +52,9 @@ def basic_group_config_schema(domain: str) -> vol.Schema: ) -def binary_sensor_options_schema( - handler: SchemaConfigFlowHandler | SchemaOptionsFlowHandler, - options: dict[str, Any], -) -> vol.Schema: +def binary_sensor_options_schema(handler: SchemaCommonFlowHandler) -> vol.Schema: """Generate options schema.""" - return basic_group_options_schema("binary_sensor", handler, options).extend( + return basic_group_options_schema("binary_sensor", handler).extend( { vol.Required(CONF_ALL, default=False): selector.BooleanSelector(), } @@ -74,12 +69,10 @@ BINARY_SENSOR_CONFIG_SCHEMA = basic_group_config_schema("binary_sensor").extend( def light_switch_options_schema( - domain: str, - handler: SchemaConfigFlowHandler | SchemaOptionsFlowHandler, - options: dict[str, Any], + domain: str, handler: SchemaCommonFlowHandler ) -> vol.Schema: """Generate options schema.""" - return basic_group_options_schema(domain, handler, options).extend( + return basic_group_options_schema(domain, handler).extend( { vol.Required( CONF_ALL, default=False, description={"advanced": True} @@ -145,7 +138,7 @@ CONFIG_FLOW = { OPTIONS_FLOW = { - "init": SchemaFlowFormStep(None, next_step=choose_options_step), + "init": SchemaFlowFormStep(next_step=choose_options_step), "binary_sensor": SchemaFlowFormStep(binary_sensor_options_schema), "cover": SchemaFlowFormStep(partial(basic_group_options_schema, "cover")), "fan": SchemaFlowFormStep(partial(basic_group_options_schema, "fan")), diff --git a/homeassistant/helpers/schema_config_entry_flow.py b/homeassistant/helpers/schema_config_entry_flow.py index dd4de4e3f33..c721b9d4ad6 100644 --- a/homeassistant/helpers/schema_config_entry_flow.py +++ b/homeassistant/helpers/schema_config_entry_flow.py @@ -31,15 +31,13 @@ class SchemaFlowFormStep(SchemaFlowStep): """Define a config or options flow form step.""" schema: vol.Schema | Callable[ - [SchemaConfigFlowHandler | SchemaOptionsFlowHandler, dict[str, Any]], - vol.Schema | None, - ] | None + [SchemaCommonFlowHandler], vol.Schema | None + ] | None = None """Optional voluptuous schema, or function which returns a schema or None, for requesting and validating user input. - - If a function is specified, the function will be passed the handler, which is - either an instance of SchemaConfigFlowHandler or SchemaOptionsFlowHandler, and the - union of config entry options and user input from previous steps. + - If a function is specified, the function will be passed the current + `SchemaCommonFlowHandler`. - If schema validation fails, the step will be retried. If the schema is None, no user input is requested. """ @@ -50,7 +48,8 @@ class SchemaFlowFormStep(SchemaFlowStep): """Optional function to validate user input. - The `validate_user_input` function is called if the schema validates successfully. - - The `validate_user_input` function is passed the user input from the current step. + - The first argument is a reference to the current `SchemaCommonFlowHandler`. + - The second argument is the user input from the current step. - The `validate_user_input` should raise `SchemaFlowError` is user input is invalid. """ @@ -86,6 +85,11 @@ class SchemaCommonFlowHandler: self._handler = handler self._options = options if options is not None else {} + @property + def parent_handler(self) -> SchemaConfigFlowHandler | SchemaOptionsFlowHandler: + """Return parent handler.""" + return self._handler + async def async_step( self, step_id: str, user_input: dict[str, Any] | None = None ) -> FlowResult: @@ -94,14 +98,12 @@ class SchemaCommonFlowHandler: return await self._async_form_step(step_id, user_input) return await self._async_menu_step(step_id, user_input) - def _get_schema( - self, form_step: SchemaFlowFormStep, options: dict[str, Any] - ) -> vol.Schema | None: + def _get_schema(self, form_step: SchemaFlowFormStep) -> vol.Schema | None: if form_step.schema is None: return None if isinstance(form_step.schema, vol.Schema): return form_step.schema - return form_step.schema(self._handler, options) + return form_step.schema(self) async def _async_form_step( self, step_id: str, user_input: dict[str, Any] | None = None @@ -111,7 +113,7 @@ class SchemaCommonFlowHandler: if ( user_input is not None - and (data_schema := self._get_schema(form_step, self._options)) + and (data_schema := self._get_schema(form_step)) and data_schema.schema and not self._handler.show_advanced_options ): @@ -171,9 +173,7 @@ class SchemaCommonFlowHandler: form_step = cast(SchemaFlowFormStep, self._flow[next_step_id]) - if ( - data_schema := self._get_schema(form_step, self._options) - ) and data_schema.schema: + if (data_schema := self._get_schema(form_step)) and data_schema.schema: # Make a copy of the schema with suggested values set to saved options schema = {} for key, val in data_schema.schema.items():