RFC: Use bind_hass for helpers (#9745)

* Add Helpers bind_hass functionality

* Update other helpers
This commit is contained in:
Paulus Schoutsen 2017-10-08 08:17:54 -07:00 committed by GitHub
parent e19e9a1f2b
commit ca54bbfcc9
16 changed files with 108 additions and 25 deletions

View file

@ -11,13 +11,11 @@ from typing import Any, Optional, Dict
import voluptuous as vol
import homeassistant.components as core_components
from homeassistant import (
core, config as conf_util, loader, components as core_components)
from homeassistant.components import persistent_notification
import homeassistant.config as conf_util
import homeassistant.core as core
from homeassistant.const import EVENT_HOMEASSISTANT_CLOSE
from homeassistant.setup import async_setup_component
import homeassistant.loader as loader
from homeassistant.util.logging import AsyncHandler
from homeassistant.util.package import async_get_user_site, get_user_site
from homeassistant.util.yaml import clear_secret_cache

View file

@ -30,7 +30,7 @@ from homeassistant.const import (
EVENT_SERVICE_EXECUTED, EVENT_SERVICE_REGISTERED, EVENT_STATE_CHANGED,
EVENT_TIME_CHANGED, MATCH_ALL, EVENT_HOMEASSISTANT_CLOSE,
EVENT_SERVICE_REMOVED, __version__)
from homeassistant.loader import Components
from homeassistant import loader
from homeassistant.exceptions import (
HomeAssistantError, InvalidEntityFormatError)
from homeassistant.util.async import (
@ -129,7 +129,8 @@ class HomeAssistant(object):
self.services = ServiceRegistry(self)
self.states = StateMachine(self.bus, self.loop)
self.config = Config() # type: Config
self.components = Components(self)
self.components = loader.Components(self)
self.helpers = loader.Helpers(self)
# This is a dictionary that any component can store any data on.
self.data = {}
self.state = CoreState.not_running

View file

@ -9,8 +9,8 @@ from aiohttp.web_exceptions import HTTPGatewayTimeout, HTTPBadGateway
import async_timeout
from homeassistant.core import callback
from homeassistant.const import EVENT_HOMEASSISTANT_CLOSE
from homeassistant.const import __version__
from homeassistant.const import EVENT_HOMEASSISTANT_CLOSE, __version__
from homeassistant.loader import bind_hass
DATA_CONNECTOR = 'aiohttp_connector'
DATA_CONNECTOR_NOTVERIFY = 'aiohttp_connector_notverify'
@ -21,6 +21,7 @@ SERVER_SOFTWARE = 'HomeAssistant/{0} aiohttp/{1} Python/{2[0]}.{2[1]}'.format(
@callback
@bind_hass
def async_get_clientsession(hass, verify_ssl=True):
"""Return default aiohttp ClientSession.
@ -45,6 +46,7 @@ def async_get_clientsession(hass, verify_ssl=True):
@callback
@bind_hass
def async_create_clientsession(hass, verify_ssl=True, auto_cleanup=True,
**kwargs):
"""Create a new ClientSession with kwargs, i.e. for cookies.
@ -71,6 +73,7 @@ def async_create_clientsession(hass, verify_ssl=True, auto_cleanup=True,
@asyncio.coroutine
@bind_hass
def async_aiohttp_proxy_web(hass, request, web_coro, buffer_size=102400,
timeout=10):
"""Stream websession request to aiohttp web response."""
@ -102,6 +105,7 @@ def async_aiohttp_proxy_web(hass, request, web_coro, buffer_size=102400,
@asyncio.coroutine
@bind_hass
def async_aiohttp_proxy_stream(hass, request, stream, content_type,
buffer_size=102400, timeout=10):
"""Stream a stream to aiohttp web response."""

View file

@ -8,6 +8,7 @@ There are two different types of discoveries that can be fired/listened for.
import asyncio
from homeassistant import setup, core
from homeassistant.loader import bind_hass
from homeassistant.const import (
ATTR_DISCOVERED, ATTR_SERVICE, EVENT_PLATFORM_DISCOVERED)
from homeassistant.exceptions import HomeAssistantError
@ -18,6 +19,7 @@ EVENT_LOAD_PLATFORM = 'load_platform.{}'
ATTR_PLATFORM = 'platform'
@bind_hass
def listen(hass, service, callback):
"""Set up listener for discovery of specific service.
@ -28,6 +30,7 @@ def listen(hass, service, callback):
@core.callback
@bind_hass
def async_listen(hass, service, callback):
"""Set up listener for discovery of specific service.
@ -48,6 +51,7 @@ def async_listen(hass, service, callback):
hass.bus.async_listen(EVENT_PLATFORM_DISCOVERED, discovery_event_listener)
@bind_hass
def discover(hass, service, discovered=None, component=None, hass_config=None):
"""Fire discovery event. Can ensure a component is loaded."""
hass.add_job(
@ -55,6 +59,7 @@ def discover(hass, service, discovered=None, component=None, hass_config=None):
@asyncio.coroutine
@bind_hass
def async_discover(hass, service, discovered=None, component=None,
hass_config=None):
"""Fire discovery event. Can ensure a component is loaded."""
@ -76,6 +81,7 @@ def async_discover(hass, service, discovered=None, component=None,
hass.bus.async_fire(EVENT_PLATFORM_DISCOVERED, data)
@bind_hass
def listen_platform(hass, component, callback):
"""Register a platform loader listener."""
run_callback_threadsafe(
@ -83,6 +89,7 @@ def listen_platform(hass, component, callback):
).result()
@bind_hass
def async_listen_platform(hass, component, callback):
"""Register a platform loader listener.
@ -109,6 +116,7 @@ def async_listen_platform(hass, component, callback):
EVENT_PLATFORM_DISCOVERED, discovery_platform_listener)
@bind_hass
def load_platform(hass, component, platform, discovered=None,
hass_config=None):
"""Load a component and platform dynamically.
@ -127,6 +135,7 @@ def load_platform(hass, component, platform, discovered=None,
@asyncio.coroutine
@bind_hass
def async_load_platform(hass, component, platform, discovered=None,
hass_config=None):
"""Load a component and platform dynamically.

View file

@ -2,6 +2,7 @@
import logging
from homeassistant.core import callback
from homeassistant.loader import bind_hass
from homeassistant.util.async import run_callback_threadsafe
@ -9,6 +10,7 @@ _LOGGER = logging.getLogger(__name__)
DATA_DISPATCHER = 'dispatcher'
@bind_hass
def dispatcher_connect(hass, signal, target):
"""Connect a callable function to a signal."""
async_unsub = run_callback_threadsafe(
@ -22,6 +24,7 @@ def dispatcher_connect(hass, signal, target):
@callback
@bind_hass
def async_dispatcher_connect(hass, signal, target):
"""Connect a callable function to a signal.
@ -49,12 +52,14 @@ def async_dispatcher_connect(hass, signal, target):
return async_remove_dispatcher
@bind_hass
def dispatcher_send(hass, signal, *args):
"""Send signal and data."""
hass.loop.call_soon_threadsafe(async_dispatcher_send, hass, signal, *args)
@callback
@bind_hass
def async_dispatcher_send(hass, signal, *args):
"""Send signal and data.

View file

@ -1,6 +1,7 @@
"""Helpers for listening to events."""
import functools as ft
from homeassistant.loader import bind_hass
from homeassistant.helpers.sun import get_astral_event_next
from ..core import HomeAssistant, callback
from ..const import (
@ -35,6 +36,7 @@ def threaded_listener_factory(async_factory):
@callback
@bind_hass
def async_track_state_change(hass, entity_ids, action, from_state=None,
to_state=None):
"""Track specific state changes.
@ -86,6 +88,7 @@ track_state_change = threaded_listener_factory(async_track_state_change)
@callback
@bind_hass
def async_track_template(hass, template, action, variables=None):
"""Add a listener that track state changes with template condition."""
from . import condition
@ -114,6 +117,7 @@ track_template = threaded_listener_factory(async_track_template)
@callback
@bind_hass
def async_track_same_state(hass, orig_value, period, action,
async_check_func=None, entity_ids=MATCH_ALL):
"""Track the state of entities for a period and run a action.
@ -170,6 +174,7 @@ track_same_state = threaded_listener_factory(async_track_same_state)
@callback
@bind_hass
def async_track_point_in_time(hass, action, point_in_time):
"""Add a listener that fires once after a specific point in time."""
utc_point_in_time = dt_util.as_utc(point_in_time)
@ -187,6 +192,7 @@ track_point_in_time = threaded_listener_factory(async_track_point_in_time)
@callback
@bind_hass
def async_track_point_in_utc_time(hass, action, point_in_time):
"""Add a listener that fires once after a specific point in UTC time."""
# Ensure point_in_time is UTC
@ -221,6 +227,7 @@ track_point_in_utc_time = threaded_listener_factory(
@callback
@bind_hass
def async_track_time_interval(hass, action, interval):
"""Add a listener that fires repetitively at every timedelta interval."""
remove = None
@ -251,6 +258,7 @@ track_time_interval = threaded_listener_factory(async_track_time_interval)
@callback
@bind_hass
def async_track_sunrise(hass, action, offset=None):
"""Add a listener that will fire a specified offset from sunrise daily."""
remove = None
@ -279,6 +287,7 @@ track_sunrise = threaded_listener_factory(async_track_sunrise)
@callback
@bind_hass
def async_track_sunset(hass, action, offset=None):
"""Add a listener that will fire a specified offset from sunset daily."""
remove = None
@ -307,6 +316,7 @@ track_sunset = threaded_listener_factory(async_track_sunset)
@callback
@bind_hass
def async_track_utc_time_change(hass, action, year=None, month=None, day=None,
hour=None, minute=None, second=None,
local=False):
@ -352,6 +362,7 @@ track_utc_time_change = threaded_listener_factory(async_track_utc_time_change)
@callback
@bind_hass
def async_track_time_change(hass, action, year=None, month=None, day=None,
hour=None, minute=None, second=None):
"""Add a listener that will fire if UTC time matches a pattern."""

View file

@ -6,6 +6,7 @@ import voluptuous as vol
from homeassistant.core import callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.loader import bind_hass
DATA_KEY = 'intent'
@ -19,6 +20,7 @@ SPEECH_TYPE_SSML = 'ssml'
@callback
@bind_hass
def async_register(hass, handler):
"""Register an intent with Home Assistant."""
intents = hass.data.get(DATA_KEY)
@ -33,6 +35,7 @@ def async_register(hass, handler):
@asyncio.coroutine
@bind_hass
def async_handle(hass, platform, intent_type, slots=None, text_input=None):
"""Handle an intent."""
handler = hass.data.get(DATA_KEY, {}).get(intent_type)

View file

@ -7,6 +7,7 @@ import async_timeout
from homeassistant.core import HomeAssistant, CoreState, callback
from homeassistant.const import EVENT_HOMEASSISTANT_START
from homeassistant.loader import bind_hass
from homeassistant.components.history import get_states, last_recorder_run
from homeassistant.components.recorder import (
wait_connection_ready, DOMAIN as _RECORDER)
@ -49,6 +50,7 @@ def _load_restore_cache(hass: HomeAssistant):
@asyncio.coroutine
@bind_hass
def async_get_last_state(hass, entity_id: str):
"""Restore state."""
if DATA_RESTORE_CACHE in hass.data:

View file

@ -9,7 +9,7 @@ import voluptuous as vol
from homeassistant.const import ATTR_ENTITY_ID
from homeassistant.core import HomeAssistant # NOQA
from homeassistant.exceptions import TemplateError
from homeassistant.loader import get_component
from homeassistant.loader import get_component, bind_hass
import homeassistant.helpers.config_validation as cv
from homeassistant.util.async import run_coroutine_threadsafe
@ -22,6 +22,7 @@ CONF_SERVICE_DATA_TEMPLATE = 'data_template'
_LOGGER = logging.getLogger(__name__)
@bind_hass
def call_from_config(hass, config, blocking=False, variables=None,
validate_config=True):
"""Call a service based on a config hash."""
@ -31,6 +32,7 @@ def call_from_config(hass, config, blocking=False, variables=None,
@asyncio.coroutine
@bind_hass
def async_call_from_config(hass, config, blocking=False, variables=None,
validate_config=True):
"""Call a service based on a config hash."""
@ -80,6 +82,7 @@ def async_call_from_config(hass, config, blocking=False, variables=None,
domain, service_name, service_data, blocking)
@bind_hass
def extract_entity_ids(hass, service_call, expand_group=True):
"""Extract a list of entity ids from a service call.

View file

@ -5,11 +5,13 @@ import sys
from homeassistant.core import callback
from homeassistant.const import RESTART_EXIT_CODE
from homeassistant.loader import bind_hass
_LOGGER = logging.getLogger(__name__)
@callback
@bind_hass
def async_register_signal_handling(hass):
"""Register system signal handler for core."""
if sys.platform != 'win32':

View file

@ -4,6 +4,7 @@ import json
import logging
from collections import defaultdict
from homeassistant.loader import bind_hass
import homeassistant.util.dt as dt_util
from homeassistant.components.media_player import (
ATTR_MEDIA_CONTENT_ID, ATTR_MEDIA_CONTENT_TYPE, ATTR_MEDIA_SEEK_POSITION,
@ -120,6 +121,7 @@ def get_changed_since(states, utc_point_in_time):
if state.last_updated >= utc_point_in_time]
@bind_hass
def reproduce_state(hass, states, blocking=False):
"""Reproduce given state."""
return run_coroutine_threadsafe(
@ -127,6 +129,7 @@ def reproduce_state(hass, states, blocking=False):
@asyncio.coroutine
@bind_hass
def async_reproduce_state(hass, states, blocking=False):
"""Reproduce given state."""
if isinstance(states, State):

View file

@ -3,11 +3,13 @@ import datetime
from homeassistant.core import callback
from homeassistant.util import dt as dt_util
from homeassistant.loader import bind_hass
DATA_LOCATION_CACHE = 'astral_location_cache'
@callback
@bind_hass
def get_astral_location(hass):
"""Get an astral location for the current Home Assistant configuration."""
from astral import Location
@ -29,6 +31,7 @@ def get_astral_location(hass):
@callback
@bind_hass
def get_astral_event_next(hass, event, utc_point_in_time=None, offset=None):
"""Calculate the next specified solar event."""
import astral
@ -56,6 +59,7 @@ def get_astral_event_next(hass, event, utc_point_in_time=None, offset=None):
@callback
@bind_hass
def get_astral_event_date(hass, event, date=None):
"""Calculate the astral event time for the specified date."""
import astral
@ -76,6 +80,7 @@ def get_astral_event_date(hass, event, date=None):
@callback
@bind_hass
def is_up(hass, utc_point_in_time=None):
"""Calculate if the sun is currently up."""
if utc_point_in_time is None:

View file

@ -15,7 +15,7 @@ from homeassistant.const import (
from homeassistant.core import State
from homeassistant.exceptions import TemplateError
from homeassistant.helpers import location as loc_helper
from homeassistant.loader import get_component
from homeassistant.loader import get_component, bind_hass
from homeassistant.util import convert, dt as dt_util, location as loc_util
from homeassistant.util.async import run_callback_threadsafe
@ -30,6 +30,7 @@ _RE_GET_ENTITIES = re.compile(
)
@bind_hass
def attach(hass, obj):
"""Recursively attach hass to all template instances in list and dict."""
if isinstance(obj, list):

View file

@ -4,7 +4,7 @@ Provides methods for loading Home Assistant components.
This module has quite some complex parts. I have tried to add as much
documentation as possible to keep it understandable.
Components are loaded by calling get_component('switch') from your code.
Components can be accessed via hass.components.switch from your code.
If you want to retrieve a platform that is part of a component, you should
call get_component('switch.your_platform'). In both cases the config directory
is checked to see if it contains a user provided version. If not available it
@ -183,22 +183,38 @@ class Components:
component = get_component(comp_name)
if component is None:
raise ImportError('Unable to load {}'.format(comp_name))
wrapped = ComponentWrapper(self._hass, component)
wrapped = ModuleWrapper(self._hass, component)
setattr(self, comp_name, wrapped)
return wrapped
class ComponentWrapper:
"""Class to wrap a component and auto fill in hass argument."""
class Helpers:
"""Helper to load helpers."""
def __init__(self, hass, component):
"""Initialize the component wrapper."""
def __init__(self, hass):
"""Initialize the Helpers class."""
self._hass = hass
self._component = component
def __getattr__(self, helper_name):
"""Fetch a helper."""
helper = importlib.import_module(
'homeassistant.helpers.{}'.format(helper_name))
wrapped = ModuleWrapper(self._hass, helper)
setattr(self, helper_name, wrapped)
return wrapped
class ModuleWrapper:
"""Class to wrap a Python module and auto fill in hass argument."""
def __init__(self, hass, module):
"""Initialize the module wrapper."""
self._hass = hass
self._module = module
def __getattr__(self, attr):
"""Fetch an attribute."""
value = getattr(self._component, attr)
value = getattr(self._module, attr)
if hasattr(value, '__bind_hass'):
value = ft.partial(value, self._hass)

View file

@ -27,6 +27,7 @@ class TestHelpersDiscovery:
@patch('homeassistant.setup.async_setup_component')
def test_listen(self, mock_setup_component):
"""Test discovery listen/discover combo."""
helpers = self.hass.helpers
calls_single = []
calls_multi = []
@ -40,12 +41,12 @@ class TestHelpersDiscovery:
"""Service discovered callback."""
calls_multi.append((service, info))
discovery.listen(self.hass, 'test service', callback_single)
discovery.listen(self.hass, ['test service', 'another service'],
callback_multi)
helpers.discovery.listen('test service', callback_single)
helpers.discovery.listen(['test service', 'another service'],
callback_multi)
discovery.discover(self.hass, 'test service', 'discovery info',
'test_component')
helpers.discovery.discover('test service', 'discovery info',
'test_component')
self.hass.block_till_done()
assert mock_setup_component.called
@ -54,8 +55,8 @@ class TestHelpersDiscovery:
assert len(calls_single) == 1
assert calls_single[0] == ('test service', 'discovery info')
discovery.discover(self.hass, 'another service', 'discovery info',
'test_component')
helpers.discovery.discover('another service', 'discovery info',
'test_component')
self.hass.block_till_done()
assert len(calls_single) == 1

View file

@ -84,3 +84,22 @@ def test_component_wrapper(hass):
yield from hass.async_block_till_done()
assert len(calls) == 1
@asyncio.coroutine
def test_helpers_wrapper(hass):
"""Test helpers wrapper."""
helpers = loader.Helpers(hass)
result = []
def discovery_callback(service, discovered):
"""Handle discovery callback."""
result.append(discovered)
helpers.discovery.async_listen('service_name', discovery_callback)
yield from helpers.discovery.async_discover('service_name', 'hello')
yield from hass.async_block_till_done()
assert result == ['hello']