RFC: Call services directly (#18720)
* Call services directly * Simplify * Type * Lint * Update name * Fix tests * Catch exceptions in HTTP view * Lint * Handle ServiceNotFound in API endpoints that call services * Type * Don't crash recorder on non-JSON serializable objects
This commit is contained in:
parent
53cbb28926
commit
df21dd21f2
30 changed files with 312 additions and 186 deletions
|
@ -11,6 +11,7 @@ import voluptuous as vol
|
|||
|
||||
from homeassistant.const import CONF_EXCLUDE, CONF_INCLUDE
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.exceptions import ServiceNotFound
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
|
||||
from . import MultiFactorAuthModule, MULTI_FACTOR_AUTH_MODULES, \
|
||||
|
@ -314,8 +315,11 @@ class NotifySetupFlow(SetupFlow):
|
|||
_generate_otp, self._secret, self._count)
|
||||
|
||||
assert self._notify_service
|
||||
try:
|
||||
await self._auth_module.async_notify(
|
||||
code, self._notify_service, self._target)
|
||||
except ServiceNotFound:
|
||||
return self.async_abort(reason='notify_service_not_exist')
|
||||
|
||||
return self.async_show_form(
|
||||
step_id='setup',
|
||||
|
|
|
@ -226,7 +226,11 @@ class LoginFlow(data_entry_flow.FlowHandler):
|
|||
|
||||
if user_input is None and hasattr(auth_module,
|
||||
'async_initialize_login_mfa_step'):
|
||||
try:
|
||||
await auth_module.async_initialize_login_mfa_step(self.user.id)
|
||||
except HomeAssistantError:
|
||||
_LOGGER.exception('Error initializing MFA step')
|
||||
return self.async_abort(reason='unknown_error')
|
||||
|
||||
if user_input is not None:
|
||||
expires = self.created_at + MFA_SESSION_EXPIRATION
|
||||
|
|
|
@ -9,7 +9,9 @@ import json
|
|||
import logging
|
||||
|
||||
from aiohttp import web
|
||||
from aiohttp.web_exceptions import HTTPBadRequest
|
||||
import async_timeout
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.bootstrap import DATA_LOGGING
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
|
@ -21,7 +23,8 @@ from homeassistant.const import (
|
|||
URL_API_TEMPLATE, __version__)
|
||||
import homeassistant.core as ha
|
||||
from homeassistant.auth.permissions.const import POLICY_READ
|
||||
from homeassistant.exceptions import TemplateError, Unauthorized
|
||||
from homeassistant.exceptions import (
|
||||
TemplateError, Unauthorized, ServiceNotFound)
|
||||
from homeassistant.helpers import template
|
||||
from homeassistant.helpers.service import async_get_all_descriptions
|
||||
from homeassistant.helpers.state import AsyncTrackStates
|
||||
|
@ -339,8 +342,11 @@ class APIDomainServicesView(HomeAssistantView):
|
|||
"Data should be valid JSON.", HTTP_BAD_REQUEST)
|
||||
|
||||
with AsyncTrackStates(hass) as changed_states:
|
||||
try:
|
||||
await hass.services.async_call(
|
||||
domain, service, data, True, self.context(request))
|
||||
except (vol.Invalid, ServiceNotFound):
|
||||
raise HTTPBadRequest()
|
||||
|
||||
return self.json(changed_states)
|
||||
|
||||
|
|
|
@ -9,7 +9,9 @@ import json
|
|||
import logging
|
||||
|
||||
from aiohttp import web
|
||||
from aiohttp.web_exceptions import HTTPUnauthorized, HTTPInternalServerError
|
||||
from aiohttp.web_exceptions import (
|
||||
HTTPUnauthorized, HTTPInternalServerError, HTTPBadRequest)
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components.http.ban import process_success_login
|
||||
from homeassistant.core import Context, is_callback
|
||||
|
@ -114,6 +116,10 @@ def request_handler_factory(view, handler):
|
|||
|
||||
if asyncio.iscoroutine(result):
|
||||
result = await result
|
||||
except vol.Invalid:
|
||||
raise HTTPBadRequest()
|
||||
except exceptions.ServiceNotFound:
|
||||
raise HTTPInternalServerError()
|
||||
except exceptions.Unauthorized:
|
||||
raise HTTPUnauthorized()
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ from homeassistant.core import callback
|
|||
from homeassistant.components.mqtt import (
|
||||
valid_publish_topic, valid_subscribe_topic)
|
||||
from homeassistant.const import (
|
||||
ATTR_SERVICE_DATA, EVENT_CALL_SERVICE, EVENT_SERVICE_EXECUTED,
|
||||
ATTR_SERVICE_DATA, EVENT_CALL_SERVICE,
|
||||
EVENT_STATE_CHANGED, EVENT_TIME_CHANGED, MATCH_ALL)
|
||||
from homeassistant.core import EventOrigin, State
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
|
@ -69,16 +69,6 @@ def async_setup(hass, config):
|
|||
):
|
||||
return
|
||||
|
||||
# Filter out all the "event service executed" events because they
|
||||
# are only used internally by core as callbacks for blocking
|
||||
# during the interval while a service is being executed.
|
||||
# They will serve no purpose to the external system,
|
||||
# and thus are unnecessary traffic.
|
||||
# And at any rate it would cause an infinite loop to publish them
|
||||
# because publishing to an MQTT topic itself triggers one.
|
||||
if event.event_type == EVENT_SERVICE_EXECUTED:
|
||||
return
|
||||
|
||||
event_info = {'event_type': event.event_type, 'event_data': event.data}
|
||||
msg = json.dumps(event_info, cls=JSONEncoder)
|
||||
mqtt.async_publish(pub_topic, msg)
|
||||
|
|
|
@ -300,14 +300,24 @@ class Recorder(threading.Thread):
|
|||
time.sleep(CONNECT_RETRY_WAIT)
|
||||
try:
|
||||
with session_scope(session=self.get_session()) as session:
|
||||
try:
|
||||
dbevent = Events.from_event(event)
|
||||
session.add(dbevent)
|
||||
session.flush()
|
||||
except (TypeError, ValueError):
|
||||
_LOGGER.warning(
|
||||
"Event is not JSON serializable: %s", event)
|
||||
|
||||
if event.event_type == EVENT_STATE_CHANGED:
|
||||
try:
|
||||
dbstate = States.from_event(event)
|
||||
dbstate.event_id = dbevent.event_id
|
||||
session.add(dbstate)
|
||||
except (TypeError, ValueError):
|
||||
_LOGGER.warning(
|
||||
"State is not JSON serializable: %s",
|
||||
event.data.get('new_state'))
|
||||
|
||||
updated = True
|
||||
|
||||
except exc.OperationalError as err:
|
||||
|
|
|
@ -3,7 +3,7 @@ import voluptuous as vol
|
|||
|
||||
from homeassistant.const import MATCH_ALL, EVENT_TIME_CHANGED
|
||||
from homeassistant.core import callback, DOMAIN as HASS_DOMAIN
|
||||
from homeassistant.exceptions import Unauthorized
|
||||
from homeassistant.exceptions import Unauthorized, ServiceNotFound
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers.service import async_get_all_descriptions
|
||||
|
||||
|
@ -141,10 +141,15 @@ async def handle_call_service(hass, connection, msg):
|
|||
if (msg['domain'] == HASS_DOMAIN and
|
||||
msg['service'] in ['restart', 'stop']):
|
||||
blocking = False
|
||||
|
||||
try:
|
||||
await hass.services.async_call(
|
||||
msg['domain'], msg['service'], msg.get('service_data'), blocking,
|
||||
connection.context(msg))
|
||||
connection.send_message(messages.result_message(msg['id']))
|
||||
except ServiceNotFound:
|
||||
connection.send_message(messages.error_message(
|
||||
msg['id'], const.ERR_NOT_FOUND, 'Service not found.'))
|
||||
|
||||
|
||||
@callback
|
||||
|
|
|
@ -163,7 +163,6 @@ EVENT_HOMEASSISTANT_CLOSE = 'homeassistant_close'
|
|||
EVENT_STATE_CHANGED = 'state_changed'
|
||||
EVENT_TIME_CHANGED = 'time_changed'
|
||||
EVENT_CALL_SERVICE = 'call_service'
|
||||
EVENT_SERVICE_EXECUTED = 'service_executed'
|
||||
EVENT_PLATFORM_DISCOVERED = 'platform_discovered'
|
||||
EVENT_COMPONENT_LOADED = 'component_loaded'
|
||||
EVENT_SERVICE_REGISTERED = 'service_registered'
|
||||
|
@ -233,9 +232,6 @@ ATTR_ID = 'id'
|
|||
# Name
|
||||
ATTR_NAME = 'name'
|
||||
|
||||
# Data for a SERVICE_EXECUTED event
|
||||
ATTR_SERVICE_CALL_ID = 'service_call_id'
|
||||
|
||||
# Contains one string or a list of strings, each being an entity id
|
||||
ATTR_ENTITY_ID = 'entity_id'
|
||||
|
||||
|
|
|
@ -25,18 +25,18 @@ from typing import ( # noqa: F401 pylint: disable=unused-import
|
|||
from async_timeout import timeout
|
||||
import attr
|
||||
import voluptuous as vol
|
||||
from voluptuous.humanize import humanize_error
|
||||
|
||||
from homeassistant.const import (
|
||||
ATTR_DOMAIN, ATTR_FRIENDLY_NAME, ATTR_NOW, ATTR_SERVICE,
|
||||
ATTR_SERVICE_CALL_ID, ATTR_SERVICE_DATA, ATTR_SECONDS, EVENT_CALL_SERVICE,
|
||||
ATTR_SERVICE_DATA, ATTR_SECONDS, EVENT_CALL_SERVICE,
|
||||
EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP,
|
||||
EVENT_HOMEASSISTANT_CLOSE, EVENT_SERVICE_REMOVED,
|
||||
EVENT_SERVICE_EXECUTED, EVENT_SERVICE_REGISTERED, EVENT_STATE_CHANGED,
|
||||
EVENT_SERVICE_REGISTERED, EVENT_STATE_CHANGED,
|
||||
EVENT_TIME_CHANGED, EVENT_TIMER_OUT_OF_SYNC, MATCH_ALL, __version__)
|
||||
from homeassistant import loader
|
||||
from homeassistant.exceptions import (
|
||||
HomeAssistantError, InvalidEntityFormatError, InvalidStateError)
|
||||
HomeAssistantError, InvalidEntityFormatError, InvalidStateError,
|
||||
Unauthorized, ServiceNotFound)
|
||||
from homeassistant.util.async_ import (
|
||||
run_coroutine_threadsafe, run_callback_threadsafe,
|
||||
fire_coroutine_threadsafe)
|
||||
|
@ -954,7 +954,6 @@ class ServiceRegistry:
|
|||
"""Initialize a service registry."""
|
||||
self._services = {} # type: Dict[str, Dict[str, Service]]
|
||||
self._hass = hass
|
||||
self._async_unsub_call_event = None # type: Optional[CALLBACK_TYPE]
|
||||
|
||||
@property
|
||||
def services(self) -> Dict[str, Dict[str, Service]]:
|
||||
|
@ -1010,10 +1009,6 @@ class ServiceRegistry:
|
|||
else:
|
||||
self._services[domain] = {service: service_obj}
|
||||
|
||||
if self._async_unsub_call_event is None:
|
||||
self._async_unsub_call_event = self._hass.bus.async_listen(
|
||||
EVENT_CALL_SERVICE, self._event_to_service_call)
|
||||
|
||||
self._hass.bus.async_fire(
|
||||
EVENT_SERVICE_REGISTERED,
|
||||
{ATTR_DOMAIN: domain, ATTR_SERVICE: service}
|
||||
|
@ -1092,100 +1087,61 @@ class ServiceRegistry:
|
|||
|
||||
This method is a coroutine.
|
||||
"""
|
||||
domain = domain.lower()
|
||||
service = service.lower()
|
||||
context = context or Context()
|
||||
call_id = uuid.uuid4().hex
|
||||
event_data = {
|
||||
service_data = service_data or {}
|
||||
|
||||
try:
|
||||
handler = self._services[domain][service]
|
||||
except KeyError:
|
||||
raise ServiceNotFound(domain, service) from None
|
||||
|
||||
if handler.schema:
|
||||
service_data = handler.schema(service_data)
|
||||
|
||||
service_call = ServiceCall(domain, service, service_data, context)
|
||||
|
||||
self._hass.bus.async_fire(EVENT_CALL_SERVICE, {
|
||||
ATTR_DOMAIN: domain.lower(),
|
||||
ATTR_SERVICE: service.lower(),
|
||||
ATTR_SERVICE_DATA: service_data,
|
||||
ATTR_SERVICE_CALL_ID: call_id,
|
||||
}
|
||||
})
|
||||
|
||||
if not blocking:
|
||||
self._hass.bus.async_fire(
|
||||
EVENT_CALL_SERVICE, event_data, EventOrigin.local, context)
|
||||
self._hass.async_create_task(
|
||||
self._safe_execute(handler, service_call))
|
||||
return None
|
||||
|
||||
fut = asyncio.Future() # type: asyncio.Future
|
||||
|
||||
@callback
|
||||
def service_executed(event: Event) -> None:
|
||||
"""Handle an executed service."""
|
||||
if event.data[ATTR_SERVICE_CALL_ID] == call_id:
|
||||
fut.set_result(True)
|
||||
unsub()
|
||||
|
||||
unsub = self._hass.bus.async_listen(
|
||||
EVENT_SERVICE_EXECUTED, service_executed)
|
||||
|
||||
self._hass.bus.async_fire(EVENT_CALL_SERVICE, event_data,
|
||||
EventOrigin.local, context)
|
||||
|
||||
done, _ = await asyncio.wait([fut], timeout=SERVICE_CALL_LIMIT)
|
||||
success = bool(done)
|
||||
if not success:
|
||||
unsub()
|
||||
return success
|
||||
|
||||
async def _event_to_service_call(self, event: Event) -> None:
|
||||
"""Handle the SERVICE_CALLED events from the EventBus."""
|
||||
service_data = event.data.get(ATTR_SERVICE_DATA) or {}
|
||||
domain = event.data.get(ATTR_DOMAIN).lower() # type: ignore
|
||||
service = event.data.get(ATTR_SERVICE).lower() # type: ignore
|
||||
call_id = event.data.get(ATTR_SERVICE_CALL_ID)
|
||||
|
||||
if not self.has_service(domain, service):
|
||||
if event.origin == EventOrigin.local:
|
||||
_LOGGER.warning("Unable to find service %s/%s",
|
||||
domain, service)
|
||||
return
|
||||
|
||||
service_handler = self._services[domain][service]
|
||||
|
||||
def fire_service_executed() -> None:
|
||||
"""Fire service executed event."""
|
||||
if not call_id:
|
||||
return
|
||||
|
||||
data = {ATTR_SERVICE_CALL_ID: call_id}
|
||||
|
||||
if (service_handler.is_coroutinefunction or
|
||||
service_handler.is_callback):
|
||||
self._hass.bus.async_fire(EVENT_SERVICE_EXECUTED, data,
|
||||
EventOrigin.local, event.context)
|
||||
else:
|
||||
self._hass.bus.fire(EVENT_SERVICE_EXECUTED, data,
|
||||
EventOrigin.local, event.context)
|
||||
|
||||
try:
|
||||
if service_handler.schema:
|
||||
service_data = service_handler.schema(service_data)
|
||||
except vol.Invalid as ex:
|
||||
_LOGGER.error("Invalid service data for %s.%s: %s",
|
||||
domain, service, humanize_error(service_data, ex))
|
||||
fire_service_executed()
|
||||
return
|
||||
|
||||
service_call = ServiceCall(
|
||||
domain, service, service_data, event.context)
|
||||
with timeout(SERVICE_CALL_LIMIT):
|
||||
await asyncio.shield(
|
||||
self._execute_service(handler, service_call))
|
||||
return True
|
||||
except asyncio.TimeoutError:
|
||||
return False
|
||||
|
||||
async def _safe_execute(self, handler: Service,
|
||||
service_call: ServiceCall) -> None:
|
||||
"""Execute a service and catch exceptions."""
|
||||
try:
|
||||
if service_handler.is_callback:
|
||||
service_handler.func(service_call)
|
||||
fire_service_executed()
|
||||
elif service_handler.is_coroutinefunction:
|
||||
await service_handler.func(service_call)
|
||||
fire_service_executed()
|
||||
else:
|
||||
def execute_service() -> None:
|
||||
"""Execute a service and fires a SERVICE_EXECUTED event."""
|
||||
service_handler.func(service_call)
|
||||
fire_service_executed()
|
||||
|
||||
await self._hass.async_add_executor_job(execute_service)
|
||||
await self._execute_service(handler, service_call)
|
||||
except Unauthorized:
|
||||
_LOGGER.warning('Unauthorized service called %s/%s',
|
||||
service_call.domain, service_call.service)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
_LOGGER.exception('Error executing service %s', service_call)
|
||||
|
||||
async def _execute_service(self, handler: Service,
|
||||
service_call: ServiceCall) -> None:
|
||||
"""Execute a service."""
|
||||
if handler.is_callback:
|
||||
handler.func(service_call)
|
||||
elif handler.is_coroutinefunction:
|
||||
await handler.func(service_call)
|
||||
else:
|
||||
await self._hass.async_add_executor_job(handler.func, service_call)
|
||||
|
||||
|
||||
class Config:
|
||||
"""Configuration settings for Home Assistant."""
|
||||
|
|
|
@ -58,3 +58,14 @@ class Unauthorized(HomeAssistantError):
|
|||
|
||||
class UnknownUser(Unauthorized):
|
||||
"""When call is made with user ID that doesn't exist."""
|
||||
|
||||
|
||||
class ServiceNotFound(HomeAssistantError):
|
||||
"""Raised when a service is not found."""
|
||||
|
||||
def __init__(self, domain: str, service: str) -> None:
|
||||
"""Initialize error."""
|
||||
super().__init__(
|
||||
self, "Service {}.{} not found".format(domain, service))
|
||||
self.domain = domain
|
||||
self.service = service
|
||||
|
|
|
@ -61,6 +61,7 @@ async def test_validating_mfa_counter(hass):
|
|||
'counter': 0,
|
||||
'notify_service': 'dummy',
|
||||
})
|
||||
async_mock_service(hass, 'notify', 'dummy')
|
||||
|
||||
assert notify_auth_module._user_settings
|
||||
notify_setting = list(notify_auth_module._user_settings.values())[0]
|
||||
|
@ -389,9 +390,8 @@ async def test_not_raise_exception_when_service_not_exist(hass):
|
|||
'username': 'test-user',
|
||||
'password': 'test-pass',
|
||||
})
|
||||
assert result['type'] == data_entry_flow.RESULT_TYPE_FORM
|
||||
assert result['step_id'] == 'mfa'
|
||||
assert result['data_schema'].schema.get('code') == str
|
||||
assert result['type'] == data_entry_flow.RESULT_TYPE_ABORT
|
||||
assert result['reason'] == 'unknown_error'
|
||||
|
||||
# wait service call finished
|
||||
await hass.async_block_till_done()
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
"""The tests for the demo climate component."""
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.util.unit_system import (
|
||||
METRIC_SYSTEM
|
||||
)
|
||||
|
@ -57,6 +60,7 @@ class TestDemoClimate(unittest.TestCase):
|
|||
"""Test setting the target temperature without required attribute."""
|
||||
state = self.hass.states.get(ENTITY_CLIMATE)
|
||||
assert 21 == state.attributes.get('temperature')
|
||||
with pytest.raises(vol.Invalid):
|
||||
common.set_temperature(self.hass, None, ENTITY_CLIMATE)
|
||||
self.hass.block_till_done()
|
||||
assert 21 == state.attributes.get('temperature')
|
||||
|
@ -99,8 +103,10 @@ class TestDemoClimate(unittest.TestCase):
|
|||
assert state.attributes.get('temperature') is None
|
||||
assert 21.0 == state.attributes.get('target_temp_low')
|
||||
assert 24.0 == state.attributes.get('target_temp_high')
|
||||
with pytest.raises(vol.Invalid):
|
||||
common.set_temperature(self.hass, temperature=None,
|
||||
entity_id=ENTITY_ECOBEE, target_temp_low=None,
|
||||
entity_id=ENTITY_ECOBEE,
|
||||
target_temp_low=None,
|
||||
target_temp_high=None)
|
||||
self.hass.block_till_done()
|
||||
state = self.hass.states.get(ENTITY_ECOBEE)
|
||||
|
@ -112,6 +118,7 @@ class TestDemoClimate(unittest.TestCase):
|
|||
"""Test setting the target humidity without required attribute."""
|
||||
state = self.hass.states.get(ENTITY_CLIMATE)
|
||||
assert 67 == state.attributes.get('humidity')
|
||||
with pytest.raises(vol.Invalid):
|
||||
common.set_humidity(self.hass, None, ENTITY_CLIMATE)
|
||||
self.hass.block_till_done()
|
||||
state = self.hass.states.get(ENTITY_CLIMATE)
|
||||
|
@ -130,6 +137,7 @@ class TestDemoClimate(unittest.TestCase):
|
|||
"""Test setting fan mode without required attribute."""
|
||||
state = self.hass.states.get(ENTITY_CLIMATE)
|
||||
assert "On High" == state.attributes.get('fan_mode')
|
||||
with pytest.raises(vol.Invalid):
|
||||
common.set_fan_mode(self.hass, None, ENTITY_CLIMATE)
|
||||
self.hass.block_till_done()
|
||||
state = self.hass.states.get(ENTITY_CLIMATE)
|
||||
|
@ -148,6 +156,7 @@ class TestDemoClimate(unittest.TestCase):
|
|||
"""Test setting swing mode without required attribute."""
|
||||
state = self.hass.states.get(ENTITY_CLIMATE)
|
||||
assert "Off" == state.attributes.get('swing_mode')
|
||||
with pytest.raises(vol.Invalid):
|
||||
common.set_swing_mode(self.hass, None, ENTITY_CLIMATE)
|
||||
self.hass.block_till_done()
|
||||
state = self.hass.states.get(ENTITY_CLIMATE)
|
||||
|
@ -170,6 +179,7 @@ class TestDemoClimate(unittest.TestCase):
|
|||
state = self.hass.states.get(ENTITY_CLIMATE)
|
||||
assert "cool" == state.attributes.get('operation_mode')
|
||||
assert "cool" == state.state
|
||||
with pytest.raises(vol.Invalid):
|
||||
common.set_operation_mode(self.hass, None, ENTITY_CLIMATE)
|
||||
self.hass.block_till_done()
|
||||
state = self.hass.states.get(ENTITY_CLIMATE)
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
"""The tests for the climate component."""
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components.climate import SET_TEMPERATURE_SCHEMA
|
||||
from tests.common import async_mock_service
|
||||
|
||||
|
@ -14,12 +17,11 @@ def test_set_temp_schema_no_req(hass, caplog):
|
|||
calls = async_mock_service(hass, domain, service, schema)
|
||||
|
||||
data = {'operation_mode': 'test', 'entity_id': ['climate.test_id']}
|
||||
with pytest.raises(vol.Invalid):
|
||||
yield from hass.services.async_call(domain, service, data)
|
||||
yield from hass.async_block_till_done()
|
||||
|
||||
assert len(calls) == 0
|
||||
assert 'ERROR' in caplog.text
|
||||
assert 'Invalid service data' in caplog.text
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
|
|
|
@ -2,6 +2,9 @@
|
|||
import unittest
|
||||
import copy
|
||||
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.util.unit_system import (
|
||||
METRIC_SYSTEM
|
||||
)
|
||||
|
@ -91,6 +94,7 @@ class TestMQTTClimate(unittest.TestCase):
|
|||
state = self.hass.states.get(ENTITY_CLIMATE)
|
||||
assert "off" == state.attributes.get('operation_mode')
|
||||
assert "off" == state.state
|
||||
with pytest.raises(vol.Invalid):
|
||||
common.set_operation_mode(self.hass, None, ENTITY_CLIMATE)
|
||||
self.hass.block_till_done()
|
||||
state = self.hass.states.get(ENTITY_CLIMATE)
|
||||
|
@ -177,6 +181,7 @@ class TestMQTTClimate(unittest.TestCase):
|
|||
|
||||
state = self.hass.states.get(ENTITY_CLIMATE)
|
||||
assert "low" == state.attributes.get('fan_mode')
|
||||
with pytest.raises(vol.Invalid):
|
||||
common.set_fan_mode(self.hass, None, ENTITY_CLIMATE)
|
||||
self.hass.block_till_done()
|
||||
state = self.hass.states.get(ENTITY_CLIMATE)
|
||||
|
@ -225,6 +230,7 @@ class TestMQTTClimate(unittest.TestCase):
|
|||
|
||||
state = self.hass.states.get(ENTITY_CLIMATE)
|
||||
assert "off" == state.attributes.get('swing_mode')
|
||||
with pytest.raises(vol.Invalid):
|
||||
common.set_swing_mode(self.hass, None, ENTITY_CLIMATE)
|
||||
self.hass.block_till_done()
|
||||
state = self.hass.states.get(ENTITY_CLIMATE)
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
"""Test deCONZ component setup process."""
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.components import deconz
|
||||
|
||||
|
@ -163,9 +166,11 @@ async def test_service_configure(hass):
|
|||
await hass.async_block_till_done()
|
||||
|
||||
# field does not start with /
|
||||
with pytest.raises(vol.Invalid):
|
||||
with patch('pydeconz.DeconzSession.async_put_state',
|
||||
return_value=mock_coro(True)):
|
||||
await hass.services.async_call('deconz', 'configure', service_data={
|
||||
await hass.services.async_call(
|
||||
'deconz', 'configure', service_data={
|
||||
'entity': 'light.test', 'field': 'state', 'data': data})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
|
|
|
@ -1,8 +1,25 @@
|
|||
"""Tests for Home Assistant View."""
|
||||
from aiohttp.web_exceptions import HTTPInternalServerError
|
||||
import pytest
|
||||
from unittest.mock import Mock
|
||||
|
||||
from homeassistant.components.http.view import HomeAssistantView
|
||||
from aiohttp.web_exceptions import (
|
||||
HTTPInternalServerError, HTTPBadRequest, HTTPUnauthorized)
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components.http.view import (
|
||||
HomeAssistantView, request_handler_factory)
|
||||
from homeassistant.exceptions import ServiceNotFound, Unauthorized
|
||||
|
||||
from tests.common import mock_coro_func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_request():
|
||||
"""Mock a request."""
|
||||
return Mock(
|
||||
app={'hass': Mock(is_running=True)},
|
||||
match_info={},
|
||||
)
|
||||
|
||||
|
||||
async def test_invalid_json(caplog):
|
||||
|
@ -13,3 +30,30 @@ async def test_invalid_json(caplog):
|
|||
view.json(float("NaN"))
|
||||
|
||||
assert str(float("NaN")) in caplog.text
|
||||
|
||||
|
||||
async def test_handling_unauthorized(mock_request):
|
||||
"""Test handling unauth exceptions."""
|
||||
with pytest.raises(HTTPUnauthorized):
|
||||
await request_handler_factory(
|
||||
Mock(requires_auth=False),
|
||||
mock_coro_func(exception=Unauthorized)
|
||||
)(mock_request)
|
||||
|
||||
|
||||
async def test_handling_invalid_data(mock_request):
|
||||
"""Test handling unauth exceptions."""
|
||||
with pytest.raises(HTTPBadRequest):
|
||||
await request_handler_factory(
|
||||
Mock(requires_auth=False),
|
||||
mock_coro_func(exception=vol.Invalid('yo'))
|
||||
)(mock_request)
|
||||
|
||||
|
||||
async def test_handling_service_not_found(mock_request):
|
||||
"""Test handling unauth exceptions."""
|
||||
with pytest.raises(HTTPInternalServerError):
|
||||
await request_handler_factory(
|
||||
Mock(requires_auth=False),
|
||||
mock_coro_func(exception=ServiceNotFound('test', 'test'))
|
||||
)(mock_request)
|
||||
|
|
|
@ -3,6 +3,9 @@ import unittest
|
|||
from unittest.mock import patch
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.setup import setup_component
|
||||
from homeassistant.const import HTTP_HEADER_HA_AUTH
|
||||
import homeassistant.components.media_player as mp
|
||||
|
@ -43,6 +46,7 @@ class TestDemoMediaPlayer(unittest.TestCase):
|
|||
state = self.hass.states.get(entity_id)
|
||||
assert 'dvd' == state.attributes.get('source')
|
||||
|
||||
with pytest.raises(vol.Invalid):
|
||||
common.select_source(self.hass, None, entity_id)
|
||||
self.hass.block_till_done()
|
||||
state = self.hass.states.get(entity_id)
|
||||
|
@ -72,6 +76,7 @@ class TestDemoMediaPlayer(unittest.TestCase):
|
|||
state = self.hass.states.get(entity_id)
|
||||
assert 1.0 == state.attributes.get('volume_level')
|
||||
|
||||
with pytest.raises(vol.Invalid):
|
||||
common.set_volume_level(self.hass, None, entity_id)
|
||||
self.hass.block_till_done()
|
||||
state = self.hass.states.get(entity_id)
|
||||
|
@ -201,6 +206,7 @@ class TestDemoMediaPlayer(unittest.TestCase):
|
|||
state.attributes.get('supported_features'))
|
||||
assert state.attributes.get('media_content_id') is not None
|
||||
|
||||
with pytest.raises(vol.Invalid):
|
||||
common.play_media(self.hass, None, 'some_id', ent_id)
|
||||
self.hass.block_till_done()
|
||||
state = self.hass.states.get(ent_id)
|
||||
|
@ -216,6 +222,7 @@ class TestDemoMediaPlayer(unittest.TestCase):
|
|||
assert 'some_id' == state.attributes.get('media_content_id')
|
||||
|
||||
assert not mock_seek.called
|
||||
with pytest.raises(vol.Invalid):
|
||||
common.media_seek(self.hass, None, ent_id)
|
||||
self.hass.block_till_done()
|
||||
assert not mock_seek.called
|
||||
|
|
|
@ -223,7 +223,7 @@ class TestMonopriceMediaPlayer(unittest.TestCase):
|
|||
# Restoring wrong media player to its previous state
|
||||
# Nothing should be done
|
||||
self.hass.services.call(DOMAIN, SERVICE_RESTORE,
|
||||
{'entity_id': 'not_existing'},
|
||||
{'entity_id': 'media.not_existing'},
|
||||
blocking=True)
|
||||
# self.hass.block_till_done()
|
||||
|
||||
|
|
|
@ -113,6 +113,7 @@ class TestMQTTComponent(unittest.TestCase):
|
|||
"""
|
||||
payload = "not a template"
|
||||
payload_template = "a template"
|
||||
with pytest.raises(vol.Invalid):
|
||||
self.hass.services.call(mqtt.DOMAIN, mqtt.SERVICE_PUBLISH, {
|
||||
mqtt.ATTR_TOPIC: "test/topic",
|
||||
mqtt.ATTR_PAYLOAD: payload,
|
||||
|
|
|
@ -2,6 +2,9 @@
|
|||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
|
||||
import homeassistant.components.notify as notify
|
||||
from homeassistant.setup import setup_component
|
||||
from homeassistant.components.notify import demo
|
||||
|
@ -81,6 +84,7 @@ class TestNotifyDemo(unittest.TestCase):
|
|||
def test_sending_none_message(self):
|
||||
"""Test send with None as message."""
|
||||
self._setup_notify()
|
||||
with pytest.raises(vol.Invalid):
|
||||
common.send_message(self.hass, None)
|
||||
self.hass.block_till_done()
|
||||
assert len(self.events) == 0
|
||||
|
|
|
@ -99,6 +99,7 @@ class TestAlert(unittest.TestCase):
|
|||
def setUp(self):
|
||||
"""Set up things to be run when tests are started."""
|
||||
self.hass = get_test_home_assistant()
|
||||
self._setup_notify()
|
||||
|
||||
def tearDown(self):
|
||||
"""Stop everything that was started."""
|
||||
|
|
|
@ -6,6 +6,7 @@ from unittest.mock import patch
|
|||
|
||||
from aiohttp import web
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant import const
|
||||
from homeassistant.bootstrap import DATA_LOGGING
|
||||
|
@ -578,3 +579,29 @@ async def test_rendering_template_legacy_user(
|
|||
json={"template": '{{ states.sensor.temperature.state }}'}
|
||||
)
|
||||
assert resp.status == 401
|
||||
|
||||
|
||||
async def test_api_call_service_not_found(hass, mock_api_client):
|
||||
"""Test if the API failes 400 if unknown service."""
|
||||
resp = await mock_api_client.post(
|
||||
const.URL_API_SERVICES_SERVICE.format(
|
||||
"test_domain", "test_service"))
|
||||
assert resp.status == 400
|
||||
|
||||
|
||||
async def test_api_call_service_bad_data(hass, mock_api_client):
|
||||
"""Test if the API failes 400 if unknown service."""
|
||||
test_value = []
|
||||
|
||||
@ha.callback
|
||||
def listener(service_call):
|
||||
"""Record that our service got called."""
|
||||
test_value.append(1)
|
||||
|
||||
hass.services.async_register("test_domain", "test_service", listener,
|
||||
schema=vol.Schema({'hello': str}))
|
||||
|
||||
resp = await mock_api_client.post(
|
||||
const.URL_API_SERVICES_SERVICE.format(
|
||||
"test_domain", "test_service"), json={'hello': 5})
|
||||
assert resp.status == 400
|
||||
|
|
|
@ -3,6 +3,9 @@
|
|||
import asyncio
|
||||
import datetime
|
||||
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.core import CoreState, State, Context
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.components.input_datetime import (
|
||||
|
@ -109,6 +112,7 @@ def test_set_invalid(hass):
|
|||
dt_obj = datetime.datetime(2017, 9, 7, 19, 46)
|
||||
time_portion = dt_obj.time()
|
||||
|
||||
with pytest.raises(vol.Invalid):
|
||||
yield from hass.services.async_call('input_datetime', 'set_datetime', {
|
||||
'entity_id': 'test_date',
|
||||
'time': time_portion
|
||||
|
|
|
@ -4,6 +4,9 @@ import logging
|
|||
from datetime import (timedelta, datetime)
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import sun
|
||||
import homeassistant.core as ha
|
||||
from homeassistant.const import (
|
||||
|
@ -89,6 +92,8 @@ class TestComponentLogbook(unittest.TestCase):
|
|||
calls.append(event)
|
||||
|
||||
self.hass.bus.listen(logbook.EVENT_LOGBOOK_ENTRY, event_listener)
|
||||
|
||||
with pytest.raises(vol.Invalid):
|
||||
self.hass.services.call(logbook.DOMAIN, 'log', {}, True)
|
||||
|
||||
# Logbook entry service call results in firing an event.
|
||||
|
|
|
@ -2,6 +2,9 @@
|
|||
import json
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.bootstrap import async_setup_component
|
||||
from homeassistant.components.mqtt import MQTT_PUBLISH_SCHEMA
|
||||
import homeassistant.components.snips as snips
|
||||
|
@ -452,12 +455,11 @@ async def test_snips_say_invalid_config(hass, caplog):
|
|||
snips.SERVICE_SCHEMA_SAY)
|
||||
|
||||
data = {'text': 'Hello', 'badKey': 'boo'}
|
||||
with pytest.raises(vol.Invalid):
|
||||
await hass.services.async_call('snips', 'say', data)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(calls) == 0
|
||||
assert 'ERROR' in caplog.text
|
||||
assert 'Invalid service data' in caplog.text
|
||||
|
||||
|
||||
async def test_snips_say_action_invalid(hass, caplog):
|
||||
|
@ -466,12 +468,12 @@ async def test_snips_say_action_invalid(hass, caplog):
|
|||
snips.SERVICE_SCHEMA_SAY_ACTION)
|
||||
|
||||
data = {'text': 'Hello', 'can_be_enqueued': 'notabool'}
|
||||
|
||||
with pytest.raises(vol.Invalid):
|
||||
await hass.services.async_call('snips', 'say_action', data)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(calls) == 0
|
||||
assert 'ERROR' in caplog.text
|
||||
assert 'Invalid service data' in caplog.text
|
||||
|
||||
|
||||
async def test_snips_feedback_on(hass, caplog):
|
||||
|
@ -510,6 +512,7 @@ async def test_snips_feedback_config(hass, caplog):
|
|||
snips.SERVICE_SCHEMA_FEEDBACK)
|
||||
|
||||
data = {'site_id': 'remote', 'test': 'test'}
|
||||
with pytest.raises(vol.Invalid):
|
||||
await hass.services.async_call('snips', 'feedback_on', data)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@ import asyncio
|
|||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.components.wake_on_lan import (
|
||||
|
@ -34,10 +35,10 @@ def test_send_magic_packet(hass, caplog, mock_wakeonlan):
|
|||
assert mock_wakeonlan.mock_calls[-1][1][0] == mac
|
||||
assert mock_wakeonlan.mock_calls[-1][2]['ip_address'] == bc_ip
|
||||
|
||||
with pytest.raises(vol.Invalid):
|
||||
yield from hass.services.async_call(
|
||||
DOMAIN, SERVICE_SEND_MAGIC_PACKET,
|
||||
{"broadcast_address": bc_ip}, blocking=True)
|
||||
assert 'ERROR' in caplog.text
|
||||
assert len(mock_wakeonlan.mock_calls) == 1
|
||||
|
||||
yield from hass.services.async_call(
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
"""The tests for the demo water_heater component."""
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.util.unit_system import (
|
||||
IMPERIAL_SYSTEM
|
||||
)
|
||||
|
@ -48,6 +51,7 @@ class TestDemowater_heater(unittest.TestCase):
|
|||
"""Test setting the target temperature without required attribute."""
|
||||
state = self.hass.states.get(ENTITY_WATER_HEATER)
|
||||
assert 119 == state.attributes.get('temperature')
|
||||
with pytest.raises(vol.Invalid):
|
||||
common.set_temperature(self.hass, None, ENTITY_WATER_HEATER)
|
||||
self.hass.block_till_done()
|
||||
assert 119 == state.attributes.get('temperature')
|
||||
|
@ -69,6 +73,7 @@ class TestDemowater_heater(unittest.TestCase):
|
|||
state = self.hass.states.get(ENTITY_WATER_HEATER)
|
||||
assert "eco" == state.attributes.get('operation_mode')
|
||||
assert "eco" == state.state
|
||||
with pytest.raises(vol.Invalid):
|
||||
common.set_operation_mode(self.hass, None, ENTITY_WATER_HEATER)
|
||||
self.hass.block_till_done()
|
||||
state = self.hass.states.get(ENTITY_WATER_HEATER)
|
||||
|
|
|
@ -49,6 +49,25 @@ async def test_call_service(hass, websocket_client):
|
|||
assert call.data == {'hello': 'world'}
|
||||
|
||||
|
||||
async def test_call_service_not_found(hass, websocket_client):
|
||||
"""Test call service command."""
|
||||
await websocket_client.send_json({
|
||||
'id': 5,
|
||||
'type': commands.TYPE_CALL_SERVICE,
|
||||
'domain': 'domain_test',
|
||||
'service': 'test_service',
|
||||
'service_data': {
|
||||
'hello': 'world'
|
||||
}
|
||||
})
|
||||
|
||||
msg = await websocket_client.receive_json()
|
||||
assert msg['id'] == 5
|
||||
assert msg['type'] == const.TYPE_RESULT
|
||||
assert not msg['success']
|
||||
assert msg['error']['code'] == const.ERR_NOT_FOUND
|
||||
|
||||
|
||||
async def test_subscribe_unsubscribe_events(hass, websocket_client):
|
||||
"""Test subscribe/unsubscribe events command."""
|
||||
init_count = sum(hass.bus.async_listeners().values())
|
||||
|
|
|
@ -947,7 +947,7 @@ class TestZWaveServices(unittest.TestCase):
|
|||
assert self.zwave_network.stop.called
|
||||
assert len(self.zwave_network.stop.mock_calls) == 1
|
||||
assert mock_fire.called
|
||||
assert len(mock_fire.mock_calls) == 2
|
||||
assert len(mock_fire.mock_calls) == 1
|
||||
assert mock_fire.mock_calls[0][1][0] == const.EVENT_NETWORK_STOP
|
||||
|
||||
def test_rename_node(self):
|
||||
|
|
|
@ -21,7 +21,7 @@ from homeassistant.const import (
|
|||
__version__, EVENT_STATE_CHANGED, ATTR_FRIENDLY_NAME, CONF_UNIT_SYSTEM,
|
||||
ATTR_NOW, EVENT_TIME_CHANGED, EVENT_TIMER_OUT_OF_SYNC, ATTR_SECONDS,
|
||||
EVENT_HOMEASSISTANT_STOP, EVENT_HOMEASSISTANT_CLOSE,
|
||||
EVENT_SERVICE_REGISTERED, EVENT_SERVICE_REMOVED, EVENT_SERVICE_EXECUTED)
|
||||
EVENT_SERVICE_REGISTERED, EVENT_SERVICE_REMOVED)
|
||||
|
||||
from tests.common import get_test_home_assistant, async_mock_service
|
||||
|
||||
|
@ -673,13 +673,8 @@ class TestServiceRegistry(unittest.TestCase):
|
|||
|
||||
def test_call_non_existing_with_blocking(self):
|
||||
"""Test non-existing with blocking."""
|
||||
prior = ha.SERVICE_CALL_LIMIT
|
||||
try:
|
||||
ha.SERVICE_CALL_LIMIT = 0.01
|
||||
assert not self.services.call('test_domain', 'i_do_not_exist',
|
||||
blocking=True)
|
||||
finally:
|
||||
ha.SERVICE_CALL_LIMIT = prior
|
||||
with pytest.raises(ha.ServiceNotFound):
|
||||
self.services.call('test_domain', 'i_do_not_exist', blocking=True)
|
||||
|
||||
def test_async_service(self):
|
||||
"""Test registering and calling an async service."""
|
||||
|
@ -1005,4 +1000,3 @@ async def test_service_executed_with_subservices(hass):
|
|||
assert len(calls) == 4
|
||||
assert [call.service for call in calls] == [
|
||||
'outer', 'inner', 'inner', 'outer']
|
||||
assert len(hass.bus.async_listeners().get(EVENT_SERVICE_EXECUTED, [])) == 0
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue