Start type annotating/testing helpers (#17858)

* Add type hints to helpers.intent and location

* Test typing for helpers.icon, json, and typing

* Add type hints to helpers.state

* Add type hints to helpers.translation
This commit is contained in:
Ville Skyttä 2018-10-28 21:12:52 +02:00 committed by GitHub
parent 0f877711a0
commit c9c707e368
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 92 additions and 55 deletions

View file

@ -1,17 +1,20 @@
"""Module to coordinate user intentions."""
import logging
import re
from typing import Any, Callable, Dict, Iterable, Optional
import voluptuous as vol
from homeassistant.const import ATTR_SUPPORTED_FEATURES
from homeassistant.core import callback
from homeassistant.core import callback, State, T
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.typing import HomeAssistantType
from homeassistant.loader import bind_hass
from homeassistant.const import ATTR_ENTITY_ID
_LOGGER = logging.getLogger(__name__)
_SlotsType = Dict[str, Any]
INTENT_TURN_OFF = 'HassTurnOff'
INTENT_TURN_ON = 'HassTurnOn'
@ -28,7 +31,7 @@ SPEECH_TYPE_SSML = 'ssml'
@callback
@bind_hass
def async_register(hass, handler):
def async_register(hass: HomeAssistantType, handler: 'IntentHandler') -> None:
"""Register an intent with Home Assistant."""
intents = hass.data.get(DATA_KEY)
if intents is None:
@ -44,10 +47,12 @@ def async_register(hass, handler):
@bind_hass
async def async_handle(hass, platform, intent_type, slots=None,
text_input=None):
async def async_handle(hass: HomeAssistantType, platform: str,
intent_type: str, slots: Optional[_SlotsType] = None,
text_input: Optional[str] = None) -> 'IntentResponse':
"""Handle an intent."""
handler = hass.data.get(DATA_KEY, {}).get(intent_type)
handler = \
hass.data.get(DATA_KEY, {}).get(intent_type) # type: IntentHandler
if handler is None:
raise UnknownIntent('Unknown intent {}'.format(intent_type))
@ -93,7 +98,8 @@ class IntentUnexpectedError(IntentError):
@callback
@bind_hass
def async_match_state(hass, name, states=None):
def async_match_state(hass: HomeAssistantType, name: str,
states: Optional[Iterable[State]] = None) -> State:
"""Find a state that matches the name."""
if states is None:
states = hass.states.async_all()
@ -108,7 +114,7 @@ def async_match_state(hass, name, states=None):
@callback
def async_test_feature(state, feature, feature_name):
def async_test_feature(state: State, feature: int, feature_name: str) -> None:
"""Test is state supports a feature."""
if state.attributes.get(ATTR_SUPPORTED_FEATURES, 0) & feature == 0:
raise IntentHandleError(
@ -119,18 +125,18 @@ def async_test_feature(state, feature, feature_name):
class IntentHandler:
"""Intent handler registration."""
intent_type = None
slot_schema = None
intent_type = None # type: Optional[str]
slot_schema = None # type: Optional[vol.Schema]
_slot_schema = None
platforms = []
platforms = [] # type: Optional[Iterable[str]]
@callback
def async_can_handle(self, intent_obj):
def async_can_handle(self, intent_obj: 'Intent') -> bool:
"""Test if an intent can be handled."""
return self.platforms is None or intent_obj.platform in self.platforms
@callback
def async_validate_slots(self, slots):
def async_validate_slots(self, slots: _SlotsType) -> _SlotsType:
"""Validate slot information."""
if self.slot_schema is None:
return slots
@ -141,18 +147,19 @@ class IntentHandler:
for key, validator in self.slot_schema.items()},
extra=vol.ALLOW_EXTRA)
return self._slot_schema(slots)
return self._slot_schema(slots) # type: ignore
async def async_handle(self, intent_obj):
async def async_handle(self, intent_obj: 'Intent') -> 'IntentResponse':
"""Handle the intent."""
raise NotImplementedError()
def __repr__(self):
def __repr__(self) -> str:
"""Represent a string of an intent handler."""
return '<{} - {}>'.format(self.__class__.__name__, self.intent_type)
def _fuzzymatch(name, items, key):
def _fuzzymatch(name: str, items: Iterable[T], key: Callable[[T], str]) \
-> Optional[T]:
"""Fuzzy matching function."""
matches = []
pattern = '.*?'.join(name)
@ -176,14 +183,15 @@ class ServiceIntentHandler(IntentHandler):
vol.Required('name'): cv.string,
}
def __init__(self, intent_type, domain, service, speech):
def __init__(self, intent_type: str, domain: str, service: str,
speech: str) -> None:
"""Create Service Intent Handler."""
self.intent_type = intent_type
self.domain = domain
self.service = service
self.speech = speech
async def async_handle(self, intent_obj):
async def async_handle(self, intent_obj: 'Intent') -> 'IntentResponse':
"""Handle the hass intent."""
hass = intent_obj.hass
slots = self.async_validate_slots(intent_obj.slots)
@ -203,7 +211,9 @@ class Intent:
__slots__ = ['hass', 'platform', 'intent_type', 'slots', 'text_input']
def __init__(self, hass, platform, intent_type, slots, text_input):
def __init__(self, hass: HomeAssistantType, platform: str,
intent_type: str, slots: _SlotsType,
text_input: Optional[str]) -> None:
"""Initialize an intent."""
self.hass = hass
self.platform = platform
@ -212,7 +222,7 @@ class Intent:
self.text_input = text_input
@callback
def create_response(self):
def create_response(self) -> 'IntentResponse':
"""Create a response."""
return IntentResponse(self)
@ -220,14 +230,15 @@ class Intent:
class IntentResponse:
"""Response to an intent."""
def __init__(self, intent=None):
def __init__(self, intent: Optional[Intent] = None) -> None:
"""Initialize an IntentResponse."""
self.intent = intent
self.speech = {}
self.card = {}
self.speech = {} # type: Dict[str, Dict[str, Any]]
self.card = {} # type: Dict[str, Dict[str, str]]
@callback
def async_set_speech(self, speech, speech_type='plain', extra_data=None):
def async_set_speech(self, speech: str, speech_type: str = 'plain',
extra_data: Optional[Any] = None) -> None:
"""Set speech response."""
self.speech[speech_type] = {
'speech': speech,
@ -235,7 +246,8 @@ class IntentResponse:
}
@callback
def async_set_card(self, title, content, card_type='simple'):
def async_set_card(self, title: str, content: str,
card_type: str = 'simple') -> None:
"""Set speech response."""
self.card[card_type] = {
'title': title,
@ -243,7 +255,7 @@ class IntentResponse:
}
@callback
def as_dict(self):
def as_dict(self) -> Dict[str, Dict[str, Dict[str, Any]]]:
"""Return a dictionary representation of an intent response."""
return {
'speech': self.speech,

View file

@ -1,6 +1,6 @@
"""Location helpers for Home Assistant."""
from typing import Sequence
from typing import Optional, Sequence
from homeassistant.const import ATTR_LATITUDE, ATTR_LONGITUDE
from homeassistant.core import State
@ -18,7 +18,7 @@ def has_location(state: State) -> bool:
def closest(latitude: float, longitude: float,
states: Sequence[State]) -> State:
states: Sequence[State]) -> Optional[State]:
"""Return closest state to point.
Async friendly.
@ -31,6 +31,7 @@ def closest(latitude: float, longitude: float,
return min(
with_location,
key=lambda state: loc_util.distance(
latitude, longitude, state.attributes.get(ATTR_LATITUDE),
state.attributes.get(ATTR_LONGITUDE))
state.attributes.get(ATTR_LATITUDE),
state.attributes.get(ATTR_LONGITUDE),
latitude, longitude)
)

View file

@ -1,8 +1,12 @@
"""Helpers that help with state related things."""
import asyncio
import datetime as dt
import json
import logging
from collections import defaultdict
from types import TracebackType
from typing import ( # noqa: F401 pylint: disable=unused-import
Awaitable, Dict, Iterable, List, Optional, Tuple, Type, Union)
from homeassistant.loader import bind_hass
import homeassistant.util.dt as dt_util
@ -42,6 +46,7 @@ from homeassistant.const import (
STATE_UNLOCKED, SERVICE_SELECT_OPTION)
from homeassistant.core import State
from homeassistant.util.async_ import run_coroutine_threadsafe
from .typing import HomeAssistantType
_LOGGER = logging.getLogger(__name__)
@ -102,43 +107,50 @@ class AsyncTrackStates:
Must be run within the event loop.
"""
def __init__(self, hass):
def __init__(self, hass: HomeAssistantType) -> None:
"""Initialize a TrackStates block."""
self.hass = hass
self.states = []
self.states = [] # type: List[State]
# pylint: disable=attribute-defined-outside-init
def __enter__(self):
def __enter__(self) -> List[State]:
"""Record time from which to track changes."""
self.now = dt_util.utcnow()
return self.states
def __exit__(self, exc_type, exc_value, traceback):
def __exit__(self, exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType]) -> None:
"""Add changes states to changes list."""
self.states.extend(get_changed_since(self.hass.states.async_all(),
self.now))
def get_changed_since(states, utc_point_in_time):
def get_changed_since(states: Iterable[State],
utc_point_in_time: dt.datetime) -> List[State]:
"""Return list of states that have been changed since utc_point_in_time."""
return [state for state in states
if state.last_updated >= utc_point_in_time]
@bind_hass
def reproduce_state(hass, states, blocking=False):
def reproduce_state(hass: HomeAssistantType,
states: Union[State, Iterable[State]],
blocking: bool = False) -> None:
"""Reproduce given state."""
return run_coroutine_threadsafe(
return run_coroutine_threadsafe( # type: ignore
async_reproduce_state(hass, states, blocking), hass.loop).result()
@bind_hass
async def async_reproduce_state(hass, states, blocking=False):
async def async_reproduce_state(hass: HomeAssistantType,
states: Union[State, Iterable[State]],
blocking: bool = False) -> None:
"""Reproduce given state."""
if isinstance(states, State):
states = [states]
to_call = defaultdict(list)
to_call = defaultdict(list) # type: Dict[Tuple[str, str, str], List[str]]
for state in states:
@ -182,7 +194,7 @@ async def async_reproduce_state(hass, states, blocking=False):
json.dumps(dict(state.attributes), sort_keys=True))
to_call[key].append(state.entity_id)
domain_tasks = {}
domain_tasks = {} # type: Dict[str, List[Awaitable[Optional[bool]]]]
for (service_domain, service, service_data), entity_ids in to_call.items():
data = json.loads(service_data)
data[ATTR_ENTITY_ID] = entity_ids
@ -194,7 +206,8 @@ async def async_reproduce_state(hass, states, blocking=False):
hass.services.async_call(service_domain, service, data, blocking)
)
async def async_handle_service_calls(coro_list):
async def async_handle_service_calls(
coro_list: Iterable[Awaitable]) -> None:
"""Handle service calls by domain sequence."""
for coro in coro_list:
await coro
@ -205,7 +218,7 @@ async def async_reproduce_state(hass, states, blocking=False):
await asyncio.wait(execute_tasks, loop=hass.loop)
def state_as_number(state):
def state_as_number(state: State) -> float:
"""
Try to coerce our state to a number.

View file

@ -1,17 +1,19 @@
"""Translation string lookup helpers."""
import logging
from os import path
from typing import Any, Dict, Iterable
from homeassistant import config_entries
from homeassistant.loader import get_component, bind_hass
from homeassistant.util.json import load_json
from .typing import HomeAssistantType
_LOGGER = logging.getLogger(__name__)
TRANSLATION_STRING_CACHE = 'translation_string_cache'
def recursive_flatten(prefix, data):
def recursive_flatten(prefix: Any, data: Dict) -> Dict[str, Any]:
"""Return a flattened representation of dict data."""
output = {}
for key, value in data.items():
@ -23,12 +25,13 @@ def recursive_flatten(prefix, data):
return output
def flatten(data):
def flatten(data: Dict) -> Dict[str, Any]:
"""Return a flattened representation of dict data."""
return recursive_flatten('', data)
def component_translation_file(hass, component, language):
def component_translation_file(hass: HomeAssistantType, component: str,
language: str) -> str:
"""Return the translation json file location for a component."""
if '.' in component:
name = component.split('.', 1)[1]
@ -36,6 +39,7 @@ def component_translation_file(hass, component, language):
name = component
module = get_component(hass, component)
assert module is not None
component_path = path.dirname(module.__file__)
# If loading translations for the package root, (__init__.py), the
@ -48,19 +52,23 @@ def component_translation_file(hass, component, language):
return path.join(component_path, '.translations', filename)
def load_translations_files(translation_files):
def load_translations_files(translation_files: Dict[str, str]) \
-> Dict[str, Dict[str, Any]]:
"""Load and parse translation.json files."""
loaded = {}
for component, translation_file in translation_files.items():
loaded[component] = load_json(translation_file)
loaded_json = load_json(translation_file)
assert isinstance(loaded_json, dict)
loaded[component] = loaded_json
return loaded
def build_resources(translation_cache, components):
def build_resources(translation_cache: Dict[str, Dict[str, Any]],
components: Iterable[str]) -> Dict[str, Dict[str, Any]]:
"""Build the resources response for the given components."""
# Build response
resources = {}
resources = {} # type: Dict[str, Dict[str, Any]]
for component in components:
if '.' not in component:
domain = component
@ -79,7 +87,8 @@ def build_resources(translation_cache, components):
@bind_hass
async def async_get_component_resources(hass, language):
async def async_get_component_resources(hass: HomeAssistantType,
language: str) -> Dict[str, Any]:
"""Return translation resources for all components."""
if TRANSLATION_STRING_CACHE not in hass.data:
hass.data[TRANSLATION_STRING_CACHE] = {}
@ -99,12 +108,13 @@ async def async_get_component_resources(hass, language):
# Load missing files
if missing_files:
loaded_translations = await hass.async_add_job(
load_translations_job = hass.async_add_job(
load_translations_files, missing_files)
assert load_translations_job is not None
loaded_translations = await load_translations_job
# Update cache
for component, translation_data in loaded_translations.items():
translation_cache[component] = translation_data
translation_cache.update(loaded_translations)
resources = build_resources(translation_cache, components)
@ -114,7 +124,8 @@ async def async_get_component_resources(hass, language):
@bind_hass
async def async_get_translations(hass, language):
async def async_get_translations(hass: HomeAssistantType,
language: str) -> Dict[str, Any]:
"""Return all backend translations."""
resources = await async_get_component_resources(hass, language)
if language != 'en':

View file

@ -60,4 +60,4 @@ whitelist_externals=/bin/bash
deps =
-r{toxinidir}/requirements_test.txt
commands =
/bin/bash -c 'mypy homeassistant/*.py homeassistant/auth/ homeassistant/util/'
/bin/bash -c 'mypy homeassistant/*.py homeassistant/{auth,util}/ homeassistant/helpers/{icon,intent,json,location,state,translation,typing}.py'