Improve Selector typing (#82636)

This commit is contained in:
epenet 2022-11-24 19:27:26 +01:00 committed by GitHub
parent 34633b0ede
commit 9132c42037
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -1,8 +1,8 @@
"""Selectors for Home Assistant."""
from __future__ import annotations
from collections.abc import Callable, Sequence
from typing import Any, Literal, TypedDict, cast
from collections.abc import Callable, Mapping, Sequence
from typing import Any, Generic, Literal, TypedDict, TypeVar, cast
from uuid import UUID
import voluptuous as vol
@ -17,6 +17,8 @@ from . import config_validation as cv
SELECTORS: decorator.Registry[str, type[Selector]] = decorator.Registry()
_T = TypeVar("_T", bound=Mapping[str, Any])
def _get_selector_class(config: Any) -> type[Selector]:
"""Get selector class type."""
@ -56,14 +58,14 @@ def validate_selector(config: Any) -> dict:
}
class Selector:
class Selector(Generic[_T]):
"""Base class for selectors."""
CONFIG_SCHEMA: Callable
config: Any
config: _T
selector_type: str
def __init__(self, config: Any = None) -> None:
def __init__(self, config: Mapping[str, Any] | None = None) -> None:
"""Instantiate a selector."""
# Selectors can be empty
if config is None:
@ -71,7 +73,7 @@ class Selector:
self.config = self.CONFIG_SCHEMA(config)
def serialize(self) -> Any:
def serialize(self) -> dict[str, dict[str, _T]]:
"""Serialize Selector for voluptuous_serialize."""
return {"selector": {self.selector_type: self.config}}
@ -124,7 +126,7 @@ class ActionSelectorConfig(TypedDict):
@SELECTORS.register("action")
class ActionSelector(Selector):
class ActionSelector(Selector[ActionSelectorConfig]):
"""Selector of an action sequence (script syntax)."""
selector_type = "action"
@ -148,7 +150,7 @@ class AddonSelectorConfig(TypedDict, total=False):
@SELECTORS.register("addon")
class AddonSelector(Selector):
class AddonSelector(Selector[AddonSelectorConfig]):
"""Selector of a add-on."""
selector_type = "addon"
@ -179,7 +181,7 @@ class AreaSelectorConfig(TypedDict, total=False):
@SELECTORS.register("area")
class AreaSelector(Selector):
class AreaSelector(Selector[AreaSelectorConfig]):
"""Selector of a single or list of areas."""
selector_type = "area"
@ -214,7 +216,7 @@ class AttributeSelectorConfig(TypedDict, total=False):
@SELECTORS.register("attribute")
class AttributeSelector(Selector):
class AttributeSelector(Selector[AttributeSelectorConfig]):
"""Selector for an entity attribute."""
selector_type = "attribute"
@ -243,7 +245,7 @@ class BooleanSelectorConfig(TypedDict):
@SELECTORS.register("boolean")
class BooleanSelector(Selector):
class BooleanSelector(Selector[BooleanSelectorConfig]):
"""Selector of a boolean value."""
selector_type = "boolean"
@ -265,7 +267,7 @@ class ColorRGBSelectorConfig(TypedDict):
@SELECTORS.register("color_rgb")
class ColorRGBSelector(Selector):
class ColorRGBSelector(Selector[ColorRGBSelectorConfig]):
"""Selector of an RGB color value."""
selector_type = "color_rgb"
@ -290,7 +292,7 @@ class ColorTempSelectorConfig(TypedDict, total=False):
@SELECTORS.register("color_temp")
class ColorTempSelector(Selector):
class ColorTempSelector(Selector[ColorTempSelectorConfig]):
"""Selector of an color temperature."""
selector_type = "color_temp"
@ -325,7 +327,7 @@ class ConfigEntrySelectorConfig(TypedDict, total=False):
@SELECTORS.register("config_entry")
class ConfigEntrySelector(Selector):
class ConfigEntrySelector(Selector[ConfigEntrySelectorConfig]):
"""Selector of a config entry."""
selector_type = "config_entry"
@ -351,7 +353,7 @@ class DateSelectorConfig(TypedDict):
@SELECTORS.register("date")
class DateSelector(Selector):
class DateSelector(Selector[DateSelectorConfig]):
"""Selector of a date."""
selector_type = "date"
@ -373,7 +375,7 @@ class DateTimeSelectorConfig(TypedDict):
@SELECTORS.register("datetime")
class DateTimeSelector(Selector):
class DateTimeSelector(Selector[DateTimeSelectorConfig]):
"""Selector of a datetime."""
selector_type = "datetime"
@ -401,7 +403,7 @@ class DeviceSelectorConfig(TypedDict, total=False):
@SELECTORS.register("device")
class DeviceSelector(Selector):
class DeviceSelector(Selector[DeviceSelectorConfig]):
"""Selector of a single or list of devices."""
selector_type = "device"
@ -431,7 +433,7 @@ class DurationSelectorConfig(TypedDict, total=False):
@SELECTORS.register("duration")
class DurationSelector(Selector):
class DurationSelector(Selector[DurationSelectorConfig]):
"""Selector for a duration."""
selector_type = "duration"
@ -463,7 +465,7 @@ class EntitySelectorConfig(SingleEntitySelectorConfig, total=False):
@SELECTORS.register("entity")
class EntitySelector(Selector):
class EntitySelector(Selector[EntitySelectorConfig]):
"""Selector of a single or list of entities."""
selector_type = "entity"
@ -517,7 +519,7 @@ class IconSelectorConfig(TypedDict, total=False):
@SELECTORS.register("icon")
class IconSelector(Selector):
class IconSelector(Selector[IconSelectorConfig]):
"""Selector for an icon."""
selector_type = "icon"
@ -545,7 +547,7 @@ class LocationSelectorConfig(TypedDict, total=False):
@SELECTORS.register("location")
class LocationSelector(Selector):
class LocationSelector(Selector[LocationSelectorConfig]):
"""Selector for a location."""
selector_type = "location"
@ -576,7 +578,7 @@ class MediaSelectorConfig(TypedDict):
@SELECTORS.register("media")
class MediaSelector(Selector):
class MediaSelector(Selector[MediaSelectorConfig]):
"""Selector for media."""
selector_type = "media"
@ -636,7 +638,7 @@ def validate_slider(data: Any) -> Any:
@SELECTORS.register("number")
class NumberSelector(Selector):
class NumberSelector(Selector[NumberSelectorConfig]):
"""Selector of a numeric value."""
selector_type = "number"
@ -682,7 +684,7 @@ class ObjectSelectorConfig(TypedDict):
@SELECTORS.register("object")
class ObjectSelector(Selector):
class ObjectSelector(Selector[ObjectSelectorConfig]):
"""Selector for an arbitrary object."""
selector_type = "object"
@ -733,7 +735,7 @@ class SelectSelectorConfig(TypedDict, total=False):
@SELECTORS.register("select")
class SelectSelector(Selector):
class SelectSelector(Selector[SelectSelectorConfig]):
"""Selector for an single-choice input select."""
selector_type = "select"
@ -755,12 +757,15 @@ class SelectSelector(Selector):
def __call__(self, data: Any) -> Any:
"""Validate the passed selection."""
options = []
if self.config["options"]:
if isinstance(self.config["options"][0], str):
options = self.config["options"]
options: Sequence[str] = []
if config_options := self.config["options"]:
if isinstance(config_options[0], str):
options = cast(Sequence[str], config_options)
else:
options = [option["value"] for option in self.config["options"]]
options = [
option["value"]
for option in cast(Sequence[SelectOptionDict], config_options)
]
parent_schema = vol.In(options)
if self.config["custom_value"]:
@ -787,7 +792,7 @@ class StateSelectorConfig(TypedDict, total=False):
@SELECTORS.register("state")
class StateSelector(Selector):
class StateSelector(Selector[StateSelectorConfig]):
"""Selector for an entity state."""
selector_type = "state"
@ -814,7 +819,7 @@ class StateSelector(Selector):
@SELECTORS.register("target")
class TargetSelector(Selector):
class TargetSelector(Selector[TargetSelectorConfig]):
"""Selector of a target value (area ID, device ID, entity ID etc).
Value should follow cv.TARGET_SERVICE_FIELDS format.
@ -846,7 +851,7 @@ class TemplateSelectorConfig(TypedDict):
@SELECTORS.register("template")
class TemplateSelector(Selector):
class TemplateSelector(Selector[TemplateSelectorConfig]):
"""Selector for an template."""
selector_type = "template"
@ -891,7 +896,7 @@ class TextSelectorType(StrEnum):
@SELECTORS.register("text")
class TextSelector(Selector):
class TextSelector(Selector[TextSelectorConfig]):
"""Selector for a multi-line text string."""
selector_type = "text"
@ -924,7 +929,7 @@ class ThemeSelectorConfig(TypedDict):
@SELECTORS.register("theme")
class ThemeSelector(Selector):
class ThemeSelector(Selector[ThemeSelectorConfig]):
"""Selector for an theme."""
selector_type = "theme"
@ -946,7 +951,7 @@ class TimeSelectorConfig(TypedDict):
@SELECTORS.register("time")
class TimeSelector(Selector):
class TimeSelector(Selector[TimeSelectorConfig]):
"""Selector of a time value."""
selector_type = "time"
@ -970,7 +975,7 @@ class FileSelectorConfig(TypedDict):
@SELECTORS.register("file")
class FileSelector(Selector):
class FileSelector(Selector[FileSelectorConfig]):
"""Selector of a file."""
selector_type = "file"