Spread async love (#3575)
* Convert Entity.update_ha_state to be async * Make Service.call async * Update entity.py * Add Entity.async_update * Make automation zone trigger async * Fix linting * Reduce flakiness in hass.block_till_done * Make automation.numeric_state async * Make mqtt.subscribe async * Make automation.mqtt async * Make automation.time async * Make automation.sun async * Add async_track_point_in_utc_time * Make helpers.track_sunrise/set async * Add async_track_state_change * Make automation.state async * Clean up helpers/entity.py tests * Lint * Lint * Core.is_state and Core.is_state_attr are async friendly * Lint * Lint
This commit is contained in:
parent
7e50ccd32a
commit
b650b2b0db
17 changed files with 323 additions and 151 deletions
|
@ -4,6 +4,7 @@ Offer MQTT listening automation rules.
|
|||
For more details about this automation rule, please refer to the documentation
|
||||
at https://home-assistant.io/components/automation/#mqtt-trigger
|
||||
"""
|
||||
import asyncio
|
||||
import voluptuous as vol
|
||||
|
||||
import homeassistant.components.mqtt as mqtt
|
||||
|
@ -26,10 +27,11 @@ def trigger(hass, config, action):
|
|||
topic = config.get(CONF_TOPIC)
|
||||
payload = config.get(CONF_PAYLOAD)
|
||||
|
||||
@asyncio.coroutine
|
||||
def mqtt_automation_listener(msg_topic, msg_payload, qos):
|
||||
"""Listen for MQTT messages."""
|
||||
if payload is None or payload == msg_payload:
|
||||
action({
|
||||
hass.async_add_job(action, {
|
||||
'trigger': {
|
||||
'platform': 'mqtt',
|
||||
'topic': msg_topic,
|
||||
|
|
|
@ -4,6 +4,7 @@ Offer numeric state listening automation rules.
|
|||
For more details about this automation rule, please refer to the documentation
|
||||
at https://home-assistant.io/components/automation/#numeric-state-trigger
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
import voluptuous as vol
|
||||
|
@ -34,7 +35,7 @@ def trigger(hass, config, action):
|
|||
if value_template is not None:
|
||||
value_template.hass = hass
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
@asyncio.coroutine
|
||||
def state_automation_listener(entity, from_s, to_s):
|
||||
"""Listen for state changes and calls action."""
|
||||
if to_s is None:
|
||||
|
@ -50,19 +51,19 @@ def trigger(hass, config, action):
|
|||
}
|
||||
|
||||
# If new one doesn't match, nothing to do
|
||||
if not condition.numeric_state(
|
||||
if not condition.async_numeric_state(
|
||||
hass, to_s, below, above, value_template, variables):
|
||||
return
|
||||
|
||||
# Only match if old didn't exist or existed but didn't match
|
||||
# Written as: skip if old one did exist and matched
|
||||
if from_s is not None and condition.numeric_state(
|
||||
if from_s is not None and condition.async_numeric_state(
|
||||
hass, from_s, below, above, value_template, variables):
|
||||
return
|
||||
|
||||
variables['trigger']['from_state'] = from_s
|
||||
variables['trigger']['to_state'] = to_s
|
||||
|
||||
action(variables)
|
||||
hass.async_add_job(action, variables)
|
||||
|
||||
return track_state_change(hass, entity_id, state_automation_listener)
|
||||
|
|
|
@ -4,12 +4,15 @@ Offer state listening automation rules.
|
|||
For more details about this automation rule, please refer to the documentation
|
||||
at https://home-assistant.io/components/automation/#state-trigger
|
||||
"""
|
||||
import asyncio
|
||||
import voluptuous as vol
|
||||
|
||||
import homeassistant.util.dt as dt_util
|
||||
from homeassistant.const import MATCH_ALL, CONF_PLATFORM
|
||||
from homeassistant.helpers.event import track_state_change, track_point_in_time
|
||||
from homeassistant.helpers.event import (
|
||||
async_track_state_change, async_track_point_in_utc_time)
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.util.async import run_callback_threadsafe
|
||||
|
||||
CONF_ENTITY_ID = "entity_id"
|
||||
CONF_FROM = "from"
|
||||
|
@ -38,16 +41,17 @@ def trigger(hass, config, action):
|
|||
from_state = config.get(CONF_FROM, MATCH_ALL)
|
||||
to_state = config.get(CONF_TO) or config.get(CONF_STATE) or MATCH_ALL
|
||||
time_delta = config.get(CONF_FOR)
|
||||
remove_state_for_cancel = None
|
||||
remove_state_for_listener = None
|
||||
async_remove_state_for_cancel = None
|
||||
async_remove_state_for_listener = None
|
||||
|
||||
@asyncio.coroutine
|
||||
def state_automation_listener(entity, from_s, to_s):
|
||||
"""Listen for state changes and calls action."""
|
||||
nonlocal remove_state_for_cancel, remove_state_for_listener
|
||||
nonlocal async_remove_state_for_cancel, async_remove_state_for_listener
|
||||
|
||||
def call_action():
|
||||
"""Call action with right context."""
|
||||
action({
|
||||
hass.async_add_job(action, {
|
||||
'trigger': {
|
||||
'platform': 'state',
|
||||
'entity_id': entity,
|
||||
|
@ -61,35 +65,41 @@ def trigger(hass, config, action):
|
|||
call_action()
|
||||
return
|
||||
|
||||
@asyncio.coroutine
|
||||
def state_for_listener(now):
|
||||
"""Fire on state changes after a delay and calls action."""
|
||||
remove_state_for_cancel()
|
||||
async_remove_state_for_cancel()
|
||||
call_action()
|
||||
|
||||
@asyncio.coroutine
|
||||
def state_for_cancel_listener(entity, inner_from_s, inner_to_s):
|
||||
"""Fire on changes and cancel for listener if changed."""
|
||||
if inner_to_s.state == to_s.state:
|
||||
return
|
||||
remove_state_for_listener()
|
||||
remove_state_for_cancel()
|
||||
async_remove_state_for_listener()
|
||||
async_remove_state_for_cancel()
|
||||
|
||||
remove_state_for_listener = track_point_in_time(
|
||||
async_remove_state_for_listener = async_track_point_in_utc_time(
|
||||
hass, state_for_listener, dt_util.utcnow() + time_delta)
|
||||
|
||||
remove_state_for_cancel = track_state_change(
|
||||
async_remove_state_for_cancel = async_track_state_change(
|
||||
hass, entity, state_for_cancel_listener)
|
||||
|
||||
unsub = track_state_change(hass, entity_id, state_automation_listener,
|
||||
from_state, to_state)
|
||||
unsub = async_track_state_change(
|
||||
hass, entity_id, state_automation_listener, from_state, to_state)
|
||||
|
||||
def async_remove():
|
||||
"""Remove state listeners async."""
|
||||
unsub()
|
||||
# pylint: disable=not-callable
|
||||
if async_remove_state_for_cancel is not None:
|
||||
async_remove_state_for_cancel()
|
||||
|
||||
if async_remove_state_for_listener is not None:
|
||||
async_remove_state_for_listener()
|
||||
|
||||
def remove():
|
||||
"""Remove state listeners."""
|
||||
unsub()
|
||||
# pylint: disable=not-callable
|
||||
if remove_state_for_cancel is not None:
|
||||
remove_state_for_cancel()
|
||||
|
||||
if remove_state_for_listener is not None:
|
||||
remove_state_for_listener()
|
||||
run_callback_threadsafe(hass.loop, async_remove).result()
|
||||
|
||||
return remove
|
||||
|
|
|
@ -4,6 +4,7 @@ Offer sun based automation rules.
|
|||
For more details about this automation rule, please refer to the documentation
|
||||
at https://home-assistant.io/components/automation/#sun-trigger
|
||||
"""
|
||||
import asyncio
|
||||
from datetime import timedelta
|
||||
import logging
|
||||
|
||||
|
@ -30,9 +31,10 @@ def trigger(hass, config, action):
|
|||
event = config.get(CONF_EVENT)
|
||||
offset = config.get(CONF_OFFSET)
|
||||
|
||||
@asyncio.coroutine
|
||||
def call_action():
|
||||
"""Call action with right context."""
|
||||
action({
|
||||
hass.async_add_job(action, {
|
||||
'trigger': {
|
||||
'platform': 'sun',
|
||||
'event': event,
|
||||
|
|
|
@ -4,6 +4,7 @@ Offer time listening automation rules.
|
|||
For more details about this automation rule, please refer to the documentation
|
||||
at https://home-assistant.io/components/automation/#time-trigger
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
import voluptuous as vol
|
||||
|
@ -38,9 +39,10 @@ def trigger(hass, config, action):
|
|||
minutes = config.get(CONF_MINUTES)
|
||||
seconds = config.get(CONF_SECONDS)
|
||||
|
||||
@asyncio.coroutine
|
||||
def time_automation_listener(now):
|
||||
"""Listen for time changes and calls action."""
|
||||
action({
|
||||
hass.async_add_job(action, {
|
||||
'trigger': {
|
||||
'platform': 'time',
|
||||
'now': now,
|
||||
|
|
|
@ -4,6 +4,7 @@ Offer zone automation rules.
|
|||
For more details about this automation rule, please refer to the documentation
|
||||
at https://home-assistant.io/components/automation/#zone-trigger
|
||||
"""
|
||||
import asyncio
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.const import (
|
||||
|
@ -31,6 +32,7 @@ def trigger(hass, config, action):
|
|||
zone_entity_id = config.get(CONF_ZONE)
|
||||
event = config.get(CONF_EVENT)
|
||||
|
||||
@asyncio.coroutine
|
||||
def zone_automation_listener(entity, from_s, to_s):
|
||||
"""Listen for state changes and calls action."""
|
||||
if from_s and not location.has_location(from_s) or \
|
||||
|
@ -47,7 +49,7 @@ def trigger(hass, config, action):
|
|||
# pylint: disable=too-many-boolean-expressions
|
||||
if event == EVENT_ENTER and not from_match and to_match or \
|
||||
event == EVENT_LEAVE and from_match and not to_match:
|
||||
action({
|
||||
hass.async_add_job(action, {
|
||||
'trigger': {
|
||||
'platform': 'zone',
|
||||
'entity_id': entity,
|
||||
|
|
|
@ -4,6 +4,7 @@ Event parser and human readable log generator.
|
|||
For more details about this component, please refer to the documentation at
|
||||
https://home-assistant.io/components/logbook/
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
from itertools import groupby
|
||||
|
@ -20,6 +21,7 @@ from homeassistant.const import (EVENT_HOMEASSISTANT_START,
|
|||
STATE_NOT_HOME, STATE_OFF, STATE_ON,
|
||||
ATTR_HIDDEN)
|
||||
from homeassistant.core import State, split_entity_id, DOMAIN as HA_DOMAIN
|
||||
from homeassistant.util.async import run_callback_threadsafe
|
||||
|
||||
DOMAIN = "logbook"
|
||||
DEPENDENCIES = ['recorder', 'frontend']
|
||||
|
@ -57,6 +59,13 @@ LOG_MESSAGE_SCHEMA = vol.Schema({
|
|||
|
||||
|
||||
def log_entry(hass, name, message, domain=None, entity_id=None):
|
||||
"""Add an entry to the logbook."""
|
||||
run_callback_threadsafe(
|
||||
hass.loop, async_log_entry, hass, name, message, domain, entity_id
|
||||
).result()
|
||||
|
||||
|
||||
def async_log_entry(hass, name, message, domain=None, entity_id=None):
|
||||
"""Add an entry to the logbook."""
|
||||
data = {
|
||||
ATTR_NAME: name,
|
||||
|
@ -67,11 +76,12 @@ def log_entry(hass, name, message, domain=None, entity_id=None):
|
|||
data[ATTR_DOMAIN] = domain
|
||||
if entity_id is not None:
|
||||
data[ATTR_ENTITY_ID] = entity_id
|
||||
hass.bus.fire(EVENT_LOGBOOK_ENTRY, data)
|
||||
hass.bus.async_fire(EVENT_LOGBOOK_ENTRY, data)
|
||||
|
||||
|
||||
def setup(hass, config):
|
||||
"""Listen for download events to download files."""
|
||||
@asyncio.coroutine
|
||||
def log_message(service):
|
||||
"""Handle sending notification message service calls."""
|
||||
message = service.data[ATTR_MESSAGE]
|
||||
|
@ -80,8 +90,8 @@ def setup(hass, config):
|
|||
entity_id = service.data.get(ATTR_ENTITY_ID)
|
||||
|
||||
message.hass = hass
|
||||
message = message.render()
|
||||
log_entry(hass, name, message, domain, entity_id)
|
||||
message = message.async_render()
|
||||
async_log_entry(hass, name, message, domain, entity_id)
|
||||
|
||||
hass.wsgi.register_view(LogbookView(hass, config))
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@ Support for MQTT message handling.
|
|||
For more details about this component, please refer to the documentation at
|
||||
https://home-assistant.io/components/mqtt/
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
|
@ -11,6 +12,7 @@ import time
|
|||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.core import JobPriority
|
||||
from homeassistant.bootstrap import prepare_setup_platform
|
||||
from homeassistant.config import load_yaml_config_file
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
|
@ -164,11 +166,20 @@ def publish_template(hass, topic, payload_template, qos=None, retain=None):
|
|||
|
||||
def subscribe(hass, topic, callback, qos=DEFAULT_QOS):
|
||||
"""Subscribe to an MQTT topic."""
|
||||
@asyncio.coroutine
|
||||
def mqtt_topic_subscriber(event):
|
||||
"""Match subscribed MQTT topic."""
|
||||
if _match_topic(topic, event.data[ATTR_TOPIC]):
|
||||
callback(event.data[ATTR_TOPIC], event.data[ATTR_PAYLOAD],
|
||||
event.data[ATTR_QOS])
|
||||
if not _match_topic(topic, event.data[ATTR_TOPIC]):
|
||||
return
|
||||
|
||||
if asyncio.iscoroutinefunction(callback):
|
||||
yield from callback(
|
||||
event.data[ATTR_TOPIC], event.data[ATTR_PAYLOAD],
|
||||
event.data[ATTR_QOS])
|
||||
else:
|
||||
hass.add_job(callback, event.data[ATTR_TOPIC],
|
||||
event.data[ATTR_PAYLOAD], event.data[ATTR_QOS],
|
||||
priority=JobPriority.EVENT_CALLBACK)
|
||||
|
||||
remove = hass.bus.listen(EVENT_MQTT_MESSAGE_RECEIVED,
|
||||
mqtt_topic_subscriber)
|
||||
|
|
|
@ -248,12 +248,16 @@ class HomeAssistant(object):
|
|||
|
||||
def notify_when_done():
|
||||
"""Notify event loop when pool done."""
|
||||
count = 0
|
||||
while True:
|
||||
# Wait for the work queue to empty
|
||||
self.pool.block_till_done()
|
||||
|
||||
# Verify the loop is empty
|
||||
if self._loop_empty():
|
||||
count += 1
|
||||
|
||||
if count == 2:
|
||||
break
|
||||
|
||||
# sleep in the loop executor, this forces execution back into
|
||||
|
@ -675,40 +679,29 @@ class StateMachine(object):
|
|||
return list(self._states.values())
|
||||
|
||||
def get(self, entity_id):
|
||||
"""Retrieve state of entity_id or None if not found."""
|
||||
"""Retrieve state of entity_id or None if not found.
|
||||
|
||||
Async friendly.
|
||||
"""
|
||||
return self._states.get(entity_id.lower())
|
||||
|
||||
def is_state(self, entity_id, state):
|
||||
"""Test if entity exists and is specified state."""
|
||||
return run_callback_threadsafe(
|
||||
self._loop, self.async_is_state, entity_id, state
|
||||
).result()
|
||||
|
||||
def async_is_state(self, entity_id, state):
|
||||
"""Test if entity exists and is specified state.
|
||||
|
||||
This method must be run in the event loop.
|
||||
Async friendly.
|
||||
"""
|
||||
entity_id = entity_id.lower()
|
||||
state_obj = self.get(entity_id)
|
||||
|
||||
return (entity_id in self._states and
|
||||
self._states[entity_id].state == state)
|
||||
return state_obj and state_obj.state == state
|
||||
|
||||
def is_state_attr(self, entity_id, name, value):
|
||||
"""Test if entity exists and has a state attribute set to value."""
|
||||
return run_callback_threadsafe(
|
||||
self._loop, self.async_is_state_attr, entity_id, name, value
|
||||
).result()
|
||||
|
||||
def async_is_state_attr(self, entity_id, name, value):
|
||||
"""Test if entity exists and has a state attribute set to value.
|
||||
|
||||
This method must be run in the event loop.
|
||||
Async friendly.
|
||||
"""
|
||||
entity_id = entity_id.lower()
|
||||
state_obj = self.get(entity_id)
|
||||
|
||||
return (entity_id in self._states and
|
||||
self._states[entity_id].attributes.get(name, None) == value)
|
||||
return state_obj and state_obj.attributes.get(name, None) == value
|
||||
|
||||
def remove(self, entity_id):
|
||||
"""Remove the state of an entity.
|
||||
|
@ -799,7 +792,8 @@ class StateMachine(object):
|
|||
class Service(object):
|
||||
"""Represents a callable service."""
|
||||
|
||||
__slots__ = ['func', 'description', 'fields', 'schema']
|
||||
__slots__ = ['func', 'description', 'fields', 'schema',
|
||||
'iscoroutinefunction']
|
||||
|
||||
def __init__(self, func, description, fields, schema):
|
||||
"""Initialize a service."""
|
||||
|
@ -807,6 +801,7 @@ class Service(object):
|
|||
self.description = description or ''
|
||||
self.fields = fields or {}
|
||||
self.schema = schema
|
||||
self.iscoroutinefunction = asyncio.iscoroutinefunction(func)
|
||||
|
||||
def as_dict(self):
|
||||
"""Return dictionary representation of this service."""
|
||||
|
@ -815,19 +810,6 @@ class Service(object):
|
|||
'fields': self.fields,
|
||||
}
|
||||
|
||||
def __call__(self, call):
|
||||
"""Execute the service."""
|
||||
try:
|
||||
if self.schema:
|
||||
call.data = self.schema(call.data)
|
||||
call.data = MappingProxyType(call.data)
|
||||
|
||||
self.func(call)
|
||||
except vol.MultipleInvalid as ex:
|
||||
_LOGGER.error('Invalid service data for %s.%s: %s',
|
||||
call.domain, call.service,
|
||||
humanize_error(call.data, ex))
|
||||
|
||||
|
||||
# pylint: disable=too-few-public-methods
|
||||
class ServiceCall(object):
|
||||
|
@ -839,7 +821,7 @@ class ServiceCall(object):
|
|||
"""Initialize a service call."""
|
||||
self.domain = domain.lower()
|
||||
self.service = service.lower()
|
||||
self.data = data or {}
|
||||
self.data = MappingProxyType(data or {})
|
||||
self.call_id = call_id
|
||||
|
||||
def __repr__(self):
|
||||
|
@ -983,9 +965,9 @@ class ServiceRegistry(object):
|
|||
fut = asyncio.Future(loop=self._loop)
|
||||
|
||||
@asyncio.coroutine
|
||||
def service_executed(call):
|
||||
def service_executed(event):
|
||||
"""Callback method that is called when service is executed."""
|
||||
if call.data[ATTR_SERVICE_CALL_ID] == call_id:
|
||||
if event.data[ATTR_SERVICE_CALL_ID] == call_id:
|
||||
fut.set_result(True)
|
||||
|
||||
unsub = self._bus.async_listen(EVENT_SERVICE_EXECUTED,
|
||||
|
@ -1000,9 +982,10 @@ class ServiceRegistry(object):
|
|||
unsub()
|
||||
return success
|
||||
|
||||
@asyncio.coroutine
|
||||
def _event_to_service_call(self, event):
|
||||
"""Callback for SERVICE_CALLED events from the event bus."""
|
||||
service_data = event.data.get(ATTR_SERVICE_DATA)
|
||||
service_data = event.data.get(ATTR_SERVICE_DATA) or {}
|
||||
domain = event.data.get(ATTR_DOMAIN).lower()
|
||||
service = event.data.get(ATTR_SERVICE).lower()
|
||||
call_id = event.data.get(ATTR_SERVICE_CALL_ID)
|
||||
|
@ -1014,19 +997,41 @@ class ServiceRegistry(object):
|
|||
return
|
||||
|
||||
service_handler = self._services[domain][service]
|
||||
|
||||
def fire_service_executed():
|
||||
"""Fire service executed event."""
|
||||
if not call_id:
|
||||
return
|
||||
|
||||
data = {ATTR_SERVICE_CALL_ID: call_id}
|
||||
|
||||
if service_handler.iscoroutinefunction:
|
||||
self._bus.async_fire(EVENT_SERVICE_EXECUTED, data)
|
||||
else:
|
||||
self._bus.fire(EVENT_SERVICE_EXECUTED, data)
|
||||
|
||||
try:
|
||||
if service_handler.schema:
|
||||
service_data = service_handler.schema(service_data)
|
||||
except vol.Invalid as ex:
|
||||
_LOGGER.error('Invalid service data for %s.%s: %s',
|
||||
domain, service, humanize_error(service_data, ex))
|
||||
fire_service_executed()
|
||||
return
|
||||
|
||||
service_call = ServiceCall(domain, service, service_data, call_id)
|
||||
|
||||
# Add a job to the pool that calls _execute_service
|
||||
self._add_job(self._execute_service, service_handler, service_call,
|
||||
priority=JobPriority.EVENT_SERVICE)
|
||||
if not service_handler.iscoroutinefunction:
|
||||
def execute_service():
|
||||
"""Execute a service and fires a SERVICE_EXECUTED event."""
|
||||
service_handler.func(service_call)
|
||||
fire_service_executed()
|
||||
|
||||
def _execute_service(self, service, call):
|
||||
"""Execute a service and fires a SERVICE_EXECUTED event."""
|
||||
service(call)
|
||||
self._add_job(execute_service, priority=JobPriority.EVENT_SERVICE)
|
||||
return
|
||||
|
||||
if call.call_id is not None:
|
||||
self._bus.fire(
|
||||
EVENT_SERVICE_EXECUTED, {ATTR_SERVICE_CALL_ID: call.call_id})
|
||||
yield from service_handler.func(service_call)
|
||||
fire_service_executed()
|
||||
|
||||
def _generate_unique_id(self):
|
||||
"""Generate a unique service call id."""
|
||||
|
|
|
@ -84,6 +84,15 @@ def or_from_config(config: ConfigType, config_validation: bool=True):
|
|||
def numeric_state(hass: HomeAssistant, entity, below=None, above=None,
|
||||
value_template=None, variables=None):
|
||||
"""Test a numeric state condition."""
|
||||
return run_callback_threadsafe(
|
||||
hass.loop, async_numeric_state, hass, entity, below, above,
|
||||
value_template, variables,
|
||||
).result()
|
||||
|
||||
|
||||
def async_numeric_state(hass: HomeAssistant, entity, below=None, above=None,
|
||||
value_template=None, variables=None):
|
||||
"""Test a numeric state condition."""
|
||||
if isinstance(entity, str):
|
||||
entity = hass.states.get(entity)
|
||||
|
||||
|
@ -96,7 +105,7 @@ def numeric_state(hass: HomeAssistant, entity, below=None, above=None,
|
|||
variables = dict(variables or {})
|
||||
variables['state'] = entity
|
||||
try:
|
||||
value = value_template.render(variables)
|
||||
value = value_template.async_render(variables)
|
||||
except TemplateError as ex:
|
||||
_LOGGER.error("Template error: %s", ex)
|
||||
return False
|
||||
|
@ -290,7 +299,10 @@ def time_from_config(config, config_validation=True):
|
|||
|
||||
|
||||
def zone(hass, zone_ent, entity):
|
||||
"""Test if zone-condition matches."""
|
||||
"""Test if zone-condition matches.
|
||||
|
||||
Can be run async.
|
||||
"""
|
||||
if isinstance(zone_ent, str):
|
||||
zone_ent = hass.states.get(zone_ent)
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
"""An abstract class for entities."""
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from typing import Any, Optional, List, Dict
|
||||
|
@ -11,6 +12,7 @@ from homeassistant.const import (
|
|||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import NoEntitySpecifiedError
|
||||
from homeassistant.util import ensure_unique_string, slugify
|
||||
from homeassistant.util.async import run_coroutine_threadsafe
|
||||
|
||||
# Entity attributes that we will overwrite
|
||||
_OVERWRITE = {} # type: Dict[str, Any]
|
||||
|
@ -143,6 +145,23 @@ class Entity(object):
|
|||
|
||||
If force_refresh == True will update entity before setting state.
|
||||
"""
|
||||
# We're already in a thread, do the force refresh here.
|
||||
if force_refresh and not hasattr(self, 'async_update'):
|
||||
self.update()
|
||||
force_refresh = False
|
||||
|
||||
run_coroutine_threadsafe(
|
||||
self.async_update_ha_state(force_refresh), self.hass.loop
|
||||
).result()
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_update_ha_state(self, force_refresh=False):
|
||||
"""Update Home Assistant with current state of entity.
|
||||
|
||||
If force_refresh == True will update entity before setting state.
|
||||
|
||||
This method must be run in the event loop.
|
||||
"""
|
||||
if self.hass is None:
|
||||
raise RuntimeError("Attribute hass is None for {}".format(self))
|
||||
|
||||
|
@ -151,7 +170,13 @@ class Entity(object):
|
|||
"No entity id specified for entity {}".format(self.name))
|
||||
|
||||
if force_refresh:
|
||||
self.update()
|
||||
if hasattr(self, 'async_update'):
|
||||
# pylint: disable=no-member
|
||||
self.async_update()
|
||||
else:
|
||||
# PS: Run this in our own thread pool once we have
|
||||
# future support?
|
||||
yield from self.hass.loop.run_in_executor(None, self.update)
|
||||
|
||||
state = STATE_UNKNOWN if self.state is None else str(self.state)
|
||||
attr = self.state_attributes or {}
|
||||
|
@ -192,7 +217,7 @@ class Entity(object):
|
|||
# Could not convert state to float
|
||||
pass
|
||||
|
||||
return self.hass.states.set(
|
||||
self.hass.states.async_set(
|
||||
self.entity_id, state, attr, self.force_update)
|
||||
|
||||
def remove(self) -> None:
|
||||
|
|
|
@ -18,6 +18,28 @@ def track_state_change(hass, entity_ids, action, from_state=None,
|
|||
|
||||
Returns a function that can be called to remove the listener.
|
||||
"""
|
||||
async_unsub = run_callback_threadsafe(
|
||||
hass.loop, async_track_state_change, hass, entity_ids, action,
|
||||
from_state, to_state).result()
|
||||
|
||||
def remove():
|
||||
"""Remove listener."""
|
||||
run_callback_threadsafe(hass.loop, async_unsub).result()
|
||||
|
||||
return remove
|
||||
|
||||
|
||||
def async_track_state_change(hass, entity_ids, action, from_state=None,
|
||||
to_state=None):
|
||||
"""Track specific state changes.
|
||||
|
||||
entity_ids, from_state and to_state can be string or list.
|
||||
Use list to match multiple.
|
||||
|
||||
Returns a function that can be called to remove the listener.
|
||||
|
||||
Must be run within the event loop.
|
||||
"""
|
||||
from_state = _process_state_match(from_state)
|
||||
to_state = _process_state_match(to_state)
|
||||
|
||||
|
@ -52,7 +74,7 @@ def track_state_change(hass, entity_ids, action, from_state=None,
|
|||
event.data.get('old_state'),
|
||||
event.data.get('new_state'))
|
||||
|
||||
return hass.bus.listen(EVENT_STATE_CHANGED, state_change_listener)
|
||||
return hass.bus.async_listen(EVENT_STATE_CHANGED, state_change_listener)
|
||||
|
||||
|
||||
def track_point_in_time(hass, action, point_in_time):
|
||||
|
@ -69,6 +91,19 @@ def track_point_in_time(hass, action, point_in_time):
|
|||
|
||||
|
||||
def track_point_in_utc_time(hass, action, point_in_time):
|
||||
"""Add a listener that fires once after a specific point in UTC time."""
|
||||
async_unsub = run_callback_threadsafe(
|
||||
hass.loop, async_track_point_in_utc_time, hass, action, point_in_time
|
||||
).result()
|
||||
|
||||
def remove():
|
||||
"""Remove listener."""
|
||||
run_callback_threadsafe(hass.loop, async_unsub).result()
|
||||
|
||||
return remove
|
||||
|
||||
|
||||
def async_track_point_in_utc_time(hass, action, point_in_time):
|
||||
"""Add a listener that fires once after a specific point in UTC time."""
|
||||
# Ensure point_in_time is UTC
|
||||
point_in_time = dt_util.as_utc(point_in_time)
|
||||
|
@ -88,20 +123,14 @@ def track_point_in_utc_time(hass, action, point_in_time):
|
|||
# listener gets lined up twice to be executed. This will make
|
||||
# sure the second time it does nothing.
|
||||
point_in_time_listener.run = True
|
||||
async_remove()
|
||||
async_unsub()
|
||||
|
||||
hass.async_add_job(action, now)
|
||||
|
||||
future = run_callback_threadsafe(
|
||||
hass.loop, hass.bus.async_listen, EVENT_TIME_CHANGED,
|
||||
point_in_time_listener)
|
||||
async_remove = future.result()
|
||||
async_unsub = hass.bus.async_listen(EVENT_TIME_CHANGED,
|
||||
point_in_time_listener)
|
||||
|
||||
def remove():
|
||||
"""Remove listener."""
|
||||
run_callback_threadsafe(hass.loop, async_remove).result()
|
||||
|
||||
return remove
|
||||
return async_unsub
|
||||
|
||||
|
||||
def track_sunrise(hass, action, offset=None):
|
||||
|
@ -118,19 +147,21 @@ def track_sunrise(hass, action, offset=None):
|
|||
|
||||
return next_time
|
||||
|
||||
@asyncio.coroutine
|
||||
def sunrise_automation_listener(now):
|
||||
"""Called when it's time for action."""
|
||||
nonlocal remove
|
||||
remove = track_point_in_utc_time(hass, sunrise_automation_listener,
|
||||
next_rise())
|
||||
action()
|
||||
remove = async_track_point_in_utc_time(
|
||||
hass, sunrise_automation_listener, next_rise())
|
||||
hass.async_add_job(action)
|
||||
|
||||
remove = track_point_in_utc_time(hass, sunrise_automation_listener,
|
||||
next_rise())
|
||||
remove = run_callback_threadsafe(
|
||||
hass.loop, async_track_point_in_utc_time, hass,
|
||||
sunrise_automation_listener, next_rise()).result()
|
||||
|
||||
def remove_listener():
|
||||
"""Remove sunrise listener."""
|
||||
remove()
|
||||
"""Remove sunset listener."""
|
||||
run_callback_threadsafe(hass.loop, remove).result()
|
||||
|
||||
return remove_listener
|
||||
|
||||
|
@ -149,19 +180,21 @@ def track_sunset(hass, action, offset=None):
|
|||
|
||||
return next_time
|
||||
|
||||
@asyncio.coroutine
|
||||
def sunset_automation_listener(now):
|
||||
"""Called when it's time for action."""
|
||||
nonlocal remove
|
||||
remove = track_point_in_utc_time(hass, sunset_automation_listener,
|
||||
next_set())
|
||||
action()
|
||||
remove = async_track_point_in_utc_time(
|
||||
hass, sunset_automation_listener, next_set())
|
||||
hass.async_add_job(action)
|
||||
|
||||
remove = track_point_in_utc_time(hass, sunset_automation_listener,
|
||||
next_set())
|
||||
remove = run_callback_threadsafe(
|
||||
hass.loop, async_track_point_in_utc_time, hass,
|
||||
sunset_automation_listener, next_set()).result()
|
||||
|
||||
def remove_listener():
|
||||
"""Remove sunset listener."""
|
||||
remove()
|
||||
run_callback_threadsafe(hass.loop, remove).result()
|
||||
|
||||
return remove_listener
|
||||
|
||||
|
|
|
@ -149,8 +149,8 @@ class Template(object):
|
|||
global_vars = ENV.make_globals({
|
||||
'closest': location_methods.closest,
|
||||
'distance': location_methods.distance,
|
||||
'is_state': self.hass.states.async_is_state,
|
||||
'is_state_attr': self.hass.states.async_is_state_attr,
|
||||
'is_state': self.hass.states.is_state,
|
||||
'is_state_attr': self.hass.states.is_state_attr,
|
||||
'states': AllStates(self.hass),
|
||||
})
|
||||
|
||||
|
|
|
@ -77,7 +77,8 @@ class TestComponentsCore(unittest.TestCase):
|
|||
service_call = ha.ServiceCall('homeassistant', 'turn_on', {
|
||||
'entity_id': ['light.test', 'sensor.bla', 'light.bla']
|
||||
})
|
||||
self.hass.services._services['homeassistant']['turn_on'](service_call)
|
||||
service = self.hass.services._services['homeassistant']['turn_on']
|
||||
service.func(service_call)
|
||||
|
||||
self.assertEqual(2, mock_call.call_count)
|
||||
self.assertEqual(
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
"""The tests for the logbook component."""
|
||||
# pylint: disable=protected-access,too-many-public-methods
|
||||
import unittest
|
||||
from datetime import timedelta
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from homeassistant.components import sun
|
||||
import homeassistant.core as ha
|
||||
|
@ -18,13 +19,17 @@ from tests.common import mock_http_component, get_test_home_assistant
|
|||
class TestComponentLogbook(unittest.TestCase):
|
||||
"""Test the History component."""
|
||||
|
||||
EMPTY_CONFIG = logbook.CONFIG_SCHEMA({ha.DOMAIN: {}, logbook.DOMAIN: {}})
|
||||
EMPTY_CONFIG = logbook.CONFIG_SCHEMA({logbook.DOMAIN: {}})
|
||||
|
||||
def setUp(self):
|
||||
"""Setup things to be run when tests are started."""
|
||||
self.hass = get_test_home_assistant()
|
||||
mock_http_component(self.hass)
|
||||
assert setup_component(self.hass, logbook.DOMAIN, self.EMPTY_CONFIG)
|
||||
self.hass.config.components += ['frontend', 'recorder', 'api']
|
||||
with patch('homeassistant.components.logbook.'
|
||||
'register_built_in_panel'):
|
||||
assert setup_component(self.hass, logbook.DOMAIN,
|
||||
self.EMPTY_CONFIG)
|
||||
|
||||
def tearDown(self):
|
||||
"""Stop everything that was started."""
|
||||
|
@ -44,7 +49,6 @@ class TestComponentLogbook(unittest.TestCase):
|
|||
logbook.ATTR_DOMAIN: 'switch',
|
||||
logbook.ATTR_ENTITY_ID: 'switch.test_switch'
|
||||
}, True)
|
||||
self.hass.block_till_done()
|
||||
|
||||
self.assertEqual(1, len(calls))
|
||||
last_call = calls[-1]
|
||||
|
@ -65,7 +69,6 @@ class TestComponentLogbook(unittest.TestCase):
|
|||
|
||||
self.hass.bus.listen(logbook.EVENT_LOGBOOK_ENTRY, event_listener)
|
||||
self.hass.services.call(logbook.DOMAIN, 'log', {}, True)
|
||||
self.hass.block_till_done()
|
||||
|
||||
self.assertEqual(0, len(calls))
|
||||
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
"""Test the entity helper."""
|
||||
# pylint: disable=protected-access,too-many-public-methods
|
||||
import unittest
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
import homeassistant.helpers.entity as entity
|
||||
from homeassistant.const import ATTR_HIDDEN
|
||||
|
@ -8,26 +11,75 @@ from homeassistant.const import ATTR_HIDDEN
|
|||
from tests.common import get_test_home_assistant
|
||||
|
||||
|
||||
class TestHelpersEntity(unittest.TestCase):
|
||||
def test_generate_entity_id_requires_hass_or_ids():
|
||||
"""Ensure we require at least hass or current ids."""
|
||||
fmt = 'test.{}'
|
||||
with pytest.raises(ValueError):
|
||||
entity.generate_entity_id(fmt, 'hello world')
|
||||
|
||||
|
||||
def test_generate_entity_id_given_keys():
|
||||
"""Test generating an entity id given current ids."""
|
||||
fmt = 'test.{}'
|
||||
assert entity.generate_entity_id(
|
||||
fmt, 'overwrite hidden true', current_ids=[
|
||||
'test.overwrite_hidden_true']) == 'test.overwrite_hidden_true_2'
|
||||
assert entity.generate_entity_id(
|
||||
fmt, 'overwrite hidden true', current_ids=[
|
||||
'test.another_entity']) == 'test.overwrite_hidden_true'
|
||||
|
||||
|
||||
def test_async_update_support(event_loop):
|
||||
"""Test async update getting called."""
|
||||
sync_update = []
|
||||
async_update = []
|
||||
|
||||
class AsyncEntity(entity.Entity):
|
||||
hass = MagicMock()
|
||||
entity_id = 'sensor.test'
|
||||
|
||||
def update(self):
|
||||
sync_update.append([1])
|
||||
|
||||
ent = AsyncEntity()
|
||||
ent.hass.loop = event_loop
|
||||
|
||||
@asyncio.coroutine
|
||||
def test():
|
||||
yield from ent.async_update_ha_state(True)
|
||||
|
||||
event_loop.run_until_complete(test())
|
||||
|
||||
assert len(sync_update) == 1
|
||||
assert len(async_update) == 0
|
||||
|
||||
ent.async_update = lambda: async_update.append(1)
|
||||
|
||||
event_loop.run_until_complete(test())
|
||||
|
||||
assert len(sync_update) == 1
|
||||
assert len(async_update) == 1
|
||||
|
||||
|
||||
class TestHelpersEntity(object):
|
||||
"""Test homeassistant.helpers.entity module."""
|
||||
|
||||
def setUp(self): # pylint: disable=invalid-name
|
||||
def setup_method(self, method):
|
||||
"""Setup things to be run when tests are started."""
|
||||
self.entity = entity.Entity()
|
||||
self.entity.entity_id = 'test.overwrite_hidden_true'
|
||||
self.hass = self.entity.hass = get_test_home_assistant()
|
||||
self.entity.update_ha_state()
|
||||
|
||||
def tearDown(self): # pylint: disable=invalid-name
|
||||
def teardown_method(self, method):
|
||||
"""Stop everything that was started."""
|
||||
self.hass.stop()
|
||||
entity.set_customize({})
|
||||
self.hass.stop()
|
||||
|
||||
def test_default_hidden_not_in_attributes(self):
|
||||
"""Test that the default hidden property is set to False."""
|
||||
self.assertNotIn(
|
||||
ATTR_HIDDEN,
|
||||
self.hass.states.get(self.entity.entity_id).attributes)
|
||||
assert ATTR_HIDDEN not in self.hass.states.get(
|
||||
self.entity.entity_id).attributes
|
||||
|
||||
def test_overwriting_hidden_property_to_true(self):
|
||||
"""Test we can overwrite hidden property to True."""
|
||||
|
@ -35,31 +87,11 @@ class TestHelpersEntity(unittest.TestCase):
|
|||
self.entity.update_ha_state()
|
||||
|
||||
state = self.hass.states.get(self.entity.entity_id)
|
||||
self.assertTrue(state.attributes.get(ATTR_HIDDEN))
|
||||
|
||||
def test_generate_entity_id_requires_hass_or_ids(self):
|
||||
"""Ensure we require at least hass or current ids."""
|
||||
fmt = 'test.{}'
|
||||
with self.assertRaises(ValueError):
|
||||
entity.generate_entity_id(fmt, 'hello world')
|
||||
assert state.attributes.get(ATTR_HIDDEN)
|
||||
|
||||
def test_generate_entity_id_given_hass(self):
|
||||
"""Test generating an entity id given hass object."""
|
||||
fmt = 'test.{}'
|
||||
self.assertEqual(
|
||||
'test.overwrite_hidden_true_2',
|
||||
entity.generate_entity_id(fmt, 'overwrite hidden true',
|
||||
hass=self.hass))
|
||||
|
||||
def test_generate_entity_id_given_keys(self):
|
||||
"""Test generating an entity id given current ids."""
|
||||
fmt = 'test.{}'
|
||||
self.assertEqual(
|
||||
'test.overwrite_hidden_true_2',
|
||||
entity.generate_entity_id(
|
||||
fmt, 'overwrite hidden true',
|
||||
current_ids=['test.overwrite_hidden_true']))
|
||||
self.assertEqual(
|
||||
'test.overwrite_hidden_true',
|
||||
entity.generate_entity_id(fmt, 'overwrite hidden true',
|
||||
current_ids=['test.another_entity']))
|
||||
assert entity.generate_entity_id(
|
||||
fmt, 'overwrite hidden true',
|
||||
hass=self.hass) == 'test.overwrite_hidden_true_2'
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""Test to verify that Home Assistant core works."""
|
||||
# pylint: disable=protected-access,too-many-public-methods
|
||||
# pylint: disable=too-few-public-methods
|
||||
import asyncio
|
||||
import os
|
||||
import signal
|
||||
import unittest
|
||||
|
@ -362,7 +363,6 @@ class TestServiceRegistry(unittest.TestCase):
|
|||
self.hass = get_test_home_assistant()
|
||||
self.services = self.hass.services
|
||||
self.services.register("Test_Domain", "TEST_SERVICE", lambda x: None)
|
||||
self.hass.block_till_done()
|
||||
|
||||
def tearDown(self): # pylint: disable=invalid-name
|
||||
"""Stop down stuff we started."""
|
||||
|
@ -387,8 +387,13 @@ class TestServiceRegistry(unittest.TestCase):
|
|||
def test_call_with_blocking_done_in_time(self):
|
||||
"""Test call with blocking."""
|
||||
calls = []
|
||||
|
||||
def service_handler(call):
|
||||
"""Service handler."""
|
||||
calls.append(call)
|
||||
|
||||
self.services.register("test_domain", "register_calls",
|
||||
lambda x: calls.append(1))
|
||||
service_handler)
|
||||
|
||||
self.assertTrue(
|
||||
self.services.call('test_domain', 'REGISTER_CALLS', blocking=True))
|
||||
|
@ -404,6 +409,22 @@ class TestServiceRegistry(unittest.TestCase):
|
|||
finally:
|
||||
ha.SERVICE_CALL_LIMIT = prior
|
||||
|
||||
def test_async_service(self):
|
||||
"""Test registering and calling an async service."""
|
||||
calls = []
|
||||
|
||||
@asyncio.coroutine
|
||||
def service_handler(call):
|
||||
"""Service handler coroutine."""
|
||||
calls.append(call)
|
||||
|
||||
self.services.register('test_domain', 'register_calls',
|
||||
service_handler)
|
||||
self.assertTrue(
|
||||
self.services.call('test_domain', 'REGISTER_CALLS', blocking=True))
|
||||
self.hass.block_till_done()
|
||||
self.assertEqual(1, len(calls))
|
||||
|
||||
|
||||
class TestConfig(unittest.TestCase):
|
||||
"""Test configuration methods."""
|
||||
|
|
Loading…
Add table
Reference in a new issue