diff --git a/homeassistant/components/template/config_flow.py b/homeassistant/components/template/config_flow.py index 5c28a68a8ae..f648d5ca8d5 100644 --- a/homeassistant/components/template/config_flow.py +++ b/homeassistant/components/template/config_flow.py @@ -3,6 +3,7 @@ from __future__ import annotations from collections.abc import Callable, Coroutine, Mapping +from functools import partial from typing import Any, cast import voluptuous as vol @@ -39,25 +40,34 @@ from .const import DOMAIN from .sensor import async_create_preview_sensor from .template_entity import TemplateEntity +_SCHEMA_STATE: dict[vol.Marker, Any] = { + vol.Required(CONF_STATE): selector.TemplateSelector(), +} -def generate_schema(domain: str, flow_type: str) -> dict[vol.Marker, Any]: + +def generate_schema(domain: str, flow_type: str) -> vol.Schema: """Generate schema.""" schema: dict[vol.Marker, Any] = {} - if domain == Platform.BINARY_SENSOR and flow_type == "config": - schema = { - vol.Optional(CONF_DEVICE_CLASS): selector.SelectSelector( - selector.SelectSelectorConfig( - options=[cls.value for cls in BinarySensorDeviceClass], - mode=selector.SelectSelectorMode.DROPDOWN, - translation_key="binary_sensor_device_class", - sort=True, + if flow_type == "config": + schema = {vol.Required(CONF_NAME): selector.TextSelector()} + + if domain == Platform.BINARY_SENSOR: + schema |= _SCHEMA_STATE + if flow_type == "config": + schema |= { + vol.Optional(CONF_DEVICE_CLASS): selector.SelectSelector( + selector.SelectSelectorConfig( + options=[cls.value for cls in BinarySensorDeviceClass], + mode=selector.SelectSelectorMode.DROPDOWN, + translation_key="binary_sensor_device_class", + sort=True, + ), ), - ) - } + } if domain == Platform.SENSOR: - schema = { + schema |= _SCHEMA_STATE | { vol.Optional(CONF_UNIT_OF_MEASUREMENT): selector.SelectSelector( selector.SelectSelectorConfig( options=list( @@ -98,26 +108,12 @@ def generate_schema(domain: str, flow_type: str) -> dict[vol.Marker, Any]: schema[vol.Optional(CONF_DEVICE_ID)] = selector.DeviceSelector() - return schema + return vol.Schema(schema) -def options_schema(domain: str) -> vol.Schema: - """Generate options schema.""" - return vol.Schema( - {vol.Required(CONF_STATE): selector.TemplateSelector()} - | generate_schema(domain, "option"), - ) +options_schema = partial(generate_schema, flow_type="options") - -def config_schema(domain: str) -> vol.Schema: - """Generate config schema.""" - return vol.Schema( - { - vol.Required(CONF_NAME): selector.TextSelector(), - vol.Required(CONF_STATE): selector.TemplateSelector(), - } - | generate_schema(domain, "config"), - ) +config_schema = partial(generate_schema, flow_type="config") async def choose_options_step(options: dict[str, Any]) -> str: