Improve conversation typing (#106905)

This commit is contained in:
Marc Mueller 2024-01-05 18:40:34 +01:00 committed by GitHub
parent b2a4de6eed
commit 833cddc8f5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 34 additions and 18 deletions

View file

@ -8,6 +8,7 @@ import logging
import re
from typing import Any, Literal
from aiohttp import web
from hassil.recognize import RecognizeResult
import voluptuous as vol
@ -108,7 +109,7 @@ def async_set_agent(
hass: core.HomeAssistant,
config_entry: ConfigEntry,
agent: AbstractConversationAgent,
):
) -> None:
"""Set the agent to handle the conversations."""
_get_agent_manager(hass).async_set_agent(config_entry.entry_id, agent)
@ -118,7 +119,7 @@ def async_set_agent(
def async_unset_agent(
hass: core.HomeAssistant,
config_entry: ConfigEntry,
):
) -> None:
"""Set the agent to handle the conversations."""
_get_agent_manager(hass).async_unset_agent(config_entry.entry_id)
@ -133,7 +134,7 @@ async def async_get_conversation_languages(
all conversation agents.
"""
agent_manager = _get_agent_manager(hass)
languages = set()
languages: set[str] = set()
agent_ids: Iterable[str]
if agent_id is None:
@ -408,7 +409,7 @@ class ConversationProcessView(http.HomeAssistantView):
}
)
)
async def post(self, request, data):
async def post(self, request: web.Request, data: dict[str, str]) -> web.Response:
"""Send a request for processing."""
hass = request.app["hass"]

View file

@ -34,7 +34,7 @@ from homeassistant.components.homeassistant.exposed_entities import (
async_listen_entity_updates,
async_should_expose,
)
from homeassistant.const import MATCH_ALL
from homeassistant.const import EVENT_STATE_CHANGED, MATCH_ALL
from homeassistant.helpers import (
area_registry as ar,
device_registry as dr,
@ -145,7 +145,7 @@ class DefaultAgent(AbstractConversationAgent):
"""Return a list of supported languages."""
return get_domains_and_languages()["homeassistant"]
async def async_initialize(self, config_intents):
async def async_initialize(self, config_intents: dict[str, Any] | None) -> None:
"""Initialize the default agent."""
if "intent" not in self.hass.config.components:
await setup.async_setup_component(self.hass, "intent", {})
@ -156,17 +156,17 @@ class DefaultAgent(AbstractConversationAgent):
self.hass.bus.async_listen(
ar.EVENT_AREA_REGISTRY_UPDATED,
self._async_handle_area_registry_changed,
self._async_handle_area_registry_changed, # type: ignore[arg-type]
run_immediately=True,
)
self.hass.bus.async_listen(
er.EVENT_ENTITY_REGISTRY_UPDATED,
self._async_handle_entity_registry_changed,
self._async_handle_entity_registry_changed, # type: ignore[arg-type]
run_immediately=True,
)
self.hass.bus.async_listen(
core.EVENT_STATE_CHANGED,
self._async_handle_state_changed,
EVENT_STATE_CHANGED,
self._async_handle_state_changed, # type: ignore[arg-type]
run_immediately=True,
)
async_listen_entity_updates(
@ -433,7 +433,7 @@ class DefaultAgent(AbstractConversationAgent):
return speech
async def async_reload(self, language: str | None = None):
async def async_reload(self, language: str | None = None) -> None:
"""Clear cached intents for a language."""
if language is None:
self._lang_intents.clear()
@ -442,7 +442,7 @@ class DefaultAgent(AbstractConversationAgent):
self._lang_intents.pop(language, None)
_LOGGER.debug("Cleared intents for language: %s", language)
async def async_prepare(self, language: str | None = None):
async def async_prepare(self, language: str | None = None) -> None:
"""Load intents for a language."""
if language is None:
language = self.hass.config.language
@ -594,12 +594,16 @@ class DefaultAgent(AbstractConversationAgent):
return lang_intents
@core.callback
def _async_handle_area_registry_changed(self, event: core.Event) -> None:
def _async_handle_area_registry_changed(
self, event: EventType[ar.EventAreaRegistryUpdatedData]
) -> None:
"""Clear area area cache when the area registry has changed."""
self._slot_lists = None
@core.callback
def _async_handle_entity_registry_changed(self, event: core.Event) -> None:
def _async_handle_entity_registry_changed(
self, event: EventType[er.EventEntityRegistryUpdatedData]
) -> None:
"""Clear names list cache when an entity registry entry has changed."""
if event.data["action"] != "update" or not any(
field in event.data["changes"] for field in _ENTITY_REGISTRY_UPDATE_FIELDS
@ -608,9 +612,11 @@ class DefaultAgent(AbstractConversationAgent):
self._slot_lists = None
@core.callback
def _async_handle_state_changed(self, event: core.Event) -> None:
def _async_handle_state_changed(
self, event: EventType[EventStateChangedData]
) -> None:
"""Clear names list cache when a state is added or removed from the state machine."""
if event.data.get("old_state") and event.data.get("new_state"):
if event.data["old_state"] and event.data["new_state"]:
return
self._slot_lists = None

View file

@ -1,8 +1,10 @@
"""Util for Conversation."""
from __future__ import annotations
import re
def create_matcher(utterance):
def create_matcher(utterance: str) -> re.Pattern[str]:
"""Create a regex that matches the utterance."""
# Split utterance into parts that are type: NORMAL, GROUP or OPTIONAL
# Pattern matches (GROUP|OPTIONAL): Change light to [the color] {name}

View file

@ -3,7 +3,7 @@ from __future__ import annotations
from collections import OrderedDict
from collections.abc import Container, Iterable, MutableMapping
from typing import Any, cast
from typing import Any, Literal, TypedDict, cast
import attr
@ -22,6 +22,13 @@ STORAGE_VERSION_MINOR = 3
SAVE_DELAY = 10
class EventAreaRegistryUpdatedData(TypedDict):
"""EventAreaRegistryUpdated data."""
action: Literal["create", "remove", "update"]
area_id: str
@attr.s(slots=True, frozen=True)
class AreaEntry:
"""Area Registry Entry."""