Improve conversation typing (#106905)
This commit is contained in:
parent
b2a4de6eed
commit
833cddc8f5
4 changed files with 34 additions and 18 deletions
|
@ -8,6 +8,7 @@ import logging
|
||||||
import re
|
import re
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from aiohttp import web
|
||||||
from hassil.recognize import RecognizeResult
|
from hassil.recognize import RecognizeResult
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
|
@ -108,7 +109,7 @@ def async_set_agent(
|
||||||
hass: core.HomeAssistant,
|
hass: core.HomeAssistant,
|
||||||
config_entry: ConfigEntry,
|
config_entry: ConfigEntry,
|
||||||
agent: AbstractConversationAgent,
|
agent: AbstractConversationAgent,
|
||||||
):
|
) -> None:
|
||||||
"""Set the agent to handle the conversations."""
|
"""Set the agent to handle the conversations."""
|
||||||
_get_agent_manager(hass).async_set_agent(config_entry.entry_id, agent)
|
_get_agent_manager(hass).async_set_agent(config_entry.entry_id, agent)
|
||||||
|
|
||||||
|
@ -118,7 +119,7 @@ def async_set_agent(
|
||||||
def async_unset_agent(
|
def async_unset_agent(
|
||||||
hass: core.HomeAssistant,
|
hass: core.HomeAssistant,
|
||||||
config_entry: ConfigEntry,
|
config_entry: ConfigEntry,
|
||||||
):
|
) -> None:
|
||||||
"""Set the agent to handle the conversations."""
|
"""Set the agent to handle the conversations."""
|
||||||
_get_agent_manager(hass).async_unset_agent(config_entry.entry_id)
|
_get_agent_manager(hass).async_unset_agent(config_entry.entry_id)
|
||||||
|
|
||||||
|
@ -133,7 +134,7 @@ async def async_get_conversation_languages(
|
||||||
all conversation agents.
|
all conversation agents.
|
||||||
"""
|
"""
|
||||||
agent_manager = _get_agent_manager(hass)
|
agent_manager = _get_agent_manager(hass)
|
||||||
languages = set()
|
languages: set[str] = set()
|
||||||
|
|
||||||
agent_ids: Iterable[str]
|
agent_ids: Iterable[str]
|
||||||
if agent_id is None:
|
if agent_id is None:
|
||||||
|
@ -408,7 +409,7 @@ class ConversationProcessView(http.HomeAssistantView):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
async def post(self, request, data):
|
async def post(self, request: web.Request, data: dict[str, str]) -> web.Response:
|
||||||
"""Send a request for processing."""
|
"""Send a request for processing."""
|
||||||
hass = request.app["hass"]
|
hass = request.app["hass"]
|
||||||
|
|
||||||
|
|
|
@ -34,7 +34,7 @@ from homeassistant.components.homeassistant.exposed_entities import (
|
||||||
async_listen_entity_updates,
|
async_listen_entity_updates,
|
||||||
async_should_expose,
|
async_should_expose,
|
||||||
)
|
)
|
||||||
from homeassistant.const import MATCH_ALL
|
from homeassistant.const import EVENT_STATE_CHANGED, MATCH_ALL
|
||||||
from homeassistant.helpers import (
|
from homeassistant.helpers import (
|
||||||
area_registry as ar,
|
area_registry as ar,
|
||||||
device_registry as dr,
|
device_registry as dr,
|
||||||
|
@ -145,7 +145,7 @@ class DefaultAgent(AbstractConversationAgent):
|
||||||
"""Return a list of supported languages."""
|
"""Return a list of supported languages."""
|
||||||
return get_domains_and_languages()["homeassistant"]
|
return get_domains_and_languages()["homeassistant"]
|
||||||
|
|
||||||
async def async_initialize(self, config_intents):
|
async def async_initialize(self, config_intents: dict[str, Any] | None) -> None:
|
||||||
"""Initialize the default agent."""
|
"""Initialize the default agent."""
|
||||||
if "intent" not in self.hass.config.components:
|
if "intent" not in self.hass.config.components:
|
||||||
await setup.async_setup_component(self.hass, "intent", {})
|
await setup.async_setup_component(self.hass, "intent", {})
|
||||||
|
@ -156,17 +156,17 @@ class DefaultAgent(AbstractConversationAgent):
|
||||||
|
|
||||||
self.hass.bus.async_listen(
|
self.hass.bus.async_listen(
|
||||||
ar.EVENT_AREA_REGISTRY_UPDATED,
|
ar.EVENT_AREA_REGISTRY_UPDATED,
|
||||||
self._async_handle_area_registry_changed,
|
self._async_handle_area_registry_changed, # type: ignore[arg-type]
|
||||||
run_immediately=True,
|
run_immediately=True,
|
||||||
)
|
)
|
||||||
self.hass.bus.async_listen(
|
self.hass.bus.async_listen(
|
||||||
er.EVENT_ENTITY_REGISTRY_UPDATED,
|
er.EVENT_ENTITY_REGISTRY_UPDATED,
|
||||||
self._async_handle_entity_registry_changed,
|
self._async_handle_entity_registry_changed, # type: ignore[arg-type]
|
||||||
run_immediately=True,
|
run_immediately=True,
|
||||||
)
|
)
|
||||||
self.hass.bus.async_listen(
|
self.hass.bus.async_listen(
|
||||||
core.EVENT_STATE_CHANGED,
|
EVENT_STATE_CHANGED,
|
||||||
self._async_handle_state_changed,
|
self._async_handle_state_changed, # type: ignore[arg-type]
|
||||||
run_immediately=True,
|
run_immediately=True,
|
||||||
)
|
)
|
||||||
async_listen_entity_updates(
|
async_listen_entity_updates(
|
||||||
|
@ -433,7 +433,7 @@ class DefaultAgent(AbstractConversationAgent):
|
||||||
|
|
||||||
return speech
|
return speech
|
||||||
|
|
||||||
async def async_reload(self, language: str | None = None):
|
async def async_reload(self, language: str | None = None) -> None:
|
||||||
"""Clear cached intents for a language."""
|
"""Clear cached intents for a language."""
|
||||||
if language is None:
|
if language is None:
|
||||||
self._lang_intents.clear()
|
self._lang_intents.clear()
|
||||||
|
@ -442,7 +442,7 @@ class DefaultAgent(AbstractConversationAgent):
|
||||||
self._lang_intents.pop(language, None)
|
self._lang_intents.pop(language, None)
|
||||||
_LOGGER.debug("Cleared intents for language: %s", language)
|
_LOGGER.debug("Cleared intents for language: %s", language)
|
||||||
|
|
||||||
async def async_prepare(self, language: str | None = None):
|
async def async_prepare(self, language: str | None = None) -> None:
|
||||||
"""Load intents for a language."""
|
"""Load intents for a language."""
|
||||||
if language is None:
|
if language is None:
|
||||||
language = self.hass.config.language
|
language = self.hass.config.language
|
||||||
|
@ -594,12 +594,16 @@ class DefaultAgent(AbstractConversationAgent):
|
||||||
return lang_intents
|
return lang_intents
|
||||||
|
|
||||||
@core.callback
|
@core.callback
|
||||||
def _async_handle_area_registry_changed(self, event: core.Event) -> None:
|
def _async_handle_area_registry_changed(
|
||||||
|
self, event: EventType[ar.EventAreaRegistryUpdatedData]
|
||||||
|
) -> None:
|
||||||
"""Clear area area cache when the area registry has changed."""
|
"""Clear area area cache when the area registry has changed."""
|
||||||
self._slot_lists = None
|
self._slot_lists = None
|
||||||
|
|
||||||
@core.callback
|
@core.callback
|
||||||
def _async_handle_entity_registry_changed(self, event: core.Event) -> None:
|
def _async_handle_entity_registry_changed(
|
||||||
|
self, event: EventType[er.EventEntityRegistryUpdatedData]
|
||||||
|
) -> None:
|
||||||
"""Clear names list cache when an entity registry entry has changed."""
|
"""Clear names list cache when an entity registry entry has changed."""
|
||||||
if event.data["action"] != "update" or not any(
|
if event.data["action"] != "update" or not any(
|
||||||
field in event.data["changes"] for field in _ENTITY_REGISTRY_UPDATE_FIELDS
|
field in event.data["changes"] for field in _ENTITY_REGISTRY_UPDATE_FIELDS
|
||||||
|
@ -608,9 +612,11 @@ class DefaultAgent(AbstractConversationAgent):
|
||||||
self._slot_lists = None
|
self._slot_lists = None
|
||||||
|
|
||||||
@core.callback
|
@core.callback
|
||||||
def _async_handle_state_changed(self, event: core.Event) -> None:
|
def _async_handle_state_changed(
|
||||||
|
self, event: EventType[EventStateChangedData]
|
||||||
|
) -> None:
|
||||||
"""Clear names list cache when a state is added or removed from the state machine."""
|
"""Clear names list cache when a state is added or removed from the state machine."""
|
||||||
if event.data.get("old_state") and event.data.get("new_state"):
|
if event.data["old_state"] and event.data["new_state"]:
|
||||||
return
|
return
|
||||||
self._slot_lists = None
|
self._slot_lists = None
|
||||||
|
|
||||||
|
|
|
@ -1,8 +1,10 @@
|
||||||
"""Util for Conversation."""
|
"""Util for Conversation."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
|
||||||
def create_matcher(utterance):
|
def create_matcher(utterance: str) -> re.Pattern[str]:
|
||||||
"""Create a regex that matches the utterance."""
|
"""Create a regex that matches the utterance."""
|
||||||
# Split utterance into parts that are type: NORMAL, GROUP or OPTIONAL
|
# Split utterance into parts that are type: NORMAL, GROUP or OPTIONAL
|
||||||
# Pattern matches (GROUP|OPTIONAL): Change light to [the color] {name}
|
# Pattern matches (GROUP|OPTIONAL): Change light to [the color] {name}
|
||||||
|
|
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||||
|
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from collections.abc import Container, Iterable, MutableMapping
|
from collections.abc import Container, Iterable, MutableMapping
|
||||||
from typing import Any, cast
|
from typing import Any, Literal, TypedDict, cast
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
|
@ -22,6 +22,13 @@ STORAGE_VERSION_MINOR = 3
|
||||||
SAVE_DELAY = 10
|
SAVE_DELAY = 10
|
||||||
|
|
||||||
|
|
||||||
|
class EventAreaRegistryUpdatedData(TypedDict):
|
||||||
|
"""EventAreaRegistryUpdated data."""
|
||||||
|
|
||||||
|
action: Literal["create", "remove", "update"]
|
||||||
|
area_id: str
|
||||||
|
|
||||||
|
|
||||||
@attr.s(slots=True, frozen=True)
|
@attr.s(slots=True, frozen=True)
|
||||||
class AreaEntry:
|
class AreaEntry:
|
||||||
"""Area Registry Entry."""
|
"""Area Registry Entry."""
|
||||||
|
|
Loading…
Add table
Reference in a new issue