Add script llm tool (#118936)

* Add script llm tool

* Add tests

* More tests

* more test

* more test

* Add area and floor resolving

* coverage

* coverage

* fix ColorTempSelector

* fix mypy

* fix mypy

* add script reload test

* Cache script tool parameters

* Make custom_serializer a part of api

---------

Co-authored-by: Michael Hansen <mike@rhasspy.org>
This commit is contained in:
Denis Shulyaka 2024-06-25 18:43:26 +03:00 committed by GitHub
parent 77fea8a73e
commit 2386ed3830
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 639 additions and 55 deletions

View file

@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
import codecs import codecs
from collections.abc import Callable
from typing import Any, Literal from typing import Any, Literal
from google.api_core.exceptions import GoogleAPICallError from google.api_core.exceptions import GoogleAPICallError
@ -89,10 +90,14 @@ def _format_schema(schema: dict[str, Any]) -> dict[str, Any]:
return result return result
def _format_tool(tool: llm.Tool) -> dict[str, Any]: def _format_tool(
tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
) -> dict[str, Any]:
"""Format tool specification.""" """Format tool specification."""
parameters = _format_schema(convert(tool.parameters)) parameters = _format_schema(
convert(tool.parameters, custom_serializer=custom_serializer)
)
return protos.Tool( return protos.Tool(
{ {
@ -193,7 +198,9 @@ class GoogleGenerativeAIConversationEntity(
f"Error preparing LLM API: {err}", f"Error preparing LLM API: {err}",
) )
return result return result
tools = [_format_tool(tool) for tool in llm_api.tools] tools = [
_format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools
]
try: try:
prompt = await self._async_render_prompt(user_input, llm_api, llm_context) prompt = await self._async_render_prompt(user_input, llm_api, llm_context)

View file

@ -9,5 +9,5 @@
"integration_type": "service", "integration_type": "service",
"iot_class": "cloud_polling", "iot_class": "cloud_polling",
"quality_scale": "platinum", "quality_scale": "platinum",
"requirements": ["google-generativeai==0.6.0", "voluptuous-openapi==0.0.4"] "requirements": ["google-generativeai==0.6.0"]
} }

View file

@ -1,7 +1,8 @@
"""Conversation support for OpenAI.""" """Conversation support for OpenAI."""
from collections.abc import Callable
import json import json
from typing import Literal from typing import Any, Literal
import openai import openai
from openai._types import NOT_GIVEN from openai._types import NOT_GIVEN
@ -58,9 +59,14 @@ async def async_setup_entry(
async_add_entities([agent]) async_add_entities([agent])
def _format_tool(tool: llm.Tool) -> ChatCompletionToolParam: def _format_tool(
tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
) -> ChatCompletionToolParam:
"""Format tool specification.""" """Format tool specification."""
tool_spec = FunctionDefinition(name=tool.name, parameters=convert(tool.parameters)) tool_spec = FunctionDefinition(
name=tool.name,
parameters=convert(tool.parameters, custom_serializer=custom_serializer),
)
if tool.description: if tool.description:
tool_spec["description"] = tool.description tool_spec["description"] = tool.description
return ChatCompletionToolParam(type="function", function=tool_spec) return ChatCompletionToolParam(type="function", function=tool_spec)
@ -139,7 +145,9 @@ class OpenAIConversationEntity(
return conversation.ConversationResult( return conversation.ConversationResult(
response=intent_response, conversation_id=user_input.conversation_id response=intent_response, conversation_id=user_input.conversation_id
) )
tools = [_format_tool(tool) for tool in llm_api.tools] tools = [
_format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools
]
if user_input.conversation_id is None: if user_input.conversation_id is None:
conversation_id = ulid.ulid_now() conversation_id = ulid.ulid_now()

View file

@ -8,5 +8,5 @@
"documentation": "https://www.home-assistant.io/integrations/openai_conversation", "documentation": "https://www.home-assistant.io/integrations/openai_conversation",
"integration_type": "service", "integration_type": "service",
"iot_class": "cloud_polling", "iot_class": "cloud_polling",
"requirements": ["openai==1.3.8", "voluptuous-openapi==0.0.4"] "requirements": ["openai==1.3.8"]
} }

View file

@ -352,7 +352,7 @@ class MatchTargetsCandidate:
matched_name: str | None = None matched_name: str | None = None
def _find_areas( def find_areas(
name: str, areas: area_registry.AreaRegistry name: str, areas: area_registry.AreaRegistry
) -> Iterable[area_registry.AreaEntry]: ) -> Iterable[area_registry.AreaEntry]:
"""Find all areas matching a name (including aliases).""" """Find all areas matching a name (including aliases)."""
@ -372,7 +372,7 @@ def _find_areas(
break break
def _find_floors( def find_floors(
name: str, floors: floor_registry.FloorRegistry name: str, floors: floor_registry.FloorRegistry
) -> Iterable[floor_registry.FloorEntry]: ) -> Iterable[floor_registry.FloorEntry]:
"""Find all floors matching a name (including aliases).""" """Find all floors matching a name (including aliases)."""
@ -530,7 +530,7 @@ def async_match_targets( # noqa: C901
if not states: if not states:
return MatchTargetsResult(False, MatchFailedReason.STATE) return MatchTargetsResult(False, MatchFailedReason.STATE)
# Exit early so we can to avoid registry lookups # Exit early so we can avoid registry lookups
if not ( if not (
constraints.name constraints.name
or constraints.features or constraints.features
@ -580,7 +580,7 @@ def async_match_targets( # noqa: C901
if constraints.floor_name: if constraints.floor_name:
# Filter by areas associated with floor # Filter by areas associated with floor
fr = floor_registry.async_get(hass) fr = floor_registry.async_get(hass)
targeted_floors = list(_find_floors(constraints.floor_name, fr)) targeted_floors = list(find_floors(constraints.floor_name, fr))
if not targeted_floors: if not targeted_floors:
return MatchTargetsResult( return MatchTargetsResult(
False, False,
@ -609,7 +609,7 @@ def async_match_targets( # noqa: C901
possible_area_ids = {area.id for area in ar.async_list_areas()} possible_area_ids = {area.id for area in ar.async_list_areas()}
if constraints.area_name: if constraints.area_name:
targeted_areas = list(_find_areas(constraints.area_name, ar)) targeted_areas = list(find_areas(constraints.area_name, ar))
if not targeted_areas: if not targeted_areas:
return MatchTargetsResult( return MatchTargetsResult(
False, False,

View file

@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
from decimal import Decimal from decimal import Decimal
from enum import Enum from enum import Enum
@ -11,6 +12,7 @@ from typing import Any
import slugify as unicode_slug import slugify as unicode_slug
import voluptuous as vol import voluptuous as vol
from voluptuous_openapi import UNSUPPORTED, convert
from homeassistant.components.climate.intent import INTENT_GET_TEMPERATURE from homeassistant.components.climate.intent import INTENT_GET_TEMPERATURE
from homeassistant.components.conversation.trace import ( from homeassistant.components.conversation.trace import (
@ -20,22 +22,39 @@ from homeassistant.components.conversation.trace import (
from homeassistant.components.cover.intent import INTENT_CLOSE_COVER, INTENT_OPEN_COVER from homeassistant.components.cover.intent import INTENT_CLOSE_COVER, INTENT_OPEN_COVER
from homeassistant.components.homeassistant.exposed_entities import async_should_expose from homeassistant.components.homeassistant.exposed_entities import async_should_expose
from homeassistant.components.intent import async_device_supports_timers from homeassistant.components.intent import async_device_supports_timers
from homeassistant.components.script import ATTR_VARIABLES, DOMAIN as SCRIPT_DOMAIN
from homeassistant.components.weather.intent import INTENT_GET_WEATHER from homeassistant.components.weather.intent import INTENT_GET_WEATHER
from homeassistant.core import Context, HomeAssistant, callback from homeassistant.const import (
ATTR_DOMAIN,
ATTR_ENTITY_ID,
ATTR_SERVICE,
EVENT_HOMEASSISTANT_CLOSE,
EVENT_SERVICE_REMOVED,
SERVICE_TURN_ON,
)
from homeassistant.core import Context, Event, HomeAssistant, callback, split_entity_id
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.util import yaml from homeassistant.util import yaml
from homeassistant.util.hass_dict import HassKey
from homeassistant.util.json import JsonObjectType from homeassistant.util.json import JsonObjectType
from . import ( from . import (
area_registry as ar, area_registry as ar,
config_validation as cv,
device_registry as dr, device_registry as dr,
entity_registry as er, entity_registry as er,
floor_registry as fr, floor_registry as fr,
intent, intent,
selector,
service, service,
) )
from .singleton import singleton from .singleton import singleton
SCRIPT_PARAMETERS_CACHE: HassKey[dict[str, tuple[str | None, vol.Schema]]] = HassKey(
"llm_script_parameters_cache"
)
LLM_API_ASSIST = "assist" LLM_API_ASSIST = "assist"
BASE_PROMPT = ( BASE_PROMPT = (
@ -143,6 +162,7 @@ class APIInstance:
api_prompt: str api_prompt: str
llm_context: LLMContext llm_context: LLMContext
tools: list[Tool] tools: list[Tool]
custom_serializer: Callable[[Any], Any] | None = None
async def async_call_tool(self, tool_input: ToolInput) -> JsonObjectType: async def async_call_tool(self, tool_input: ToolInput) -> JsonObjectType:
"""Call a LLM tool, validate args and return the response.""" """Call a LLM tool, validate args and return the response."""
@ -284,6 +304,7 @@ class AssistAPI(API):
api_prompt=self._async_get_api_prompt(llm_context, exposed_entities), api_prompt=self._async_get_api_prompt(llm_context, exposed_entities),
llm_context=llm_context, llm_context=llm_context,
tools=self._async_get_tools(llm_context, exposed_entities), tools=self._async_get_tools(llm_context, exposed_entities),
custom_serializer=_selector_serializer,
) )
@callback @callback
@ -372,7 +393,7 @@ class AssistAPI(API):
exposed_domains: set[str] | None = None exposed_domains: set[str] | None = None
if exposed_entities is not None: if exposed_entities is not None:
exposed_domains = { exposed_domains = {
entity_id.split(".")[0] for entity_id in exposed_entities split_entity_id(entity_id)[0] for entity_id in exposed_entities
} }
intent_handlers = [ intent_handlers = [
intent_handler intent_handler
@ -381,11 +402,22 @@ class AssistAPI(API):
or intent_handler.platforms & exposed_domains or intent_handler.platforms & exposed_domains
] ]
return [ tools: list[Tool] = [
IntentTool(self.cached_slugify(intent_handler.intent_type), intent_handler) IntentTool(self.cached_slugify(intent_handler.intent_type), intent_handler)
for intent_handler in intent_handlers for intent_handler in intent_handlers
] ]
if llm_context.assistant is not None:
for state in self.hass.states.async_all(SCRIPT_DOMAIN):
if not async_should_expose(
self.hass, llm_context.assistant, state.entity_id
):
continue
tools.append(ScriptTool(self.hass, state.entity_id))
return tools
def _get_exposed_entities( def _get_exposed_entities(
hass: HomeAssistant, assistant: str hass: HomeAssistant, assistant: str
@ -413,13 +445,15 @@ def _get_exposed_entities(
entities = {} entities = {}
for state in hass.states.async_all(): for state in hass.states.async_all():
if state.domain == SCRIPT_DOMAIN:
continue
if not async_should_expose(hass, assistant, state.entity_id): if not async_should_expose(hass, assistant, state.entity_id):
continue continue
entity_entry = entity_registry.async_get(state.entity_id) entity_entry = entity_registry.async_get(state.entity_id)
names = [state.name] names = [state.name]
area_names = [] area_names = []
description: str | None = None
if entity_entry is not None: if entity_entry is not None:
names.extend(entity_entry.aliases) names.extend(entity_entry.aliases)
@ -439,25 +473,11 @@ def _get_exposed_entities(
area_names.append(area.name) area_names.append(area.name)
area_names.extend(area.aliases) area_names.extend(area.aliases)
if (
state.domain == "script"
and entity_entry.unique_id
and (
service_desc := service.async_get_cached_service_description(
hass, "script", entity_entry.unique_id
)
)
):
description = service_desc.get("description")
info: dict[str, Any] = { info: dict[str, Any] = {
"names": ", ".join(names), "names": ", ".join(names),
"state": state.state, "state": state.state,
} }
if description:
info["description"] = description
if area_names: if area_names:
info["areas"] = ", ".join(area_names) info["areas"] = ", ".join(area_names)
@ -473,3 +493,231 @@ def _get_exposed_entities(
entities[state.entity_id] = info entities[state.entity_id] = info
return entities return entities
def _selector_serializer(schema: Any) -> Any: # noqa: C901
"""Convert selectors into OpenAPI schema."""
if not isinstance(schema, selector.Selector):
return UNSUPPORTED
if isinstance(schema, selector.BackupLocationSelector):
return {"type": "string", "pattern": "^(?:\\/backup|\\w+)$"}
if isinstance(schema, selector.BooleanSelector):
return {"type": "boolean"}
if isinstance(schema, selector.ColorRGBSelector):
return {
"type": "array",
"items": {"type": "number"},
"minItems": 3,
"maxItems": 3,
"format": "RGB",
}
if isinstance(schema, selector.ConditionSelector):
return convert(cv.CONDITIONS_SCHEMA)
if isinstance(schema, selector.ConstantSelector):
return {"enum": [schema.config["value"]]}
result: dict[str, Any]
if isinstance(schema, selector.ColorTempSelector):
result = {"type": "number"}
if "min" in schema.config:
result["minimum"] = schema.config["min"]
elif "min_mireds" in schema.config:
result["minimum"] = schema.config["min_mireds"]
if "max" in schema.config:
result["maximum"] = schema.config["max"]
elif "max_mireds" in schema.config:
result["maximum"] = schema.config["max_mireds"]
return result
if isinstance(schema, selector.CountrySelector):
if schema.config.get("countries"):
return {"type": "string", "enum": schema.config["countries"]}
return {"type": "string", "format": "ISO 3166-1 alpha-2"}
if isinstance(schema, selector.DateSelector):
return {"type": "string", "format": "date"}
if isinstance(schema, selector.DateTimeSelector):
return {"type": "string", "format": "date-time"}
if isinstance(schema, selector.DurationSelector):
return convert(cv.time_period_dict)
if isinstance(schema, selector.EntitySelector):
if schema.config.get("multiple"):
return {"type": "array", "items": {"type": "string", "format": "entity_id"}}
return {"type": "string", "format": "entity_id"}
if isinstance(schema, selector.LanguageSelector):
if schema.config.get("languages"):
return {"type": "string", "enum": schema.config["languages"]}
return {"type": "string", "format": "RFC 5646"}
if isinstance(schema, (selector.LocationSelector, selector.MediaSelector)):
return convert(schema.DATA_SCHEMA)
if isinstance(schema, selector.NumberSelector):
result = {"type": "number"}
if "min" in schema.config:
result["minimum"] = schema.config["min"]
if "max" in schema.config:
result["maximum"] = schema.config["max"]
return result
if isinstance(schema, selector.ObjectSelector):
return {"type": "object"}
if isinstance(schema, selector.SelectSelector):
options = [
x["value"] if isinstance(x, dict) else x for x in schema.config["options"]
]
if schema.config.get("multiple"):
return {
"type": "array",
"items": {"type": "string", "enum": options},
"uniqueItems": True,
}
return {"type": "string", "enum": options}
if isinstance(schema, selector.TargetSelector):
return convert(cv.TARGET_SERVICE_FIELDS)
if isinstance(schema, selector.TemplateSelector):
return {"type": "string", "format": "jinja2"}
if isinstance(schema, selector.TimeSelector):
return {"type": "string", "format": "time"}
if isinstance(schema, selector.TriggerSelector):
return convert(cv.TRIGGER_SCHEMA)
if schema.config.get("multiple"):
return {"type": "array", "items": {"type": "string"}}
return {"type": "string"}
class ScriptTool(Tool):
"""LLM Tool representing a Script."""
def __init__(
self,
hass: HomeAssistant,
script_entity_id: str,
) -> None:
"""Init the class."""
entity_registry = er.async_get(hass)
self.name = split_entity_id(script_entity_id)[1]
self.parameters = vol.Schema({})
entity_entry = entity_registry.async_get(script_entity_id)
if entity_entry and entity_entry.unique_id:
parameters_cache = hass.data.get(SCRIPT_PARAMETERS_CACHE)
if parameters_cache is None:
parameters_cache = hass.data[SCRIPT_PARAMETERS_CACHE] = {}
@callback
def clear_cache(event: Event) -> None:
"""Clear script parameter cache on script reload or delete."""
if (
event.data[ATTR_DOMAIN] == SCRIPT_DOMAIN
and event.data[ATTR_SERVICE] in parameters_cache
):
parameters_cache.pop(event.data[ATTR_SERVICE])
cancel = hass.bus.async_listen(EVENT_SERVICE_REMOVED, clear_cache)
@callback
def on_homeassistant_close(event: Event) -> None:
"""Cleanup."""
cancel()
hass.bus.async_listen_once(
EVENT_HOMEASSISTANT_CLOSE, on_homeassistant_close
)
if entity_entry.unique_id in parameters_cache:
self.description, self.parameters = parameters_cache[
entity_entry.unique_id
]
return
if service_desc := service.async_get_cached_service_description(
hass, SCRIPT_DOMAIN, entity_entry.unique_id
):
self.description = service_desc.get("description")
schema: dict[vol.Marker, Any] = {}
fields = service_desc.get("fields", {})
for field, config in fields.items():
description = config.get("description")
if not description:
description = config.get("name")
if config.get("required"):
key = vol.Required(field, description=description)
else:
key = vol.Optional(field, description=description)
if "selector" in config:
schema[key] = selector.selector(config["selector"])
else:
schema[key] = cv.string
self.parameters = vol.Schema(schema)
parameters_cache[entity_entry.unique_id] = (
self.description,
self.parameters,
)
async def async_call(
self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext
) -> JsonObjectType:
"""Run the script."""
for field, validator in self.parameters.schema.items():
if field not in tool_input.tool_args:
continue
if isinstance(validator, selector.AreaSelector):
area_reg = ar.async_get(hass)
if validator.config.get("multiple"):
areas: list[ar.AreaEntry] = []
for area in tool_input.tool_args[field]:
areas.extend(intent.find_areas(area, area_reg))
tool_input.tool_args[field] = list({area.id for area in areas})
else:
area = tool_input.tool_args[field]
area = list(intent.find_areas(area, area_reg))[0].id
tool_input.tool_args[field] = area
elif isinstance(validator, selector.FloorSelector):
floor_reg = fr.async_get(hass)
if validator.config.get("multiple"):
floors: list[fr.FloorEntry] = []
for floor in tool_input.tool_args[field]:
floors.extend(intent.find_floors(floor, floor_reg))
tool_input.tool_args[field] = list(
{floor.floor_id for floor in floors}
)
else:
floor = tool_input.tool_args[field]
floor = list(intent.find_floors(floor, floor_reg))[0].floor_id
tool_input.tool_args[field] = floor
await hass.services.async_call(
SCRIPT_DOMAIN,
SERVICE_TURN_ON,
{
ATTR_ENTITY_ID: SCRIPT_DOMAIN + "." + self.name,
ATTR_VARIABLES: tool_input.tool_args,
},
context=llm_context.context,
)
return {"success": True}

View file

@ -75,6 +75,13 @@ class Selector[_T: Mapping[str, Any]]:
self.config = self.CONFIG_SCHEMA(config) self.config = self.CONFIG_SCHEMA(config)
def __eq__(self, other: object) -> bool:
"""Check equality."""
if not isinstance(other, Selector):
return NotImplemented
return self.selector_type == other.selector_type and self.config == other.config
def serialize(self) -> dict[str, dict[str, _T]]: def serialize(self) -> dict[str, dict[str, _T]]:
"""Serialize Selector for voluptuous_serialize.""" """Serialize Selector for voluptuous_serialize."""
return {"selector": {self.selector_type: self.config}} return {"selector": {self.selector_type: self.config}}
@ -278,7 +285,7 @@ class AssistPipelineSelector(Selector[AssistPipelineSelectorConfig]):
CONFIG_SCHEMA = vol.Schema({}) CONFIG_SCHEMA = vol.Schema({})
def __init__(self, config: AssistPipelineSelectorConfig) -> None: def __init__(self, config: AssistPipelineSelectorConfig | None = None) -> None:
"""Instantiate a selector.""" """Instantiate a selector."""
super().__init__(config) super().__init__(config)
@ -430,10 +437,10 @@ class ColorTempSelector(Selector[ColorTempSelectorConfig]):
range_min = self.config.get("min") range_min = self.config.get("min")
range_max = self.config.get("max") range_max = self.config.get("max")
if not range_min: if range_min is None:
range_min = self.config.get("min_mireds") range_min = self.config.get("min_mireds")
if not range_max: if range_max is None:
range_max = self.config.get("max_mireds") range_max = self.config.get("max_mireds")
value: int = vol.All( value: int = vol.All(
@ -517,7 +524,7 @@ class ConstantSelector(Selector[ConstantSelectorConfig]):
} }
) )
def __init__(self, config: ConstantSelectorConfig | None = None) -> None: def __init__(self, config: ConstantSelectorConfig) -> None:
"""Instantiate a selector.""" """Instantiate a selector."""
super().__init__(config) super().__init__(config)
@ -560,7 +567,7 @@ class QrCodeSelector(Selector[QrCodeSelectorConfig]):
} }
) )
def __init__(self, config: QrCodeSelectorConfig | None = None) -> None: def __init__(self, config: QrCodeSelectorConfig) -> None:
"""Instantiate a selector.""" """Instantiate a selector."""
super().__init__(config) super().__init__(config)
@ -588,7 +595,7 @@ class ConversationAgentSelector(Selector[ConversationAgentSelectorConfig]):
} }
) )
def __init__(self, config: ConversationAgentSelectorConfig) -> None: def __init__(self, config: ConversationAgentSelectorConfig | None = None) -> None:
"""Instantiate a selector.""" """Instantiate a selector."""
super().__init__(config) super().__init__(config)
@ -820,7 +827,7 @@ class FloorSelectorConfig(TypedDict, total=False):
@SELECTORS.register("floor") @SELECTORS.register("floor")
class FloorSelector(Selector[AreaSelectorConfig]): class FloorSelector(Selector[FloorSelectorConfig]):
"""Selector of a single or list of floors.""" """Selector of a single or list of floors."""
selector_type = "floor" selector_type = "floor"
@ -934,7 +941,7 @@ class LanguageSelector(Selector[LanguageSelectorConfig]):
} }
) )
def __init__(self, config: LanguageSelectorConfig) -> None: def __init__(self, config: LanguageSelectorConfig | None = None) -> None:
"""Instantiate a selector.""" """Instantiate a selector."""
super().__init__(config) super().__init__(config)
@ -1159,7 +1166,7 @@ class SelectSelector(Selector[SelectSelectorConfig]):
} }
) )
def __init__(self, config: SelectSelectorConfig | None = None) -> None: def __init__(self, config: SelectSelectorConfig) -> None:
"""Instantiate a selector.""" """Instantiate a selector."""
super().__init__(config) super().__init__(config)
@ -1434,7 +1441,7 @@ class FileSelector(Selector[FileSelectorConfig]):
} }
) )
def __init__(self, config: FileSelectorConfig | None = None) -> None: def __init__(self, config: FileSelectorConfig) -> None:
"""Instantiate a selector.""" """Instantiate a selector."""
super().__init__(config) super().__init__(config)

View file

@ -58,6 +58,7 @@ SQLAlchemy==2.0.31
typing-extensions>=4.12.2,<5.0 typing-extensions>=4.12.2,<5.0
ulid-transform==0.9.0 ulid-transform==0.9.0
urllib3>=1.26.5,<2 urllib3>=1.26.5,<2
voluptuous-openapi==0.0.4
voluptuous-serialize==2.6.0 voluptuous-serialize==2.6.0
voluptuous==0.13.1 voluptuous==0.13.1
webrtc-noise-gain==1.2.3 webrtc-noise-gain==1.2.3

View file

@ -69,6 +69,7 @@ dependencies = [
"urllib3>=1.26.5,<2", "urllib3>=1.26.5,<2",
"voluptuous==0.13.1", "voluptuous==0.13.1",
"voluptuous-serialize==2.6.0", "voluptuous-serialize==2.6.0",
"voluptuous-openapi==0.0.4",
"yarl==1.9.4", "yarl==1.9.4",
] ]

View file

@ -41,4 +41,5 @@ ulid-transform==0.9.0
urllib3>=1.26.5,<2 urllib3>=1.26.5,<2
voluptuous==0.13.1 voluptuous==0.13.1
voluptuous-serialize==2.6.0 voluptuous-serialize==2.6.0
voluptuous-openapi==0.0.4
yarl==1.9.4 yarl==1.9.4

View file

@ -2846,10 +2846,6 @@ voip-utils==0.1.0
# homeassistant.components.volkszaehler # homeassistant.components.volkszaehler
volkszaehler==0.4.0 volkszaehler==0.4.0
# homeassistant.components.google_generative_ai_conversation
# homeassistant.components.openai_conversation
voluptuous-openapi==0.0.4
# homeassistant.components.volvooncall # homeassistant.components.volvooncall
volvooncall==0.10.3 volvooncall==0.10.3

View file

@ -2217,10 +2217,6 @@ vilfo-api-client==0.5.0
# homeassistant.components.voip # homeassistant.components.voip
voip-utils==0.1.0 voip-utils==0.1.0
# homeassistant.components.google_generative_ai_conversation
# homeassistant.components.openai_conversation
voluptuous-openapi==0.0.4
# homeassistant.components.volvooncall # homeassistant.components.volvooncall
volvooncall==0.10.3 volvooncall==0.10.3

View file

@ -8,6 +8,7 @@ import voluptuous as vol
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
from homeassistant.components.intent import async_register_timer_handler from homeassistant.components.intent import async_register_timer_handler
from homeassistant.components.script.config import ScriptConfig
from homeassistant.core import Context, HomeAssistant, State from homeassistant.core import Context, HomeAssistant, State
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import ( from homeassistant.helpers import (
@ -18,6 +19,7 @@ from homeassistant.helpers import (
floor_registry as fr, floor_registry as fr,
intent, intent,
llm, llm,
selector,
) )
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from homeassistant.util import yaml from homeassistant.util import yaml
@ -564,11 +566,6 @@ async def test_assist_api_prompt(
"names": "Unnamed Device", "names": "Unnamed Device",
"state": "unavailable", "state": "unavailable",
}, },
"script.test_script": {
"description": "This is a test script",
"names": "test_script",
"state": "off",
},
} }
exposed_entities_prompt = ( exposed_entities_prompt = (
"An overview of the areas and the devices in this smart home:\n" "An overview of the areas and the devices in this smart home:\n"
@ -634,3 +631,323 @@ async def test_assist_api_prompt(
{area_prompt} {area_prompt}
{exposed_entities_prompt}""" {exposed_entities_prompt}"""
) )
async def test_script_tool(
hass: HomeAssistant,
area_registry: ar.AreaRegistry,
floor_registry: fr.FloorRegistry,
) -> None:
"""Test ScriptTool for the assist API."""
assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, "intent", {})
context = Context()
llm_context = llm.LLMContext(
platform="test_platform",
context=context,
user_prompt="test_text",
language="*",
assistant="conversation",
device_id=None,
)
# Create a script with a unique ID
assert await async_setup_component(
hass,
"script",
{
"script": {
"test_script": {
"description": "This is a test script",
"sequence": [],
"fields": {
"beer": {"description": "Number of beers", "required": True},
"wine": {"selector": {"number": {"min": 0, "max": 3}}},
"where": {"selector": {"area": {}}},
"area_list": {"selector": {"area": {"multiple": True}}},
"floor": {"selector": {"floor": {}}},
"floor_list": {"selector": {"floor": {"multiple": True}}},
"extra_field": {"selector": {"area": {}}},
},
},
"unexposed_script": {
"sequence": [],
},
}
},
)
async_expose_entity(hass, "conversation", "script.test_script", True)
area = area_registry.async_create("Living room")
floor = floor_registry.async_create("2")
assert llm.SCRIPT_PARAMETERS_CACHE not in hass.data
api = await llm.async_get_api(hass, "assist", llm_context)
tools = [tool for tool in api.tools if isinstance(tool, llm.ScriptTool)]
assert len(tools) == 1
tool = tools[0]
assert tool.name == "test_script"
assert tool.description == "This is a test script"
schema = {
vol.Required("beer", description="Number of beers"): cv.string,
vol.Optional("wine"): selector.NumberSelector({"min": 0, "max": 3}),
vol.Optional("where"): selector.AreaSelector(),
vol.Optional("area_list"): selector.AreaSelector({"multiple": True}),
vol.Optional("floor"): selector.FloorSelector(),
vol.Optional("floor_list"): selector.FloorSelector({"multiple": True}),
vol.Optional("extra_field"): selector.AreaSelector(),
}
assert tool.parameters.schema == schema
assert hass.data[llm.SCRIPT_PARAMETERS_CACHE] == {
"test_script": ("This is a test script", vol.Schema(schema))
}
tool_input = llm.ToolInput(
tool_name="test_script",
tool_args={
"beer": "3",
"wine": 0,
"where": "Living room",
"area_list": ["Living room"],
"floor": "2",
"floor_list": ["2"],
},
)
with patch("homeassistant.core.ServiceRegistry.async_call") as mock_service_call:
response = await api.async_call_tool(tool_input)
mock_service_call.assert_awaited_once_with(
"script",
"turn_on",
{
"entity_id": "script.test_script",
"variables": {
"beer": "3",
"wine": 0,
"where": area.id,
"area_list": [area.id],
"floor": floor.floor_id,
"floor_list": [floor.floor_id],
},
},
context=context,
)
assert response == {"success": True}
# Test reload script with new parameters
config = {
"script": {
"test_script": ScriptConfig(
{
"description": "This is a new test script",
"sequence": [],
"mode": "single",
"max": 2,
"max_exceeded": "WARNING",
"trace": {},
"fields": {
"beer": {"description": "Number of beers", "required": True},
},
}
)
}
}
with patch(
"homeassistant.helpers.entity_component.EntityComponent.async_prepare_reload",
return_value=config,
):
await hass.services.async_call("script", "reload", blocking=True)
assert hass.data[llm.SCRIPT_PARAMETERS_CACHE] == {}
api = await llm.async_get_api(hass, "assist", llm_context)
tools = [tool for tool in api.tools if isinstance(tool, llm.ScriptTool)]
assert len(tools) == 1
tool = tools[0]
assert tool.name == "test_script"
assert tool.description == "This is a new test script"
schema = {vol.Required("beer", description="Number of beers"): cv.string}
assert tool.parameters.schema == schema
assert hass.data[llm.SCRIPT_PARAMETERS_CACHE] == {
"test_script": ("This is a new test script", vol.Schema(schema))
}
async def test_selector_serializer(
hass: HomeAssistant, llm_context: llm.LLMContext
) -> None:
"""Test serialization of Selectors in Open API format."""
api = await llm.async_get_api(hass, "assist", llm_context)
selector_serializer = api.custom_serializer
assert selector_serializer(selector.ActionSelector()) == {"type": "string"}
assert selector_serializer(selector.AddonSelector()) == {"type": "string"}
assert selector_serializer(selector.AreaSelector()) == {"type": "string"}
assert selector_serializer(selector.AreaSelector({"multiple": True})) == {
"type": "array",
"items": {"type": "string"},
}
assert selector_serializer(selector.AssistPipelineSelector()) == {"type": "string"}
assert selector_serializer(
selector.AttributeSelector({"entity_id": "sensor.test"})
) == {"type": "string"}
assert selector_serializer(selector.BackupLocationSelector()) == {
"type": "string",
"pattern": "^(?:\\/backup|\\w+)$",
}
assert selector_serializer(selector.BooleanSelector()) == {"type": "boolean"}
assert selector_serializer(selector.ColorRGBSelector()) == {
"type": "array",
"items": {"type": "number"},
"maxItems": 3,
"minItems": 3,
"format": "RGB",
}
assert selector_serializer(selector.ColorTempSelector()) == {"type": "number"}
assert selector_serializer(selector.ColorTempSelector({"min": 0, "max": 1000})) == {
"type": "number",
"minimum": 0,
"maximum": 1000,
}
assert selector_serializer(
selector.ColorTempSelector({"min_mireds": 100, "max_mireds": 1000})
) == {"type": "number", "minimum": 100, "maximum": 1000}
assert selector_serializer(selector.ConfigEntrySelector()) == {"type": "string"}
assert selector_serializer(selector.ConstantSelector({"value": "test"})) == {
"enum": ["test"]
}
assert selector_serializer(selector.ConstantSelector({"value": 1})) == {"enum": [1]}
assert selector_serializer(selector.ConstantSelector({"value": True})) == {
"enum": [True]
}
assert selector_serializer(selector.QrCodeSelector({"data": "test"})) == {
"type": "string"
}
assert selector_serializer(selector.ConversationAgentSelector()) == {
"type": "string"
}
assert selector_serializer(selector.CountrySelector()) == {
"type": "string",
"format": "ISO 3166-1 alpha-2",
}
assert selector_serializer(
selector.CountrySelector({"countries": ["GB", "FR"]})
) == {"type": "string", "enum": ["GB", "FR"]}
assert selector_serializer(selector.DateSelector()) == {
"type": "string",
"format": "date",
}
assert selector_serializer(selector.DateTimeSelector()) == {
"type": "string",
"format": "date-time",
}
assert selector_serializer(selector.DeviceSelector()) == {"type": "string"}
assert selector_serializer(selector.DeviceSelector({"multiple": True})) == {
"type": "array",
"items": {"type": "string"},
}
assert selector_serializer(selector.EntitySelector()) == {
"type": "string",
"format": "entity_id",
}
assert selector_serializer(selector.EntitySelector({"multiple": True})) == {
"type": "array",
"items": {"type": "string", "format": "entity_id"},
}
assert selector_serializer(selector.FloorSelector()) == {"type": "string"}
assert selector_serializer(selector.FloorSelector({"multiple": True})) == {
"type": "array",
"items": {"type": "string"},
}
assert selector_serializer(selector.IconSelector()) == {"type": "string"}
assert selector_serializer(selector.LabelSelector()) == {"type": "string"}
assert selector_serializer(selector.LabelSelector({"multiple": True})) == {
"type": "array",
"items": {"type": "string"},
}
assert selector_serializer(selector.LanguageSelector()) == {
"type": "string",
"format": "RFC 5646",
}
assert selector_serializer(
selector.LanguageSelector({"languages": ["en", "fr"]})
) == {"type": "string", "enum": ["en", "fr"]}
assert selector_serializer(selector.LocationSelector()) == {
"type": "object",
"properties": {
"latitude": {"type": "number"},
"longitude": {"type": "number"},
"radius": {"type": "number"},
},
"required": ["latitude", "longitude"],
}
assert selector_serializer(selector.MediaSelector()) == {
"type": "object",
"properties": {
"entity_id": {"type": "string"},
"media_content_id": {"type": "string"},
"media_content_type": {"type": "string"},
"metadata": {"type": "object", "additionalProperties": True},
},
"required": ["entity_id", "media_content_id", "media_content_type"],
}
assert selector_serializer(selector.NumberSelector({"mode": "box"})) == {
"type": "number"
}
assert selector_serializer(selector.NumberSelector({"min": 30, "max": 100})) == {
"type": "number",
"minimum": 30,
"maximum": 100,
}
assert selector_serializer(selector.ObjectSelector()) == {"type": "object"}
assert selector_serializer(
selector.SelectSelector(
{
"options": [
{"value": "A", "label": "Letter A"},
{"value": "B", "label": "Letter B"},
{"value": "C", "label": "Letter C"},
]
}
)
) == {"type": "string", "enum": ["A", "B", "C"]}
assert selector_serializer(
selector.SelectSelector({"options": ["A", "B", "C"], "multiple": True})
) == {
"type": "array",
"items": {"type": "string", "enum": ["A", "B", "C"]},
"uniqueItems": True,
}
assert selector_serializer(
selector.StateSelector({"entity_id": "sensor.test"})
) == {"type": "string"}
assert selector_serializer(selector.TemplateSelector()) == {
"type": "string",
"format": "jinja2",
}
assert selector_serializer(selector.TextSelector()) == {"type": "string"}
assert selector_serializer(selector.TextSelector({"multiple": True})) == {
"type": "array",
"items": {"type": "string"},
}
assert selector_serializer(selector.ThemeSelector()) == {"type": "string"}
assert selector_serializer(selector.TimeSelector()) == {
"type": "string",
"format": "time",
}
assert selector_serializer(selector.TriggerSelector()) == {
"type": "array",
"items": {"type": "string"},
}
assert selector_serializer(selector.FileSelector({"accept": ".txt"})) == {
"type": "string"
}

View file

@ -55,6 +55,8 @@ def _test_selector(
config = {selector_type: schema} config = {selector_type: schema}
selector.validate_selector(config) selector.validate_selector(config)
selector_instance = selector.selector(config) selector_instance = selector.selector(config)
assert selector_instance == selector.selector(config)
assert selector_instance != 5
# We do not allow enums in the config, as they cannot serialize # We do not allow enums in the config, as they cannot serialize
assert not any(isinstance(val, Enum) for val in selector_instance.config.values()) assert not any(isinstance(val, Enum) for val in selector_instance.config.values())