Expose scripts with no fields as entities (#123061)
This commit is contained in:
parent
3ddef56167
commit
e0e61b5262
2 changed files with 113 additions and 82 deletions
|
@ -420,7 +420,9 @@ class AssistAPI(API):
|
|||
):
|
||||
continue
|
||||
|
||||
tools.append(ScriptTool(self.hass, state.entity_id))
|
||||
script_tool = ScriptTool(self.hass, state.entity_id)
|
||||
if script_tool.parameters.schema:
|
||||
tools.append(script_tool)
|
||||
|
||||
return tools
|
||||
|
||||
|
@ -451,12 +453,17 @@ def _get_exposed_entities(
|
|||
entities = {}
|
||||
|
||||
for state in hass.states.async_all():
|
||||
if state.domain == SCRIPT_DOMAIN:
|
||||
continue
|
||||
|
||||
if not async_should_expose(hass, assistant, state.entity_id):
|
||||
continue
|
||||
|
||||
description: str | None = None
|
||||
if state.domain == SCRIPT_DOMAIN:
|
||||
description, parameters = _get_cached_script_parameters(
|
||||
hass, state.entity_id
|
||||
)
|
||||
if parameters.schema: # Only list scripts without input fields here
|
||||
continue
|
||||
|
||||
entity_entry = entity_registry.async_get(state.entity_id)
|
||||
names = [state.name]
|
||||
area_names = []
|
||||
|
@ -485,6 +492,9 @@ def _get_exposed_entities(
|
|||
"state": state.state,
|
||||
}
|
||||
|
||||
if description:
|
||||
info["description"] = description
|
||||
|
||||
if area_names:
|
||||
info["areas"] = ", ".join(area_names)
|
||||
|
||||
|
@ -610,6 +620,83 @@ def _selector_serializer(schema: Any) -> Any: # noqa: C901
|
|||
return {"type": "string"}
|
||||
|
||||
|
||||
def _get_cached_script_parameters(
|
||||
hass: HomeAssistant, entity_id: str
|
||||
) -> tuple[str | None, vol.Schema]:
|
||||
"""Get script description and schema."""
|
||||
entity_registry = er.async_get(hass)
|
||||
|
||||
description = None
|
||||
parameters = vol.Schema({})
|
||||
entity_entry = entity_registry.async_get(entity_id)
|
||||
if entity_entry and entity_entry.unique_id:
|
||||
parameters_cache = hass.data.get(SCRIPT_PARAMETERS_CACHE)
|
||||
|
||||
if parameters_cache is None:
|
||||
parameters_cache = hass.data[SCRIPT_PARAMETERS_CACHE] = {}
|
||||
|
||||
@callback
|
||||
def clear_cache(event: Event) -> None:
|
||||
"""Clear script parameter cache on script reload or delete."""
|
||||
if (
|
||||
event.data[ATTR_DOMAIN] == SCRIPT_DOMAIN
|
||||
and event.data[ATTR_SERVICE] in parameters_cache
|
||||
):
|
||||
parameters_cache.pop(event.data[ATTR_SERVICE])
|
||||
|
||||
cancel = hass.bus.async_listen(EVENT_SERVICE_REMOVED, clear_cache)
|
||||
|
||||
@callback
|
||||
def on_homeassistant_close(event: Event) -> None:
|
||||
"""Cleanup."""
|
||||
cancel()
|
||||
|
||||
hass.bus.async_listen_once(
|
||||
EVENT_HOMEASSISTANT_CLOSE, on_homeassistant_close
|
||||
)
|
||||
|
||||
if entity_entry.unique_id in parameters_cache:
|
||||
return parameters_cache[entity_entry.unique_id]
|
||||
|
||||
if service_desc := service.async_get_cached_service_description(
|
||||
hass, SCRIPT_DOMAIN, entity_entry.unique_id
|
||||
):
|
||||
description = service_desc.get("description")
|
||||
schema: dict[vol.Marker, Any] = {}
|
||||
fields = service_desc.get("fields", {})
|
||||
|
||||
for field, config in fields.items():
|
||||
field_description = config.get("description")
|
||||
if not field_description:
|
||||
field_description = config.get("name")
|
||||
key: vol.Marker
|
||||
if config.get("required"):
|
||||
key = vol.Required(field, description=field_description)
|
||||
else:
|
||||
key = vol.Optional(field, description=field_description)
|
||||
if "selector" in config:
|
||||
schema[key] = selector.selector(config["selector"])
|
||||
else:
|
||||
schema[key] = cv.string
|
||||
|
||||
parameters = vol.Schema(schema)
|
||||
|
||||
aliases: list[str] = []
|
||||
if entity_entry.name:
|
||||
aliases.append(entity_entry.name)
|
||||
if entity_entry.aliases:
|
||||
aliases.extend(entity_entry.aliases)
|
||||
if aliases:
|
||||
if description:
|
||||
description = description + ". Aliases: " + str(list(aliases))
|
||||
else:
|
||||
description = "Aliases: " + str(list(aliases))
|
||||
|
||||
parameters_cache[entity_entry.unique_id] = (description, parameters)
|
||||
|
||||
return description, parameters
|
||||
|
||||
|
||||
class ScriptTool(Tool):
|
||||
"""LLM Tool representing a Script."""
|
||||
|
||||
|
@ -619,86 +706,14 @@ class ScriptTool(Tool):
|
|||
script_entity_id: str,
|
||||
) -> None:
|
||||
"""Init the class."""
|
||||
entity_registry = er.async_get(hass)
|
||||
|
||||
self.name = split_entity_id(script_entity_id)[1]
|
||||
if self.name[0].isdigit():
|
||||
self.name = "_" + self.name
|
||||
self._entity_id = script_entity_id
|
||||
self.parameters = vol.Schema({})
|
||||
entity_entry = entity_registry.async_get(script_entity_id)
|
||||
if entity_entry and entity_entry.unique_id:
|
||||
parameters_cache = hass.data.get(SCRIPT_PARAMETERS_CACHE)
|
||||
|
||||
if parameters_cache is None:
|
||||
parameters_cache = hass.data[SCRIPT_PARAMETERS_CACHE] = {}
|
||||
|
||||
@callback
|
||||
def clear_cache(event: Event) -> None:
|
||||
"""Clear script parameter cache on script reload or delete."""
|
||||
if (
|
||||
event.data[ATTR_DOMAIN] == SCRIPT_DOMAIN
|
||||
and event.data[ATTR_SERVICE] in parameters_cache
|
||||
):
|
||||
parameters_cache.pop(event.data[ATTR_SERVICE])
|
||||
|
||||
cancel = hass.bus.async_listen(EVENT_SERVICE_REMOVED, clear_cache)
|
||||
|
||||
@callback
|
||||
def on_homeassistant_close(event: Event) -> None:
|
||||
"""Cleanup."""
|
||||
cancel()
|
||||
|
||||
hass.bus.async_listen_once(
|
||||
EVENT_HOMEASSISTANT_CLOSE, on_homeassistant_close
|
||||
)
|
||||
|
||||
if entity_entry.unique_id in parameters_cache:
|
||||
self.description, self.parameters = parameters_cache[
|
||||
entity_entry.unique_id
|
||||
]
|
||||
return
|
||||
|
||||
if service_desc := service.async_get_cached_service_description(
|
||||
hass, SCRIPT_DOMAIN, entity_entry.unique_id
|
||||
):
|
||||
self.description = service_desc.get("description")
|
||||
schema: dict[vol.Marker, Any] = {}
|
||||
fields = service_desc.get("fields", {})
|
||||
|
||||
for field, config in fields.items():
|
||||
description = config.get("description")
|
||||
if not description:
|
||||
description = config.get("name")
|
||||
key: vol.Marker
|
||||
if config.get("required"):
|
||||
key = vol.Required(field, description=description)
|
||||
else:
|
||||
key = vol.Optional(field, description=description)
|
||||
if "selector" in config:
|
||||
schema[key] = selector.selector(config["selector"])
|
||||
else:
|
||||
schema[key] = cv.string
|
||||
|
||||
self.parameters = vol.Schema(schema)
|
||||
|
||||
aliases: list[str] = []
|
||||
if entity_entry.name:
|
||||
aliases.append(entity_entry.name)
|
||||
if entity_entry.aliases:
|
||||
aliases.extend(entity_entry.aliases)
|
||||
if aliases:
|
||||
if self.description:
|
||||
self.description = (
|
||||
self.description + ". Aliases: " + str(list(aliases))
|
||||
)
|
||||
else:
|
||||
self.description = "Aliases: " + str(list(aliases))
|
||||
|
||||
parameters_cache[entity_entry.unique_id] = (
|
||||
self.description,
|
||||
self.parameters,
|
||||
)
|
||||
self.description, self.parameters = _get_cached_script_parameters(
|
||||
hass, script_entity_id
|
||||
)
|
||||
|
||||
async def async_call(
|
||||
self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext
|
||||
|
|
|
@ -374,11 +374,16 @@ async def test_assist_api_prompt(
|
|||
"beer": {"description": "Number of beers"},
|
||||
"wine": {},
|
||||
},
|
||||
}
|
||||
},
|
||||
"script_with_no_fields": {
|
||||
"description": "This is another test script",
|
||||
"sequence": [],
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
async_expose_entity(hass, "conversation", "script.test_script", True)
|
||||
async_expose_entity(hass, "conversation", "script.script_with_no_fields", True)
|
||||
|
||||
entry = MockConfigEntry(title=None)
|
||||
entry.add_to_hass(hass)
|
||||
|
@ -511,6 +516,10 @@ async def test_assist_api_prompt(
|
|||
)
|
||||
)
|
||||
exposed_entities_prompt = """An overview of the areas and the devices in this smart home:
|
||||
- names: script_with_no_fields
|
||||
domain: script
|
||||
state: 'off'
|
||||
description: This is another test script
|
||||
- names: Kitchen
|
||||
domain: light
|
||||
state: 'on'
|
||||
|
@ -657,6 +666,10 @@ async def test_script_tool(
|
|||
"extra_field": {"selector": {"area": {}}},
|
||||
},
|
||||
},
|
||||
"script_with_no_fields": {
|
||||
"description": "This is another test script",
|
||||
"sequence": [],
|
||||
},
|
||||
"unexposed_script": {
|
||||
"sequence": [],
|
||||
},
|
||||
|
@ -664,6 +677,7 @@ async def test_script_tool(
|
|||
},
|
||||
)
|
||||
async_expose_entity(hass, "conversation", "script.test_script", True)
|
||||
async_expose_entity(hass, "conversation", "script.script_with_no_fields", True)
|
||||
|
||||
entity_registry.async_update_entity(
|
||||
"script.test_script", name="script name", aliases={"script alias"}
|
||||
|
@ -700,7 +714,8 @@ async def test_script_tool(
|
|||
"test_script": (
|
||||
"This is a test script. Aliases: ['script name', 'script alias']",
|
||||
vol.Schema(schema),
|
||||
)
|
||||
),
|
||||
"script_with_no_fields": ("This is another test script", vol.Schema({})),
|
||||
}
|
||||
|
||||
tool_input = llm.ToolInput(
|
||||
|
@ -781,7 +796,8 @@ async def test_script_tool(
|
|||
"test_script": (
|
||||
"This is a new test script. Aliases: ['script name', 'script alias']",
|
||||
vol.Schema(schema),
|
||||
)
|
||||
),
|
||||
"script_with_no_fields": ("This is another test script", vol.Schema({})),
|
||||
}
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue