Add script llm tool (#118936)
* Add script llm tool * Add tests * More tests * more test * more test * Add area and floor resolving * coverage * coverage * fix ColorTempSelector * fix mypy * fix mypy * add script reload test * Cache script tool parameters * Make custom_serializer a part of api --------- Co-authored-by: Michael Hansen <mike@rhasspy.org>
This commit is contained in:
parent
77fea8a73e
commit
2386ed3830
14 changed files with 639 additions and 55 deletions
|
@ -3,6 +3,7 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import codecs
|
import codecs
|
||||||
|
from collections.abc import Callable
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
from google.api_core.exceptions import GoogleAPICallError
|
from google.api_core.exceptions import GoogleAPICallError
|
||||||
|
@ -89,10 +90,14 @@ def _format_schema(schema: dict[str, Any]) -> dict[str, Any]:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def _format_tool(tool: llm.Tool) -> dict[str, Any]:
|
def _format_tool(
|
||||||
|
tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
|
||||||
|
) -> dict[str, Any]:
|
||||||
"""Format tool specification."""
|
"""Format tool specification."""
|
||||||
|
|
||||||
parameters = _format_schema(convert(tool.parameters))
|
parameters = _format_schema(
|
||||||
|
convert(tool.parameters, custom_serializer=custom_serializer)
|
||||||
|
)
|
||||||
|
|
||||||
return protos.Tool(
|
return protos.Tool(
|
||||||
{
|
{
|
||||||
|
@ -193,7 +198,9 @@ class GoogleGenerativeAIConversationEntity(
|
||||||
f"Error preparing LLM API: {err}",
|
f"Error preparing LLM API: {err}",
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
tools = [_format_tool(tool) for tool in llm_api.tools]
|
tools = [
|
||||||
|
_format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools
|
||||||
|
]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
prompt = await self._async_render_prompt(user_input, llm_api, llm_context)
|
prompt = await self._async_render_prompt(user_input, llm_api, llm_context)
|
||||||
|
|
|
@ -9,5 +9,5 @@
|
||||||
"integration_type": "service",
|
"integration_type": "service",
|
||||||
"iot_class": "cloud_polling",
|
"iot_class": "cloud_polling",
|
||||||
"quality_scale": "platinum",
|
"quality_scale": "platinum",
|
||||||
"requirements": ["google-generativeai==0.6.0", "voluptuous-openapi==0.0.4"]
|
"requirements": ["google-generativeai==0.6.0"]
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
"""Conversation support for OpenAI."""
|
"""Conversation support for OpenAI."""
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
import json
|
import json
|
||||||
from typing import Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
from openai._types import NOT_GIVEN
|
from openai._types import NOT_GIVEN
|
||||||
|
@ -58,9 +59,14 @@ async def async_setup_entry(
|
||||||
async_add_entities([agent])
|
async_add_entities([agent])
|
||||||
|
|
||||||
|
|
||||||
def _format_tool(tool: llm.Tool) -> ChatCompletionToolParam:
|
def _format_tool(
|
||||||
|
tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
|
||||||
|
) -> ChatCompletionToolParam:
|
||||||
"""Format tool specification."""
|
"""Format tool specification."""
|
||||||
tool_spec = FunctionDefinition(name=tool.name, parameters=convert(tool.parameters))
|
tool_spec = FunctionDefinition(
|
||||||
|
name=tool.name,
|
||||||
|
parameters=convert(tool.parameters, custom_serializer=custom_serializer),
|
||||||
|
)
|
||||||
if tool.description:
|
if tool.description:
|
||||||
tool_spec["description"] = tool.description
|
tool_spec["description"] = tool.description
|
||||||
return ChatCompletionToolParam(type="function", function=tool_spec)
|
return ChatCompletionToolParam(type="function", function=tool_spec)
|
||||||
|
@ -139,7 +145,9 @@ class OpenAIConversationEntity(
|
||||||
return conversation.ConversationResult(
|
return conversation.ConversationResult(
|
||||||
response=intent_response, conversation_id=user_input.conversation_id
|
response=intent_response, conversation_id=user_input.conversation_id
|
||||||
)
|
)
|
||||||
tools = [_format_tool(tool) for tool in llm_api.tools]
|
tools = [
|
||||||
|
_format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools
|
||||||
|
]
|
||||||
|
|
||||||
if user_input.conversation_id is None:
|
if user_input.conversation_id is None:
|
||||||
conversation_id = ulid.ulid_now()
|
conversation_id = ulid.ulid_now()
|
||||||
|
|
|
@ -8,5 +8,5 @@
|
||||||
"documentation": "https://www.home-assistant.io/integrations/openai_conversation",
|
"documentation": "https://www.home-assistant.io/integrations/openai_conversation",
|
||||||
"integration_type": "service",
|
"integration_type": "service",
|
||||||
"iot_class": "cloud_polling",
|
"iot_class": "cloud_polling",
|
||||||
"requirements": ["openai==1.3.8", "voluptuous-openapi==0.0.4"]
|
"requirements": ["openai==1.3.8"]
|
||||||
}
|
}
|
||||||
|
|
|
@ -352,7 +352,7 @@ class MatchTargetsCandidate:
|
||||||
matched_name: str | None = None
|
matched_name: str | None = None
|
||||||
|
|
||||||
|
|
||||||
def _find_areas(
|
def find_areas(
|
||||||
name: str, areas: area_registry.AreaRegistry
|
name: str, areas: area_registry.AreaRegistry
|
||||||
) -> Iterable[area_registry.AreaEntry]:
|
) -> Iterable[area_registry.AreaEntry]:
|
||||||
"""Find all areas matching a name (including aliases)."""
|
"""Find all areas matching a name (including aliases)."""
|
||||||
|
@ -372,7 +372,7 @@ def _find_areas(
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
def _find_floors(
|
def find_floors(
|
||||||
name: str, floors: floor_registry.FloorRegistry
|
name: str, floors: floor_registry.FloorRegistry
|
||||||
) -> Iterable[floor_registry.FloorEntry]:
|
) -> Iterable[floor_registry.FloorEntry]:
|
||||||
"""Find all floors matching a name (including aliases)."""
|
"""Find all floors matching a name (including aliases)."""
|
||||||
|
@ -530,7 +530,7 @@ def async_match_targets( # noqa: C901
|
||||||
if not states:
|
if not states:
|
||||||
return MatchTargetsResult(False, MatchFailedReason.STATE)
|
return MatchTargetsResult(False, MatchFailedReason.STATE)
|
||||||
|
|
||||||
# Exit early so we can to avoid registry lookups
|
# Exit early so we can avoid registry lookups
|
||||||
if not (
|
if not (
|
||||||
constraints.name
|
constraints.name
|
||||||
or constraints.features
|
or constraints.features
|
||||||
|
@ -580,7 +580,7 @@ def async_match_targets( # noqa: C901
|
||||||
if constraints.floor_name:
|
if constraints.floor_name:
|
||||||
# Filter by areas associated with floor
|
# Filter by areas associated with floor
|
||||||
fr = floor_registry.async_get(hass)
|
fr = floor_registry.async_get(hass)
|
||||||
targeted_floors = list(_find_floors(constraints.floor_name, fr))
|
targeted_floors = list(find_floors(constraints.floor_name, fr))
|
||||||
if not targeted_floors:
|
if not targeted_floors:
|
||||||
return MatchTargetsResult(
|
return MatchTargetsResult(
|
||||||
False,
|
False,
|
||||||
|
@ -609,7 +609,7 @@ def async_match_targets( # noqa: C901
|
||||||
possible_area_ids = {area.id for area in ar.async_list_areas()}
|
possible_area_ids = {area.id for area in ar.async_list_areas()}
|
||||||
|
|
||||||
if constraints.area_name:
|
if constraints.area_name:
|
||||||
targeted_areas = list(_find_areas(constraints.area_name, ar))
|
targeted_areas = list(find_areas(constraints.area_name, ar))
|
||||||
if not targeted_areas:
|
if not targeted_areas:
|
||||||
return MatchTargetsResult(
|
return MatchTargetsResult(
|
||||||
False,
|
False,
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from collections.abc import Callable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
@ -11,6 +12,7 @@ from typing import Any
|
||||||
|
|
||||||
import slugify as unicode_slug
|
import slugify as unicode_slug
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
from voluptuous_openapi import UNSUPPORTED, convert
|
||||||
|
|
||||||
from homeassistant.components.climate.intent import INTENT_GET_TEMPERATURE
|
from homeassistant.components.climate.intent import INTENT_GET_TEMPERATURE
|
||||||
from homeassistant.components.conversation.trace import (
|
from homeassistant.components.conversation.trace import (
|
||||||
|
@ -20,22 +22,39 @@ from homeassistant.components.conversation.trace import (
|
||||||
from homeassistant.components.cover.intent import INTENT_CLOSE_COVER, INTENT_OPEN_COVER
|
from homeassistant.components.cover.intent import INTENT_CLOSE_COVER, INTENT_OPEN_COVER
|
||||||
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
|
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
|
||||||
from homeassistant.components.intent import async_device_supports_timers
|
from homeassistant.components.intent import async_device_supports_timers
|
||||||
|
from homeassistant.components.script import ATTR_VARIABLES, DOMAIN as SCRIPT_DOMAIN
|
||||||
from homeassistant.components.weather.intent import INTENT_GET_WEATHER
|
from homeassistant.components.weather.intent import INTENT_GET_WEATHER
|
||||||
from homeassistant.core import Context, HomeAssistant, callback
|
from homeassistant.const import (
|
||||||
|
ATTR_DOMAIN,
|
||||||
|
ATTR_ENTITY_ID,
|
||||||
|
ATTR_SERVICE,
|
||||||
|
EVENT_HOMEASSISTANT_CLOSE,
|
||||||
|
EVENT_SERVICE_REMOVED,
|
||||||
|
SERVICE_TURN_ON,
|
||||||
|
)
|
||||||
|
from homeassistant.core import Context, Event, HomeAssistant, callback, split_entity_id
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.util import yaml
|
from homeassistant.util import yaml
|
||||||
|
from homeassistant.util.hass_dict import HassKey
|
||||||
from homeassistant.util.json import JsonObjectType
|
from homeassistant.util.json import JsonObjectType
|
||||||
|
|
||||||
from . import (
|
from . import (
|
||||||
area_registry as ar,
|
area_registry as ar,
|
||||||
|
config_validation as cv,
|
||||||
device_registry as dr,
|
device_registry as dr,
|
||||||
entity_registry as er,
|
entity_registry as er,
|
||||||
floor_registry as fr,
|
floor_registry as fr,
|
||||||
intent,
|
intent,
|
||||||
|
selector,
|
||||||
service,
|
service,
|
||||||
)
|
)
|
||||||
from .singleton import singleton
|
from .singleton import singleton
|
||||||
|
|
||||||
|
SCRIPT_PARAMETERS_CACHE: HassKey[dict[str, tuple[str | None, vol.Schema]]] = HassKey(
|
||||||
|
"llm_script_parameters_cache"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
LLM_API_ASSIST = "assist"
|
LLM_API_ASSIST = "assist"
|
||||||
|
|
||||||
BASE_PROMPT = (
|
BASE_PROMPT = (
|
||||||
|
@ -143,6 +162,7 @@ class APIInstance:
|
||||||
api_prompt: str
|
api_prompt: str
|
||||||
llm_context: LLMContext
|
llm_context: LLMContext
|
||||||
tools: list[Tool]
|
tools: list[Tool]
|
||||||
|
custom_serializer: Callable[[Any], Any] | None = None
|
||||||
|
|
||||||
async def async_call_tool(self, tool_input: ToolInput) -> JsonObjectType:
|
async def async_call_tool(self, tool_input: ToolInput) -> JsonObjectType:
|
||||||
"""Call a LLM tool, validate args and return the response."""
|
"""Call a LLM tool, validate args and return the response."""
|
||||||
|
@ -284,6 +304,7 @@ class AssistAPI(API):
|
||||||
api_prompt=self._async_get_api_prompt(llm_context, exposed_entities),
|
api_prompt=self._async_get_api_prompt(llm_context, exposed_entities),
|
||||||
llm_context=llm_context,
|
llm_context=llm_context,
|
||||||
tools=self._async_get_tools(llm_context, exposed_entities),
|
tools=self._async_get_tools(llm_context, exposed_entities),
|
||||||
|
custom_serializer=_selector_serializer,
|
||||||
)
|
)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
|
@ -372,7 +393,7 @@ class AssistAPI(API):
|
||||||
exposed_domains: set[str] | None = None
|
exposed_domains: set[str] | None = None
|
||||||
if exposed_entities is not None:
|
if exposed_entities is not None:
|
||||||
exposed_domains = {
|
exposed_domains = {
|
||||||
entity_id.split(".")[0] for entity_id in exposed_entities
|
split_entity_id(entity_id)[0] for entity_id in exposed_entities
|
||||||
}
|
}
|
||||||
intent_handlers = [
|
intent_handlers = [
|
||||||
intent_handler
|
intent_handler
|
||||||
|
@ -381,11 +402,22 @@ class AssistAPI(API):
|
||||||
or intent_handler.platforms & exposed_domains
|
or intent_handler.platforms & exposed_domains
|
||||||
]
|
]
|
||||||
|
|
||||||
return [
|
tools: list[Tool] = [
|
||||||
IntentTool(self.cached_slugify(intent_handler.intent_type), intent_handler)
|
IntentTool(self.cached_slugify(intent_handler.intent_type), intent_handler)
|
||||||
for intent_handler in intent_handlers
|
for intent_handler in intent_handlers
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if llm_context.assistant is not None:
|
||||||
|
for state in self.hass.states.async_all(SCRIPT_DOMAIN):
|
||||||
|
if not async_should_expose(
|
||||||
|
self.hass, llm_context.assistant, state.entity_id
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
|
||||||
|
tools.append(ScriptTool(self.hass, state.entity_id))
|
||||||
|
|
||||||
|
return tools
|
||||||
|
|
||||||
|
|
||||||
def _get_exposed_entities(
|
def _get_exposed_entities(
|
||||||
hass: HomeAssistant, assistant: str
|
hass: HomeAssistant, assistant: str
|
||||||
|
@ -413,13 +445,15 @@ def _get_exposed_entities(
|
||||||
entities = {}
|
entities = {}
|
||||||
|
|
||||||
for state in hass.states.async_all():
|
for state in hass.states.async_all():
|
||||||
|
if state.domain == SCRIPT_DOMAIN:
|
||||||
|
continue
|
||||||
|
|
||||||
if not async_should_expose(hass, assistant, state.entity_id):
|
if not async_should_expose(hass, assistant, state.entity_id):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
entity_entry = entity_registry.async_get(state.entity_id)
|
entity_entry = entity_registry.async_get(state.entity_id)
|
||||||
names = [state.name]
|
names = [state.name]
|
||||||
area_names = []
|
area_names = []
|
||||||
description: str | None = None
|
|
||||||
|
|
||||||
if entity_entry is not None:
|
if entity_entry is not None:
|
||||||
names.extend(entity_entry.aliases)
|
names.extend(entity_entry.aliases)
|
||||||
|
@ -439,25 +473,11 @@ def _get_exposed_entities(
|
||||||
area_names.append(area.name)
|
area_names.append(area.name)
|
||||||
area_names.extend(area.aliases)
|
area_names.extend(area.aliases)
|
||||||
|
|
||||||
if (
|
|
||||||
state.domain == "script"
|
|
||||||
and entity_entry.unique_id
|
|
||||||
and (
|
|
||||||
service_desc := service.async_get_cached_service_description(
|
|
||||||
hass, "script", entity_entry.unique_id
|
|
||||||
)
|
|
||||||
)
|
|
||||||
):
|
|
||||||
description = service_desc.get("description")
|
|
||||||
|
|
||||||
info: dict[str, Any] = {
|
info: dict[str, Any] = {
|
||||||
"names": ", ".join(names),
|
"names": ", ".join(names),
|
||||||
"state": state.state,
|
"state": state.state,
|
||||||
}
|
}
|
||||||
|
|
||||||
if description:
|
|
||||||
info["description"] = description
|
|
||||||
|
|
||||||
if area_names:
|
if area_names:
|
||||||
info["areas"] = ", ".join(area_names)
|
info["areas"] = ", ".join(area_names)
|
||||||
|
|
||||||
|
@ -473,3 +493,231 @@ def _get_exposed_entities(
|
||||||
entities[state.entity_id] = info
|
entities[state.entity_id] = info
|
||||||
|
|
||||||
return entities
|
return entities
|
||||||
|
|
||||||
|
|
||||||
|
def _selector_serializer(schema: Any) -> Any: # noqa: C901
|
||||||
|
"""Convert selectors into OpenAPI schema."""
|
||||||
|
if not isinstance(schema, selector.Selector):
|
||||||
|
return UNSUPPORTED
|
||||||
|
|
||||||
|
if isinstance(schema, selector.BackupLocationSelector):
|
||||||
|
return {"type": "string", "pattern": "^(?:\\/backup|\\w+)$"}
|
||||||
|
|
||||||
|
if isinstance(schema, selector.BooleanSelector):
|
||||||
|
return {"type": "boolean"}
|
||||||
|
|
||||||
|
if isinstance(schema, selector.ColorRGBSelector):
|
||||||
|
return {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "number"},
|
||||||
|
"minItems": 3,
|
||||||
|
"maxItems": 3,
|
||||||
|
"format": "RGB",
|
||||||
|
}
|
||||||
|
|
||||||
|
if isinstance(schema, selector.ConditionSelector):
|
||||||
|
return convert(cv.CONDITIONS_SCHEMA)
|
||||||
|
|
||||||
|
if isinstance(schema, selector.ConstantSelector):
|
||||||
|
return {"enum": [schema.config["value"]]}
|
||||||
|
|
||||||
|
result: dict[str, Any]
|
||||||
|
if isinstance(schema, selector.ColorTempSelector):
|
||||||
|
result = {"type": "number"}
|
||||||
|
if "min" in schema.config:
|
||||||
|
result["minimum"] = schema.config["min"]
|
||||||
|
elif "min_mireds" in schema.config:
|
||||||
|
result["minimum"] = schema.config["min_mireds"]
|
||||||
|
if "max" in schema.config:
|
||||||
|
result["maximum"] = schema.config["max"]
|
||||||
|
elif "max_mireds" in schema.config:
|
||||||
|
result["maximum"] = schema.config["max_mireds"]
|
||||||
|
return result
|
||||||
|
|
||||||
|
if isinstance(schema, selector.CountrySelector):
|
||||||
|
if schema.config.get("countries"):
|
||||||
|
return {"type": "string", "enum": schema.config["countries"]}
|
||||||
|
return {"type": "string", "format": "ISO 3166-1 alpha-2"}
|
||||||
|
|
||||||
|
if isinstance(schema, selector.DateSelector):
|
||||||
|
return {"type": "string", "format": "date"}
|
||||||
|
|
||||||
|
if isinstance(schema, selector.DateTimeSelector):
|
||||||
|
return {"type": "string", "format": "date-time"}
|
||||||
|
|
||||||
|
if isinstance(schema, selector.DurationSelector):
|
||||||
|
return convert(cv.time_period_dict)
|
||||||
|
|
||||||
|
if isinstance(schema, selector.EntitySelector):
|
||||||
|
if schema.config.get("multiple"):
|
||||||
|
return {"type": "array", "items": {"type": "string", "format": "entity_id"}}
|
||||||
|
|
||||||
|
return {"type": "string", "format": "entity_id"}
|
||||||
|
|
||||||
|
if isinstance(schema, selector.LanguageSelector):
|
||||||
|
if schema.config.get("languages"):
|
||||||
|
return {"type": "string", "enum": schema.config["languages"]}
|
||||||
|
return {"type": "string", "format": "RFC 5646"}
|
||||||
|
|
||||||
|
if isinstance(schema, (selector.LocationSelector, selector.MediaSelector)):
|
||||||
|
return convert(schema.DATA_SCHEMA)
|
||||||
|
|
||||||
|
if isinstance(schema, selector.NumberSelector):
|
||||||
|
result = {"type": "number"}
|
||||||
|
if "min" in schema.config:
|
||||||
|
result["minimum"] = schema.config["min"]
|
||||||
|
if "max" in schema.config:
|
||||||
|
result["maximum"] = schema.config["max"]
|
||||||
|
return result
|
||||||
|
|
||||||
|
if isinstance(schema, selector.ObjectSelector):
|
||||||
|
return {"type": "object"}
|
||||||
|
|
||||||
|
if isinstance(schema, selector.SelectSelector):
|
||||||
|
options = [
|
||||||
|
x["value"] if isinstance(x, dict) else x for x in schema.config["options"]
|
||||||
|
]
|
||||||
|
if schema.config.get("multiple"):
|
||||||
|
return {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string", "enum": options},
|
||||||
|
"uniqueItems": True,
|
||||||
|
}
|
||||||
|
return {"type": "string", "enum": options}
|
||||||
|
|
||||||
|
if isinstance(schema, selector.TargetSelector):
|
||||||
|
return convert(cv.TARGET_SERVICE_FIELDS)
|
||||||
|
|
||||||
|
if isinstance(schema, selector.TemplateSelector):
|
||||||
|
return {"type": "string", "format": "jinja2"}
|
||||||
|
|
||||||
|
if isinstance(schema, selector.TimeSelector):
|
||||||
|
return {"type": "string", "format": "time"}
|
||||||
|
|
||||||
|
if isinstance(schema, selector.TriggerSelector):
|
||||||
|
return convert(cv.TRIGGER_SCHEMA)
|
||||||
|
|
||||||
|
if schema.config.get("multiple"):
|
||||||
|
return {"type": "array", "items": {"type": "string"}}
|
||||||
|
|
||||||
|
return {"type": "string"}
|
||||||
|
|
||||||
|
|
||||||
|
class ScriptTool(Tool):
|
||||||
|
"""LLM Tool representing a Script."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hass: HomeAssistant,
|
||||||
|
script_entity_id: str,
|
||||||
|
) -> None:
|
||||||
|
"""Init the class."""
|
||||||
|
entity_registry = er.async_get(hass)
|
||||||
|
|
||||||
|
self.name = split_entity_id(script_entity_id)[1]
|
||||||
|
self.parameters = vol.Schema({})
|
||||||
|
entity_entry = entity_registry.async_get(script_entity_id)
|
||||||
|
if entity_entry and entity_entry.unique_id:
|
||||||
|
parameters_cache = hass.data.get(SCRIPT_PARAMETERS_CACHE)
|
||||||
|
|
||||||
|
if parameters_cache is None:
|
||||||
|
parameters_cache = hass.data[SCRIPT_PARAMETERS_CACHE] = {}
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def clear_cache(event: Event) -> None:
|
||||||
|
"""Clear script parameter cache on script reload or delete."""
|
||||||
|
if (
|
||||||
|
event.data[ATTR_DOMAIN] == SCRIPT_DOMAIN
|
||||||
|
and event.data[ATTR_SERVICE] in parameters_cache
|
||||||
|
):
|
||||||
|
parameters_cache.pop(event.data[ATTR_SERVICE])
|
||||||
|
|
||||||
|
cancel = hass.bus.async_listen(EVENT_SERVICE_REMOVED, clear_cache)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def on_homeassistant_close(event: Event) -> None:
|
||||||
|
"""Cleanup."""
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
hass.bus.async_listen_once(
|
||||||
|
EVENT_HOMEASSISTANT_CLOSE, on_homeassistant_close
|
||||||
|
)
|
||||||
|
|
||||||
|
if entity_entry.unique_id in parameters_cache:
|
||||||
|
self.description, self.parameters = parameters_cache[
|
||||||
|
entity_entry.unique_id
|
||||||
|
]
|
||||||
|
return
|
||||||
|
|
||||||
|
if service_desc := service.async_get_cached_service_description(
|
||||||
|
hass, SCRIPT_DOMAIN, entity_entry.unique_id
|
||||||
|
):
|
||||||
|
self.description = service_desc.get("description")
|
||||||
|
schema: dict[vol.Marker, Any] = {}
|
||||||
|
fields = service_desc.get("fields", {})
|
||||||
|
|
||||||
|
for field, config in fields.items():
|
||||||
|
description = config.get("description")
|
||||||
|
if not description:
|
||||||
|
description = config.get("name")
|
||||||
|
if config.get("required"):
|
||||||
|
key = vol.Required(field, description=description)
|
||||||
|
else:
|
||||||
|
key = vol.Optional(field, description=description)
|
||||||
|
if "selector" in config:
|
||||||
|
schema[key] = selector.selector(config["selector"])
|
||||||
|
else:
|
||||||
|
schema[key] = cv.string
|
||||||
|
|
||||||
|
self.parameters = vol.Schema(schema)
|
||||||
|
|
||||||
|
parameters_cache[entity_entry.unique_id] = (
|
||||||
|
self.description,
|
||||||
|
self.parameters,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def async_call(
|
||||||
|
self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext
|
||||||
|
) -> JsonObjectType:
|
||||||
|
"""Run the script."""
|
||||||
|
|
||||||
|
for field, validator in self.parameters.schema.items():
|
||||||
|
if field not in tool_input.tool_args:
|
||||||
|
continue
|
||||||
|
if isinstance(validator, selector.AreaSelector):
|
||||||
|
area_reg = ar.async_get(hass)
|
||||||
|
if validator.config.get("multiple"):
|
||||||
|
areas: list[ar.AreaEntry] = []
|
||||||
|
for area in tool_input.tool_args[field]:
|
||||||
|
areas.extend(intent.find_areas(area, area_reg))
|
||||||
|
tool_input.tool_args[field] = list({area.id for area in areas})
|
||||||
|
else:
|
||||||
|
area = tool_input.tool_args[field]
|
||||||
|
area = list(intent.find_areas(area, area_reg))[0].id
|
||||||
|
tool_input.tool_args[field] = area
|
||||||
|
|
||||||
|
elif isinstance(validator, selector.FloorSelector):
|
||||||
|
floor_reg = fr.async_get(hass)
|
||||||
|
if validator.config.get("multiple"):
|
||||||
|
floors: list[fr.FloorEntry] = []
|
||||||
|
for floor in tool_input.tool_args[field]:
|
||||||
|
floors.extend(intent.find_floors(floor, floor_reg))
|
||||||
|
tool_input.tool_args[field] = list(
|
||||||
|
{floor.floor_id for floor in floors}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
floor = tool_input.tool_args[field]
|
||||||
|
floor = list(intent.find_floors(floor, floor_reg))[0].floor_id
|
||||||
|
tool_input.tool_args[field] = floor
|
||||||
|
|
||||||
|
await hass.services.async_call(
|
||||||
|
SCRIPT_DOMAIN,
|
||||||
|
SERVICE_TURN_ON,
|
||||||
|
{
|
||||||
|
ATTR_ENTITY_ID: SCRIPT_DOMAIN + "." + self.name,
|
||||||
|
ATTR_VARIABLES: tool_input.tool_args,
|
||||||
|
},
|
||||||
|
context=llm_context.context,
|
||||||
|
)
|
||||||
|
|
||||||
|
return {"success": True}
|
||||||
|
|
|
@ -75,6 +75,13 @@ class Selector[_T: Mapping[str, Any]]:
|
||||||
|
|
||||||
self.config = self.CONFIG_SCHEMA(config)
|
self.config = self.CONFIG_SCHEMA(config)
|
||||||
|
|
||||||
|
def __eq__(self, other: object) -> bool:
|
||||||
|
"""Check equality."""
|
||||||
|
if not isinstance(other, Selector):
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
|
return self.selector_type == other.selector_type and self.config == other.config
|
||||||
|
|
||||||
def serialize(self) -> dict[str, dict[str, _T]]:
|
def serialize(self) -> dict[str, dict[str, _T]]:
|
||||||
"""Serialize Selector for voluptuous_serialize."""
|
"""Serialize Selector for voluptuous_serialize."""
|
||||||
return {"selector": {self.selector_type: self.config}}
|
return {"selector": {self.selector_type: self.config}}
|
||||||
|
@ -278,7 +285,7 @@ class AssistPipelineSelector(Selector[AssistPipelineSelectorConfig]):
|
||||||
|
|
||||||
CONFIG_SCHEMA = vol.Schema({})
|
CONFIG_SCHEMA = vol.Schema({})
|
||||||
|
|
||||||
def __init__(self, config: AssistPipelineSelectorConfig) -> None:
|
def __init__(self, config: AssistPipelineSelectorConfig | None = None) -> None:
|
||||||
"""Instantiate a selector."""
|
"""Instantiate a selector."""
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
|
@ -430,10 +437,10 @@ class ColorTempSelector(Selector[ColorTempSelectorConfig]):
|
||||||
range_min = self.config.get("min")
|
range_min = self.config.get("min")
|
||||||
range_max = self.config.get("max")
|
range_max = self.config.get("max")
|
||||||
|
|
||||||
if not range_min:
|
if range_min is None:
|
||||||
range_min = self.config.get("min_mireds")
|
range_min = self.config.get("min_mireds")
|
||||||
|
|
||||||
if not range_max:
|
if range_max is None:
|
||||||
range_max = self.config.get("max_mireds")
|
range_max = self.config.get("max_mireds")
|
||||||
|
|
||||||
value: int = vol.All(
|
value: int = vol.All(
|
||||||
|
@ -517,7 +524,7 @@ class ConstantSelector(Selector[ConstantSelectorConfig]):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self, config: ConstantSelectorConfig | None = None) -> None:
|
def __init__(self, config: ConstantSelectorConfig) -> None:
|
||||||
"""Instantiate a selector."""
|
"""Instantiate a selector."""
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
|
@ -560,7 +567,7 @@ class QrCodeSelector(Selector[QrCodeSelectorConfig]):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self, config: QrCodeSelectorConfig | None = None) -> None:
|
def __init__(self, config: QrCodeSelectorConfig) -> None:
|
||||||
"""Instantiate a selector."""
|
"""Instantiate a selector."""
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
|
@ -588,7 +595,7 @@ class ConversationAgentSelector(Selector[ConversationAgentSelectorConfig]):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self, config: ConversationAgentSelectorConfig) -> None:
|
def __init__(self, config: ConversationAgentSelectorConfig | None = None) -> None:
|
||||||
"""Instantiate a selector."""
|
"""Instantiate a selector."""
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
|
@ -820,7 +827,7 @@ class FloorSelectorConfig(TypedDict, total=False):
|
||||||
|
|
||||||
|
|
||||||
@SELECTORS.register("floor")
|
@SELECTORS.register("floor")
|
||||||
class FloorSelector(Selector[AreaSelectorConfig]):
|
class FloorSelector(Selector[FloorSelectorConfig]):
|
||||||
"""Selector of a single or list of floors."""
|
"""Selector of a single or list of floors."""
|
||||||
|
|
||||||
selector_type = "floor"
|
selector_type = "floor"
|
||||||
|
@ -934,7 +941,7 @@ class LanguageSelector(Selector[LanguageSelectorConfig]):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self, config: LanguageSelectorConfig) -> None:
|
def __init__(self, config: LanguageSelectorConfig | None = None) -> None:
|
||||||
"""Instantiate a selector."""
|
"""Instantiate a selector."""
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
|
@ -1159,7 +1166,7 @@ class SelectSelector(Selector[SelectSelectorConfig]):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self, config: SelectSelectorConfig | None = None) -> None:
|
def __init__(self, config: SelectSelectorConfig) -> None:
|
||||||
"""Instantiate a selector."""
|
"""Instantiate a selector."""
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
|
@ -1434,7 +1441,7 @@ class FileSelector(Selector[FileSelectorConfig]):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self, config: FileSelectorConfig | None = None) -> None:
|
def __init__(self, config: FileSelectorConfig) -> None:
|
||||||
"""Instantiate a selector."""
|
"""Instantiate a selector."""
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
|
|
|
@ -58,6 +58,7 @@ SQLAlchemy==2.0.31
|
||||||
typing-extensions>=4.12.2,<5.0
|
typing-extensions>=4.12.2,<5.0
|
||||||
ulid-transform==0.9.0
|
ulid-transform==0.9.0
|
||||||
urllib3>=1.26.5,<2
|
urllib3>=1.26.5,<2
|
||||||
|
voluptuous-openapi==0.0.4
|
||||||
voluptuous-serialize==2.6.0
|
voluptuous-serialize==2.6.0
|
||||||
voluptuous==0.13.1
|
voluptuous==0.13.1
|
||||||
webrtc-noise-gain==1.2.3
|
webrtc-noise-gain==1.2.3
|
||||||
|
|
|
@ -69,6 +69,7 @@ dependencies = [
|
||||||
"urllib3>=1.26.5,<2",
|
"urllib3>=1.26.5,<2",
|
||||||
"voluptuous==0.13.1",
|
"voluptuous==0.13.1",
|
||||||
"voluptuous-serialize==2.6.0",
|
"voluptuous-serialize==2.6.0",
|
||||||
|
"voluptuous-openapi==0.0.4",
|
||||||
"yarl==1.9.4",
|
"yarl==1.9.4",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
@ -41,4 +41,5 @@ ulid-transform==0.9.0
|
||||||
urllib3>=1.26.5,<2
|
urllib3>=1.26.5,<2
|
||||||
voluptuous==0.13.1
|
voluptuous==0.13.1
|
||||||
voluptuous-serialize==2.6.0
|
voluptuous-serialize==2.6.0
|
||||||
|
voluptuous-openapi==0.0.4
|
||||||
yarl==1.9.4
|
yarl==1.9.4
|
||||||
|
|
|
@ -2846,10 +2846,6 @@ voip-utils==0.1.0
|
||||||
# homeassistant.components.volkszaehler
|
# homeassistant.components.volkszaehler
|
||||||
volkszaehler==0.4.0
|
volkszaehler==0.4.0
|
||||||
|
|
||||||
# homeassistant.components.google_generative_ai_conversation
|
|
||||||
# homeassistant.components.openai_conversation
|
|
||||||
voluptuous-openapi==0.0.4
|
|
||||||
|
|
||||||
# homeassistant.components.volvooncall
|
# homeassistant.components.volvooncall
|
||||||
volvooncall==0.10.3
|
volvooncall==0.10.3
|
||||||
|
|
||||||
|
|
|
@ -2217,10 +2217,6 @@ vilfo-api-client==0.5.0
|
||||||
# homeassistant.components.voip
|
# homeassistant.components.voip
|
||||||
voip-utils==0.1.0
|
voip-utils==0.1.0
|
||||||
|
|
||||||
# homeassistant.components.google_generative_ai_conversation
|
|
||||||
# homeassistant.components.openai_conversation
|
|
||||||
voluptuous-openapi==0.0.4
|
|
||||||
|
|
||||||
# homeassistant.components.volvooncall
|
# homeassistant.components.volvooncall
|
||||||
volvooncall==0.10.3
|
volvooncall==0.10.3
|
||||||
|
|
||||||
|
|
|
@ -8,6 +8,7 @@ import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
|
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
|
||||||
from homeassistant.components.intent import async_register_timer_handler
|
from homeassistant.components.intent import async_register_timer_handler
|
||||||
|
from homeassistant.components.script.config import ScriptConfig
|
||||||
from homeassistant.core import Context, HomeAssistant, State
|
from homeassistant.core import Context, HomeAssistant, State
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers import (
|
from homeassistant.helpers import (
|
||||||
|
@ -18,6 +19,7 @@ from homeassistant.helpers import (
|
||||||
floor_registry as fr,
|
floor_registry as fr,
|
||||||
intent,
|
intent,
|
||||||
llm,
|
llm,
|
||||||
|
selector,
|
||||||
)
|
)
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
from homeassistant.util import yaml
|
from homeassistant.util import yaml
|
||||||
|
@ -564,11 +566,6 @@ async def test_assist_api_prompt(
|
||||||
"names": "Unnamed Device",
|
"names": "Unnamed Device",
|
||||||
"state": "unavailable",
|
"state": "unavailable",
|
||||||
},
|
},
|
||||||
"script.test_script": {
|
|
||||||
"description": "This is a test script",
|
|
||||||
"names": "test_script",
|
|
||||||
"state": "off",
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
exposed_entities_prompt = (
|
exposed_entities_prompt = (
|
||||||
"An overview of the areas and the devices in this smart home:\n"
|
"An overview of the areas and the devices in this smart home:\n"
|
||||||
|
@ -634,3 +631,323 @@ async def test_assist_api_prompt(
|
||||||
{area_prompt}
|
{area_prompt}
|
||||||
{exposed_entities_prompt}"""
|
{exposed_entities_prompt}"""
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_script_tool(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
area_registry: ar.AreaRegistry,
|
||||||
|
floor_registry: fr.FloorRegistry,
|
||||||
|
) -> None:
|
||||||
|
"""Test ScriptTool for the assist API."""
|
||||||
|
assert await async_setup_component(hass, "homeassistant", {})
|
||||||
|
assert await async_setup_component(hass, "intent", {})
|
||||||
|
context = Context()
|
||||||
|
llm_context = llm.LLMContext(
|
||||||
|
platform="test_platform",
|
||||||
|
context=context,
|
||||||
|
user_prompt="test_text",
|
||||||
|
language="*",
|
||||||
|
assistant="conversation",
|
||||||
|
device_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a script with a unique ID
|
||||||
|
assert await async_setup_component(
|
||||||
|
hass,
|
||||||
|
"script",
|
||||||
|
{
|
||||||
|
"script": {
|
||||||
|
"test_script": {
|
||||||
|
"description": "This is a test script",
|
||||||
|
"sequence": [],
|
||||||
|
"fields": {
|
||||||
|
"beer": {"description": "Number of beers", "required": True},
|
||||||
|
"wine": {"selector": {"number": {"min": 0, "max": 3}}},
|
||||||
|
"where": {"selector": {"area": {}}},
|
||||||
|
"area_list": {"selector": {"area": {"multiple": True}}},
|
||||||
|
"floor": {"selector": {"floor": {}}},
|
||||||
|
"floor_list": {"selector": {"floor": {"multiple": True}}},
|
||||||
|
"extra_field": {"selector": {"area": {}}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"unexposed_script": {
|
||||||
|
"sequence": [],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async_expose_entity(hass, "conversation", "script.test_script", True)
|
||||||
|
|
||||||
|
area = area_registry.async_create("Living room")
|
||||||
|
floor = floor_registry.async_create("2")
|
||||||
|
|
||||||
|
assert llm.SCRIPT_PARAMETERS_CACHE not in hass.data
|
||||||
|
|
||||||
|
api = await llm.async_get_api(hass, "assist", llm_context)
|
||||||
|
|
||||||
|
tools = [tool for tool in api.tools if isinstance(tool, llm.ScriptTool)]
|
||||||
|
assert len(tools) == 1
|
||||||
|
|
||||||
|
tool = tools[0]
|
||||||
|
assert tool.name == "test_script"
|
||||||
|
assert tool.description == "This is a test script"
|
||||||
|
schema = {
|
||||||
|
vol.Required("beer", description="Number of beers"): cv.string,
|
||||||
|
vol.Optional("wine"): selector.NumberSelector({"min": 0, "max": 3}),
|
||||||
|
vol.Optional("where"): selector.AreaSelector(),
|
||||||
|
vol.Optional("area_list"): selector.AreaSelector({"multiple": True}),
|
||||||
|
vol.Optional("floor"): selector.FloorSelector(),
|
||||||
|
vol.Optional("floor_list"): selector.FloorSelector({"multiple": True}),
|
||||||
|
vol.Optional("extra_field"): selector.AreaSelector(),
|
||||||
|
}
|
||||||
|
assert tool.parameters.schema == schema
|
||||||
|
|
||||||
|
assert hass.data[llm.SCRIPT_PARAMETERS_CACHE] == {
|
||||||
|
"test_script": ("This is a test script", vol.Schema(schema))
|
||||||
|
}
|
||||||
|
|
||||||
|
tool_input = llm.ToolInput(
|
||||||
|
tool_name="test_script",
|
||||||
|
tool_args={
|
||||||
|
"beer": "3",
|
||||||
|
"wine": 0,
|
||||||
|
"where": "Living room",
|
||||||
|
"area_list": ["Living room"],
|
||||||
|
"floor": "2",
|
||||||
|
"floor_list": ["2"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("homeassistant.core.ServiceRegistry.async_call") as mock_service_call:
|
||||||
|
response = await api.async_call_tool(tool_input)
|
||||||
|
|
||||||
|
mock_service_call.assert_awaited_once_with(
|
||||||
|
"script",
|
||||||
|
"turn_on",
|
||||||
|
{
|
||||||
|
"entity_id": "script.test_script",
|
||||||
|
"variables": {
|
||||||
|
"beer": "3",
|
||||||
|
"wine": 0,
|
||||||
|
"where": area.id,
|
||||||
|
"area_list": [area.id],
|
||||||
|
"floor": floor.floor_id,
|
||||||
|
"floor_list": [floor.floor_id],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
context=context,
|
||||||
|
)
|
||||||
|
assert response == {"success": True}
|
||||||
|
|
||||||
|
# Test reload script with new parameters
|
||||||
|
config = {
|
||||||
|
"script": {
|
||||||
|
"test_script": ScriptConfig(
|
||||||
|
{
|
||||||
|
"description": "This is a new test script",
|
||||||
|
"sequence": [],
|
||||||
|
"mode": "single",
|
||||||
|
"max": 2,
|
||||||
|
"max_exceeded": "WARNING",
|
||||||
|
"trace": {},
|
||||||
|
"fields": {
|
||||||
|
"beer": {"description": "Number of beers", "required": True},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.helpers.entity_component.EntityComponent.async_prepare_reload",
|
||||||
|
return_value=config,
|
||||||
|
):
|
||||||
|
await hass.services.async_call("script", "reload", blocking=True)
|
||||||
|
|
||||||
|
assert hass.data[llm.SCRIPT_PARAMETERS_CACHE] == {}
|
||||||
|
|
||||||
|
api = await llm.async_get_api(hass, "assist", llm_context)
|
||||||
|
|
||||||
|
tools = [tool for tool in api.tools if isinstance(tool, llm.ScriptTool)]
|
||||||
|
assert len(tools) == 1
|
||||||
|
|
||||||
|
tool = tools[0]
|
||||||
|
assert tool.name == "test_script"
|
||||||
|
assert tool.description == "This is a new test script"
|
||||||
|
schema = {vol.Required("beer", description="Number of beers"): cv.string}
|
||||||
|
assert tool.parameters.schema == schema
|
||||||
|
|
||||||
|
assert hass.data[llm.SCRIPT_PARAMETERS_CACHE] == {
|
||||||
|
"test_script": ("This is a new test script", vol.Schema(schema))
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_selector_serializer(
|
||||||
|
hass: HomeAssistant, llm_context: llm.LLMContext
|
||||||
|
) -> None:
|
||||||
|
"""Test serialization of Selectors in Open API format."""
|
||||||
|
api = await llm.async_get_api(hass, "assist", llm_context)
|
||||||
|
selector_serializer = api.custom_serializer
|
||||||
|
|
||||||
|
assert selector_serializer(selector.ActionSelector()) == {"type": "string"}
|
||||||
|
assert selector_serializer(selector.AddonSelector()) == {"type": "string"}
|
||||||
|
assert selector_serializer(selector.AreaSelector()) == {"type": "string"}
|
||||||
|
assert selector_serializer(selector.AreaSelector({"multiple": True})) == {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
}
|
||||||
|
assert selector_serializer(selector.AssistPipelineSelector()) == {"type": "string"}
|
||||||
|
assert selector_serializer(
|
||||||
|
selector.AttributeSelector({"entity_id": "sensor.test"})
|
||||||
|
) == {"type": "string"}
|
||||||
|
assert selector_serializer(selector.BackupLocationSelector()) == {
|
||||||
|
"type": "string",
|
||||||
|
"pattern": "^(?:\\/backup|\\w+)$",
|
||||||
|
}
|
||||||
|
assert selector_serializer(selector.BooleanSelector()) == {"type": "boolean"}
|
||||||
|
assert selector_serializer(selector.ColorRGBSelector()) == {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "number"},
|
||||||
|
"maxItems": 3,
|
||||||
|
"minItems": 3,
|
||||||
|
"format": "RGB",
|
||||||
|
}
|
||||||
|
assert selector_serializer(selector.ColorTempSelector()) == {"type": "number"}
|
||||||
|
assert selector_serializer(selector.ColorTempSelector({"min": 0, "max": 1000})) == {
|
||||||
|
"type": "number",
|
||||||
|
"minimum": 0,
|
||||||
|
"maximum": 1000,
|
||||||
|
}
|
||||||
|
assert selector_serializer(
|
||||||
|
selector.ColorTempSelector({"min_mireds": 100, "max_mireds": 1000})
|
||||||
|
) == {"type": "number", "minimum": 100, "maximum": 1000}
|
||||||
|
assert selector_serializer(selector.ConfigEntrySelector()) == {"type": "string"}
|
||||||
|
assert selector_serializer(selector.ConstantSelector({"value": "test"})) == {
|
||||||
|
"enum": ["test"]
|
||||||
|
}
|
||||||
|
assert selector_serializer(selector.ConstantSelector({"value": 1})) == {"enum": [1]}
|
||||||
|
assert selector_serializer(selector.ConstantSelector({"value": True})) == {
|
||||||
|
"enum": [True]
|
||||||
|
}
|
||||||
|
assert selector_serializer(selector.QrCodeSelector({"data": "test"})) == {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
assert selector_serializer(selector.ConversationAgentSelector()) == {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
assert selector_serializer(selector.CountrySelector()) == {
|
||||||
|
"type": "string",
|
||||||
|
"format": "ISO 3166-1 alpha-2",
|
||||||
|
}
|
||||||
|
assert selector_serializer(
|
||||||
|
selector.CountrySelector({"countries": ["GB", "FR"]})
|
||||||
|
) == {"type": "string", "enum": ["GB", "FR"]}
|
||||||
|
assert selector_serializer(selector.DateSelector()) == {
|
||||||
|
"type": "string",
|
||||||
|
"format": "date",
|
||||||
|
}
|
||||||
|
assert selector_serializer(selector.DateTimeSelector()) == {
|
||||||
|
"type": "string",
|
||||||
|
"format": "date-time",
|
||||||
|
}
|
||||||
|
assert selector_serializer(selector.DeviceSelector()) == {"type": "string"}
|
||||||
|
assert selector_serializer(selector.DeviceSelector({"multiple": True})) == {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
}
|
||||||
|
assert selector_serializer(selector.EntitySelector()) == {
|
||||||
|
"type": "string",
|
||||||
|
"format": "entity_id",
|
||||||
|
}
|
||||||
|
assert selector_serializer(selector.EntitySelector({"multiple": True})) == {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string", "format": "entity_id"},
|
||||||
|
}
|
||||||
|
assert selector_serializer(selector.FloorSelector()) == {"type": "string"}
|
||||||
|
assert selector_serializer(selector.FloorSelector({"multiple": True})) == {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
}
|
||||||
|
assert selector_serializer(selector.IconSelector()) == {"type": "string"}
|
||||||
|
assert selector_serializer(selector.LabelSelector()) == {"type": "string"}
|
||||||
|
assert selector_serializer(selector.LabelSelector({"multiple": True})) == {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
}
|
||||||
|
assert selector_serializer(selector.LanguageSelector()) == {
|
||||||
|
"type": "string",
|
||||||
|
"format": "RFC 5646",
|
||||||
|
}
|
||||||
|
assert selector_serializer(
|
||||||
|
selector.LanguageSelector({"languages": ["en", "fr"]})
|
||||||
|
) == {"type": "string", "enum": ["en", "fr"]}
|
||||||
|
assert selector_serializer(selector.LocationSelector()) == {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"latitude": {"type": "number"},
|
||||||
|
"longitude": {"type": "number"},
|
||||||
|
"radius": {"type": "number"},
|
||||||
|
},
|
||||||
|
"required": ["latitude", "longitude"],
|
||||||
|
}
|
||||||
|
assert selector_serializer(selector.MediaSelector()) == {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"entity_id": {"type": "string"},
|
||||||
|
"media_content_id": {"type": "string"},
|
||||||
|
"media_content_type": {"type": "string"},
|
||||||
|
"metadata": {"type": "object", "additionalProperties": True},
|
||||||
|
},
|
||||||
|
"required": ["entity_id", "media_content_id", "media_content_type"],
|
||||||
|
}
|
||||||
|
assert selector_serializer(selector.NumberSelector({"mode": "box"})) == {
|
||||||
|
"type": "number"
|
||||||
|
}
|
||||||
|
assert selector_serializer(selector.NumberSelector({"min": 30, "max": 100})) == {
|
||||||
|
"type": "number",
|
||||||
|
"minimum": 30,
|
||||||
|
"maximum": 100,
|
||||||
|
}
|
||||||
|
assert selector_serializer(selector.ObjectSelector()) == {"type": "object"}
|
||||||
|
assert selector_serializer(
|
||||||
|
selector.SelectSelector(
|
||||||
|
{
|
||||||
|
"options": [
|
||||||
|
{"value": "A", "label": "Letter A"},
|
||||||
|
{"value": "B", "label": "Letter B"},
|
||||||
|
{"value": "C", "label": "Letter C"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
) == {"type": "string", "enum": ["A", "B", "C"]}
|
||||||
|
assert selector_serializer(
|
||||||
|
selector.SelectSelector({"options": ["A", "B", "C"], "multiple": True})
|
||||||
|
) == {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string", "enum": ["A", "B", "C"]},
|
||||||
|
"uniqueItems": True,
|
||||||
|
}
|
||||||
|
assert selector_serializer(
|
||||||
|
selector.StateSelector({"entity_id": "sensor.test"})
|
||||||
|
) == {"type": "string"}
|
||||||
|
assert selector_serializer(selector.TemplateSelector()) == {
|
||||||
|
"type": "string",
|
||||||
|
"format": "jinja2",
|
||||||
|
}
|
||||||
|
assert selector_serializer(selector.TextSelector()) == {"type": "string"}
|
||||||
|
assert selector_serializer(selector.TextSelector({"multiple": True})) == {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
}
|
||||||
|
assert selector_serializer(selector.ThemeSelector()) == {"type": "string"}
|
||||||
|
assert selector_serializer(selector.TimeSelector()) == {
|
||||||
|
"type": "string",
|
||||||
|
"format": "time",
|
||||||
|
}
|
||||||
|
assert selector_serializer(selector.TriggerSelector()) == {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
}
|
||||||
|
assert selector_serializer(selector.FileSelector({"accept": ".txt"})) == {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
|
|
@ -55,6 +55,8 @@ def _test_selector(
|
||||||
config = {selector_type: schema}
|
config = {selector_type: schema}
|
||||||
selector.validate_selector(config)
|
selector.validate_selector(config)
|
||||||
selector_instance = selector.selector(config)
|
selector_instance = selector.selector(config)
|
||||||
|
assert selector_instance == selector.selector(config)
|
||||||
|
assert selector_instance != 5
|
||||||
# We do not allow enums in the config, as they cannot serialize
|
# We do not allow enums in the config, as they cannot serialize
|
||||||
assert not any(isinstance(val, Enum) for val in selector_instance.config.values())
|
assert not any(isinstance(val, Enum) for val in selector_instance.config.values())
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue