Improve conversation typing (#106905)

This commit is contained in:
Marc Mueller 2024-01-05 18:40:34 +01:00 committed by GitHub
parent b2a4de6eed
commit 833cddc8f5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 34 additions and 18 deletions

View file

@ -8,6 +8,7 @@ import logging
import re import re
from typing import Any, Literal from typing import Any, Literal
from aiohttp import web
from hassil.recognize import RecognizeResult from hassil.recognize import RecognizeResult
import voluptuous as vol import voluptuous as vol
@ -108,7 +109,7 @@ def async_set_agent(
hass: core.HomeAssistant, hass: core.HomeAssistant,
config_entry: ConfigEntry, config_entry: ConfigEntry,
agent: AbstractConversationAgent, agent: AbstractConversationAgent,
): ) -> None:
"""Set the agent to handle the conversations.""" """Set the agent to handle the conversations."""
_get_agent_manager(hass).async_set_agent(config_entry.entry_id, agent) _get_agent_manager(hass).async_set_agent(config_entry.entry_id, agent)
@ -118,7 +119,7 @@ def async_set_agent(
def async_unset_agent( def async_unset_agent(
hass: core.HomeAssistant, hass: core.HomeAssistant,
config_entry: ConfigEntry, config_entry: ConfigEntry,
): ) -> None:
"""Set the agent to handle the conversations.""" """Set the agent to handle the conversations."""
_get_agent_manager(hass).async_unset_agent(config_entry.entry_id) _get_agent_manager(hass).async_unset_agent(config_entry.entry_id)
@ -133,7 +134,7 @@ async def async_get_conversation_languages(
all conversation agents. all conversation agents.
""" """
agent_manager = _get_agent_manager(hass) agent_manager = _get_agent_manager(hass)
languages = set() languages: set[str] = set()
agent_ids: Iterable[str] agent_ids: Iterable[str]
if agent_id is None: if agent_id is None:
@ -408,7 +409,7 @@ class ConversationProcessView(http.HomeAssistantView):
} }
) )
) )
async def post(self, request, data): async def post(self, request: web.Request, data: dict[str, str]) -> web.Response:
"""Send a request for processing.""" """Send a request for processing."""
hass = request.app["hass"] hass = request.app["hass"]

View file

@ -34,7 +34,7 @@ from homeassistant.components.homeassistant.exposed_entities import (
async_listen_entity_updates, async_listen_entity_updates,
async_should_expose, async_should_expose,
) )
from homeassistant.const import MATCH_ALL from homeassistant.const import EVENT_STATE_CHANGED, MATCH_ALL
from homeassistant.helpers import ( from homeassistant.helpers import (
area_registry as ar, area_registry as ar,
device_registry as dr, device_registry as dr,
@ -145,7 +145,7 @@ class DefaultAgent(AbstractConversationAgent):
"""Return a list of supported languages.""" """Return a list of supported languages."""
return get_domains_and_languages()["homeassistant"] return get_domains_and_languages()["homeassistant"]
async def async_initialize(self, config_intents): async def async_initialize(self, config_intents: dict[str, Any] | None) -> None:
"""Initialize the default agent.""" """Initialize the default agent."""
if "intent" not in self.hass.config.components: if "intent" not in self.hass.config.components:
await setup.async_setup_component(self.hass, "intent", {}) await setup.async_setup_component(self.hass, "intent", {})
@ -156,17 +156,17 @@ class DefaultAgent(AbstractConversationAgent):
self.hass.bus.async_listen( self.hass.bus.async_listen(
ar.EVENT_AREA_REGISTRY_UPDATED, ar.EVENT_AREA_REGISTRY_UPDATED,
self._async_handle_area_registry_changed, self._async_handle_area_registry_changed, # type: ignore[arg-type]
run_immediately=True, run_immediately=True,
) )
self.hass.bus.async_listen( self.hass.bus.async_listen(
er.EVENT_ENTITY_REGISTRY_UPDATED, er.EVENT_ENTITY_REGISTRY_UPDATED,
self._async_handle_entity_registry_changed, self._async_handle_entity_registry_changed, # type: ignore[arg-type]
run_immediately=True, run_immediately=True,
) )
self.hass.bus.async_listen( self.hass.bus.async_listen(
core.EVENT_STATE_CHANGED, EVENT_STATE_CHANGED,
self._async_handle_state_changed, self._async_handle_state_changed, # type: ignore[arg-type]
run_immediately=True, run_immediately=True,
) )
async_listen_entity_updates( async_listen_entity_updates(
@ -433,7 +433,7 @@ class DefaultAgent(AbstractConversationAgent):
return speech return speech
async def async_reload(self, language: str | None = None): async def async_reload(self, language: str | None = None) -> None:
"""Clear cached intents for a language.""" """Clear cached intents for a language."""
if language is None: if language is None:
self._lang_intents.clear() self._lang_intents.clear()
@ -442,7 +442,7 @@ class DefaultAgent(AbstractConversationAgent):
self._lang_intents.pop(language, None) self._lang_intents.pop(language, None)
_LOGGER.debug("Cleared intents for language: %s", language) _LOGGER.debug("Cleared intents for language: %s", language)
async def async_prepare(self, language: str | None = None): async def async_prepare(self, language: str | None = None) -> None:
"""Load intents for a language.""" """Load intents for a language."""
if language is None: if language is None:
language = self.hass.config.language language = self.hass.config.language
@ -594,12 +594,16 @@ class DefaultAgent(AbstractConversationAgent):
return lang_intents return lang_intents
@core.callback @core.callback
def _async_handle_area_registry_changed(self, event: core.Event) -> None: def _async_handle_area_registry_changed(
self, event: EventType[ar.EventAreaRegistryUpdatedData]
) -> None:
"""Clear area area cache when the area registry has changed.""" """Clear area area cache when the area registry has changed."""
self._slot_lists = None self._slot_lists = None
@core.callback @core.callback
def _async_handle_entity_registry_changed(self, event: core.Event) -> None: def _async_handle_entity_registry_changed(
self, event: EventType[er.EventEntityRegistryUpdatedData]
) -> None:
"""Clear names list cache when an entity registry entry has changed.""" """Clear names list cache when an entity registry entry has changed."""
if event.data["action"] != "update" or not any( if event.data["action"] != "update" or not any(
field in event.data["changes"] for field in _ENTITY_REGISTRY_UPDATE_FIELDS field in event.data["changes"] for field in _ENTITY_REGISTRY_UPDATE_FIELDS
@ -608,9 +612,11 @@ class DefaultAgent(AbstractConversationAgent):
self._slot_lists = None self._slot_lists = None
@core.callback @core.callback
def _async_handle_state_changed(self, event: core.Event) -> None: def _async_handle_state_changed(
self, event: EventType[EventStateChangedData]
) -> None:
"""Clear names list cache when a state is added or removed from the state machine.""" """Clear names list cache when a state is added or removed from the state machine."""
if event.data.get("old_state") and event.data.get("new_state"): if event.data["old_state"] and event.data["new_state"]:
return return
self._slot_lists = None self._slot_lists = None

View file

@ -1,8 +1,10 @@
"""Util for Conversation.""" """Util for Conversation."""
from __future__ import annotations
import re import re
def create_matcher(utterance): def create_matcher(utterance: str) -> re.Pattern[str]:
"""Create a regex that matches the utterance.""" """Create a regex that matches the utterance."""
# Split utterance into parts that are type: NORMAL, GROUP or OPTIONAL # Split utterance into parts that are type: NORMAL, GROUP or OPTIONAL
# Pattern matches (GROUP|OPTIONAL): Change light to [the color] {name} # Pattern matches (GROUP|OPTIONAL): Change light to [the color] {name}

View file

@ -3,7 +3,7 @@ from __future__ import annotations
from collections import OrderedDict from collections import OrderedDict
from collections.abc import Container, Iterable, MutableMapping from collections.abc import Container, Iterable, MutableMapping
from typing import Any, cast from typing import Any, Literal, TypedDict, cast
import attr import attr
@ -22,6 +22,13 @@ STORAGE_VERSION_MINOR = 3
SAVE_DELAY = 10 SAVE_DELAY = 10
class EventAreaRegistryUpdatedData(TypedDict):
"""EventAreaRegistryUpdated data."""
action: Literal["create", "remove", "update"]
area_id: str
@attr.s(slots=True, frozen=True) @attr.s(slots=True, frozen=True)
class AreaEntry: class AreaEntry:
"""Area Registry Entry.""" """Area Registry Entry."""