RFC: Call services directly (#18720)

* Call services directly

* Simplify

* Type

* Lint

* Update name

* Fix tests

* Catch exceptions in HTTP view

* Lint

* Handle ServiceNotFound in API endpoints that call services

* Type

* Don't crash recorder on non-JSON serializable objects
This commit is contained in:
Paulus Schoutsen 2018-11-30 21:28:35 +01:00 committed by Pascal Vizeli
parent 53cbb28926
commit df21dd21f2
30 changed files with 312 additions and 186 deletions

View file

@ -11,6 +11,7 @@ import voluptuous as vol
from homeassistant.const import CONF_EXCLUDE, CONF_INCLUDE
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import ServiceNotFound
from homeassistant.helpers import config_validation as cv
from . import MultiFactorAuthModule, MULTI_FACTOR_AUTH_MODULES, \
@ -314,8 +315,11 @@ class NotifySetupFlow(SetupFlow):
_generate_otp, self._secret, self._count)
assert self._notify_service
try:
await self._auth_module.async_notify(
code, self._notify_service, self._target)
except ServiceNotFound:
return self.async_abort(reason='notify_service_not_exist')
return self.async_show_form(
step_id='setup',

View file

@ -226,7 +226,11 @@ class LoginFlow(data_entry_flow.FlowHandler):
if user_input is None and hasattr(auth_module,
'async_initialize_login_mfa_step'):
try:
await auth_module.async_initialize_login_mfa_step(self.user.id)
except HomeAssistantError:
_LOGGER.exception('Error initializing MFA step')
return self.async_abort(reason='unknown_error')
if user_input is not None:
expires = self.created_at + MFA_SESSION_EXPIRATION

View file

@ -9,7 +9,9 @@ import json
import logging
from aiohttp import web
from aiohttp.web_exceptions import HTTPBadRequest
import async_timeout
import voluptuous as vol
from homeassistant.bootstrap import DATA_LOGGING
from homeassistant.components.http import HomeAssistantView
@ -21,7 +23,8 @@ from homeassistant.const import (
URL_API_TEMPLATE, __version__)
import homeassistant.core as ha
from homeassistant.auth.permissions.const import POLICY_READ
from homeassistant.exceptions import TemplateError, Unauthorized
from homeassistant.exceptions import (
TemplateError, Unauthorized, ServiceNotFound)
from homeassistant.helpers import template
from homeassistant.helpers.service import async_get_all_descriptions
from homeassistant.helpers.state import AsyncTrackStates
@ -339,8 +342,11 @@ class APIDomainServicesView(HomeAssistantView):
"Data should be valid JSON.", HTTP_BAD_REQUEST)
with AsyncTrackStates(hass) as changed_states:
try:
await hass.services.async_call(
domain, service, data, True, self.context(request))
except (vol.Invalid, ServiceNotFound):
raise HTTPBadRequest()
return self.json(changed_states)

View file

@ -9,7 +9,9 @@ import json
import logging
from aiohttp import web
from aiohttp.web_exceptions import HTTPUnauthorized, HTTPInternalServerError
from aiohttp.web_exceptions import (
HTTPUnauthorized, HTTPInternalServerError, HTTPBadRequest)
import voluptuous as vol
from homeassistant.components.http.ban import process_success_login
from homeassistant.core import Context, is_callback
@ -114,6 +116,10 @@ def request_handler_factory(view, handler):
if asyncio.iscoroutine(result):
result = await result
except vol.Invalid:
raise HTTPBadRequest()
except exceptions.ServiceNotFound:
raise HTTPInternalServerError()
except exceptions.Unauthorized:
raise HTTPUnauthorized()

View file

@ -13,7 +13,7 @@ from homeassistant.core import callback
from homeassistant.components.mqtt import (
valid_publish_topic, valid_subscribe_topic)
from homeassistant.const import (
ATTR_SERVICE_DATA, EVENT_CALL_SERVICE, EVENT_SERVICE_EXECUTED,
ATTR_SERVICE_DATA, EVENT_CALL_SERVICE,
EVENT_STATE_CHANGED, EVENT_TIME_CHANGED, MATCH_ALL)
from homeassistant.core import EventOrigin, State
import homeassistant.helpers.config_validation as cv
@ -69,16 +69,6 @@ def async_setup(hass, config):
):
return
# Filter out all the "event service executed" events because they
# are only used internally by core as callbacks for blocking
# during the interval while a service is being executed.
# They will serve no purpose to the external system,
# and thus are unnecessary traffic.
# And at any rate it would cause an infinite loop to publish them
# because publishing to an MQTT topic itself triggers one.
if event.event_type == EVENT_SERVICE_EXECUTED:
return
event_info = {'event_type': event.event_type, 'event_data': event.data}
msg = json.dumps(event_info, cls=JSONEncoder)
mqtt.async_publish(pub_topic, msg)

View file

@ -300,14 +300,24 @@ class Recorder(threading.Thread):
time.sleep(CONNECT_RETRY_WAIT)
try:
with session_scope(session=self.get_session()) as session:
try:
dbevent = Events.from_event(event)
session.add(dbevent)
session.flush()
except (TypeError, ValueError):
_LOGGER.warning(
"Event is not JSON serializable: %s", event)
if event.event_type == EVENT_STATE_CHANGED:
try:
dbstate = States.from_event(event)
dbstate.event_id = dbevent.event_id
session.add(dbstate)
except (TypeError, ValueError):
_LOGGER.warning(
"State is not JSON serializable: %s",
event.data.get('new_state'))
updated = True
except exc.OperationalError as err:

View file

@ -3,7 +3,7 @@ import voluptuous as vol
from homeassistant.const import MATCH_ALL, EVENT_TIME_CHANGED
from homeassistant.core import callback, DOMAIN as HASS_DOMAIN
from homeassistant.exceptions import Unauthorized
from homeassistant.exceptions import Unauthorized, ServiceNotFound
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.service import async_get_all_descriptions
@ -141,10 +141,15 @@ async def handle_call_service(hass, connection, msg):
if (msg['domain'] == HASS_DOMAIN and
msg['service'] in ['restart', 'stop']):
blocking = False
try:
await hass.services.async_call(
msg['domain'], msg['service'], msg.get('service_data'), blocking,
connection.context(msg))
connection.send_message(messages.result_message(msg['id']))
except ServiceNotFound:
connection.send_message(messages.error_message(
msg['id'], const.ERR_NOT_FOUND, 'Service not found.'))
@callback

View file

@ -163,7 +163,6 @@ EVENT_HOMEASSISTANT_CLOSE = 'homeassistant_close'
EVENT_STATE_CHANGED = 'state_changed'
EVENT_TIME_CHANGED = 'time_changed'
EVENT_CALL_SERVICE = 'call_service'
EVENT_SERVICE_EXECUTED = 'service_executed'
EVENT_PLATFORM_DISCOVERED = 'platform_discovered'
EVENT_COMPONENT_LOADED = 'component_loaded'
EVENT_SERVICE_REGISTERED = 'service_registered'
@ -233,9 +232,6 @@ ATTR_ID = 'id'
# Name
ATTR_NAME = 'name'
# Data for a SERVICE_EXECUTED event
ATTR_SERVICE_CALL_ID = 'service_call_id'
# Contains one string or a list of strings, each being an entity id
ATTR_ENTITY_ID = 'entity_id'

View file

@ -25,18 +25,18 @@ from typing import ( # noqa: F401 pylint: disable=unused-import
from async_timeout import timeout
import attr
import voluptuous as vol
from voluptuous.humanize import humanize_error
from homeassistant.const import (
ATTR_DOMAIN, ATTR_FRIENDLY_NAME, ATTR_NOW, ATTR_SERVICE,
ATTR_SERVICE_CALL_ID, ATTR_SERVICE_DATA, ATTR_SECONDS, EVENT_CALL_SERVICE,
ATTR_SERVICE_DATA, ATTR_SECONDS, EVENT_CALL_SERVICE,
EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP,
EVENT_HOMEASSISTANT_CLOSE, EVENT_SERVICE_REMOVED,
EVENT_SERVICE_EXECUTED, EVENT_SERVICE_REGISTERED, EVENT_STATE_CHANGED,
EVENT_SERVICE_REGISTERED, EVENT_STATE_CHANGED,
EVENT_TIME_CHANGED, EVENT_TIMER_OUT_OF_SYNC, MATCH_ALL, __version__)
from homeassistant import loader
from homeassistant.exceptions import (
HomeAssistantError, InvalidEntityFormatError, InvalidStateError)
HomeAssistantError, InvalidEntityFormatError, InvalidStateError,
Unauthorized, ServiceNotFound)
from homeassistant.util.async_ import (
run_coroutine_threadsafe, run_callback_threadsafe,
fire_coroutine_threadsafe)
@ -954,7 +954,6 @@ class ServiceRegistry:
"""Initialize a service registry."""
self._services = {} # type: Dict[str, Dict[str, Service]]
self._hass = hass
self._async_unsub_call_event = None # type: Optional[CALLBACK_TYPE]
@property
def services(self) -> Dict[str, Dict[str, Service]]:
@ -1010,10 +1009,6 @@ class ServiceRegistry:
else:
self._services[domain] = {service: service_obj}
if self._async_unsub_call_event is None:
self._async_unsub_call_event = self._hass.bus.async_listen(
EVENT_CALL_SERVICE, self._event_to_service_call)
self._hass.bus.async_fire(
EVENT_SERVICE_REGISTERED,
{ATTR_DOMAIN: domain, ATTR_SERVICE: service}
@ -1092,100 +1087,61 @@ class ServiceRegistry:
This method is a coroutine.
"""
domain = domain.lower()
service = service.lower()
context = context or Context()
call_id = uuid.uuid4().hex
event_data = {
service_data = service_data or {}
try:
handler = self._services[domain][service]
except KeyError:
raise ServiceNotFound(domain, service) from None
if handler.schema:
service_data = handler.schema(service_data)
service_call = ServiceCall(domain, service, service_data, context)
self._hass.bus.async_fire(EVENT_CALL_SERVICE, {
ATTR_DOMAIN: domain.lower(),
ATTR_SERVICE: service.lower(),
ATTR_SERVICE_DATA: service_data,
ATTR_SERVICE_CALL_ID: call_id,
}
})
if not blocking:
self._hass.bus.async_fire(
EVENT_CALL_SERVICE, event_data, EventOrigin.local, context)
self._hass.async_create_task(
self._safe_execute(handler, service_call))
return None
fut = asyncio.Future() # type: asyncio.Future
@callback
def service_executed(event: Event) -> None:
"""Handle an executed service."""
if event.data[ATTR_SERVICE_CALL_ID] == call_id:
fut.set_result(True)
unsub()
unsub = self._hass.bus.async_listen(
EVENT_SERVICE_EXECUTED, service_executed)
self._hass.bus.async_fire(EVENT_CALL_SERVICE, event_data,
EventOrigin.local, context)
done, _ = await asyncio.wait([fut], timeout=SERVICE_CALL_LIMIT)
success = bool(done)
if not success:
unsub()
return success
async def _event_to_service_call(self, event: Event) -> None:
"""Handle the SERVICE_CALLED events from the EventBus."""
service_data = event.data.get(ATTR_SERVICE_DATA) or {}
domain = event.data.get(ATTR_DOMAIN).lower() # type: ignore
service = event.data.get(ATTR_SERVICE).lower() # type: ignore
call_id = event.data.get(ATTR_SERVICE_CALL_ID)
if not self.has_service(domain, service):
if event.origin == EventOrigin.local:
_LOGGER.warning("Unable to find service %s/%s",
domain, service)
return
service_handler = self._services[domain][service]
def fire_service_executed() -> None:
"""Fire service executed event."""
if not call_id:
return
data = {ATTR_SERVICE_CALL_ID: call_id}
if (service_handler.is_coroutinefunction or
service_handler.is_callback):
self._hass.bus.async_fire(EVENT_SERVICE_EXECUTED, data,
EventOrigin.local, event.context)
else:
self._hass.bus.fire(EVENT_SERVICE_EXECUTED, data,
EventOrigin.local, event.context)
try:
if service_handler.schema:
service_data = service_handler.schema(service_data)
except vol.Invalid as ex:
_LOGGER.error("Invalid service data for %s.%s: %s",
domain, service, humanize_error(service_data, ex))
fire_service_executed()
return
service_call = ServiceCall(
domain, service, service_data, event.context)
with timeout(SERVICE_CALL_LIMIT):
await asyncio.shield(
self._execute_service(handler, service_call))
return True
except asyncio.TimeoutError:
return False
async def _safe_execute(self, handler: Service,
service_call: ServiceCall) -> None:
"""Execute a service and catch exceptions."""
try:
if service_handler.is_callback:
service_handler.func(service_call)
fire_service_executed()
elif service_handler.is_coroutinefunction:
await service_handler.func(service_call)
fire_service_executed()
else:
def execute_service() -> None:
"""Execute a service and fires a SERVICE_EXECUTED event."""
service_handler.func(service_call)
fire_service_executed()
await self._hass.async_add_executor_job(execute_service)
await self._execute_service(handler, service_call)
except Unauthorized:
_LOGGER.warning('Unauthorized service called %s/%s',
service_call.domain, service_call.service)
except Exception: # pylint: disable=broad-except
_LOGGER.exception('Error executing service %s', service_call)
async def _execute_service(self, handler: Service,
service_call: ServiceCall) -> None:
"""Execute a service."""
if handler.is_callback:
handler.func(service_call)
elif handler.is_coroutinefunction:
await handler.func(service_call)
else:
await self._hass.async_add_executor_job(handler.func, service_call)
class Config:
"""Configuration settings for Home Assistant."""

View file

@ -58,3 +58,14 @@ class Unauthorized(HomeAssistantError):
class UnknownUser(Unauthorized):
"""When call is made with user ID that doesn't exist."""
class ServiceNotFound(HomeAssistantError):
"""Raised when a service is not found."""
def __init__(self, domain: str, service: str) -> None:
"""Initialize error."""
super().__init__(
self, "Service {}.{} not found".format(domain, service))
self.domain = domain
self.service = service

View file

@ -61,6 +61,7 @@ async def test_validating_mfa_counter(hass):
'counter': 0,
'notify_service': 'dummy',
})
async_mock_service(hass, 'notify', 'dummy')
assert notify_auth_module._user_settings
notify_setting = list(notify_auth_module._user_settings.values())[0]
@ -389,9 +390,8 @@ async def test_not_raise_exception_when_service_not_exist(hass):
'username': 'test-user',
'password': 'test-pass',
})
assert result['type'] == data_entry_flow.RESULT_TYPE_FORM
assert result['step_id'] == 'mfa'
assert result['data_schema'].schema.get('code') == str
assert result['type'] == data_entry_flow.RESULT_TYPE_ABORT
assert result['reason'] == 'unknown_error'
# wait service call finished
await hass.async_block_till_done()

View file

@ -1,6 +1,9 @@
"""The tests for the demo climate component."""
import unittest
import pytest
import voluptuous as vol
from homeassistant.util.unit_system import (
METRIC_SYSTEM
)
@ -57,6 +60,7 @@ class TestDemoClimate(unittest.TestCase):
"""Test setting the target temperature without required attribute."""
state = self.hass.states.get(ENTITY_CLIMATE)
assert 21 == state.attributes.get('temperature')
with pytest.raises(vol.Invalid):
common.set_temperature(self.hass, None, ENTITY_CLIMATE)
self.hass.block_till_done()
assert 21 == state.attributes.get('temperature')
@ -99,8 +103,10 @@ class TestDemoClimate(unittest.TestCase):
assert state.attributes.get('temperature') is None
assert 21.0 == state.attributes.get('target_temp_low')
assert 24.0 == state.attributes.get('target_temp_high')
with pytest.raises(vol.Invalid):
common.set_temperature(self.hass, temperature=None,
entity_id=ENTITY_ECOBEE, target_temp_low=None,
entity_id=ENTITY_ECOBEE,
target_temp_low=None,
target_temp_high=None)
self.hass.block_till_done()
state = self.hass.states.get(ENTITY_ECOBEE)
@ -112,6 +118,7 @@ class TestDemoClimate(unittest.TestCase):
"""Test setting the target humidity without required attribute."""
state = self.hass.states.get(ENTITY_CLIMATE)
assert 67 == state.attributes.get('humidity')
with pytest.raises(vol.Invalid):
common.set_humidity(self.hass, None, ENTITY_CLIMATE)
self.hass.block_till_done()
state = self.hass.states.get(ENTITY_CLIMATE)
@ -130,6 +137,7 @@ class TestDemoClimate(unittest.TestCase):
"""Test setting fan mode without required attribute."""
state = self.hass.states.get(ENTITY_CLIMATE)
assert "On High" == state.attributes.get('fan_mode')
with pytest.raises(vol.Invalid):
common.set_fan_mode(self.hass, None, ENTITY_CLIMATE)
self.hass.block_till_done()
state = self.hass.states.get(ENTITY_CLIMATE)
@ -148,6 +156,7 @@ class TestDemoClimate(unittest.TestCase):
"""Test setting swing mode without required attribute."""
state = self.hass.states.get(ENTITY_CLIMATE)
assert "Off" == state.attributes.get('swing_mode')
with pytest.raises(vol.Invalid):
common.set_swing_mode(self.hass, None, ENTITY_CLIMATE)
self.hass.block_till_done()
state = self.hass.states.get(ENTITY_CLIMATE)
@ -170,6 +179,7 @@ class TestDemoClimate(unittest.TestCase):
state = self.hass.states.get(ENTITY_CLIMATE)
assert "cool" == state.attributes.get('operation_mode')
assert "cool" == state.state
with pytest.raises(vol.Invalid):
common.set_operation_mode(self.hass, None, ENTITY_CLIMATE)
self.hass.block_till_done()
state = self.hass.states.get(ENTITY_CLIMATE)

View file

@ -1,6 +1,9 @@
"""The tests for the climate component."""
import asyncio
import pytest
import voluptuous as vol
from homeassistant.components.climate import SET_TEMPERATURE_SCHEMA
from tests.common import async_mock_service
@ -14,12 +17,11 @@ def test_set_temp_schema_no_req(hass, caplog):
calls = async_mock_service(hass, domain, service, schema)
data = {'operation_mode': 'test', 'entity_id': ['climate.test_id']}
with pytest.raises(vol.Invalid):
yield from hass.services.async_call(domain, service, data)
yield from hass.async_block_till_done()
assert len(calls) == 0
assert 'ERROR' in caplog.text
assert 'Invalid service data' in caplog.text
@asyncio.coroutine

View file

@ -2,6 +2,9 @@
import unittest
import copy
import pytest
import voluptuous as vol
from homeassistant.util.unit_system import (
METRIC_SYSTEM
)
@ -91,6 +94,7 @@ class TestMQTTClimate(unittest.TestCase):
state = self.hass.states.get(ENTITY_CLIMATE)
assert "off" == state.attributes.get('operation_mode')
assert "off" == state.state
with pytest.raises(vol.Invalid):
common.set_operation_mode(self.hass, None, ENTITY_CLIMATE)
self.hass.block_till_done()
state = self.hass.states.get(ENTITY_CLIMATE)
@ -177,6 +181,7 @@ class TestMQTTClimate(unittest.TestCase):
state = self.hass.states.get(ENTITY_CLIMATE)
assert "low" == state.attributes.get('fan_mode')
with pytest.raises(vol.Invalid):
common.set_fan_mode(self.hass, None, ENTITY_CLIMATE)
self.hass.block_till_done()
state = self.hass.states.get(ENTITY_CLIMATE)
@ -225,6 +230,7 @@ class TestMQTTClimate(unittest.TestCase):
state = self.hass.states.get(ENTITY_CLIMATE)
assert "off" == state.attributes.get('swing_mode')
with pytest.raises(vol.Invalid):
common.set_swing_mode(self.hass, None, ENTITY_CLIMATE)
self.hass.block_till_done()
state = self.hass.states.get(ENTITY_CLIMATE)

View file

@ -1,6 +1,9 @@
"""Test deCONZ component setup process."""
from unittest.mock import Mock, patch
import pytest
import voluptuous as vol
from homeassistant.setup import async_setup_component
from homeassistant.components import deconz
@ -163,9 +166,11 @@ async def test_service_configure(hass):
await hass.async_block_till_done()
# field does not start with /
with pytest.raises(vol.Invalid):
with patch('pydeconz.DeconzSession.async_put_state',
return_value=mock_coro(True)):
await hass.services.async_call('deconz', 'configure', service_data={
await hass.services.async_call(
'deconz', 'configure', service_data={
'entity': 'light.test', 'field': 'state', 'data': data})
await hass.async_block_till_done()

View file

@ -1,8 +1,25 @@
"""Tests for Home Assistant View."""
from aiohttp.web_exceptions import HTTPInternalServerError
import pytest
from unittest.mock import Mock
from homeassistant.components.http.view import HomeAssistantView
from aiohttp.web_exceptions import (
HTTPInternalServerError, HTTPBadRequest, HTTPUnauthorized)
import pytest
import voluptuous as vol
from homeassistant.components.http.view import (
HomeAssistantView, request_handler_factory)
from homeassistant.exceptions import ServiceNotFound, Unauthorized
from tests.common import mock_coro_func
@pytest.fixture
def mock_request():
"""Mock a request."""
return Mock(
app={'hass': Mock(is_running=True)},
match_info={},
)
async def test_invalid_json(caplog):
@ -13,3 +30,30 @@ async def test_invalid_json(caplog):
view.json(float("NaN"))
assert str(float("NaN")) in caplog.text
async def test_handling_unauthorized(mock_request):
"""Test handling unauth exceptions."""
with pytest.raises(HTTPUnauthorized):
await request_handler_factory(
Mock(requires_auth=False),
mock_coro_func(exception=Unauthorized)
)(mock_request)
async def test_handling_invalid_data(mock_request):
"""Test handling unauth exceptions."""
with pytest.raises(HTTPBadRequest):
await request_handler_factory(
Mock(requires_auth=False),
mock_coro_func(exception=vol.Invalid('yo'))
)(mock_request)
async def test_handling_service_not_found(mock_request):
"""Test handling unauth exceptions."""
with pytest.raises(HTTPInternalServerError):
await request_handler_factory(
Mock(requires_auth=False),
mock_coro_func(exception=ServiceNotFound('test', 'test'))
)(mock_request)

View file

@ -3,6 +3,9 @@ import unittest
from unittest.mock import patch
import asyncio
import pytest
import voluptuous as vol
from homeassistant.setup import setup_component
from homeassistant.const import HTTP_HEADER_HA_AUTH
import homeassistant.components.media_player as mp
@ -43,6 +46,7 @@ class TestDemoMediaPlayer(unittest.TestCase):
state = self.hass.states.get(entity_id)
assert 'dvd' == state.attributes.get('source')
with pytest.raises(vol.Invalid):
common.select_source(self.hass, None, entity_id)
self.hass.block_till_done()
state = self.hass.states.get(entity_id)
@ -72,6 +76,7 @@ class TestDemoMediaPlayer(unittest.TestCase):
state = self.hass.states.get(entity_id)
assert 1.0 == state.attributes.get('volume_level')
with pytest.raises(vol.Invalid):
common.set_volume_level(self.hass, None, entity_id)
self.hass.block_till_done()
state = self.hass.states.get(entity_id)
@ -201,6 +206,7 @@ class TestDemoMediaPlayer(unittest.TestCase):
state.attributes.get('supported_features'))
assert state.attributes.get('media_content_id') is not None
with pytest.raises(vol.Invalid):
common.play_media(self.hass, None, 'some_id', ent_id)
self.hass.block_till_done()
state = self.hass.states.get(ent_id)
@ -216,6 +222,7 @@ class TestDemoMediaPlayer(unittest.TestCase):
assert 'some_id' == state.attributes.get('media_content_id')
assert not mock_seek.called
with pytest.raises(vol.Invalid):
common.media_seek(self.hass, None, ent_id)
self.hass.block_till_done()
assert not mock_seek.called

View file

@ -223,7 +223,7 @@ class TestMonopriceMediaPlayer(unittest.TestCase):
# Restoring wrong media player to its previous state
# Nothing should be done
self.hass.services.call(DOMAIN, SERVICE_RESTORE,
{'entity_id': 'not_existing'},
{'entity_id': 'media.not_existing'},
blocking=True)
# self.hass.block_till_done()

View file

@ -113,6 +113,7 @@ class TestMQTTComponent(unittest.TestCase):
"""
payload = "not a template"
payload_template = "a template"
with pytest.raises(vol.Invalid):
self.hass.services.call(mqtt.DOMAIN, mqtt.SERVICE_PUBLISH, {
mqtt.ATTR_TOPIC: "test/topic",
mqtt.ATTR_PAYLOAD: payload,

View file

@ -2,6 +2,9 @@
import unittest
from unittest.mock import patch
import pytest
import voluptuous as vol
import homeassistant.components.notify as notify
from homeassistant.setup import setup_component
from homeassistant.components.notify import demo
@ -81,6 +84,7 @@ class TestNotifyDemo(unittest.TestCase):
def test_sending_none_message(self):
"""Test send with None as message."""
self._setup_notify()
with pytest.raises(vol.Invalid):
common.send_message(self.hass, None)
self.hass.block_till_done()
assert len(self.events) == 0

View file

@ -99,6 +99,7 @@ class TestAlert(unittest.TestCase):
def setUp(self):
"""Set up things to be run when tests are started."""
self.hass = get_test_home_assistant()
self._setup_notify()
def tearDown(self):
"""Stop everything that was started."""

View file

@ -6,6 +6,7 @@ from unittest.mock import patch
from aiohttp import web
import pytest
import voluptuous as vol
from homeassistant import const
from homeassistant.bootstrap import DATA_LOGGING
@ -578,3 +579,29 @@ async def test_rendering_template_legacy_user(
json={"template": '{{ states.sensor.temperature.state }}'}
)
assert resp.status == 401
async def test_api_call_service_not_found(hass, mock_api_client):
"""Test if the API failes 400 if unknown service."""
resp = await mock_api_client.post(
const.URL_API_SERVICES_SERVICE.format(
"test_domain", "test_service"))
assert resp.status == 400
async def test_api_call_service_bad_data(hass, mock_api_client):
"""Test if the API failes 400 if unknown service."""
test_value = []
@ha.callback
def listener(service_call):
"""Record that our service got called."""
test_value.append(1)
hass.services.async_register("test_domain", "test_service", listener,
schema=vol.Schema({'hello': str}))
resp = await mock_api_client.post(
const.URL_API_SERVICES_SERVICE.format(
"test_domain", "test_service"), json={'hello': 5})
assert resp.status == 400

View file

@ -3,6 +3,9 @@
import asyncio
import datetime
import pytest
import voluptuous as vol
from homeassistant.core import CoreState, State, Context
from homeassistant.setup import async_setup_component
from homeassistant.components.input_datetime import (
@ -109,6 +112,7 @@ def test_set_invalid(hass):
dt_obj = datetime.datetime(2017, 9, 7, 19, 46)
time_portion = dt_obj.time()
with pytest.raises(vol.Invalid):
yield from hass.services.async_call('input_datetime', 'set_datetime', {
'entity_id': 'test_date',
'time': time_portion

View file

@ -4,6 +4,9 @@ import logging
from datetime import (timedelta, datetime)
import unittest
import pytest
import voluptuous as vol
from homeassistant.components import sun
import homeassistant.core as ha
from homeassistant.const import (
@ -89,6 +92,8 @@ class TestComponentLogbook(unittest.TestCase):
calls.append(event)
self.hass.bus.listen(logbook.EVENT_LOGBOOK_ENTRY, event_listener)
with pytest.raises(vol.Invalid):
self.hass.services.call(logbook.DOMAIN, 'log', {}, True)
# Logbook entry service call results in firing an event.

View file

@ -2,6 +2,9 @@
import json
import logging
import pytest
import voluptuous as vol
from homeassistant.bootstrap import async_setup_component
from homeassistant.components.mqtt import MQTT_PUBLISH_SCHEMA
import homeassistant.components.snips as snips
@ -452,12 +455,11 @@ async def test_snips_say_invalid_config(hass, caplog):
snips.SERVICE_SCHEMA_SAY)
data = {'text': 'Hello', 'badKey': 'boo'}
with pytest.raises(vol.Invalid):
await hass.services.async_call('snips', 'say', data)
await hass.async_block_till_done()
assert len(calls) == 0
assert 'ERROR' in caplog.text
assert 'Invalid service data' in caplog.text
async def test_snips_say_action_invalid(hass, caplog):
@ -466,12 +468,12 @@ async def test_snips_say_action_invalid(hass, caplog):
snips.SERVICE_SCHEMA_SAY_ACTION)
data = {'text': 'Hello', 'can_be_enqueued': 'notabool'}
with pytest.raises(vol.Invalid):
await hass.services.async_call('snips', 'say_action', data)
await hass.async_block_till_done()
assert len(calls) == 0
assert 'ERROR' in caplog.text
assert 'Invalid service data' in caplog.text
async def test_snips_feedback_on(hass, caplog):
@ -510,6 +512,7 @@ async def test_snips_feedback_config(hass, caplog):
snips.SERVICE_SCHEMA_FEEDBACK)
data = {'site_id': 'remote', 'test': 'test'}
with pytest.raises(vol.Invalid):
await hass.services.async_call('snips', 'feedback_on', data)
await hass.async_block_till_done()

View file

@ -3,6 +3,7 @@ import asyncio
from unittest import mock
import pytest
import voluptuous as vol
from homeassistant.setup import async_setup_component
from homeassistant.components.wake_on_lan import (
@ -34,10 +35,10 @@ def test_send_magic_packet(hass, caplog, mock_wakeonlan):
assert mock_wakeonlan.mock_calls[-1][1][0] == mac
assert mock_wakeonlan.mock_calls[-1][2]['ip_address'] == bc_ip
with pytest.raises(vol.Invalid):
yield from hass.services.async_call(
DOMAIN, SERVICE_SEND_MAGIC_PACKET,
{"broadcast_address": bc_ip}, blocking=True)
assert 'ERROR' in caplog.text
assert len(mock_wakeonlan.mock_calls) == 1
yield from hass.services.async_call(

View file

@ -1,6 +1,9 @@
"""The tests for the demo water_heater component."""
import unittest
import pytest
import voluptuous as vol
from homeassistant.util.unit_system import (
IMPERIAL_SYSTEM
)
@ -48,6 +51,7 @@ class TestDemowater_heater(unittest.TestCase):
"""Test setting the target temperature without required attribute."""
state = self.hass.states.get(ENTITY_WATER_HEATER)
assert 119 == state.attributes.get('temperature')
with pytest.raises(vol.Invalid):
common.set_temperature(self.hass, None, ENTITY_WATER_HEATER)
self.hass.block_till_done()
assert 119 == state.attributes.get('temperature')
@ -69,6 +73,7 @@ class TestDemowater_heater(unittest.TestCase):
state = self.hass.states.get(ENTITY_WATER_HEATER)
assert "eco" == state.attributes.get('operation_mode')
assert "eco" == state.state
with pytest.raises(vol.Invalid):
common.set_operation_mode(self.hass, None, ENTITY_WATER_HEATER)
self.hass.block_till_done()
state = self.hass.states.get(ENTITY_WATER_HEATER)

View file

@ -49,6 +49,25 @@ async def test_call_service(hass, websocket_client):
assert call.data == {'hello': 'world'}
async def test_call_service_not_found(hass, websocket_client):
"""Test call service command."""
await websocket_client.send_json({
'id': 5,
'type': commands.TYPE_CALL_SERVICE,
'domain': 'domain_test',
'service': 'test_service',
'service_data': {
'hello': 'world'
}
})
msg = await websocket_client.receive_json()
assert msg['id'] == 5
assert msg['type'] == const.TYPE_RESULT
assert not msg['success']
assert msg['error']['code'] == const.ERR_NOT_FOUND
async def test_subscribe_unsubscribe_events(hass, websocket_client):
"""Test subscribe/unsubscribe events command."""
init_count = sum(hass.bus.async_listeners().values())

View file

@ -947,7 +947,7 @@ class TestZWaveServices(unittest.TestCase):
assert self.zwave_network.stop.called
assert len(self.zwave_network.stop.mock_calls) == 1
assert mock_fire.called
assert len(mock_fire.mock_calls) == 2
assert len(mock_fire.mock_calls) == 1
assert mock_fire.mock_calls[0][1][0] == const.EVENT_NETWORK_STOP
def test_rename_node(self):

View file

@ -21,7 +21,7 @@ from homeassistant.const import (
__version__, EVENT_STATE_CHANGED, ATTR_FRIENDLY_NAME, CONF_UNIT_SYSTEM,
ATTR_NOW, EVENT_TIME_CHANGED, EVENT_TIMER_OUT_OF_SYNC, ATTR_SECONDS,
EVENT_HOMEASSISTANT_STOP, EVENT_HOMEASSISTANT_CLOSE,
EVENT_SERVICE_REGISTERED, EVENT_SERVICE_REMOVED, EVENT_SERVICE_EXECUTED)
EVENT_SERVICE_REGISTERED, EVENT_SERVICE_REMOVED)
from tests.common import get_test_home_assistant, async_mock_service
@ -673,13 +673,8 @@ class TestServiceRegistry(unittest.TestCase):
def test_call_non_existing_with_blocking(self):
"""Test non-existing with blocking."""
prior = ha.SERVICE_CALL_LIMIT
try:
ha.SERVICE_CALL_LIMIT = 0.01
assert not self.services.call('test_domain', 'i_do_not_exist',
blocking=True)
finally:
ha.SERVICE_CALL_LIMIT = prior
with pytest.raises(ha.ServiceNotFound):
self.services.call('test_domain', 'i_do_not_exist', blocking=True)
def test_async_service(self):
"""Test registering and calling an async service."""
@ -1005,4 +1000,3 @@ async def test_service_executed_with_subservices(hass):
assert len(calls) == 4
assert [call.service for call in calls] == [
'outer', 'inner', 'inner', 'outer']
assert len(hass.bus.async_listeners().get(EVENT_SERVICE_EXECUTED, [])) == 0