Prioritize entity names over area names in Assist matching (#86982)
* Refactor async_match_states * Check entity name after state, before aliases * Give entity name matches priority over area names * Don't force result to have area * Add area alias in tests * Move name/area list creation back * Clean up PR * More clean up
This commit is contained in:
parent
f8c6e4c20a
commit
be69c81db5
4 changed files with 148 additions and 37 deletions
|
@ -11,7 +11,7 @@ import re
|
||||||
from typing import IO, Any
|
from typing import IO, Any
|
||||||
|
|
||||||
from hassil.intents import Intents, ResponseType, SlotList, TextSlotList
|
from hassil.intents import Intents, ResponseType, SlotList, TextSlotList
|
||||||
from hassil.recognize import recognize
|
from hassil.recognize import RecognizeResult, recognize_all
|
||||||
from hassil.util import merge_dict
|
from hassil.util import merge_dict
|
||||||
from home_assistant_intents import get_intents
|
from home_assistant_intents import get_intents
|
||||||
import yaml
|
import yaml
|
||||||
|
@ -128,7 +128,10 @@ class DefaultAgent(AbstractConversationAgent):
|
||||||
}
|
}
|
||||||
|
|
||||||
result = await self.hass.async_add_executor_job(
|
result = await self.hass.async_add_executor_job(
|
||||||
recognize, user_input.text, lang_intents.intents, slot_lists
|
self._recognize,
|
||||||
|
user_input,
|
||||||
|
lang_intents,
|
||||||
|
slot_lists,
|
||||||
)
|
)
|
||||||
if result is None:
|
if result is None:
|
||||||
_LOGGER.debug("No intent was matched for '%s'", user_input.text)
|
_LOGGER.debug("No intent was matched for '%s'", user_input.text)
|
||||||
|
@ -197,6 +200,26 @@ class DefaultAgent(AbstractConversationAgent):
|
||||||
response=intent_response, conversation_id=conversation_id
|
response=intent_response, conversation_id=conversation_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _recognize(
|
||||||
|
self,
|
||||||
|
user_input: ConversationInput,
|
||||||
|
lang_intents: LanguageIntents,
|
||||||
|
slot_lists: dict[str, SlotList],
|
||||||
|
) -> RecognizeResult | None:
|
||||||
|
"""Search intents for a match to user input."""
|
||||||
|
# Prioritize matches with entity names above area names
|
||||||
|
maybe_result: RecognizeResult | None = None
|
||||||
|
for result in recognize_all(
|
||||||
|
user_input.text, lang_intents.intents, slot_lists=slot_lists
|
||||||
|
):
|
||||||
|
if "name" in result.entities:
|
||||||
|
return result
|
||||||
|
|
||||||
|
# Keep looking in case an entity has the same name
|
||||||
|
maybe_result = result
|
||||||
|
|
||||||
|
return maybe_result
|
||||||
|
|
||||||
async def async_reload(self, language: str | None = None):
|
async def async_reload(self, language: str | None = None):
|
||||||
"""Clear cached intents for a language."""
|
"""Clear cached intents for a language."""
|
||||||
if language is None:
|
if language is None:
|
||||||
|
@ -373,19 +396,19 @@ class DefaultAgent(AbstractConversationAgent):
|
||||||
if self._names_list is not None:
|
if self._names_list is not None:
|
||||||
return self._names_list
|
return self._names_list
|
||||||
states = self.hass.states.async_all()
|
states = self.hass.states.async_all()
|
||||||
registry = entity_registry.async_get(self.hass)
|
entities = entity_registry.async_get(self.hass)
|
||||||
names = []
|
names = []
|
||||||
for state in states:
|
for state in states:
|
||||||
context = {"domain": state.domain}
|
context = {"domain": state.domain}
|
||||||
|
|
||||||
entry = registry.async_get(state.entity_id)
|
entity = entities.async_get(state.entity_id)
|
||||||
if entry is not None:
|
if entity is not None:
|
||||||
if entry.entity_category:
|
if entity.entity_category:
|
||||||
# Skip configuration/diagnostic entities
|
# Skip configuration/diagnostic entities
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if entry.aliases:
|
if entity.aliases:
|
||||||
for alias in entry.aliases:
|
for alias in entity.aliases:
|
||||||
names.append((alias, state.entity_id, context))
|
names.append((alias, state.entity_id, context))
|
||||||
|
|
||||||
# Default name
|
# Default name
|
||||||
|
|
|
@ -138,15 +138,62 @@ def _has_name(
|
||||||
if name in (state.entity_id, state.name.casefold()):
|
if name in (state.entity_id, state.name.casefold()):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Check aliases
|
# Check name/aliases
|
||||||
if (entity is not None) and entity.aliases:
|
if (entity is None) or (not entity.aliases):
|
||||||
for alias in entity.aliases:
|
return False
|
||||||
if name == alias.casefold():
|
|
||||||
return True
|
for alias in entity.aliases:
|
||||||
|
if name == alias.casefold():
|
||||||
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _find_area(
|
||||||
|
id_or_name: str, areas: area_registry.AreaRegistry
|
||||||
|
) -> area_registry.AreaEntry | None:
|
||||||
|
"""Find an area by id or name, checking aliases too."""
|
||||||
|
area = areas.async_get_area(id_or_name) or areas.async_get_area_by_name(id_or_name)
|
||||||
|
if area is not None:
|
||||||
|
return area
|
||||||
|
|
||||||
|
# Check area aliases
|
||||||
|
for maybe_area in areas.areas.values():
|
||||||
|
if not maybe_area.aliases:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for area_alias in maybe_area.aliases:
|
||||||
|
if id_or_name == area_alias.casefold():
|
||||||
|
return maybe_area
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _filter_by_area(
|
||||||
|
states_and_entities: list[tuple[State, entity_registry.RegistryEntry | None]],
|
||||||
|
area: area_registry.AreaEntry,
|
||||||
|
devices: device_registry.DeviceRegistry,
|
||||||
|
) -> Iterable[tuple[State, entity_registry.RegistryEntry | None]]:
|
||||||
|
"""Filter state/entity pairs by an area."""
|
||||||
|
entity_area_ids: dict[str, str | None] = {}
|
||||||
|
for _state, entity in states_and_entities:
|
||||||
|
if entity is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if entity.area_id:
|
||||||
|
# Use entity's area id first
|
||||||
|
entity_area_ids[entity.id] = entity.area_id
|
||||||
|
elif entity.device_id:
|
||||||
|
# Fall back to device area if not set on entity
|
||||||
|
device = devices.async_get(entity.device_id)
|
||||||
|
if device is not None:
|
||||||
|
entity_area_ids[entity.id] = device.area_id
|
||||||
|
|
||||||
|
for state, entity in states_and_entities:
|
||||||
|
if (entity is not None) and (entity_area_ids.get(entity.id) == area.id):
|
||||||
|
yield (state, entity)
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@bind_hass
|
@bind_hass
|
||||||
def async_match_states(
|
def async_match_states(
|
||||||
|
@ -200,45 +247,29 @@ def async_match_states(
|
||||||
if areas is None:
|
if areas is None:
|
||||||
areas = area_registry.async_get(hass)
|
areas = area_registry.async_get(hass)
|
||||||
|
|
||||||
# id or name
|
area = _find_area(area_name, areas)
|
||||||
area = areas.async_get_area(area_name) or areas.async_get_area_by_name(
|
|
||||||
area_name
|
|
||||||
)
|
|
||||||
assert area is not None, f"No area named {area_name}"
|
assert area is not None, f"No area named {area_name}"
|
||||||
|
|
||||||
if area is not None:
|
if area is not None:
|
||||||
|
# Filter by states/entities by area
|
||||||
if devices is None:
|
if devices is None:
|
||||||
devices = device_registry.async_get(hass)
|
devices = device_registry.async_get(hass)
|
||||||
|
|
||||||
entity_area_ids: dict[str, str | None] = {}
|
states_and_entities = list(_filter_by_area(states_and_entities, area, devices))
|
||||||
for _state, entity in states_and_entities:
|
|
||||||
if entity is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if entity.area_id:
|
|
||||||
# Use entity's area id first
|
|
||||||
entity_area_ids[entity.id] = entity.area_id
|
|
||||||
elif entity.device_id:
|
|
||||||
# Fall back to device area if not set on entity
|
|
||||||
device = devices.async_get(entity.device_id)
|
|
||||||
if device is not None:
|
|
||||||
entity_area_ids[entity.id] = device.area_id
|
|
||||||
|
|
||||||
# Filter by area
|
|
||||||
states_and_entities = [
|
|
||||||
(state, entity)
|
|
||||||
for state, entity in states_and_entities
|
|
||||||
if (entity is not None) and (entity_area_ids.get(entity.id) == area.id)
|
|
||||||
]
|
|
||||||
|
|
||||||
if name is not None:
|
if name is not None:
|
||||||
|
if devices is None:
|
||||||
|
devices = device_registry.async_get(hass)
|
||||||
|
|
||||||
# Filter by name
|
# Filter by name
|
||||||
name = name.casefold()
|
name = name.casefold()
|
||||||
|
|
||||||
|
# Check states
|
||||||
for state, entity in states_and_entities:
|
for state, entity in states_and_entities:
|
||||||
if _has_name(state, entity, name):
|
if _has_name(state, entity, name):
|
||||||
yield state
|
yield state
|
||||||
break
|
break
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Not filtered by name
|
# Not filtered by name
|
||||||
for state, _entity in states_and_entities:
|
for state, _entity in states_and_entities:
|
||||||
|
|
|
@ -6,6 +6,7 @@ import pytest
|
||||||
|
|
||||||
from homeassistant.components import conversation
|
from homeassistant.components import conversation
|
||||||
from homeassistant.components.cover import SERVICE_OPEN_COVER
|
from homeassistant.components.cover import SERVICE_OPEN_COVER
|
||||||
|
from homeassistant.const import ATTR_FRIENDLY_NAME
|
||||||
from homeassistant.core import DOMAIN as HASS_DOMAIN, Context
|
from homeassistant.core import DOMAIN as HASS_DOMAIN, Context
|
||||||
from homeassistant.helpers import (
|
from homeassistant.helpers import (
|
||||||
area_registry,
|
area_registry,
|
||||||
|
@ -777,3 +778,51 @@ async def test_turn_on_area(hass, init_components):
|
||||||
assert call.domain == HASS_DOMAIN
|
assert call.domain == HASS_DOMAIN
|
||||||
assert call.service == "turn_on"
|
assert call.service == "turn_on"
|
||||||
assert call.data == {"entity_id": "light.stove"}
|
assert call.data == {"entity_id": "light.stove"}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_light_area_same_name(hass, init_components):
|
||||||
|
"""Test turning on a light with the same name as an area."""
|
||||||
|
entities = entity_registry.async_get(hass)
|
||||||
|
devices = device_registry.async_get(hass)
|
||||||
|
areas = area_registry.async_get(hass)
|
||||||
|
entry = MockConfigEntry(domain="test")
|
||||||
|
|
||||||
|
device = devices.async_get_or_create(
|
||||||
|
config_entry_id=entry.entry_id,
|
||||||
|
connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
|
||||||
|
)
|
||||||
|
|
||||||
|
kitchen_area = areas.async_create("kitchen")
|
||||||
|
devices.async_update_device(device.id, area_id=kitchen_area.id)
|
||||||
|
|
||||||
|
kitchen_light = entities.async_get_or_create(
|
||||||
|
"light", "demo", "1234", original_name="kitchen light"
|
||||||
|
)
|
||||||
|
entities.async_update_entity(kitchen_light.entity_id, area_id=kitchen_area.id)
|
||||||
|
hass.states.async_set(
|
||||||
|
kitchen_light.entity_id, "off", attributes={ATTR_FRIENDLY_NAME: "kitchen light"}
|
||||||
|
)
|
||||||
|
|
||||||
|
ceiling_light = entities.async_get_or_create(
|
||||||
|
"light", "demo", "5678", original_name="ceiling light"
|
||||||
|
)
|
||||||
|
entities.async_update_entity(ceiling_light.entity_id, area_id=kitchen_area.id)
|
||||||
|
hass.states.async_set(
|
||||||
|
ceiling_light.entity_id, "off", attributes={ATTR_FRIENDLY_NAME: "ceiling light"}
|
||||||
|
)
|
||||||
|
|
||||||
|
calls = async_mock_service(hass, HASS_DOMAIN, "turn_on")
|
||||||
|
|
||||||
|
await hass.services.async_call(
|
||||||
|
"conversation",
|
||||||
|
"process",
|
||||||
|
{conversation.ATTR_TEXT: "turn on kitchen light"},
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
# Should only turn on one light instead of all lights in the kitchen
|
||||||
|
assert len(calls) == 1
|
||||||
|
call = calls[0]
|
||||||
|
assert call.domain == HASS_DOMAIN
|
||||||
|
assert call.service == "turn_on"
|
||||||
|
assert call.data == {"entity_id": kitchen_light.entity_id}
|
||||||
|
|
|
@ -27,6 +27,7 @@ async def test_async_match_states(hass):
|
||||||
"""Test async_match_state helper."""
|
"""Test async_match_state helper."""
|
||||||
areas = area_registry.async_get(hass)
|
areas = area_registry.async_get(hass)
|
||||||
area_kitchen = areas.async_get_or_create("kitchen")
|
area_kitchen = areas.async_get_or_create("kitchen")
|
||||||
|
areas.async_update(area_kitchen.id, aliases={"food room"})
|
||||||
area_bedroom = areas.async_get_or_create("bedroom")
|
area_bedroom = areas.async_get_or_create("bedroom")
|
||||||
|
|
||||||
state1 = State(
|
state1 = State(
|
||||||
|
@ -68,6 +69,13 @@ async def test_async_match_states(hass):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Test area alias
|
||||||
|
assert [state1] == list(
|
||||||
|
intent.async_match_states(
|
||||||
|
hass, name="kitchen light", area_name="food room", states=[state1, state2]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Wrong area
|
# Wrong area
|
||||||
assert not list(
|
assert not list(
|
||||||
intent.async_match_states(
|
intent.async_match_states(
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue