Start type annotating/testing helpers (#17858)

* Add type hints to helpers.intent and location

* Test typing for helpers.icon, json, and typing

* Add type hints to helpers.state

* Add type hints to helpers.translation
This commit is contained in:
Ville Skyttä 2018-10-28 21:12:52 +02:00 committed by GitHub
parent 0f877711a0
commit c9c707e368
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 92 additions and 55 deletions

View file

@ -1,17 +1,20 @@
"""Module to coordinate user intentions.""" """Module to coordinate user intentions."""
import logging import logging
import re import re
from typing import Any, Callable, Dict, Iterable, Optional
import voluptuous as vol import voluptuous as vol
from homeassistant.const import ATTR_SUPPORTED_FEATURES from homeassistant.const import ATTR_SUPPORTED_FEATURES
from homeassistant.core import callback from homeassistant.core import callback, State, T
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.typing import HomeAssistantType
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
from homeassistant.const import ATTR_ENTITY_ID from homeassistant.const import ATTR_ENTITY_ID
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
_SlotsType = Dict[str, Any]
INTENT_TURN_OFF = 'HassTurnOff' INTENT_TURN_OFF = 'HassTurnOff'
INTENT_TURN_ON = 'HassTurnOn' INTENT_TURN_ON = 'HassTurnOn'
@ -28,7 +31,7 @@ SPEECH_TYPE_SSML = 'ssml'
@callback @callback
@bind_hass @bind_hass
def async_register(hass, handler): def async_register(hass: HomeAssistantType, handler: 'IntentHandler') -> None:
"""Register an intent with Home Assistant.""" """Register an intent with Home Assistant."""
intents = hass.data.get(DATA_KEY) intents = hass.data.get(DATA_KEY)
if intents is None: if intents is None:
@ -44,10 +47,12 @@ def async_register(hass, handler):
@bind_hass @bind_hass
async def async_handle(hass, platform, intent_type, slots=None, async def async_handle(hass: HomeAssistantType, platform: str,
text_input=None): intent_type: str, slots: Optional[_SlotsType] = None,
text_input: Optional[str] = None) -> 'IntentResponse':
"""Handle an intent.""" """Handle an intent."""
handler = hass.data.get(DATA_KEY, {}).get(intent_type) handler = \
hass.data.get(DATA_KEY, {}).get(intent_type) # type: IntentHandler
if handler is None: if handler is None:
raise UnknownIntent('Unknown intent {}'.format(intent_type)) raise UnknownIntent('Unknown intent {}'.format(intent_type))
@ -93,7 +98,8 @@ class IntentUnexpectedError(IntentError):
@callback @callback
@bind_hass @bind_hass
def async_match_state(hass, name, states=None): def async_match_state(hass: HomeAssistantType, name: str,
states: Optional[Iterable[State]] = None) -> State:
"""Find a state that matches the name.""" """Find a state that matches the name."""
if states is None: if states is None:
states = hass.states.async_all() states = hass.states.async_all()
@ -108,7 +114,7 @@ def async_match_state(hass, name, states=None):
@callback @callback
def async_test_feature(state, feature, feature_name): def async_test_feature(state: State, feature: int, feature_name: str) -> None:
"""Test is state supports a feature.""" """Test is state supports a feature."""
if state.attributes.get(ATTR_SUPPORTED_FEATURES, 0) & feature == 0: if state.attributes.get(ATTR_SUPPORTED_FEATURES, 0) & feature == 0:
raise IntentHandleError( raise IntentHandleError(
@ -119,18 +125,18 @@ def async_test_feature(state, feature, feature_name):
class IntentHandler: class IntentHandler:
"""Intent handler registration.""" """Intent handler registration."""
intent_type = None intent_type = None # type: Optional[str]
slot_schema = None slot_schema = None # type: Optional[vol.Schema]
_slot_schema = None _slot_schema = None
platforms = [] platforms = [] # type: Optional[Iterable[str]]
@callback @callback
def async_can_handle(self, intent_obj): def async_can_handle(self, intent_obj: 'Intent') -> bool:
"""Test if an intent can be handled.""" """Test if an intent can be handled."""
return self.platforms is None or intent_obj.platform in self.platforms return self.platforms is None or intent_obj.platform in self.platforms
@callback @callback
def async_validate_slots(self, slots): def async_validate_slots(self, slots: _SlotsType) -> _SlotsType:
"""Validate slot information.""" """Validate slot information."""
if self.slot_schema is None: if self.slot_schema is None:
return slots return slots
@ -141,18 +147,19 @@ class IntentHandler:
for key, validator in self.slot_schema.items()}, for key, validator in self.slot_schema.items()},
extra=vol.ALLOW_EXTRA) extra=vol.ALLOW_EXTRA)
return self._slot_schema(slots) return self._slot_schema(slots) # type: ignore
async def async_handle(self, intent_obj): async def async_handle(self, intent_obj: 'Intent') -> 'IntentResponse':
"""Handle the intent.""" """Handle the intent."""
raise NotImplementedError() raise NotImplementedError()
def __repr__(self): def __repr__(self) -> str:
"""Represent a string of an intent handler.""" """Represent a string of an intent handler."""
return '<{} - {}>'.format(self.__class__.__name__, self.intent_type) return '<{} - {}>'.format(self.__class__.__name__, self.intent_type)
def _fuzzymatch(name, items, key): def _fuzzymatch(name: str, items: Iterable[T], key: Callable[[T], str]) \
-> Optional[T]:
"""Fuzzy matching function.""" """Fuzzy matching function."""
matches = [] matches = []
pattern = '.*?'.join(name) pattern = '.*?'.join(name)
@ -176,14 +183,15 @@ class ServiceIntentHandler(IntentHandler):
vol.Required('name'): cv.string, vol.Required('name'): cv.string,
} }
def __init__(self, intent_type, domain, service, speech): def __init__(self, intent_type: str, domain: str, service: str,
speech: str) -> None:
"""Create Service Intent Handler.""" """Create Service Intent Handler."""
self.intent_type = intent_type self.intent_type = intent_type
self.domain = domain self.domain = domain
self.service = service self.service = service
self.speech = speech self.speech = speech
async def async_handle(self, intent_obj): async def async_handle(self, intent_obj: 'Intent') -> 'IntentResponse':
"""Handle the hass intent.""" """Handle the hass intent."""
hass = intent_obj.hass hass = intent_obj.hass
slots = self.async_validate_slots(intent_obj.slots) slots = self.async_validate_slots(intent_obj.slots)
@ -203,7 +211,9 @@ class Intent:
__slots__ = ['hass', 'platform', 'intent_type', 'slots', 'text_input'] __slots__ = ['hass', 'platform', 'intent_type', 'slots', 'text_input']
def __init__(self, hass, platform, intent_type, slots, text_input): def __init__(self, hass: HomeAssistantType, platform: str,
intent_type: str, slots: _SlotsType,
text_input: Optional[str]) -> None:
"""Initialize an intent.""" """Initialize an intent."""
self.hass = hass self.hass = hass
self.platform = platform self.platform = platform
@ -212,7 +222,7 @@ class Intent:
self.text_input = text_input self.text_input = text_input
@callback @callback
def create_response(self): def create_response(self) -> 'IntentResponse':
"""Create a response.""" """Create a response."""
return IntentResponse(self) return IntentResponse(self)
@ -220,14 +230,15 @@ class Intent:
class IntentResponse: class IntentResponse:
"""Response to an intent.""" """Response to an intent."""
def __init__(self, intent=None): def __init__(self, intent: Optional[Intent] = None) -> None:
"""Initialize an IntentResponse.""" """Initialize an IntentResponse."""
self.intent = intent self.intent = intent
self.speech = {} self.speech = {} # type: Dict[str, Dict[str, Any]]
self.card = {} self.card = {} # type: Dict[str, Dict[str, str]]
@callback @callback
def async_set_speech(self, speech, speech_type='plain', extra_data=None): def async_set_speech(self, speech: str, speech_type: str = 'plain',
extra_data: Optional[Any] = None) -> None:
"""Set speech response.""" """Set speech response."""
self.speech[speech_type] = { self.speech[speech_type] = {
'speech': speech, 'speech': speech,
@ -235,7 +246,8 @@ class IntentResponse:
} }
@callback @callback
def async_set_card(self, title, content, card_type='simple'): def async_set_card(self, title: str, content: str,
card_type: str = 'simple') -> None:
"""Set speech response.""" """Set speech response."""
self.card[card_type] = { self.card[card_type] = {
'title': title, 'title': title,
@ -243,7 +255,7 @@ class IntentResponse:
} }
@callback @callback
def as_dict(self): def as_dict(self) -> Dict[str, Dict[str, Dict[str, Any]]]:
"""Return a dictionary representation of an intent response.""" """Return a dictionary representation of an intent response."""
return { return {
'speech': self.speech, 'speech': self.speech,

View file

@ -1,6 +1,6 @@
"""Location helpers for Home Assistant.""" """Location helpers for Home Assistant."""
from typing import Sequence from typing import Optional, Sequence
from homeassistant.const import ATTR_LATITUDE, ATTR_LONGITUDE from homeassistant.const import ATTR_LATITUDE, ATTR_LONGITUDE
from homeassistant.core import State from homeassistant.core import State
@ -18,7 +18,7 @@ def has_location(state: State) -> bool:
def closest(latitude: float, longitude: float, def closest(latitude: float, longitude: float,
states: Sequence[State]) -> State: states: Sequence[State]) -> Optional[State]:
"""Return closest state to point. """Return closest state to point.
Async friendly. Async friendly.
@ -31,6 +31,7 @@ def closest(latitude: float, longitude: float,
return min( return min(
with_location, with_location,
key=lambda state: loc_util.distance( key=lambda state: loc_util.distance(
latitude, longitude, state.attributes.get(ATTR_LATITUDE), state.attributes.get(ATTR_LATITUDE),
state.attributes.get(ATTR_LONGITUDE)) state.attributes.get(ATTR_LONGITUDE),
latitude, longitude)
) )

View file

@ -1,8 +1,12 @@
"""Helpers that help with state related things.""" """Helpers that help with state related things."""
import asyncio import asyncio
import datetime as dt
import json import json
import logging import logging
from collections import defaultdict from collections import defaultdict
from types import TracebackType
from typing import ( # noqa: F401 pylint: disable=unused-import
Awaitable, Dict, Iterable, List, Optional, Tuple, Type, Union)
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
@ -42,6 +46,7 @@ from homeassistant.const import (
STATE_UNLOCKED, SERVICE_SELECT_OPTION) STATE_UNLOCKED, SERVICE_SELECT_OPTION)
from homeassistant.core import State from homeassistant.core import State
from homeassistant.util.async_ import run_coroutine_threadsafe from homeassistant.util.async_ import run_coroutine_threadsafe
from .typing import HomeAssistantType
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -102,43 +107,50 @@ class AsyncTrackStates:
Must be run within the event loop. Must be run within the event loop.
""" """
def __init__(self, hass): def __init__(self, hass: HomeAssistantType) -> None:
"""Initialize a TrackStates block.""" """Initialize a TrackStates block."""
self.hass = hass self.hass = hass
self.states = [] self.states = [] # type: List[State]
# pylint: disable=attribute-defined-outside-init # pylint: disable=attribute-defined-outside-init
def __enter__(self): def __enter__(self) -> List[State]:
"""Record time from which to track changes.""" """Record time from which to track changes."""
self.now = dt_util.utcnow() self.now = dt_util.utcnow()
return self.states return self.states
def __exit__(self, exc_type, exc_value, traceback): def __exit__(self, exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType]) -> None:
"""Add changes states to changes list.""" """Add changes states to changes list."""
self.states.extend(get_changed_since(self.hass.states.async_all(), self.states.extend(get_changed_since(self.hass.states.async_all(),
self.now)) self.now))
def get_changed_since(states, utc_point_in_time): def get_changed_since(states: Iterable[State],
utc_point_in_time: dt.datetime) -> List[State]:
"""Return list of states that have been changed since utc_point_in_time.""" """Return list of states that have been changed since utc_point_in_time."""
return [state for state in states return [state for state in states
if state.last_updated >= utc_point_in_time] if state.last_updated >= utc_point_in_time]
@bind_hass @bind_hass
def reproduce_state(hass, states, blocking=False): def reproduce_state(hass: HomeAssistantType,
states: Union[State, Iterable[State]],
blocking: bool = False) -> None:
"""Reproduce given state.""" """Reproduce given state."""
return run_coroutine_threadsafe( return run_coroutine_threadsafe( # type: ignore
async_reproduce_state(hass, states, blocking), hass.loop).result() async_reproduce_state(hass, states, blocking), hass.loop).result()
@bind_hass @bind_hass
async def async_reproduce_state(hass, states, blocking=False): async def async_reproduce_state(hass: HomeAssistantType,
states: Union[State, Iterable[State]],
blocking: bool = False) -> None:
"""Reproduce given state.""" """Reproduce given state."""
if isinstance(states, State): if isinstance(states, State):
states = [states] states = [states]
to_call = defaultdict(list) to_call = defaultdict(list) # type: Dict[Tuple[str, str, str], List[str]]
for state in states: for state in states:
@ -182,7 +194,7 @@ async def async_reproduce_state(hass, states, blocking=False):
json.dumps(dict(state.attributes), sort_keys=True)) json.dumps(dict(state.attributes), sort_keys=True))
to_call[key].append(state.entity_id) to_call[key].append(state.entity_id)
domain_tasks = {} domain_tasks = {} # type: Dict[str, List[Awaitable[Optional[bool]]]]
for (service_domain, service, service_data), entity_ids in to_call.items(): for (service_domain, service, service_data), entity_ids in to_call.items():
data = json.loads(service_data) data = json.loads(service_data)
data[ATTR_ENTITY_ID] = entity_ids data[ATTR_ENTITY_ID] = entity_ids
@ -194,7 +206,8 @@ async def async_reproduce_state(hass, states, blocking=False):
hass.services.async_call(service_domain, service, data, blocking) hass.services.async_call(service_domain, service, data, blocking)
) )
async def async_handle_service_calls(coro_list): async def async_handle_service_calls(
coro_list: Iterable[Awaitable]) -> None:
"""Handle service calls by domain sequence.""" """Handle service calls by domain sequence."""
for coro in coro_list: for coro in coro_list:
await coro await coro
@ -205,7 +218,7 @@ async def async_reproduce_state(hass, states, blocking=False):
await asyncio.wait(execute_tasks, loop=hass.loop) await asyncio.wait(execute_tasks, loop=hass.loop)
def state_as_number(state): def state_as_number(state: State) -> float:
""" """
Try to coerce our state to a number. Try to coerce our state to a number.

View file

@ -1,17 +1,19 @@
"""Translation string lookup helpers.""" """Translation string lookup helpers."""
import logging import logging
from os import path from os import path
from typing import Any, Dict, Iterable
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.loader import get_component, bind_hass from homeassistant.loader import get_component, bind_hass
from homeassistant.util.json import load_json from homeassistant.util.json import load_json
from .typing import HomeAssistantType
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
TRANSLATION_STRING_CACHE = 'translation_string_cache' TRANSLATION_STRING_CACHE = 'translation_string_cache'
def recursive_flatten(prefix, data): def recursive_flatten(prefix: Any, data: Dict) -> Dict[str, Any]:
"""Return a flattened representation of dict data.""" """Return a flattened representation of dict data."""
output = {} output = {}
for key, value in data.items(): for key, value in data.items():
@ -23,12 +25,13 @@ def recursive_flatten(prefix, data):
return output return output
def flatten(data): def flatten(data: Dict) -> Dict[str, Any]:
"""Return a flattened representation of dict data.""" """Return a flattened representation of dict data."""
return recursive_flatten('', data) return recursive_flatten('', data)
def component_translation_file(hass, component, language): def component_translation_file(hass: HomeAssistantType, component: str,
language: str) -> str:
"""Return the translation json file location for a component.""" """Return the translation json file location for a component."""
if '.' in component: if '.' in component:
name = component.split('.', 1)[1] name = component.split('.', 1)[1]
@ -36,6 +39,7 @@ def component_translation_file(hass, component, language):
name = component name = component
module = get_component(hass, component) module = get_component(hass, component)
assert module is not None
component_path = path.dirname(module.__file__) component_path = path.dirname(module.__file__)
# If loading translations for the package root, (__init__.py), the # If loading translations for the package root, (__init__.py), the
@ -48,19 +52,23 @@ def component_translation_file(hass, component, language):
return path.join(component_path, '.translations', filename) return path.join(component_path, '.translations', filename)
def load_translations_files(translation_files): def load_translations_files(translation_files: Dict[str, str]) \
-> Dict[str, Dict[str, Any]]:
"""Load and parse translation.json files.""" """Load and parse translation.json files."""
loaded = {} loaded = {}
for component, translation_file in translation_files.items(): for component, translation_file in translation_files.items():
loaded[component] = load_json(translation_file) loaded_json = load_json(translation_file)
assert isinstance(loaded_json, dict)
loaded[component] = loaded_json
return loaded return loaded
def build_resources(translation_cache, components): def build_resources(translation_cache: Dict[str, Dict[str, Any]],
components: Iterable[str]) -> Dict[str, Dict[str, Any]]:
"""Build the resources response for the given components.""" """Build the resources response for the given components."""
# Build response # Build response
resources = {} resources = {} # type: Dict[str, Dict[str, Any]]
for component in components: for component in components:
if '.' not in component: if '.' not in component:
domain = component domain = component
@ -79,7 +87,8 @@ def build_resources(translation_cache, components):
@bind_hass @bind_hass
async def async_get_component_resources(hass, language): async def async_get_component_resources(hass: HomeAssistantType,
language: str) -> Dict[str, Any]:
"""Return translation resources for all components.""" """Return translation resources for all components."""
if TRANSLATION_STRING_CACHE not in hass.data: if TRANSLATION_STRING_CACHE not in hass.data:
hass.data[TRANSLATION_STRING_CACHE] = {} hass.data[TRANSLATION_STRING_CACHE] = {}
@ -99,12 +108,13 @@ async def async_get_component_resources(hass, language):
# Load missing files # Load missing files
if missing_files: if missing_files:
loaded_translations = await hass.async_add_job( load_translations_job = hass.async_add_job(
load_translations_files, missing_files) load_translations_files, missing_files)
assert load_translations_job is not None
loaded_translations = await load_translations_job
# Update cache # Update cache
for component, translation_data in loaded_translations.items(): translation_cache.update(loaded_translations)
translation_cache[component] = translation_data
resources = build_resources(translation_cache, components) resources = build_resources(translation_cache, components)
@ -114,7 +124,8 @@ async def async_get_component_resources(hass, language):
@bind_hass @bind_hass
async def async_get_translations(hass, language): async def async_get_translations(hass: HomeAssistantType,
language: str) -> Dict[str, Any]:
"""Return all backend translations.""" """Return all backend translations."""
resources = await async_get_component_resources(hass, language) resources = await async_get_component_resources(hass, language)
if language != 'en': if language != 'en':

View file

@ -60,4 +60,4 @@ whitelist_externals=/bin/bash
deps = deps =
-r{toxinidir}/requirements_test.txt -r{toxinidir}/requirements_test.txt
commands = commands =
/bin/bash -c 'mypy homeassistant/*.py homeassistant/auth/ homeassistant/util/' /bin/bash -c 'mypy homeassistant/*.py homeassistant/{auth,util}/ homeassistant/helpers/{icon,intent,json,location,state,translation,typing}.py'