Prioritize entity names over area names in Assist matching (#86982)

* Refactor async_match_states

* Check entity name after state, before aliases

* Give entity name matches priority over area names

* Don't force result to have area

* Add area alias in tests

* Move name/area list creation back

* Clean up PR

* More clean up
This commit is contained in:
Michael Hansen 2023-01-30 22:46:25 -06:00 committed by GitHub
parent f8c6e4c20a
commit be69c81db5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 148 additions and 37 deletions

View file

@ -11,7 +11,7 @@ import re
from typing import IO, Any from typing import IO, Any
from hassil.intents import Intents, ResponseType, SlotList, TextSlotList from hassil.intents import Intents, ResponseType, SlotList, TextSlotList
from hassil.recognize import recognize from hassil.recognize import RecognizeResult, recognize_all
from hassil.util import merge_dict from hassil.util import merge_dict
from home_assistant_intents import get_intents from home_assistant_intents import get_intents
import yaml import yaml
@ -128,7 +128,10 @@ class DefaultAgent(AbstractConversationAgent):
} }
result = await self.hass.async_add_executor_job( result = await self.hass.async_add_executor_job(
recognize, user_input.text, lang_intents.intents, slot_lists self._recognize,
user_input,
lang_intents,
slot_lists,
) )
if result is None: if result is None:
_LOGGER.debug("No intent was matched for '%s'", user_input.text) _LOGGER.debug("No intent was matched for '%s'", user_input.text)
@ -197,6 +200,26 @@ class DefaultAgent(AbstractConversationAgent):
response=intent_response, conversation_id=conversation_id response=intent_response, conversation_id=conversation_id
) )
def _recognize(
self,
user_input: ConversationInput,
lang_intents: LanguageIntents,
slot_lists: dict[str, SlotList],
) -> RecognizeResult | None:
"""Search intents for a match to user input."""
# Prioritize matches with entity names above area names
maybe_result: RecognizeResult | None = None
for result in recognize_all(
user_input.text, lang_intents.intents, slot_lists=slot_lists
):
if "name" in result.entities:
return result
# Keep looking in case an entity has the same name
maybe_result = result
return maybe_result
async def async_reload(self, language: str | None = None): async def async_reload(self, language: str | None = None):
"""Clear cached intents for a language.""" """Clear cached intents for a language."""
if language is None: if language is None:
@ -373,19 +396,19 @@ class DefaultAgent(AbstractConversationAgent):
if self._names_list is not None: if self._names_list is not None:
return self._names_list return self._names_list
states = self.hass.states.async_all() states = self.hass.states.async_all()
registry = entity_registry.async_get(self.hass) entities = entity_registry.async_get(self.hass)
names = [] names = []
for state in states: for state in states:
context = {"domain": state.domain} context = {"domain": state.domain}
entry = registry.async_get(state.entity_id) entity = entities.async_get(state.entity_id)
if entry is not None: if entity is not None:
if entry.entity_category: if entity.entity_category:
# Skip configuration/diagnostic entities # Skip configuration/diagnostic entities
continue continue
if entry.aliases: if entity.aliases:
for alias in entry.aliases: for alias in entity.aliases:
names.append((alias, state.entity_id, context)) names.append((alias, state.entity_id, context))
# Default name # Default name

View file

@ -138,15 +138,62 @@ def _has_name(
if name in (state.entity_id, state.name.casefold()): if name in (state.entity_id, state.name.casefold()):
return True return True
# Check aliases # Check name/aliases
if (entity is not None) and entity.aliases: if (entity is None) or (not entity.aliases):
for alias in entity.aliases: return False
if name == alias.casefold():
return True for alias in entity.aliases:
if name == alias.casefold():
return True
return False return False
def _find_area(
id_or_name: str, areas: area_registry.AreaRegistry
) -> area_registry.AreaEntry | None:
"""Find an area by id or name, checking aliases too."""
area = areas.async_get_area(id_or_name) or areas.async_get_area_by_name(id_or_name)
if area is not None:
return area
# Check area aliases
for maybe_area in areas.areas.values():
if not maybe_area.aliases:
continue
for area_alias in maybe_area.aliases:
if id_or_name == area_alias.casefold():
return maybe_area
return None
def _filter_by_area(
states_and_entities: list[tuple[State, entity_registry.RegistryEntry | None]],
area: area_registry.AreaEntry,
devices: device_registry.DeviceRegistry,
) -> Iterable[tuple[State, entity_registry.RegistryEntry | None]]:
"""Filter state/entity pairs by an area."""
entity_area_ids: dict[str, str | None] = {}
for _state, entity in states_and_entities:
if entity is None:
continue
if entity.area_id:
# Use entity's area id first
entity_area_ids[entity.id] = entity.area_id
elif entity.device_id:
# Fall back to device area if not set on entity
device = devices.async_get(entity.device_id)
if device is not None:
entity_area_ids[entity.id] = device.area_id
for state, entity in states_and_entities:
if (entity is not None) and (entity_area_ids.get(entity.id) == area.id):
yield (state, entity)
@callback @callback
@bind_hass @bind_hass
def async_match_states( def async_match_states(
@ -200,45 +247,29 @@ def async_match_states(
if areas is None: if areas is None:
areas = area_registry.async_get(hass) areas = area_registry.async_get(hass)
# id or name area = _find_area(area_name, areas)
area = areas.async_get_area(area_name) or areas.async_get_area_by_name(
area_name
)
assert area is not None, f"No area named {area_name}" assert area is not None, f"No area named {area_name}"
if area is not None: if area is not None:
# Filter by states/entities by area
if devices is None: if devices is None:
devices = device_registry.async_get(hass) devices = device_registry.async_get(hass)
entity_area_ids: dict[str, str | None] = {} states_and_entities = list(_filter_by_area(states_and_entities, area, devices))
for _state, entity in states_and_entities:
if entity is None:
continue
if entity.area_id:
# Use entity's area id first
entity_area_ids[entity.id] = entity.area_id
elif entity.device_id:
# Fall back to device area if not set on entity
device = devices.async_get(entity.device_id)
if device is not None:
entity_area_ids[entity.id] = device.area_id
# Filter by area
states_and_entities = [
(state, entity)
for state, entity in states_and_entities
if (entity is not None) and (entity_area_ids.get(entity.id) == area.id)
]
if name is not None: if name is not None:
if devices is None:
devices = device_registry.async_get(hass)
# Filter by name # Filter by name
name = name.casefold() name = name.casefold()
# Check states
for state, entity in states_and_entities: for state, entity in states_and_entities:
if _has_name(state, entity, name): if _has_name(state, entity, name):
yield state yield state
break break
else: else:
# Not filtered by name # Not filtered by name
for state, _entity in states_and_entities: for state, _entity in states_and_entities:

View file

@ -6,6 +6,7 @@ import pytest
from homeassistant.components import conversation from homeassistant.components import conversation
from homeassistant.components.cover import SERVICE_OPEN_COVER from homeassistant.components.cover import SERVICE_OPEN_COVER
from homeassistant.const import ATTR_FRIENDLY_NAME
from homeassistant.core import DOMAIN as HASS_DOMAIN, Context from homeassistant.core import DOMAIN as HASS_DOMAIN, Context
from homeassistant.helpers import ( from homeassistant.helpers import (
area_registry, area_registry,
@ -777,3 +778,51 @@ async def test_turn_on_area(hass, init_components):
assert call.domain == HASS_DOMAIN assert call.domain == HASS_DOMAIN
assert call.service == "turn_on" assert call.service == "turn_on"
assert call.data == {"entity_id": "light.stove"} assert call.data == {"entity_id": "light.stove"}
async def test_light_area_same_name(hass, init_components):
"""Test turning on a light with the same name as an area."""
entities = entity_registry.async_get(hass)
devices = device_registry.async_get(hass)
areas = area_registry.async_get(hass)
entry = MockConfigEntry(domain="test")
device = devices.async_get_or_create(
config_entry_id=entry.entry_id,
connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
)
kitchen_area = areas.async_create("kitchen")
devices.async_update_device(device.id, area_id=kitchen_area.id)
kitchen_light = entities.async_get_or_create(
"light", "demo", "1234", original_name="kitchen light"
)
entities.async_update_entity(kitchen_light.entity_id, area_id=kitchen_area.id)
hass.states.async_set(
kitchen_light.entity_id, "off", attributes={ATTR_FRIENDLY_NAME: "kitchen light"}
)
ceiling_light = entities.async_get_or_create(
"light", "demo", "5678", original_name="ceiling light"
)
entities.async_update_entity(ceiling_light.entity_id, area_id=kitchen_area.id)
hass.states.async_set(
ceiling_light.entity_id, "off", attributes={ATTR_FRIENDLY_NAME: "ceiling light"}
)
calls = async_mock_service(hass, HASS_DOMAIN, "turn_on")
await hass.services.async_call(
"conversation",
"process",
{conversation.ATTR_TEXT: "turn on kitchen light"},
)
await hass.async_block_till_done()
# Should only turn on one light instead of all lights in the kitchen
assert len(calls) == 1
call = calls[0]
assert call.domain == HASS_DOMAIN
assert call.service == "turn_on"
assert call.data == {"entity_id": kitchen_light.entity_id}

View file

@ -27,6 +27,7 @@ async def test_async_match_states(hass):
"""Test async_match_state helper.""" """Test async_match_state helper."""
areas = area_registry.async_get(hass) areas = area_registry.async_get(hass)
area_kitchen = areas.async_get_or_create("kitchen") area_kitchen = areas.async_get_or_create("kitchen")
areas.async_update(area_kitchen.id, aliases={"food room"})
area_bedroom = areas.async_get_or_create("bedroom") area_bedroom = areas.async_get_or_create("bedroom")
state1 = State( state1 = State(
@ -68,6 +69,13 @@ async def test_async_match_states(hass):
) )
) )
# Test area alias
assert [state1] == list(
intent.async_match_states(
hass, name="kitchen light", area_name="food room", states=[state1, state2]
)
)
# Wrong area # Wrong area
assert not list( assert not list(
intent.async_match_states( intent.async_match_states(