Start type annotating/testing helpers (#17858)
* Add type hints to helpers.intent and location * Test typing for helpers.icon, json, and typing * Add type hints to helpers.state * Add type hints to helpers.translation
This commit is contained in:
parent
0f877711a0
commit
c9c707e368
5 changed files with 92 additions and 55 deletions
|
@ -1,17 +1,20 @@
|
|||
"""Module to coordinate user intentions."""
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Callable, Dict, Iterable, Optional
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.const import ATTR_SUPPORTED_FEATURES
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.core import callback, State, T
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers.typing import HomeAssistantType
|
||||
from homeassistant.loader import bind_hass
|
||||
from homeassistant.const import ATTR_ENTITY_ID
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
_SlotsType = Dict[str, Any]
|
||||
|
||||
INTENT_TURN_OFF = 'HassTurnOff'
|
||||
INTENT_TURN_ON = 'HassTurnOn'
|
||||
|
@ -28,7 +31,7 @@ SPEECH_TYPE_SSML = 'ssml'
|
|||
|
||||
@callback
|
||||
@bind_hass
|
||||
def async_register(hass, handler):
|
||||
def async_register(hass: HomeAssistantType, handler: 'IntentHandler') -> None:
|
||||
"""Register an intent with Home Assistant."""
|
||||
intents = hass.data.get(DATA_KEY)
|
||||
if intents is None:
|
||||
|
@ -44,10 +47,12 @@ def async_register(hass, handler):
|
|||
|
||||
|
||||
@bind_hass
|
||||
async def async_handle(hass, platform, intent_type, slots=None,
|
||||
text_input=None):
|
||||
async def async_handle(hass: HomeAssistantType, platform: str,
|
||||
intent_type: str, slots: Optional[_SlotsType] = None,
|
||||
text_input: Optional[str] = None) -> 'IntentResponse':
|
||||
"""Handle an intent."""
|
||||
handler = hass.data.get(DATA_KEY, {}).get(intent_type)
|
||||
handler = \
|
||||
hass.data.get(DATA_KEY, {}).get(intent_type) # type: IntentHandler
|
||||
|
||||
if handler is None:
|
||||
raise UnknownIntent('Unknown intent {}'.format(intent_type))
|
||||
|
@ -93,7 +98,8 @@ class IntentUnexpectedError(IntentError):
|
|||
|
||||
@callback
|
||||
@bind_hass
|
||||
def async_match_state(hass, name, states=None):
|
||||
def async_match_state(hass: HomeAssistantType, name: str,
|
||||
states: Optional[Iterable[State]] = None) -> State:
|
||||
"""Find a state that matches the name."""
|
||||
if states is None:
|
||||
states = hass.states.async_all()
|
||||
|
@ -108,7 +114,7 @@ def async_match_state(hass, name, states=None):
|
|||
|
||||
|
||||
@callback
|
||||
def async_test_feature(state, feature, feature_name):
|
||||
def async_test_feature(state: State, feature: int, feature_name: str) -> None:
|
||||
"""Test is state supports a feature."""
|
||||
if state.attributes.get(ATTR_SUPPORTED_FEATURES, 0) & feature == 0:
|
||||
raise IntentHandleError(
|
||||
|
@ -119,18 +125,18 @@ def async_test_feature(state, feature, feature_name):
|
|||
class IntentHandler:
|
||||
"""Intent handler registration."""
|
||||
|
||||
intent_type = None
|
||||
slot_schema = None
|
||||
intent_type = None # type: Optional[str]
|
||||
slot_schema = None # type: Optional[vol.Schema]
|
||||
_slot_schema = None
|
||||
platforms = []
|
||||
platforms = [] # type: Optional[Iterable[str]]
|
||||
|
||||
@callback
|
||||
def async_can_handle(self, intent_obj):
|
||||
def async_can_handle(self, intent_obj: 'Intent') -> bool:
|
||||
"""Test if an intent can be handled."""
|
||||
return self.platforms is None or intent_obj.platform in self.platforms
|
||||
|
||||
@callback
|
||||
def async_validate_slots(self, slots):
|
||||
def async_validate_slots(self, slots: _SlotsType) -> _SlotsType:
|
||||
"""Validate slot information."""
|
||||
if self.slot_schema is None:
|
||||
return slots
|
||||
|
@ -141,18 +147,19 @@ class IntentHandler:
|
|||
for key, validator in self.slot_schema.items()},
|
||||
extra=vol.ALLOW_EXTRA)
|
||||
|
||||
return self._slot_schema(slots)
|
||||
return self._slot_schema(slots) # type: ignore
|
||||
|
||||
async def async_handle(self, intent_obj):
|
||||
async def async_handle(self, intent_obj: 'Intent') -> 'IntentResponse':
|
||||
"""Handle the intent."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
"""Represent a string of an intent handler."""
|
||||
return '<{} - {}>'.format(self.__class__.__name__, self.intent_type)
|
||||
|
||||
|
||||
def _fuzzymatch(name, items, key):
|
||||
def _fuzzymatch(name: str, items: Iterable[T], key: Callable[[T], str]) \
|
||||
-> Optional[T]:
|
||||
"""Fuzzy matching function."""
|
||||
matches = []
|
||||
pattern = '.*?'.join(name)
|
||||
|
@ -176,14 +183,15 @@ class ServiceIntentHandler(IntentHandler):
|
|||
vol.Required('name'): cv.string,
|
||||
}
|
||||
|
||||
def __init__(self, intent_type, domain, service, speech):
|
||||
def __init__(self, intent_type: str, domain: str, service: str,
|
||||
speech: str) -> None:
|
||||
"""Create Service Intent Handler."""
|
||||
self.intent_type = intent_type
|
||||
self.domain = domain
|
||||
self.service = service
|
||||
self.speech = speech
|
||||
|
||||
async def async_handle(self, intent_obj):
|
||||
async def async_handle(self, intent_obj: 'Intent') -> 'IntentResponse':
|
||||
"""Handle the hass intent."""
|
||||
hass = intent_obj.hass
|
||||
slots = self.async_validate_slots(intent_obj.slots)
|
||||
|
@ -203,7 +211,9 @@ class Intent:
|
|||
|
||||
__slots__ = ['hass', 'platform', 'intent_type', 'slots', 'text_input']
|
||||
|
||||
def __init__(self, hass, platform, intent_type, slots, text_input):
|
||||
def __init__(self, hass: HomeAssistantType, platform: str,
|
||||
intent_type: str, slots: _SlotsType,
|
||||
text_input: Optional[str]) -> None:
|
||||
"""Initialize an intent."""
|
||||
self.hass = hass
|
||||
self.platform = platform
|
||||
|
@ -212,7 +222,7 @@ class Intent:
|
|||
self.text_input = text_input
|
||||
|
||||
@callback
|
||||
def create_response(self):
|
||||
def create_response(self) -> 'IntentResponse':
|
||||
"""Create a response."""
|
||||
return IntentResponse(self)
|
||||
|
||||
|
@ -220,14 +230,15 @@ class Intent:
|
|||
class IntentResponse:
|
||||
"""Response to an intent."""
|
||||
|
||||
def __init__(self, intent=None):
|
||||
def __init__(self, intent: Optional[Intent] = None) -> None:
|
||||
"""Initialize an IntentResponse."""
|
||||
self.intent = intent
|
||||
self.speech = {}
|
||||
self.card = {}
|
||||
self.speech = {} # type: Dict[str, Dict[str, Any]]
|
||||
self.card = {} # type: Dict[str, Dict[str, str]]
|
||||
|
||||
@callback
|
||||
def async_set_speech(self, speech, speech_type='plain', extra_data=None):
|
||||
def async_set_speech(self, speech: str, speech_type: str = 'plain',
|
||||
extra_data: Optional[Any] = None) -> None:
|
||||
"""Set speech response."""
|
||||
self.speech[speech_type] = {
|
||||
'speech': speech,
|
||||
|
@ -235,7 +246,8 @@ class IntentResponse:
|
|||
}
|
||||
|
||||
@callback
|
||||
def async_set_card(self, title, content, card_type='simple'):
|
||||
def async_set_card(self, title: str, content: str,
|
||||
card_type: str = 'simple') -> None:
|
||||
"""Set speech response."""
|
||||
self.card[card_type] = {
|
||||
'title': title,
|
||||
|
@ -243,7 +255,7 @@ class IntentResponse:
|
|||
}
|
||||
|
||||
@callback
|
||||
def as_dict(self):
|
||||
def as_dict(self) -> Dict[str, Dict[str, Dict[str, Any]]]:
|
||||
"""Return a dictionary representation of an intent response."""
|
||||
return {
|
||||
'speech': self.speech,
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
"""Location helpers for Home Assistant."""
|
||||
|
||||
from typing import Sequence
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from homeassistant.const import ATTR_LATITUDE, ATTR_LONGITUDE
|
||||
from homeassistant.core import State
|
||||
|
@ -18,7 +18,7 @@ def has_location(state: State) -> bool:
|
|||
|
||||
|
||||
def closest(latitude: float, longitude: float,
|
||||
states: Sequence[State]) -> State:
|
||||
states: Sequence[State]) -> Optional[State]:
|
||||
"""Return closest state to point.
|
||||
|
||||
Async friendly.
|
||||
|
@ -31,6 +31,7 @@ def closest(latitude: float, longitude: float,
|
|||
return min(
|
||||
with_location,
|
||||
key=lambda state: loc_util.distance(
|
||||
latitude, longitude, state.attributes.get(ATTR_LATITUDE),
|
||||
state.attributes.get(ATTR_LONGITUDE))
|
||||
state.attributes.get(ATTR_LATITUDE),
|
||||
state.attributes.get(ATTR_LONGITUDE),
|
||||
latitude, longitude)
|
||||
)
|
||||
|
|
|
@ -1,8 +1,12 @@
|
|||
"""Helpers that help with state related things."""
|
||||
import asyncio
|
||||
import datetime as dt
|
||||
import json
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from types import TracebackType
|
||||
from typing import ( # noqa: F401 pylint: disable=unused-import
|
||||
Awaitable, Dict, Iterable, List, Optional, Tuple, Type, Union)
|
||||
|
||||
from homeassistant.loader import bind_hass
|
||||
import homeassistant.util.dt as dt_util
|
||||
|
@ -42,6 +46,7 @@ from homeassistant.const import (
|
|||
STATE_UNLOCKED, SERVICE_SELECT_OPTION)
|
||||
from homeassistant.core import State
|
||||
from homeassistant.util.async_ import run_coroutine_threadsafe
|
||||
from .typing import HomeAssistantType
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -102,43 +107,50 @@ class AsyncTrackStates:
|
|||
Must be run within the event loop.
|
||||
"""
|
||||
|
||||
def __init__(self, hass):
|
||||
def __init__(self, hass: HomeAssistantType) -> None:
|
||||
"""Initialize a TrackStates block."""
|
||||
self.hass = hass
|
||||
self.states = []
|
||||
self.states = [] # type: List[State]
|
||||
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
def __enter__(self):
|
||||
def __enter__(self) -> List[State]:
|
||||
"""Record time from which to track changes."""
|
||||
self.now = dt_util.utcnow()
|
||||
return self.states
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
def __exit__(self, exc_type: Optional[Type[BaseException]],
|
||||
exc_value: Optional[BaseException],
|
||||
traceback: Optional[TracebackType]) -> None:
|
||||
"""Add changes states to changes list."""
|
||||
self.states.extend(get_changed_since(self.hass.states.async_all(),
|
||||
self.now))
|
||||
|
||||
|
||||
def get_changed_since(states, utc_point_in_time):
|
||||
def get_changed_since(states: Iterable[State],
|
||||
utc_point_in_time: dt.datetime) -> List[State]:
|
||||
"""Return list of states that have been changed since utc_point_in_time."""
|
||||
return [state for state in states
|
||||
if state.last_updated >= utc_point_in_time]
|
||||
|
||||
|
||||
@bind_hass
|
||||
def reproduce_state(hass, states, blocking=False):
|
||||
def reproduce_state(hass: HomeAssistantType,
|
||||
states: Union[State, Iterable[State]],
|
||||
blocking: bool = False) -> None:
|
||||
"""Reproduce given state."""
|
||||
return run_coroutine_threadsafe(
|
||||
return run_coroutine_threadsafe( # type: ignore
|
||||
async_reproduce_state(hass, states, blocking), hass.loop).result()
|
||||
|
||||
|
||||
@bind_hass
|
||||
async def async_reproduce_state(hass, states, blocking=False):
|
||||
async def async_reproduce_state(hass: HomeAssistantType,
|
||||
states: Union[State, Iterable[State]],
|
||||
blocking: bool = False) -> None:
|
||||
"""Reproduce given state."""
|
||||
if isinstance(states, State):
|
||||
states = [states]
|
||||
|
||||
to_call = defaultdict(list)
|
||||
to_call = defaultdict(list) # type: Dict[Tuple[str, str, str], List[str]]
|
||||
|
||||
for state in states:
|
||||
|
||||
|
@ -182,7 +194,7 @@ async def async_reproduce_state(hass, states, blocking=False):
|
|||
json.dumps(dict(state.attributes), sort_keys=True))
|
||||
to_call[key].append(state.entity_id)
|
||||
|
||||
domain_tasks = {}
|
||||
domain_tasks = {} # type: Dict[str, List[Awaitable[Optional[bool]]]]
|
||||
for (service_domain, service, service_data), entity_ids in to_call.items():
|
||||
data = json.loads(service_data)
|
||||
data[ATTR_ENTITY_ID] = entity_ids
|
||||
|
@ -194,7 +206,8 @@ async def async_reproduce_state(hass, states, blocking=False):
|
|||
hass.services.async_call(service_domain, service, data, blocking)
|
||||
)
|
||||
|
||||
async def async_handle_service_calls(coro_list):
|
||||
async def async_handle_service_calls(
|
||||
coro_list: Iterable[Awaitable]) -> None:
|
||||
"""Handle service calls by domain sequence."""
|
||||
for coro in coro_list:
|
||||
await coro
|
||||
|
@ -205,7 +218,7 @@ async def async_reproduce_state(hass, states, blocking=False):
|
|||
await asyncio.wait(execute_tasks, loop=hass.loop)
|
||||
|
||||
|
||||
def state_as_number(state):
|
||||
def state_as_number(state: State) -> float:
|
||||
"""
|
||||
Try to coerce our state to a number.
|
||||
|
||||
|
|
|
@ -1,17 +1,19 @@
|
|||
"""Translation string lookup helpers."""
|
||||
import logging
|
||||
from os import path
|
||||
from typing import Any, Dict, Iterable
|
||||
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.loader import get_component, bind_hass
|
||||
from homeassistant.util.json import load_json
|
||||
from .typing import HomeAssistantType
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
TRANSLATION_STRING_CACHE = 'translation_string_cache'
|
||||
|
||||
|
||||
def recursive_flatten(prefix, data):
|
||||
def recursive_flatten(prefix: Any, data: Dict) -> Dict[str, Any]:
|
||||
"""Return a flattened representation of dict data."""
|
||||
output = {}
|
||||
for key, value in data.items():
|
||||
|
@ -23,12 +25,13 @@ def recursive_flatten(prefix, data):
|
|||
return output
|
||||
|
||||
|
||||
def flatten(data):
|
||||
def flatten(data: Dict) -> Dict[str, Any]:
|
||||
"""Return a flattened representation of dict data."""
|
||||
return recursive_flatten('', data)
|
||||
|
||||
|
||||
def component_translation_file(hass, component, language):
|
||||
def component_translation_file(hass: HomeAssistantType, component: str,
|
||||
language: str) -> str:
|
||||
"""Return the translation json file location for a component."""
|
||||
if '.' in component:
|
||||
name = component.split('.', 1)[1]
|
||||
|
@ -36,6 +39,7 @@ def component_translation_file(hass, component, language):
|
|||
name = component
|
||||
|
||||
module = get_component(hass, component)
|
||||
assert module is not None
|
||||
component_path = path.dirname(module.__file__)
|
||||
|
||||
# If loading translations for the package root, (__init__.py), the
|
||||
|
@ -48,19 +52,23 @@ def component_translation_file(hass, component, language):
|
|||
return path.join(component_path, '.translations', filename)
|
||||
|
||||
|
||||
def load_translations_files(translation_files):
|
||||
def load_translations_files(translation_files: Dict[str, str]) \
|
||||
-> Dict[str, Dict[str, Any]]:
|
||||
"""Load and parse translation.json files."""
|
||||
loaded = {}
|
||||
for component, translation_file in translation_files.items():
|
||||
loaded[component] = load_json(translation_file)
|
||||
loaded_json = load_json(translation_file)
|
||||
assert isinstance(loaded_json, dict)
|
||||
loaded[component] = loaded_json
|
||||
|
||||
return loaded
|
||||
|
||||
|
||||
def build_resources(translation_cache, components):
|
||||
def build_resources(translation_cache: Dict[str, Dict[str, Any]],
|
||||
components: Iterable[str]) -> Dict[str, Dict[str, Any]]:
|
||||
"""Build the resources response for the given components."""
|
||||
# Build response
|
||||
resources = {}
|
||||
resources = {} # type: Dict[str, Dict[str, Any]]
|
||||
for component in components:
|
||||
if '.' not in component:
|
||||
domain = component
|
||||
|
@ -79,7 +87,8 @@ def build_resources(translation_cache, components):
|
|||
|
||||
|
||||
@bind_hass
|
||||
async def async_get_component_resources(hass, language):
|
||||
async def async_get_component_resources(hass: HomeAssistantType,
|
||||
language: str) -> Dict[str, Any]:
|
||||
"""Return translation resources for all components."""
|
||||
if TRANSLATION_STRING_CACHE not in hass.data:
|
||||
hass.data[TRANSLATION_STRING_CACHE] = {}
|
||||
|
@ -99,12 +108,13 @@ async def async_get_component_resources(hass, language):
|
|||
|
||||
# Load missing files
|
||||
if missing_files:
|
||||
loaded_translations = await hass.async_add_job(
|
||||
load_translations_job = hass.async_add_job(
|
||||
load_translations_files, missing_files)
|
||||
assert load_translations_job is not None
|
||||
loaded_translations = await load_translations_job
|
||||
|
||||
# Update cache
|
||||
for component, translation_data in loaded_translations.items():
|
||||
translation_cache[component] = translation_data
|
||||
translation_cache.update(loaded_translations)
|
||||
|
||||
resources = build_resources(translation_cache, components)
|
||||
|
||||
|
@ -114,7 +124,8 @@ async def async_get_component_resources(hass, language):
|
|||
|
||||
|
||||
@bind_hass
|
||||
async def async_get_translations(hass, language):
|
||||
async def async_get_translations(hass: HomeAssistantType,
|
||||
language: str) -> Dict[str, Any]:
|
||||
"""Return all backend translations."""
|
||||
resources = await async_get_component_resources(hass, language)
|
||||
if language != 'en':
|
||||
|
|
2
tox.ini
2
tox.ini
|
@ -60,4 +60,4 @@ whitelist_externals=/bin/bash
|
|||
deps =
|
||||
-r{toxinidir}/requirements_test.txt
|
||||
commands =
|
||||
/bin/bash -c 'mypy homeassistant/*.py homeassistant/auth/ homeassistant/util/'
|
||||
/bin/bash -c 'mypy homeassistant/*.py homeassistant/{auth,util}/ homeassistant/helpers/{icon,intent,json,location,state,translation,typing}.py'
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue