Improve conversation typing (#106905)
This commit is contained in:
parent
b2a4de6eed
commit
833cddc8f5
4 changed files with 34 additions and 18 deletions
|
@ -8,6 +8,7 @@ import logging
|
|||
import re
|
||||
from typing import Any, Literal
|
||||
|
||||
from aiohttp import web
|
||||
from hassil.recognize import RecognizeResult
|
||||
import voluptuous as vol
|
||||
|
||||
|
@ -108,7 +109,7 @@ def async_set_agent(
|
|||
hass: core.HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
agent: AbstractConversationAgent,
|
||||
):
|
||||
) -> None:
|
||||
"""Set the agent to handle the conversations."""
|
||||
_get_agent_manager(hass).async_set_agent(config_entry.entry_id, agent)
|
||||
|
||||
|
@ -118,7 +119,7 @@ def async_set_agent(
|
|||
def async_unset_agent(
|
||||
hass: core.HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
):
|
||||
) -> None:
|
||||
"""Set the agent to handle the conversations."""
|
||||
_get_agent_manager(hass).async_unset_agent(config_entry.entry_id)
|
||||
|
||||
|
@ -133,7 +134,7 @@ async def async_get_conversation_languages(
|
|||
all conversation agents.
|
||||
"""
|
||||
agent_manager = _get_agent_manager(hass)
|
||||
languages = set()
|
||||
languages: set[str] = set()
|
||||
|
||||
agent_ids: Iterable[str]
|
||||
if agent_id is None:
|
||||
|
@ -408,7 +409,7 @@ class ConversationProcessView(http.HomeAssistantView):
|
|||
}
|
||||
)
|
||||
)
|
||||
async def post(self, request, data):
|
||||
async def post(self, request: web.Request, data: dict[str, str]) -> web.Response:
|
||||
"""Send a request for processing."""
|
||||
hass = request.app["hass"]
|
||||
|
||||
|
|
|
@ -34,7 +34,7 @@ from homeassistant.components.homeassistant.exposed_entities import (
|
|||
async_listen_entity_updates,
|
||||
async_should_expose,
|
||||
)
|
||||
from homeassistant.const import MATCH_ALL
|
||||
from homeassistant.const import EVENT_STATE_CHANGED, MATCH_ALL
|
||||
from homeassistant.helpers import (
|
||||
area_registry as ar,
|
||||
device_registry as dr,
|
||||
|
@ -145,7 +145,7 @@ class DefaultAgent(AbstractConversationAgent):
|
|||
"""Return a list of supported languages."""
|
||||
return get_domains_and_languages()["homeassistant"]
|
||||
|
||||
async def async_initialize(self, config_intents):
|
||||
async def async_initialize(self, config_intents: dict[str, Any] | None) -> None:
|
||||
"""Initialize the default agent."""
|
||||
if "intent" not in self.hass.config.components:
|
||||
await setup.async_setup_component(self.hass, "intent", {})
|
||||
|
@ -156,17 +156,17 @@ class DefaultAgent(AbstractConversationAgent):
|
|||
|
||||
self.hass.bus.async_listen(
|
||||
ar.EVENT_AREA_REGISTRY_UPDATED,
|
||||
self._async_handle_area_registry_changed,
|
||||
self._async_handle_area_registry_changed, # type: ignore[arg-type]
|
||||
run_immediately=True,
|
||||
)
|
||||
self.hass.bus.async_listen(
|
||||
er.EVENT_ENTITY_REGISTRY_UPDATED,
|
||||
self._async_handle_entity_registry_changed,
|
||||
self._async_handle_entity_registry_changed, # type: ignore[arg-type]
|
||||
run_immediately=True,
|
||||
)
|
||||
self.hass.bus.async_listen(
|
||||
core.EVENT_STATE_CHANGED,
|
||||
self._async_handle_state_changed,
|
||||
EVENT_STATE_CHANGED,
|
||||
self._async_handle_state_changed, # type: ignore[arg-type]
|
||||
run_immediately=True,
|
||||
)
|
||||
async_listen_entity_updates(
|
||||
|
@ -433,7 +433,7 @@ class DefaultAgent(AbstractConversationAgent):
|
|||
|
||||
return speech
|
||||
|
||||
async def async_reload(self, language: str | None = None):
|
||||
async def async_reload(self, language: str | None = None) -> None:
|
||||
"""Clear cached intents for a language."""
|
||||
if language is None:
|
||||
self._lang_intents.clear()
|
||||
|
@ -442,7 +442,7 @@ class DefaultAgent(AbstractConversationAgent):
|
|||
self._lang_intents.pop(language, None)
|
||||
_LOGGER.debug("Cleared intents for language: %s", language)
|
||||
|
||||
async def async_prepare(self, language: str | None = None):
|
||||
async def async_prepare(self, language: str | None = None) -> None:
|
||||
"""Load intents for a language."""
|
||||
if language is None:
|
||||
language = self.hass.config.language
|
||||
|
@ -594,12 +594,16 @@ class DefaultAgent(AbstractConversationAgent):
|
|||
return lang_intents
|
||||
|
||||
@core.callback
|
||||
def _async_handle_area_registry_changed(self, event: core.Event) -> None:
|
||||
def _async_handle_area_registry_changed(
|
||||
self, event: EventType[ar.EventAreaRegistryUpdatedData]
|
||||
) -> None:
|
||||
"""Clear area area cache when the area registry has changed."""
|
||||
self._slot_lists = None
|
||||
|
||||
@core.callback
|
||||
def _async_handle_entity_registry_changed(self, event: core.Event) -> None:
|
||||
def _async_handle_entity_registry_changed(
|
||||
self, event: EventType[er.EventEntityRegistryUpdatedData]
|
||||
) -> None:
|
||||
"""Clear names list cache when an entity registry entry has changed."""
|
||||
if event.data["action"] != "update" or not any(
|
||||
field in event.data["changes"] for field in _ENTITY_REGISTRY_UPDATE_FIELDS
|
||||
|
@ -608,9 +612,11 @@ class DefaultAgent(AbstractConversationAgent):
|
|||
self._slot_lists = None
|
||||
|
||||
@core.callback
|
||||
def _async_handle_state_changed(self, event: core.Event) -> None:
|
||||
def _async_handle_state_changed(
|
||||
self, event: EventType[EventStateChangedData]
|
||||
) -> None:
|
||||
"""Clear names list cache when a state is added or removed from the state machine."""
|
||||
if event.data.get("old_state") and event.data.get("new_state"):
|
||||
if event.data["old_state"] and event.data["new_state"]:
|
||||
return
|
||||
self._slot_lists = None
|
||||
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
"""Util for Conversation."""
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
|
||||
def create_matcher(utterance):
|
||||
def create_matcher(utterance: str) -> re.Pattern[str]:
|
||||
"""Create a regex that matches the utterance."""
|
||||
# Split utterance into parts that are type: NORMAL, GROUP or OPTIONAL
|
||||
# Pattern matches (GROUP|OPTIONAL): Change light to [the color] {name}
|
||||
|
|
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Container, Iterable, MutableMapping
|
||||
from typing import Any, cast
|
||||
from typing import Any, Literal, TypedDict, cast
|
||||
|
||||
import attr
|
||||
|
||||
|
@ -22,6 +22,13 @@ STORAGE_VERSION_MINOR = 3
|
|||
SAVE_DELAY = 10
|
||||
|
||||
|
||||
class EventAreaRegistryUpdatedData(TypedDict):
|
||||
"""EventAreaRegistryUpdated data."""
|
||||
|
||||
action: Literal["create", "remove", "update"]
|
||||
area_id: str
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True)
|
||||
class AreaEntry:
|
||||
"""Area Registry Entry."""
|
||||
|
|
Loading…
Add table
Reference in a new issue