Prioritize entity names over area names in Assist matching (#86982)
* Refactor async_match_states * Check entity name after state, before aliases * Give entity name matches priority over area names * Don't force result to have area * Add area alias in tests * Move name/area list creation back * Clean up PR * More clean up
This commit is contained in:
parent
f8c6e4c20a
commit
be69c81db5
4 changed files with 148 additions and 37 deletions
|
@ -11,7 +11,7 @@ import re
|
|||
from typing import IO, Any
|
||||
|
||||
from hassil.intents import Intents, ResponseType, SlotList, TextSlotList
|
||||
from hassil.recognize import recognize
|
||||
from hassil.recognize import RecognizeResult, recognize_all
|
||||
from hassil.util import merge_dict
|
||||
from home_assistant_intents import get_intents
|
||||
import yaml
|
||||
|
@ -128,7 +128,10 @@ class DefaultAgent(AbstractConversationAgent):
|
|||
}
|
||||
|
||||
result = await self.hass.async_add_executor_job(
|
||||
recognize, user_input.text, lang_intents.intents, slot_lists
|
||||
self._recognize,
|
||||
user_input,
|
||||
lang_intents,
|
||||
slot_lists,
|
||||
)
|
||||
if result is None:
|
||||
_LOGGER.debug("No intent was matched for '%s'", user_input.text)
|
||||
|
@ -197,6 +200,26 @@ class DefaultAgent(AbstractConversationAgent):
|
|||
response=intent_response, conversation_id=conversation_id
|
||||
)
|
||||
|
||||
def _recognize(
|
||||
self,
|
||||
user_input: ConversationInput,
|
||||
lang_intents: LanguageIntents,
|
||||
slot_lists: dict[str, SlotList],
|
||||
) -> RecognizeResult | None:
|
||||
"""Search intents for a match to user input."""
|
||||
# Prioritize matches with entity names above area names
|
||||
maybe_result: RecognizeResult | None = None
|
||||
for result in recognize_all(
|
||||
user_input.text, lang_intents.intents, slot_lists=slot_lists
|
||||
):
|
||||
if "name" in result.entities:
|
||||
return result
|
||||
|
||||
# Keep looking in case an entity has the same name
|
||||
maybe_result = result
|
||||
|
||||
return maybe_result
|
||||
|
||||
async def async_reload(self, language: str | None = None):
|
||||
"""Clear cached intents for a language."""
|
||||
if language is None:
|
||||
|
@ -373,19 +396,19 @@ class DefaultAgent(AbstractConversationAgent):
|
|||
if self._names_list is not None:
|
||||
return self._names_list
|
||||
states = self.hass.states.async_all()
|
||||
registry = entity_registry.async_get(self.hass)
|
||||
entities = entity_registry.async_get(self.hass)
|
||||
names = []
|
||||
for state in states:
|
||||
context = {"domain": state.domain}
|
||||
|
||||
entry = registry.async_get(state.entity_id)
|
||||
if entry is not None:
|
||||
if entry.entity_category:
|
||||
entity = entities.async_get(state.entity_id)
|
||||
if entity is not None:
|
||||
if entity.entity_category:
|
||||
# Skip configuration/diagnostic entities
|
||||
continue
|
||||
|
||||
if entry.aliases:
|
||||
for alias in entry.aliases:
|
||||
if entity.aliases:
|
||||
for alias in entity.aliases:
|
||||
names.append((alias, state.entity_id, context))
|
||||
|
||||
# Default name
|
||||
|
|
|
@ -138,8 +138,10 @@ def _has_name(
|
|||
if name in (state.entity_id, state.name.casefold()):
|
||||
return True
|
||||
|
||||
# Check aliases
|
||||
if (entity is not None) and entity.aliases:
|
||||
# Check name/aliases
|
||||
if (entity is None) or (not entity.aliases):
|
||||
return False
|
||||
|
||||
for alias in entity.aliases:
|
||||
if name == alias.casefold():
|
||||
return True
|
||||
|
@ -147,6 +149,51 @@ def _has_name(
|
|||
return False
|
||||
|
||||
|
||||
def _find_area(
|
||||
id_or_name: str, areas: area_registry.AreaRegistry
|
||||
) -> area_registry.AreaEntry | None:
|
||||
"""Find an area by id or name, checking aliases too."""
|
||||
area = areas.async_get_area(id_or_name) or areas.async_get_area_by_name(id_or_name)
|
||||
if area is not None:
|
||||
return area
|
||||
|
||||
# Check area aliases
|
||||
for maybe_area in areas.areas.values():
|
||||
if not maybe_area.aliases:
|
||||
continue
|
||||
|
||||
for area_alias in maybe_area.aliases:
|
||||
if id_or_name == area_alias.casefold():
|
||||
return maybe_area
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _filter_by_area(
|
||||
states_and_entities: list[tuple[State, entity_registry.RegistryEntry | None]],
|
||||
area: area_registry.AreaEntry,
|
||||
devices: device_registry.DeviceRegistry,
|
||||
) -> Iterable[tuple[State, entity_registry.RegistryEntry | None]]:
|
||||
"""Filter state/entity pairs by an area."""
|
||||
entity_area_ids: dict[str, str | None] = {}
|
||||
for _state, entity in states_and_entities:
|
||||
if entity is None:
|
||||
continue
|
||||
|
||||
if entity.area_id:
|
||||
# Use entity's area id first
|
||||
entity_area_ids[entity.id] = entity.area_id
|
||||
elif entity.device_id:
|
||||
# Fall back to device area if not set on entity
|
||||
device = devices.async_get(entity.device_id)
|
||||
if device is not None:
|
||||
entity_area_ids[entity.id] = device.area_id
|
||||
|
||||
for state, entity in states_and_entities:
|
||||
if (entity is not None) and (entity_area_ids.get(entity.id) == area.id):
|
||||
yield (state, entity)
|
||||
|
||||
|
||||
@callback
|
||||
@bind_hass
|
||||
def async_match_states(
|
||||
|
@ -200,45 +247,29 @@ def async_match_states(
|
|||
if areas is None:
|
||||
areas = area_registry.async_get(hass)
|
||||
|
||||
# id or name
|
||||
area = areas.async_get_area(area_name) or areas.async_get_area_by_name(
|
||||
area_name
|
||||
)
|
||||
area = _find_area(area_name, areas)
|
||||
assert area is not None, f"No area named {area_name}"
|
||||
|
||||
if area is not None:
|
||||
# Filter by states/entities by area
|
||||
if devices is None:
|
||||
devices = device_registry.async_get(hass)
|
||||
|
||||
entity_area_ids: dict[str, str | None] = {}
|
||||
for _state, entity in states_and_entities:
|
||||
if entity is None:
|
||||
continue
|
||||
|
||||
if entity.area_id:
|
||||
# Use entity's area id first
|
||||
entity_area_ids[entity.id] = entity.area_id
|
||||
elif entity.device_id:
|
||||
# Fall back to device area if not set on entity
|
||||
device = devices.async_get(entity.device_id)
|
||||
if device is not None:
|
||||
entity_area_ids[entity.id] = device.area_id
|
||||
|
||||
# Filter by area
|
||||
states_and_entities = [
|
||||
(state, entity)
|
||||
for state, entity in states_and_entities
|
||||
if (entity is not None) and (entity_area_ids.get(entity.id) == area.id)
|
||||
]
|
||||
states_and_entities = list(_filter_by_area(states_and_entities, area, devices))
|
||||
|
||||
if name is not None:
|
||||
if devices is None:
|
||||
devices = device_registry.async_get(hass)
|
||||
|
||||
# Filter by name
|
||||
name = name.casefold()
|
||||
|
||||
# Check states
|
||||
for state, entity in states_and_entities:
|
||||
if _has_name(state, entity, name):
|
||||
yield state
|
||||
break
|
||||
|
||||
else:
|
||||
# Not filtered by name
|
||||
for state, _entity in states_and_entities:
|
||||
|
|
|
@ -6,6 +6,7 @@ import pytest
|
|||
|
||||
from homeassistant.components import conversation
|
||||
from homeassistant.components.cover import SERVICE_OPEN_COVER
|
||||
from homeassistant.const import ATTR_FRIENDLY_NAME
|
||||
from homeassistant.core import DOMAIN as HASS_DOMAIN, Context
|
||||
from homeassistant.helpers import (
|
||||
area_registry,
|
||||
|
@ -777,3 +778,51 @@ async def test_turn_on_area(hass, init_components):
|
|||
assert call.domain == HASS_DOMAIN
|
||||
assert call.service == "turn_on"
|
||||
assert call.data == {"entity_id": "light.stove"}
|
||||
|
||||
|
||||
async def test_light_area_same_name(hass, init_components):
|
||||
"""Test turning on a light with the same name as an area."""
|
||||
entities = entity_registry.async_get(hass)
|
||||
devices = device_registry.async_get(hass)
|
||||
areas = area_registry.async_get(hass)
|
||||
entry = MockConfigEntry(domain="test")
|
||||
|
||||
device = devices.async_get_or_create(
|
||||
config_entry_id=entry.entry_id,
|
||||
connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
|
||||
)
|
||||
|
||||
kitchen_area = areas.async_create("kitchen")
|
||||
devices.async_update_device(device.id, area_id=kitchen_area.id)
|
||||
|
||||
kitchen_light = entities.async_get_or_create(
|
||||
"light", "demo", "1234", original_name="kitchen light"
|
||||
)
|
||||
entities.async_update_entity(kitchen_light.entity_id, area_id=kitchen_area.id)
|
||||
hass.states.async_set(
|
||||
kitchen_light.entity_id, "off", attributes={ATTR_FRIENDLY_NAME: "kitchen light"}
|
||||
)
|
||||
|
||||
ceiling_light = entities.async_get_or_create(
|
||||
"light", "demo", "5678", original_name="ceiling light"
|
||||
)
|
||||
entities.async_update_entity(ceiling_light.entity_id, area_id=kitchen_area.id)
|
||||
hass.states.async_set(
|
||||
ceiling_light.entity_id, "off", attributes={ATTR_FRIENDLY_NAME: "ceiling light"}
|
||||
)
|
||||
|
||||
calls = async_mock_service(hass, HASS_DOMAIN, "turn_on")
|
||||
|
||||
await hass.services.async_call(
|
||||
"conversation",
|
||||
"process",
|
||||
{conversation.ATTR_TEXT: "turn on kitchen light"},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# Should only turn on one light instead of all lights in the kitchen
|
||||
assert len(calls) == 1
|
||||
call = calls[0]
|
||||
assert call.domain == HASS_DOMAIN
|
||||
assert call.service == "turn_on"
|
||||
assert call.data == {"entity_id": kitchen_light.entity_id}
|
||||
|
|
|
@ -27,6 +27,7 @@ async def test_async_match_states(hass):
|
|||
"""Test async_match_state helper."""
|
||||
areas = area_registry.async_get(hass)
|
||||
area_kitchen = areas.async_get_or_create("kitchen")
|
||||
areas.async_update(area_kitchen.id, aliases={"food room"})
|
||||
area_bedroom = areas.async_get_or_create("bedroom")
|
||||
|
||||
state1 = State(
|
||||
|
@ -68,6 +69,13 @@ async def test_async_match_states(hass):
|
|||
)
|
||||
)
|
||||
|
||||
# Test area alias
|
||||
assert [state1] == list(
|
||||
intent.async_match_states(
|
||||
hass, name="kitchen light", area_name="food room", states=[state1, state2]
|
||||
)
|
||||
)
|
||||
|
||||
# Wrong area
|
||||
assert not list(
|
||||
intent.async_match_states(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue