Add script llm tool (#118936)
* Add script llm tool * Add tests * More tests * more test * more test * Add area and floor resolving * coverage * coverage * fix ColorTempSelector * fix mypy * fix mypy * add script reload test * Cache script tool parameters * Make custom_serializer a part of api --------- Co-authored-by: Michael Hansen <mike@rhasspy.org>
This commit is contained in:
parent
77fea8a73e
commit
2386ed3830
14 changed files with 639 additions and 55 deletions
|
@ -3,6 +3,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import codecs
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Literal
|
||||
|
||||
from google.api_core.exceptions import GoogleAPICallError
|
||||
|
@ -89,10 +90,14 @@ def _format_schema(schema: dict[str, Any]) -> dict[str, Any]:
|
|||
return result
|
||||
|
||||
|
||||
def _format_tool(tool: llm.Tool) -> dict[str, Any]:
|
||||
def _format_tool(
|
||||
tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
|
||||
) -> dict[str, Any]:
|
||||
"""Format tool specification."""
|
||||
|
||||
parameters = _format_schema(convert(tool.parameters))
|
||||
parameters = _format_schema(
|
||||
convert(tool.parameters, custom_serializer=custom_serializer)
|
||||
)
|
||||
|
||||
return protos.Tool(
|
||||
{
|
||||
|
@ -193,7 +198,9 @@ class GoogleGenerativeAIConversationEntity(
|
|||
f"Error preparing LLM API: {err}",
|
||||
)
|
||||
return result
|
||||
tools = [_format_tool(tool) for tool in llm_api.tools]
|
||||
tools = [
|
||||
_format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools
|
||||
]
|
||||
|
||||
try:
|
||||
prompt = await self._async_render_prompt(user_input, llm_api, llm_context)
|
||||
|
|
|
@ -9,5 +9,5 @@
|
|||
"integration_type": "service",
|
||||
"iot_class": "cloud_polling",
|
||||
"quality_scale": "platinum",
|
||||
"requirements": ["google-generativeai==0.6.0", "voluptuous-openapi==0.0.4"]
|
||||
"requirements": ["google-generativeai==0.6.0"]
|
||||
}
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
"""Conversation support for OpenAI."""
|
||||
|
||||
from collections.abc import Callable
|
||||
import json
|
||||
from typing import Literal
|
||||
from typing import Any, Literal
|
||||
|
||||
import openai
|
||||
from openai._types import NOT_GIVEN
|
||||
|
@ -58,9 +59,14 @@ async def async_setup_entry(
|
|||
async_add_entities([agent])
|
||||
|
||||
|
||||
def _format_tool(tool: llm.Tool) -> ChatCompletionToolParam:
|
||||
def _format_tool(
|
||||
tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
|
||||
) -> ChatCompletionToolParam:
|
||||
"""Format tool specification."""
|
||||
tool_spec = FunctionDefinition(name=tool.name, parameters=convert(tool.parameters))
|
||||
tool_spec = FunctionDefinition(
|
||||
name=tool.name,
|
||||
parameters=convert(tool.parameters, custom_serializer=custom_serializer),
|
||||
)
|
||||
if tool.description:
|
||||
tool_spec["description"] = tool.description
|
||||
return ChatCompletionToolParam(type="function", function=tool_spec)
|
||||
|
@ -139,7 +145,9 @@ class OpenAIConversationEntity(
|
|||
return conversation.ConversationResult(
|
||||
response=intent_response, conversation_id=user_input.conversation_id
|
||||
)
|
||||
tools = [_format_tool(tool) for tool in llm_api.tools]
|
||||
tools = [
|
||||
_format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools
|
||||
]
|
||||
|
||||
if user_input.conversation_id is None:
|
||||
conversation_id = ulid.ulid_now()
|
||||
|
|
|
@ -8,5 +8,5 @@
|
|||
"documentation": "https://www.home-assistant.io/integrations/openai_conversation",
|
||||
"integration_type": "service",
|
||||
"iot_class": "cloud_polling",
|
||||
"requirements": ["openai==1.3.8", "voluptuous-openapi==0.0.4"]
|
||||
"requirements": ["openai==1.3.8"]
|
||||
}
|
||||
|
|
|
@ -352,7 +352,7 @@ class MatchTargetsCandidate:
|
|||
matched_name: str | None = None
|
||||
|
||||
|
||||
def _find_areas(
|
||||
def find_areas(
|
||||
name: str, areas: area_registry.AreaRegistry
|
||||
) -> Iterable[area_registry.AreaEntry]:
|
||||
"""Find all areas matching a name (including aliases)."""
|
||||
|
@ -372,7 +372,7 @@ def _find_areas(
|
|||
break
|
||||
|
||||
|
||||
def _find_floors(
|
||||
def find_floors(
|
||||
name: str, floors: floor_registry.FloorRegistry
|
||||
) -> Iterable[floor_registry.FloorEntry]:
|
||||
"""Find all floors matching a name (including aliases)."""
|
||||
|
@ -530,7 +530,7 @@ def async_match_targets( # noqa: C901
|
|||
if not states:
|
||||
return MatchTargetsResult(False, MatchFailedReason.STATE)
|
||||
|
||||
# Exit early so we can to avoid registry lookups
|
||||
# Exit early so we can avoid registry lookups
|
||||
if not (
|
||||
constraints.name
|
||||
or constraints.features
|
||||
|
@ -580,7 +580,7 @@ def async_match_targets( # noqa: C901
|
|||
if constraints.floor_name:
|
||||
# Filter by areas associated with floor
|
||||
fr = floor_registry.async_get(hass)
|
||||
targeted_floors = list(_find_floors(constraints.floor_name, fr))
|
||||
targeted_floors = list(find_floors(constraints.floor_name, fr))
|
||||
if not targeted_floors:
|
||||
return MatchTargetsResult(
|
||||
False,
|
||||
|
@ -609,7 +609,7 @@ def async_match_targets( # noqa: C901
|
|||
possible_area_ids = {area.id for area in ar.async_list_areas()}
|
||||
|
||||
if constraints.area_name:
|
||||
targeted_areas = list(_find_areas(constraints.area_name, ar))
|
||||
targeted_areas = list(find_areas(constraints.area_name, ar))
|
||||
if not targeted_areas:
|
||||
return MatchTargetsResult(
|
||||
False,
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
|
@ -11,6 +12,7 @@ from typing import Any
|
|||
|
||||
import slugify as unicode_slug
|
||||
import voluptuous as vol
|
||||
from voluptuous_openapi import UNSUPPORTED, convert
|
||||
|
||||
from homeassistant.components.climate.intent import INTENT_GET_TEMPERATURE
|
||||
from homeassistant.components.conversation.trace import (
|
||||
|
@ -20,22 +22,39 @@ from homeassistant.components.conversation.trace import (
|
|||
from homeassistant.components.cover.intent import INTENT_CLOSE_COVER, INTENT_OPEN_COVER
|
||||
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
|
||||
from homeassistant.components.intent import async_device_supports_timers
|
||||
from homeassistant.components.script import ATTR_VARIABLES, DOMAIN as SCRIPT_DOMAIN
|
||||
from homeassistant.components.weather.intent import INTENT_GET_WEATHER
|
||||
from homeassistant.core import Context, HomeAssistant, callback
|
||||
from homeassistant.const import (
|
||||
ATTR_DOMAIN,
|
||||
ATTR_ENTITY_ID,
|
||||
ATTR_SERVICE,
|
||||
EVENT_HOMEASSISTANT_CLOSE,
|
||||
EVENT_SERVICE_REMOVED,
|
||||
SERVICE_TURN_ON,
|
||||
)
|
||||
from homeassistant.core import Context, Event, HomeAssistant, callback, split_entity_id
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.util import yaml
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
from homeassistant.util.json import JsonObjectType
|
||||
|
||||
from . import (
|
||||
area_registry as ar,
|
||||
config_validation as cv,
|
||||
device_registry as dr,
|
||||
entity_registry as er,
|
||||
floor_registry as fr,
|
||||
intent,
|
||||
selector,
|
||||
service,
|
||||
)
|
||||
from .singleton import singleton
|
||||
|
||||
SCRIPT_PARAMETERS_CACHE: HassKey[dict[str, tuple[str | None, vol.Schema]]] = HassKey(
|
||||
"llm_script_parameters_cache"
|
||||
)
|
||||
|
||||
|
||||
LLM_API_ASSIST = "assist"
|
||||
|
||||
BASE_PROMPT = (
|
||||
|
@ -143,6 +162,7 @@ class APIInstance:
|
|||
api_prompt: str
|
||||
llm_context: LLMContext
|
||||
tools: list[Tool]
|
||||
custom_serializer: Callable[[Any], Any] | None = None
|
||||
|
||||
async def async_call_tool(self, tool_input: ToolInput) -> JsonObjectType:
|
||||
"""Call a LLM tool, validate args and return the response."""
|
||||
|
@ -284,6 +304,7 @@ class AssistAPI(API):
|
|||
api_prompt=self._async_get_api_prompt(llm_context, exposed_entities),
|
||||
llm_context=llm_context,
|
||||
tools=self._async_get_tools(llm_context, exposed_entities),
|
||||
custom_serializer=_selector_serializer,
|
||||
)
|
||||
|
||||
@callback
|
||||
|
@ -372,7 +393,7 @@ class AssistAPI(API):
|
|||
exposed_domains: set[str] | None = None
|
||||
if exposed_entities is not None:
|
||||
exposed_domains = {
|
||||
entity_id.split(".")[0] for entity_id in exposed_entities
|
||||
split_entity_id(entity_id)[0] for entity_id in exposed_entities
|
||||
}
|
||||
intent_handlers = [
|
||||
intent_handler
|
||||
|
@ -381,11 +402,22 @@ class AssistAPI(API):
|
|||
or intent_handler.platforms & exposed_domains
|
||||
]
|
||||
|
||||
return [
|
||||
tools: list[Tool] = [
|
||||
IntentTool(self.cached_slugify(intent_handler.intent_type), intent_handler)
|
||||
for intent_handler in intent_handlers
|
||||
]
|
||||
|
||||
if llm_context.assistant is not None:
|
||||
for state in self.hass.states.async_all(SCRIPT_DOMAIN):
|
||||
if not async_should_expose(
|
||||
self.hass, llm_context.assistant, state.entity_id
|
||||
):
|
||||
continue
|
||||
|
||||
tools.append(ScriptTool(self.hass, state.entity_id))
|
||||
|
||||
return tools
|
||||
|
||||
|
||||
def _get_exposed_entities(
|
||||
hass: HomeAssistant, assistant: str
|
||||
|
@ -413,13 +445,15 @@ def _get_exposed_entities(
|
|||
entities = {}
|
||||
|
||||
for state in hass.states.async_all():
|
||||
if state.domain == SCRIPT_DOMAIN:
|
||||
continue
|
||||
|
||||
if not async_should_expose(hass, assistant, state.entity_id):
|
||||
continue
|
||||
|
||||
entity_entry = entity_registry.async_get(state.entity_id)
|
||||
names = [state.name]
|
||||
area_names = []
|
||||
description: str | None = None
|
||||
|
||||
if entity_entry is not None:
|
||||
names.extend(entity_entry.aliases)
|
||||
|
@ -439,25 +473,11 @@ def _get_exposed_entities(
|
|||
area_names.append(area.name)
|
||||
area_names.extend(area.aliases)
|
||||
|
||||
if (
|
||||
state.domain == "script"
|
||||
and entity_entry.unique_id
|
||||
and (
|
||||
service_desc := service.async_get_cached_service_description(
|
||||
hass, "script", entity_entry.unique_id
|
||||
)
|
||||
)
|
||||
):
|
||||
description = service_desc.get("description")
|
||||
|
||||
info: dict[str, Any] = {
|
||||
"names": ", ".join(names),
|
||||
"state": state.state,
|
||||
}
|
||||
|
||||
if description:
|
||||
info["description"] = description
|
||||
|
||||
if area_names:
|
||||
info["areas"] = ", ".join(area_names)
|
||||
|
||||
|
@ -473,3 +493,231 @@ def _get_exposed_entities(
|
|||
entities[state.entity_id] = info
|
||||
|
||||
return entities
|
||||
|
||||
|
||||
def _selector_serializer(schema: Any) -> Any: # noqa: C901
|
||||
"""Convert selectors into OpenAPI schema."""
|
||||
if not isinstance(schema, selector.Selector):
|
||||
return UNSUPPORTED
|
||||
|
||||
if isinstance(schema, selector.BackupLocationSelector):
|
||||
return {"type": "string", "pattern": "^(?:\\/backup|\\w+)$"}
|
||||
|
||||
if isinstance(schema, selector.BooleanSelector):
|
||||
return {"type": "boolean"}
|
||||
|
||||
if isinstance(schema, selector.ColorRGBSelector):
|
||||
return {
|
||||
"type": "array",
|
||||
"items": {"type": "number"},
|
||||
"minItems": 3,
|
||||
"maxItems": 3,
|
||||
"format": "RGB",
|
||||
}
|
||||
|
||||
if isinstance(schema, selector.ConditionSelector):
|
||||
return convert(cv.CONDITIONS_SCHEMA)
|
||||
|
||||
if isinstance(schema, selector.ConstantSelector):
|
||||
return {"enum": [schema.config["value"]]}
|
||||
|
||||
result: dict[str, Any]
|
||||
if isinstance(schema, selector.ColorTempSelector):
|
||||
result = {"type": "number"}
|
||||
if "min" in schema.config:
|
||||
result["minimum"] = schema.config["min"]
|
||||
elif "min_mireds" in schema.config:
|
||||
result["minimum"] = schema.config["min_mireds"]
|
||||
if "max" in schema.config:
|
||||
result["maximum"] = schema.config["max"]
|
||||
elif "max_mireds" in schema.config:
|
||||
result["maximum"] = schema.config["max_mireds"]
|
||||
return result
|
||||
|
||||
if isinstance(schema, selector.CountrySelector):
|
||||
if schema.config.get("countries"):
|
||||
return {"type": "string", "enum": schema.config["countries"]}
|
||||
return {"type": "string", "format": "ISO 3166-1 alpha-2"}
|
||||
|
||||
if isinstance(schema, selector.DateSelector):
|
||||
return {"type": "string", "format": "date"}
|
||||
|
||||
if isinstance(schema, selector.DateTimeSelector):
|
||||
return {"type": "string", "format": "date-time"}
|
||||
|
||||
if isinstance(schema, selector.DurationSelector):
|
||||
return convert(cv.time_period_dict)
|
||||
|
||||
if isinstance(schema, selector.EntitySelector):
|
||||
if schema.config.get("multiple"):
|
||||
return {"type": "array", "items": {"type": "string", "format": "entity_id"}}
|
||||
|
||||
return {"type": "string", "format": "entity_id"}
|
||||
|
||||
if isinstance(schema, selector.LanguageSelector):
|
||||
if schema.config.get("languages"):
|
||||
return {"type": "string", "enum": schema.config["languages"]}
|
||||
return {"type": "string", "format": "RFC 5646"}
|
||||
|
||||
if isinstance(schema, (selector.LocationSelector, selector.MediaSelector)):
|
||||
return convert(schema.DATA_SCHEMA)
|
||||
|
||||
if isinstance(schema, selector.NumberSelector):
|
||||
result = {"type": "number"}
|
||||
if "min" in schema.config:
|
||||
result["minimum"] = schema.config["min"]
|
||||
if "max" in schema.config:
|
||||
result["maximum"] = schema.config["max"]
|
||||
return result
|
||||
|
||||
if isinstance(schema, selector.ObjectSelector):
|
||||
return {"type": "object"}
|
||||
|
||||
if isinstance(schema, selector.SelectSelector):
|
||||
options = [
|
||||
x["value"] if isinstance(x, dict) else x for x in schema.config["options"]
|
||||
]
|
||||
if schema.config.get("multiple"):
|
||||
return {
|
||||
"type": "array",
|
||||
"items": {"type": "string", "enum": options},
|
||||
"uniqueItems": True,
|
||||
}
|
||||
return {"type": "string", "enum": options}
|
||||
|
||||
if isinstance(schema, selector.TargetSelector):
|
||||
return convert(cv.TARGET_SERVICE_FIELDS)
|
||||
|
||||
if isinstance(schema, selector.TemplateSelector):
|
||||
return {"type": "string", "format": "jinja2"}
|
||||
|
||||
if isinstance(schema, selector.TimeSelector):
|
||||
return {"type": "string", "format": "time"}
|
||||
|
||||
if isinstance(schema, selector.TriggerSelector):
|
||||
return convert(cv.TRIGGER_SCHEMA)
|
||||
|
||||
if schema.config.get("multiple"):
|
||||
return {"type": "array", "items": {"type": "string"}}
|
||||
|
||||
return {"type": "string"}
|
||||
|
||||
|
||||
class ScriptTool(Tool):
|
||||
"""LLM Tool representing a Script."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
script_entity_id: str,
|
||||
) -> None:
|
||||
"""Init the class."""
|
||||
entity_registry = er.async_get(hass)
|
||||
|
||||
self.name = split_entity_id(script_entity_id)[1]
|
||||
self.parameters = vol.Schema({})
|
||||
entity_entry = entity_registry.async_get(script_entity_id)
|
||||
if entity_entry and entity_entry.unique_id:
|
||||
parameters_cache = hass.data.get(SCRIPT_PARAMETERS_CACHE)
|
||||
|
||||
if parameters_cache is None:
|
||||
parameters_cache = hass.data[SCRIPT_PARAMETERS_CACHE] = {}
|
||||
|
||||
@callback
|
||||
def clear_cache(event: Event) -> None:
|
||||
"""Clear script parameter cache on script reload or delete."""
|
||||
if (
|
||||
event.data[ATTR_DOMAIN] == SCRIPT_DOMAIN
|
||||
and event.data[ATTR_SERVICE] in parameters_cache
|
||||
):
|
||||
parameters_cache.pop(event.data[ATTR_SERVICE])
|
||||
|
||||
cancel = hass.bus.async_listen(EVENT_SERVICE_REMOVED, clear_cache)
|
||||
|
||||
@callback
|
||||
def on_homeassistant_close(event: Event) -> None:
|
||||
"""Cleanup."""
|
||||
cancel()
|
||||
|
||||
hass.bus.async_listen_once(
|
||||
EVENT_HOMEASSISTANT_CLOSE, on_homeassistant_close
|
||||
)
|
||||
|
||||
if entity_entry.unique_id in parameters_cache:
|
||||
self.description, self.parameters = parameters_cache[
|
||||
entity_entry.unique_id
|
||||
]
|
||||
return
|
||||
|
||||
if service_desc := service.async_get_cached_service_description(
|
||||
hass, SCRIPT_DOMAIN, entity_entry.unique_id
|
||||
):
|
||||
self.description = service_desc.get("description")
|
||||
schema: dict[vol.Marker, Any] = {}
|
||||
fields = service_desc.get("fields", {})
|
||||
|
||||
for field, config in fields.items():
|
||||
description = config.get("description")
|
||||
if not description:
|
||||
description = config.get("name")
|
||||
if config.get("required"):
|
||||
key = vol.Required(field, description=description)
|
||||
else:
|
||||
key = vol.Optional(field, description=description)
|
||||
if "selector" in config:
|
||||
schema[key] = selector.selector(config["selector"])
|
||||
else:
|
||||
schema[key] = cv.string
|
||||
|
||||
self.parameters = vol.Schema(schema)
|
||||
|
||||
parameters_cache[entity_entry.unique_id] = (
|
||||
self.description,
|
||||
self.parameters,
|
||||
)
|
||||
|
||||
async def async_call(
|
||||
self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext
|
||||
) -> JsonObjectType:
|
||||
"""Run the script."""
|
||||
|
||||
for field, validator in self.parameters.schema.items():
|
||||
if field not in tool_input.tool_args:
|
||||
continue
|
||||
if isinstance(validator, selector.AreaSelector):
|
||||
area_reg = ar.async_get(hass)
|
||||
if validator.config.get("multiple"):
|
||||
areas: list[ar.AreaEntry] = []
|
||||
for area in tool_input.tool_args[field]:
|
||||
areas.extend(intent.find_areas(area, area_reg))
|
||||
tool_input.tool_args[field] = list({area.id for area in areas})
|
||||
else:
|
||||
area = tool_input.tool_args[field]
|
||||
area = list(intent.find_areas(area, area_reg))[0].id
|
||||
tool_input.tool_args[field] = area
|
||||
|
||||
elif isinstance(validator, selector.FloorSelector):
|
||||
floor_reg = fr.async_get(hass)
|
||||
if validator.config.get("multiple"):
|
||||
floors: list[fr.FloorEntry] = []
|
||||
for floor in tool_input.tool_args[field]:
|
||||
floors.extend(intent.find_floors(floor, floor_reg))
|
||||
tool_input.tool_args[field] = list(
|
||||
{floor.floor_id for floor in floors}
|
||||
)
|
||||
else:
|
||||
floor = tool_input.tool_args[field]
|
||||
floor = list(intent.find_floors(floor, floor_reg))[0].floor_id
|
||||
tool_input.tool_args[field] = floor
|
||||
|
||||
await hass.services.async_call(
|
||||
SCRIPT_DOMAIN,
|
||||
SERVICE_TURN_ON,
|
||||
{
|
||||
ATTR_ENTITY_ID: SCRIPT_DOMAIN + "." + self.name,
|
||||
ATTR_VARIABLES: tool_input.tool_args,
|
||||
},
|
||||
context=llm_context.context,
|
||||
)
|
||||
|
||||
return {"success": True}
|
||||
|
|
|
@ -75,6 +75,13 @@ class Selector[_T: Mapping[str, Any]]:
|
|||
|
||||
self.config = self.CONFIG_SCHEMA(config)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
"""Check equality."""
|
||||
if not isinstance(other, Selector):
|
||||
return NotImplemented
|
||||
|
||||
return self.selector_type == other.selector_type and self.config == other.config
|
||||
|
||||
def serialize(self) -> dict[str, dict[str, _T]]:
|
||||
"""Serialize Selector for voluptuous_serialize."""
|
||||
return {"selector": {self.selector_type: self.config}}
|
||||
|
@ -278,7 +285,7 @@ class AssistPipelineSelector(Selector[AssistPipelineSelectorConfig]):
|
|||
|
||||
CONFIG_SCHEMA = vol.Schema({})
|
||||
|
||||
def __init__(self, config: AssistPipelineSelectorConfig) -> None:
|
||||
def __init__(self, config: AssistPipelineSelectorConfig | None = None) -> None:
|
||||
"""Instantiate a selector."""
|
||||
super().__init__(config)
|
||||
|
||||
|
@ -430,10 +437,10 @@ class ColorTempSelector(Selector[ColorTempSelectorConfig]):
|
|||
range_min = self.config.get("min")
|
||||
range_max = self.config.get("max")
|
||||
|
||||
if not range_min:
|
||||
if range_min is None:
|
||||
range_min = self.config.get("min_mireds")
|
||||
|
||||
if not range_max:
|
||||
if range_max is None:
|
||||
range_max = self.config.get("max_mireds")
|
||||
|
||||
value: int = vol.All(
|
||||
|
@ -517,7 +524,7 @@ class ConstantSelector(Selector[ConstantSelectorConfig]):
|
|||
}
|
||||
)
|
||||
|
||||
def __init__(self, config: ConstantSelectorConfig | None = None) -> None:
|
||||
def __init__(self, config: ConstantSelectorConfig) -> None:
|
||||
"""Instantiate a selector."""
|
||||
super().__init__(config)
|
||||
|
||||
|
@ -560,7 +567,7 @@ class QrCodeSelector(Selector[QrCodeSelectorConfig]):
|
|||
}
|
||||
)
|
||||
|
||||
def __init__(self, config: QrCodeSelectorConfig | None = None) -> None:
|
||||
def __init__(self, config: QrCodeSelectorConfig) -> None:
|
||||
"""Instantiate a selector."""
|
||||
super().__init__(config)
|
||||
|
||||
|
@ -588,7 +595,7 @@ class ConversationAgentSelector(Selector[ConversationAgentSelectorConfig]):
|
|||
}
|
||||
)
|
||||
|
||||
def __init__(self, config: ConversationAgentSelectorConfig) -> None:
|
||||
def __init__(self, config: ConversationAgentSelectorConfig | None = None) -> None:
|
||||
"""Instantiate a selector."""
|
||||
super().__init__(config)
|
||||
|
||||
|
@ -820,7 +827,7 @@ class FloorSelectorConfig(TypedDict, total=False):
|
|||
|
||||
|
||||
@SELECTORS.register("floor")
|
||||
class FloorSelector(Selector[AreaSelectorConfig]):
|
||||
class FloorSelector(Selector[FloorSelectorConfig]):
|
||||
"""Selector of a single or list of floors."""
|
||||
|
||||
selector_type = "floor"
|
||||
|
@ -934,7 +941,7 @@ class LanguageSelector(Selector[LanguageSelectorConfig]):
|
|||
}
|
||||
)
|
||||
|
||||
def __init__(self, config: LanguageSelectorConfig) -> None:
|
||||
def __init__(self, config: LanguageSelectorConfig | None = None) -> None:
|
||||
"""Instantiate a selector."""
|
||||
super().__init__(config)
|
||||
|
||||
|
@ -1159,7 +1166,7 @@ class SelectSelector(Selector[SelectSelectorConfig]):
|
|||
}
|
||||
)
|
||||
|
||||
def __init__(self, config: SelectSelectorConfig | None = None) -> None:
|
||||
def __init__(self, config: SelectSelectorConfig) -> None:
|
||||
"""Instantiate a selector."""
|
||||
super().__init__(config)
|
||||
|
||||
|
@ -1434,7 +1441,7 @@ class FileSelector(Selector[FileSelectorConfig]):
|
|||
}
|
||||
)
|
||||
|
||||
def __init__(self, config: FileSelectorConfig | None = None) -> None:
|
||||
def __init__(self, config: FileSelectorConfig) -> None:
|
||||
"""Instantiate a selector."""
|
||||
super().__init__(config)
|
||||
|
||||
|
|
|
@ -58,6 +58,7 @@ SQLAlchemy==2.0.31
|
|||
typing-extensions>=4.12.2,<5.0
|
||||
ulid-transform==0.9.0
|
||||
urllib3>=1.26.5,<2
|
||||
voluptuous-openapi==0.0.4
|
||||
voluptuous-serialize==2.6.0
|
||||
voluptuous==0.13.1
|
||||
webrtc-noise-gain==1.2.3
|
||||
|
|
|
@ -69,6 +69,7 @@ dependencies = [
|
|||
"urllib3>=1.26.5,<2",
|
||||
"voluptuous==0.13.1",
|
||||
"voluptuous-serialize==2.6.0",
|
||||
"voluptuous-openapi==0.0.4",
|
||||
"yarl==1.9.4",
|
||||
]
|
||||
|
||||
|
|
|
@ -41,4 +41,5 @@ ulid-transform==0.9.0
|
|||
urllib3>=1.26.5,<2
|
||||
voluptuous==0.13.1
|
||||
voluptuous-serialize==2.6.0
|
||||
voluptuous-openapi==0.0.4
|
||||
yarl==1.9.4
|
||||
|
|
|
@ -2846,10 +2846,6 @@ voip-utils==0.1.0
|
|||
# homeassistant.components.volkszaehler
|
||||
volkszaehler==0.4.0
|
||||
|
||||
# homeassistant.components.google_generative_ai_conversation
|
||||
# homeassistant.components.openai_conversation
|
||||
voluptuous-openapi==0.0.4
|
||||
|
||||
# homeassistant.components.volvooncall
|
||||
volvooncall==0.10.3
|
||||
|
||||
|
|
|
@ -2217,10 +2217,6 @@ vilfo-api-client==0.5.0
|
|||
# homeassistant.components.voip
|
||||
voip-utils==0.1.0
|
||||
|
||||
# homeassistant.components.google_generative_ai_conversation
|
||||
# homeassistant.components.openai_conversation
|
||||
voluptuous-openapi==0.0.4
|
||||
|
||||
# homeassistant.components.volvooncall
|
||||
volvooncall==0.10.3
|
||||
|
||||
|
|
|
@ -8,6 +8,7 @@ import voluptuous as vol
|
|||
|
||||
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
|
||||
from homeassistant.components.intent import async_register_timer_handler
|
||||
from homeassistant.components.script.config import ScriptConfig
|
||||
from homeassistant.core import Context, HomeAssistant, State
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import (
|
||||
|
@ -18,6 +19,7 @@ from homeassistant.helpers import (
|
|||
floor_registry as fr,
|
||||
intent,
|
||||
llm,
|
||||
selector,
|
||||
)
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.util import yaml
|
||||
|
@ -564,11 +566,6 @@ async def test_assist_api_prompt(
|
|||
"names": "Unnamed Device",
|
||||
"state": "unavailable",
|
||||
},
|
||||
"script.test_script": {
|
||||
"description": "This is a test script",
|
||||
"names": "test_script",
|
||||
"state": "off",
|
||||
},
|
||||
}
|
||||
exposed_entities_prompt = (
|
||||
"An overview of the areas and the devices in this smart home:\n"
|
||||
|
@ -634,3 +631,323 @@ async def test_assist_api_prompt(
|
|||
{area_prompt}
|
||||
{exposed_entities_prompt}"""
|
||||
)
|
||||
|
||||
|
||||
async def test_script_tool(
|
||||
hass: HomeAssistant,
|
||||
area_registry: ar.AreaRegistry,
|
||||
floor_registry: fr.FloorRegistry,
|
||||
) -> None:
|
||||
"""Test ScriptTool for the assist API."""
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
assert await async_setup_component(hass, "intent", {})
|
||||
context = Context()
|
||||
llm_context = llm.LLMContext(
|
||||
platform="test_platform",
|
||||
context=context,
|
||||
user_prompt="test_text",
|
||||
language="*",
|
||||
assistant="conversation",
|
||||
device_id=None,
|
||||
)
|
||||
|
||||
# Create a script with a unique ID
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
"script",
|
||||
{
|
||||
"script": {
|
||||
"test_script": {
|
||||
"description": "This is a test script",
|
||||
"sequence": [],
|
||||
"fields": {
|
||||
"beer": {"description": "Number of beers", "required": True},
|
||||
"wine": {"selector": {"number": {"min": 0, "max": 3}}},
|
||||
"where": {"selector": {"area": {}}},
|
||||
"area_list": {"selector": {"area": {"multiple": True}}},
|
||||
"floor": {"selector": {"floor": {}}},
|
||||
"floor_list": {"selector": {"floor": {"multiple": True}}},
|
||||
"extra_field": {"selector": {"area": {}}},
|
||||
},
|
||||
},
|
||||
"unexposed_script": {
|
||||
"sequence": [],
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
async_expose_entity(hass, "conversation", "script.test_script", True)
|
||||
|
||||
area = area_registry.async_create("Living room")
|
||||
floor = floor_registry.async_create("2")
|
||||
|
||||
assert llm.SCRIPT_PARAMETERS_CACHE not in hass.data
|
||||
|
||||
api = await llm.async_get_api(hass, "assist", llm_context)
|
||||
|
||||
tools = [tool for tool in api.tools if isinstance(tool, llm.ScriptTool)]
|
||||
assert len(tools) == 1
|
||||
|
||||
tool = tools[0]
|
||||
assert tool.name == "test_script"
|
||||
assert tool.description == "This is a test script"
|
||||
schema = {
|
||||
vol.Required("beer", description="Number of beers"): cv.string,
|
||||
vol.Optional("wine"): selector.NumberSelector({"min": 0, "max": 3}),
|
||||
vol.Optional("where"): selector.AreaSelector(),
|
||||
vol.Optional("area_list"): selector.AreaSelector({"multiple": True}),
|
||||
vol.Optional("floor"): selector.FloorSelector(),
|
||||
vol.Optional("floor_list"): selector.FloorSelector({"multiple": True}),
|
||||
vol.Optional("extra_field"): selector.AreaSelector(),
|
||||
}
|
||||
assert tool.parameters.schema == schema
|
||||
|
||||
assert hass.data[llm.SCRIPT_PARAMETERS_CACHE] == {
|
||||
"test_script": ("This is a test script", vol.Schema(schema))
|
||||
}
|
||||
|
||||
tool_input = llm.ToolInput(
|
||||
tool_name="test_script",
|
||||
tool_args={
|
||||
"beer": "3",
|
||||
"wine": 0,
|
||||
"where": "Living room",
|
||||
"area_list": ["Living room"],
|
||||
"floor": "2",
|
||||
"floor_list": ["2"],
|
||||
},
|
||||
)
|
||||
|
||||
with patch("homeassistant.core.ServiceRegistry.async_call") as mock_service_call:
|
||||
response = await api.async_call_tool(tool_input)
|
||||
|
||||
mock_service_call.assert_awaited_once_with(
|
||||
"script",
|
||||
"turn_on",
|
||||
{
|
||||
"entity_id": "script.test_script",
|
||||
"variables": {
|
||||
"beer": "3",
|
||||
"wine": 0,
|
||||
"where": area.id,
|
||||
"area_list": [area.id],
|
||||
"floor": floor.floor_id,
|
||||
"floor_list": [floor.floor_id],
|
||||
},
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
assert response == {"success": True}
|
||||
|
||||
# Test reload script with new parameters
|
||||
config = {
|
||||
"script": {
|
||||
"test_script": ScriptConfig(
|
||||
{
|
||||
"description": "This is a new test script",
|
||||
"sequence": [],
|
||||
"mode": "single",
|
||||
"max": 2,
|
||||
"max_exceeded": "WARNING",
|
||||
"trace": {},
|
||||
"fields": {
|
||||
"beer": {"description": "Number of beers", "required": True},
|
||||
},
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
with patch(
|
||||
"homeassistant.helpers.entity_component.EntityComponent.async_prepare_reload",
|
||||
return_value=config,
|
||||
):
|
||||
await hass.services.async_call("script", "reload", blocking=True)
|
||||
|
||||
assert hass.data[llm.SCRIPT_PARAMETERS_CACHE] == {}
|
||||
|
||||
api = await llm.async_get_api(hass, "assist", llm_context)
|
||||
|
||||
tools = [tool for tool in api.tools if isinstance(tool, llm.ScriptTool)]
|
||||
assert len(tools) == 1
|
||||
|
||||
tool = tools[0]
|
||||
assert tool.name == "test_script"
|
||||
assert tool.description == "This is a new test script"
|
||||
schema = {vol.Required("beer", description="Number of beers"): cv.string}
|
||||
assert tool.parameters.schema == schema
|
||||
|
||||
assert hass.data[llm.SCRIPT_PARAMETERS_CACHE] == {
|
||||
"test_script": ("This is a new test script", vol.Schema(schema))
|
||||
}
|
||||
|
||||
|
||||
async def test_selector_serializer(
|
||||
hass: HomeAssistant, llm_context: llm.LLMContext
|
||||
) -> None:
|
||||
"""Test serialization of Selectors in Open API format."""
|
||||
api = await llm.async_get_api(hass, "assist", llm_context)
|
||||
selector_serializer = api.custom_serializer
|
||||
|
||||
assert selector_serializer(selector.ActionSelector()) == {"type": "string"}
|
||||
assert selector_serializer(selector.AddonSelector()) == {"type": "string"}
|
||||
assert selector_serializer(selector.AreaSelector()) == {"type": "string"}
|
||||
assert selector_serializer(selector.AreaSelector({"multiple": True})) == {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
}
|
||||
assert selector_serializer(selector.AssistPipelineSelector()) == {"type": "string"}
|
||||
assert selector_serializer(
|
||||
selector.AttributeSelector({"entity_id": "sensor.test"})
|
||||
) == {"type": "string"}
|
||||
assert selector_serializer(selector.BackupLocationSelector()) == {
|
||||
"type": "string",
|
||||
"pattern": "^(?:\\/backup|\\w+)$",
|
||||
}
|
||||
assert selector_serializer(selector.BooleanSelector()) == {"type": "boolean"}
|
||||
assert selector_serializer(selector.ColorRGBSelector()) == {
|
||||
"type": "array",
|
||||
"items": {"type": "number"},
|
||||
"maxItems": 3,
|
||||
"minItems": 3,
|
||||
"format": "RGB",
|
||||
}
|
||||
assert selector_serializer(selector.ColorTempSelector()) == {"type": "number"}
|
||||
assert selector_serializer(selector.ColorTempSelector({"min": 0, "max": 1000})) == {
|
||||
"type": "number",
|
||||
"minimum": 0,
|
||||
"maximum": 1000,
|
||||
}
|
||||
assert selector_serializer(
|
||||
selector.ColorTempSelector({"min_mireds": 100, "max_mireds": 1000})
|
||||
) == {"type": "number", "minimum": 100, "maximum": 1000}
|
||||
assert selector_serializer(selector.ConfigEntrySelector()) == {"type": "string"}
|
||||
assert selector_serializer(selector.ConstantSelector({"value": "test"})) == {
|
||||
"enum": ["test"]
|
||||
}
|
||||
assert selector_serializer(selector.ConstantSelector({"value": 1})) == {"enum": [1]}
|
||||
assert selector_serializer(selector.ConstantSelector({"value": True})) == {
|
||||
"enum": [True]
|
||||
}
|
||||
assert selector_serializer(selector.QrCodeSelector({"data": "test"})) == {
|
||||
"type": "string"
|
||||
}
|
||||
assert selector_serializer(selector.ConversationAgentSelector()) == {
|
||||
"type": "string"
|
||||
}
|
||||
assert selector_serializer(selector.CountrySelector()) == {
|
||||
"type": "string",
|
||||
"format": "ISO 3166-1 alpha-2",
|
||||
}
|
||||
assert selector_serializer(
|
||||
selector.CountrySelector({"countries": ["GB", "FR"]})
|
||||
) == {"type": "string", "enum": ["GB", "FR"]}
|
||||
assert selector_serializer(selector.DateSelector()) == {
|
||||
"type": "string",
|
||||
"format": "date",
|
||||
}
|
||||
assert selector_serializer(selector.DateTimeSelector()) == {
|
||||
"type": "string",
|
||||
"format": "date-time",
|
||||
}
|
||||
assert selector_serializer(selector.DeviceSelector()) == {"type": "string"}
|
||||
assert selector_serializer(selector.DeviceSelector({"multiple": True})) == {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
}
|
||||
assert selector_serializer(selector.EntitySelector()) == {
|
||||
"type": "string",
|
||||
"format": "entity_id",
|
||||
}
|
||||
assert selector_serializer(selector.EntitySelector({"multiple": True})) == {
|
||||
"type": "array",
|
||||
"items": {"type": "string", "format": "entity_id"},
|
||||
}
|
||||
assert selector_serializer(selector.FloorSelector()) == {"type": "string"}
|
||||
assert selector_serializer(selector.FloorSelector({"multiple": True})) == {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
}
|
||||
assert selector_serializer(selector.IconSelector()) == {"type": "string"}
|
||||
assert selector_serializer(selector.LabelSelector()) == {"type": "string"}
|
||||
assert selector_serializer(selector.LabelSelector({"multiple": True})) == {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
}
|
||||
assert selector_serializer(selector.LanguageSelector()) == {
|
||||
"type": "string",
|
||||
"format": "RFC 5646",
|
||||
}
|
||||
assert selector_serializer(
|
||||
selector.LanguageSelector({"languages": ["en", "fr"]})
|
||||
) == {"type": "string", "enum": ["en", "fr"]}
|
||||
assert selector_serializer(selector.LocationSelector()) == {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"latitude": {"type": "number"},
|
||||
"longitude": {"type": "number"},
|
||||
"radius": {"type": "number"},
|
||||
},
|
||||
"required": ["latitude", "longitude"],
|
||||
}
|
||||
assert selector_serializer(selector.MediaSelector()) == {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"entity_id": {"type": "string"},
|
||||
"media_content_id": {"type": "string"},
|
||||
"media_content_type": {"type": "string"},
|
||||
"metadata": {"type": "object", "additionalProperties": True},
|
||||
},
|
||||
"required": ["entity_id", "media_content_id", "media_content_type"],
|
||||
}
|
||||
assert selector_serializer(selector.NumberSelector({"mode": "box"})) == {
|
||||
"type": "number"
|
||||
}
|
||||
assert selector_serializer(selector.NumberSelector({"min": 30, "max": 100})) == {
|
||||
"type": "number",
|
||||
"minimum": 30,
|
||||
"maximum": 100,
|
||||
}
|
||||
assert selector_serializer(selector.ObjectSelector()) == {"type": "object"}
|
||||
assert selector_serializer(
|
||||
selector.SelectSelector(
|
||||
{
|
||||
"options": [
|
||||
{"value": "A", "label": "Letter A"},
|
||||
{"value": "B", "label": "Letter B"},
|
||||
{"value": "C", "label": "Letter C"},
|
||||
]
|
||||
}
|
||||
)
|
||||
) == {"type": "string", "enum": ["A", "B", "C"]}
|
||||
assert selector_serializer(
|
||||
selector.SelectSelector({"options": ["A", "B", "C"], "multiple": True})
|
||||
) == {
|
||||
"type": "array",
|
||||
"items": {"type": "string", "enum": ["A", "B", "C"]},
|
||||
"uniqueItems": True,
|
||||
}
|
||||
assert selector_serializer(
|
||||
selector.StateSelector({"entity_id": "sensor.test"})
|
||||
) == {"type": "string"}
|
||||
assert selector_serializer(selector.TemplateSelector()) == {
|
||||
"type": "string",
|
||||
"format": "jinja2",
|
||||
}
|
||||
assert selector_serializer(selector.TextSelector()) == {"type": "string"}
|
||||
assert selector_serializer(selector.TextSelector({"multiple": True})) == {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
}
|
||||
assert selector_serializer(selector.ThemeSelector()) == {"type": "string"}
|
||||
assert selector_serializer(selector.TimeSelector()) == {
|
||||
"type": "string",
|
||||
"format": "time",
|
||||
}
|
||||
assert selector_serializer(selector.TriggerSelector()) == {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
}
|
||||
assert selector_serializer(selector.FileSelector({"accept": ".txt"})) == {
|
||||
"type": "string"
|
||||
}
|
||||
|
|
|
@ -55,6 +55,8 @@ def _test_selector(
|
|||
config = {selector_type: schema}
|
||||
selector.validate_selector(config)
|
||||
selector_instance = selector.selector(config)
|
||||
assert selector_instance == selector.selector(config)
|
||||
assert selector_instance != 5
|
||||
# We do not allow enums in the config, as they cannot serialize
|
||||
assert not any(isinstance(val, Enum) for val in selector_instance.config.values())
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue