Break up websocket component (#17003)

* Break up websocket component

* Lint
This commit is contained in:
Paulus Schoutsen 2018-10-01 11:21:00 +02:00 committed by GitHub
parent 9edf1e5151
commit 22a80cf733
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
21 changed files with 1041 additions and 1003 deletions

View file

@ -452,27 +452,23 @@ class CameraMjpegStream(CameraView):
raise web.HTTPBadRequest() raise web.HTTPBadRequest()
@callback @websocket_api.async_response
def websocket_camera_thumbnail(hass, connection, msg): async def websocket_camera_thumbnail(hass, connection, msg):
"""Handle get camera thumbnail websocket command. """Handle get camera thumbnail websocket command.
Async friendly. Async friendly.
""" """
async def send_camera_still(): try:
"""Send a camera still.""" image = await async_get_image(hass, msg['entity_id'])
try: connection.send_message_outside(websocket_api.result_message(
image = await async_get_image(hass, msg['entity_id']) msg['id'], {
connection.send_message_outside(websocket_api.result_message( 'content_type': image.content_type,
msg['id'], { 'content': base64.b64encode(image.content).decode('utf-8')
'content_type': image.content_type, }
'content': base64.b64encode(image.content).decode('utf-8') ))
} except HomeAssistantError:
)) connection.send_message_outside(websocket_api.error_message(
except HomeAssistantError: msg['id'], 'image_fetch_failed', 'Unable to fetch image'))
connection.send_message_outside(websocket_api.error_message(
msg['id'], 'image_fetch_failed', 'Unable to fetch image'))
hass.async_add_job(send_camera_still())
async def async_handle_snapshot_service(camera, service): async def async_handle_snapshot_service(camera, service):

View file

@ -3,6 +3,7 @@ import voluptuous as vol
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.components import websocket_api from homeassistant.components import websocket_api
from homeassistant.components.websocket_api.decorators import require_owner
WS_TYPE_LIST = 'config/auth/list' WS_TYPE_LIST = 'config/auth/list'
@ -41,7 +42,7 @@ async def async_setup(hass):
@callback @callback
@websocket_api.require_owner @require_owner
def websocket_list(hass, connection, msg): def websocket_list(hass, connection, msg):
"""Return a list of users.""" """Return a list of users."""
async def send_users(): async def send_users():
@ -55,7 +56,7 @@ def websocket_list(hass, connection, msg):
@callback @callback
@websocket_api.require_owner @require_owner
def websocket_delete(hass, connection, msg): def websocket_delete(hass, connection, msg):
"""Delete a user.""" """Delete a user."""
async def delete_user(): async def delete_user():
@ -82,7 +83,7 @@ def websocket_delete(hass, connection, msg):
@callback @callback
@websocket_api.require_owner @require_owner
def websocket_create(hass, connection, msg): def websocket_create(hass, connection, msg):
"""Create a user.""" """Create a user."""
async def create_user(): async def create_user():

View file

@ -4,6 +4,7 @@ import voluptuous as vol
from homeassistant.auth.providers import homeassistant as auth_ha from homeassistant.auth.providers import homeassistant as auth_ha
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.components import websocket_api from homeassistant.components import websocket_api
from homeassistant.components.websocket_api.decorators import require_owner
WS_TYPE_CREATE = 'config/auth_provider/homeassistant/create' WS_TYPE_CREATE = 'config/auth_provider/homeassistant/create'
@ -55,7 +56,7 @@ def _get_provider(hass):
@callback @callback
@websocket_api.require_owner @require_owner
def websocket_create(hass, connection, msg): def websocket_create(hass, connection, msg):
"""Create credentials and attach to a user.""" """Create credentials and attach to a user."""
async def create_creds(): async def create_creds():
@ -96,7 +97,7 @@ def websocket_create(hass, connection, msg):
@callback @callback
@websocket_api.require_owner @require_owner
def websocket_delete(hass, connection, msg): def websocket_delete(hass, connection, msg):
"""Delete username and related credential.""" """Delete username and related credential."""
async def delete_creds(): async def delete_creds():

View file

@ -4,6 +4,8 @@ import voluptuous as vol
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.helpers.entity_registry import async_get_registry from homeassistant.helpers.entity_registry import async_get_registry
from homeassistant.components import websocket_api from homeassistant.components import websocket_api
from homeassistant.components.websocket_api.const import ERR_NOT_FOUND
from homeassistant.components.websocket_api.decorators import async_response
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
DEPENDENCIES = ['websocket_api'] DEPENDENCIES = ['websocket_api']
@ -46,89 +48,77 @@ async def async_setup(hass):
return True return True
@callback @async_response
def websocket_list_entities(hass, connection, msg): async def websocket_list_entities(hass, connection, msg):
"""Handle list registry entries command. """Handle list registry entries command.
Async friendly. Async friendly.
""" """
async def retrieve_entities(): registry = await async_get_registry(hass)
"""Get entities from registry.""" connection.send_message_outside(websocket_api.result_message(
registry = await async_get_registry(hass) msg['id'], [{
connection.send_message_outside(websocket_api.result_message( 'config_entry_id': entry.config_entry_id,
msg['id'], [{ 'device_id': entry.device_id,
'config_entry_id': entry.config_entry_id, 'disabled_by': entry.disabled_by,
'device_id': entry.device_id, 'entity_id': entry.entity_id,
'disabled_by': entry.disabled_by, 'name': entry.name,
'entity_id': entry.entity_id, 'platform': entry.platform,
'name': entry.name, } for entry in registry.entities.values()]
'platform': entry.platform, ))
} for entry in registry.entities.values()]
))
hass.async_add_job(retrieve_entities())
@callback @async_response
def websocket_get_entity(hass, connection, msg): async def websocket_get_entity(hass, connection, msg):
"""Handle get entity registry entry command. """Handle get entity registry entry command.
Async friendly. Async friendly.
""" """
async def retrieve_entity(): registry = await async_get_registry(hass)
"""Get entity from registry.""" entry = registry.entities.get(msg['entity_id'])
registry = await async_get_registry(hass)
entry = registry.entities.get(msg['entity_id'])
if entry is None: if entry is None:
connection.send_message_outside(websocket_api.error_message( connection.send_message_outside(websocket_api.error_message(
msg['id'], websocket_api.ERR_NOT_FOUND, 'Entity not found')) msg['id'], ERR_NOT_FOUND, 'Entity not found'))
return return
connection.send_message_outside(websocket_api.result_message( connection.send_message_outside(websocket_api.result_message(
msg['id'], _entry_dict(entry) msg['id'], _entry_dict(entry)
)) ))
hass.async_add_job(retrieve_entity())
@callback @async_response
def websocket_update_entity(hass, connection, msg): async def websocket_update_entity(hass, connection, msg):
"""Handle get camera thumbnail websocket command. """Handle get camera thumbnail websocket command.
Async friendly. Async friendly.
""" """
async def update_entity(): registry = await async_get_registry(hass)
"""Get entity from registry."""
registry = await async_get_registry(hass)
if msg['entity_id'] not in registry.entities: if msg['entity_id'] not in registry.entities:
connection.send_message_outside(websocket_api.error_message( connection.send_message_outside(websocket_api.error_message(
msg['id'], websocket_api.ERR_NOT_FOUND, 'Entity not found')) msg['id'], ERR_NOT_FOUND, 'Entity not found'))
return return
changes = {} changes = {}
if 'name' in msg: if 'name' in msg:
changes['name'] = msg['name'] changes['name'] = msg['name']
if 'new_entity_id' in msg: if 'new_entity_id' in msg:
changes['new_entity_id'] = msg['new_entity_id'] changes['new_entity_id'] = msg['new_entity_id']
try: try:
if changes: if changes:
entry = registry.async_update_entity( entry = registry.async_update_entity(
msg['entity_id'], **changes) msg['entity_id'], **changes)
except ValueError as err: except ValueError as err:
connection.send_message_outside(websocket_api.error_message( connection.send_message_outside(websocket_api.error_message(
msg['id'], 'invalid_info', str(err) msg['id'], 'invalid_info', str(err)
)) ))
else: else:
connection.send_message_outside(websocket_api.result_message( connection.send_message_outside(websocket_api.result_message(
msg['id'], _entry_dict(entry) msg['id'], _entry_dict(entry)
)) ))
hass.async_create_task(update_entity())
@callback @callback

View file

@ -28,7 +28,6 @@ from homeassistant.const import (
SERVICE_TOGGLE, SERVICE_TURN_OFF, SERVICE_TURN_ON, SERVICE_VOLUME_DOWN, SERVICE_TOGGLE, SERVICE_TURN_OFF, SERVICE_TURN_ON, SERVICE_VOLUME_DOWN,
SERVICE_VOLUME_MUTE, SERVICE_VOLUME_SET, SERVICE_VOLUME_UP, STATE_IDLE, SERVICE_VOLUME_MUTE, SERVICE_VOLUME_SET, SERVICE_VOLUME_UP, STATE_IDLE,
STATE_OFF, STATE_PLAYING, STATE_UNKNOWN) STATE_OFF, STATE_PLAYING, STATE_UNKNOWN)
from homeassistant.core import callback
from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.aiohttp_client import async_get_clientsession
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.config_validation import PLATFORM_SCHEMA # noqa from homeassistant.helpers.config_validation import PLATFORM_SCHEMA # noqa
@ -865,8 +864,8 @@ class MediaPlayerImageView(HomeAssistantView):
body=data, content_type=content_type, headers=headers) body=data, content_type=content_type, headers=headers)
@callback @websocket_api.async_response
def websocket_handle_thumbnail(hass, connection, msg): async def websocket_handle_thumbnail(hass, connection, msg):
"""Handle get media player cover command. """Handle get media player cover command.
Async friendly. Async friendly.
@ -879,20 +878,16 @@ def websocket_handle_thumbnail(hass, connection, msg):
msg['id'], 'entity_not_found', 'Entity not found')) msg['id'], 'entity_not_found', 'Entity not found'))
return return
async def send_image(): data, content_type = await player.async_get_media_image()
"""Send image."""
data, content_type = await player.async_get_media_image()
if data is None: if data is None:
connection.send_message_outside(websocket_api.error_message( connection.send_message_outside(websocket_api.error_message(
msg['id'], 'thumbnail_fetch_failed', msg['id'], 'thumbnail_fetch_failed',
'Failed to fetch thumbnail')) 'Failed to fetch thumbnail'))
return return
connection.send_message_outside(websocket_api.result_message( connection.send_message_outside(websocket_api.result_message(
msg['id'], { msg['id'], {
'content_type': content_type, 'content_type': content_type,
'content': base64.b64encode(data).decode('utf-8') 'content': base64.b64encode(data).decode('utf-8')
})) }))
hass.async_add_job(send_image())

View file

@ -7,7 +7,7 @@ https://developers.home-assistant.io/docs/external_api_websocket.html
import asyncio import asyncio
from concurrent import futures from concurrent import futures
from contextlib import suppress from contextlib import suppress
from functools import partial, wraps from functools import partial
import json import json
import logging import logging
@ -15,20 +15,18 @@ from aiohttp import web
import voluptuous as vol import voluptuous as vol
from voluptuous.humanize import humanize_error from voluptuous.humanize import humanize_error
from homeassistant.const import ( from homeassistant.const import EVENT_HOMEASSISTANT_STOP, __version__
MATCH_ALL, EVENT_TIME_CHANGED, EVENT_HOMEASSISTANT_STOP, from homeassistant.core import Context, callback
__version__)
from homeassistant.core import Context, callback, HomeAssistant
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
from homeassistant.helpers.json import JSONEncoder from homeassistant.helpers.json import JSONEncoder
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.service import async_get_all_descriptions
from homeassistant.components.http import HomeAssistantView from homeassistant.components.http import HomeAssistantView
from homeassistant.components.http.auth import validate_password from homeassistant.components.http.auth import validate_password
from homeassistant.components.http.const import KEY_AUTHENTICATED from homeassistant.components.http.const import KEY_AUTHENTICATED
from homeassistant.components.http.ban import process_wrong_login, \ from homeassistant.components.http.ban import process_wrong_login, \
process_success_login process_success_login
from . import commands, const, decorators, messages
DOMAIN = 'websocket_api' DOMAIN = 'websocket_api'
URL = '/api/websocket' URL = '/api/websocket'
@ -36,87 +34,32 @@ DEPENDENCIES = ('http',)
MAX_PENDING_MSG = 512 MAX_PENDING_MSG = 512
ERR_ID_REUSE = 1
ERR_INVALID_FORMAT = 2
ERR_NOT_FOUND = 3
ERR_UNKNOWN_COMMAND = 4
ERR_UNKNOWN_ERROR = 5
TYPE_AUTH = 'auth'
TYPE_AUTH_INVALID = 'auth_invalid'
TYPE_AUTH_OK = 'auth_ok'
TYPE_AUTH_REQUIRED = 'auth_required'
TYPE_CALL_SERVICE = 'call_service'
TYPE_EVENT = 'event'
TYPE_GET_CONFIG = 'get_config'
TYPE_GET_SERVICES = 'get_services'
TYPE_GET_STATES = 'get_states'
TYPE_PING = 'ping'
TYPE_PONG = 'pong'
TYPE_RESULT = 'result'
TYPE_SUBSCRIBE_EVENTS = 'subscribe_events'
TYPE_UNSUBSCRIBE_EVENTS = 'unsubscribe_events'
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
JSON_DUMP = partial(json.dumps, cls=JSONEncoder) JSON_DUMP = partial(json.dumps, cls=JSONEncoder)
TYPE_AUTH = 'auth'
TYPE_AUTH_INVALID = 'auth_invalid'
TYPE_AUTH_OK = 'auth_ok'
TYPE_AUTH_REQUIRED = 'auth_required'
# Backwards compat
# pylint: disable=invalid-name
BASE_COMMAND_MESSAGE_SCHEMA = messages.BASE_COMMAND_MESSAGE_SCHEMA
error_message = messages.error_message
result_message = messages.result_message
async_response = decorators.async_response
ws_require_user = decorators.ws_require_user
# pylint: enable=invalid-name
AUTH_MESSAGE_SCHEMA = vol.Schema({ AUTH_MESSAGE_SCHEMA = vol.Schema({
vol.Required('type'): TYPE_AUTH, vol.Required('type'): TYPE_AUTH,
vol.Exclusive('api_password', 'auth'): str, vol.Exclusive('api_password', 'auth'): str,
vol.Exclusive('access_token', 'auth'): str, vol.Exclusive('access_token', 'auth'): str,
}) })
# Minimal requirements of a message
MINIMAL_MESSAGE_SCHEMA = vol.Schema({
vol.Required('id'): cv.positive_int,
vol.Required('type'): cv.string,
}, extra=vol.ALLOW_EXTRA)
# Base schema to extend by message handlers
BASE_COMMAND_MESSAGE_SCHEMA = vol.Schema({
vol.Required('id'): cv.positive_int,
})
SCHEMA_SUBSCRIBE_EVENTS = BASE_COMMAND_MESSAGE_SCHEMA.extend({
vol.Required('type'): TYPE_SUBSCRIBE_EVENTS,
vol.Optional('event_type', default=MATCH_ALL): str,
})
SCHEMA_UNSUBSCRIBE_EVENTS = BASE_COMMAND_MESSAGE_SCHEMA.extend({
vol.Required('type'): TYPE_UNSUBSCRIBE_EVENTS,
vol.Required('subscription'): cv.positive_int,
})
SCHEMA_CALL_SERVICE = BASE_COMMAND_MESSAGE_SCHEMA.extend({
vol.Required('type'): TYPE_CALL_SERVICE,
vol.Required('domain'): str,
vol.Required('service'): str,
vol.Optional('service_data'): dict
})
SCHEMA_GET_STATES = BASE_COMMAND_MESSAGE_SCHEMA.extend({
vol.Required('type'): TYPE_GET_STATES,
})
SCHEMA_GET_SERVICES = BASE_COMMAND_MESSAGE_SCHEMA.extend({
vol.Required('type'): TYPE_GET_SERVICES,
})
SCHEMA_GET_CONFIG = BASE_COMMAND_MESSAGE_SCHEMA.extend({
vol.Required('type'): TYPE_GET_CONFIG,
})
SCHEMA_PING = BASE_COMMAND_MESSAGE_SCHEMA.extend({
vol.Required('type'): TYPE_PING,
})
# Define the possible errors that occur when connections are cancelled. # Define the possible errors that occur when connections are cancelled.
# Originally, this was just asyncio.CancelledError, but issue #9546 showed # Originally, this was just asyncio.CancelledError, but issue #9546 showed
@ -148,46 +91,6 @@ def auth_invalid_message(message):
} }
def event_message(iden, event):
"""Return an event message."""
return {
'id': iden,
'type': TYPE_EVENT,
'event': event.as_dict(),
}
def error_message(iden, code, message):
"""Return an error result message."""
return {
'id': iden,
'type': TYPE_RESULT,
'success': False,
'error': {
'code': code,
'message': message,
},
}
def pong_message(iden):
"""Return a pong message."""
return {
'id': iden,
'type': TYPE_PONG,
}
def result_message(iden, result=None):
"""Return a success result message."""
return {
'id': iden,
'type': TYPE_RESULT,
'success': True,
'result': result,
}
@bind_hass @bind_hass
@callback @callback
def async_register_command(hass, command, handler, schema): def async_register_command(hass, command, handler, schema):
@ -198,43 +101,10 @@ def async_register_command(hass, command, handler, schema):
handlers[command] = (handler, schema) handlers[command] = (handler, schema)
def require_owner(func):
"""Websocket decorator to require user to be an owner."""
@wraps(func)
def with_owner(hass, connection, msg):
"""Check owner and call function."""
user = connection.request.get('hass_user')
if user is None or not user.is_owner:
connection.to_write.put_nowait(error_message(
msg['id'], 'unauthorized', 'This command is for owners only.'))
return
func(hass, connection, msg)
return with_owner
async def async_setup(hass, config): async def async_setup(hass, config):
"""Initialize the websocket API.""" """Initialize the websocket API."""
hass.http.register_view(WebsocketAPIView) hass.http.register_view(WebsocketAPIView)
commands.async_register_commands(hass)
async_register_command(hass, TYPE_SUBSCRIBE_EVENTS,
handle_subscribe_events, SCHEMA_SUBSCRIBE_EVENTS)
async_register_command(hass, TYPE_UNSUBSCRIBE_EVENTS,
handle_unsubscribe_events,
SCHEMA_UNSUBSCRIBE_EVENTS)
async_register_command(hass, TYPE_CALL_SERVICE,
handle_call_service, SCHEMA_CALL_SERVICE)
async_register_command(hass, TYPE_GET_STATES,
handle_get_states, SCHEMA_GET_STATES)
async_register_command(hass, TYPE_GET_SERVICES,
handle_get_services, SCHEMA_GET_SERVICES)
async_register_command(hass, TYPE_GET_CONFIG,
handle_get_config, SCHEMA_GET_CONFIG)
async_register_command(hass, TYPE_PING,
handle_ping, SCHEMA_PING)
return True return True
@ -389,19 +259,19 @@ class ActiveConnection:
while msg: while msg:
self.debug("Received", msg) self.debug("Received", msg)
msg = MINIMAL_MESSAGE_SCHEMA(msg) msg = messages.MINIMAL_MESSAGE_SCHEMA(msg)
cur_id = msg['id'] cur_id = msg['id']
if cur_id <= last_id: if cur_id <= last_id:
self.to_write.put_nowait(error_message( self.to_write.put_nowait(messages.error_message(
cur_id, ERR_ID_REUSE, cur_id, const.ERR_ID_REUSE,
'Identifier values have to increase.')) 'Identifier values have to increase.'))
elif msg['type'] not in handlers: elif msg['type'] not in handlers:
self.log_error( self.log_error(
'Received invalid command: {}'.format(msg['type'])) 'Received invalid command: {}'.format(msg['type']))
self.to_write.put_nowait(error_message( self.to_write.put_nowait(messages.error_message(
cur_id, ERR_UNKNOWN_COMMAND, cur_id, const.ERR_UNKNOWN_COMMAND,
'Unknown command.')) 'Unknown command.'))
else: else:
@ -410,8 +280,8 @@ class ActiveConnection:
handler(self.hass, self, schema(msg)) handler(self.hass, self, schema(msg))
except Exception: # pylint: disable=broad-except except Exception: # pylint: disable=broad-except
_LOGGER.exception('Error handling message: %s', msg) _LOGGER.exception('Error handling message: %s', msg)
self.to_write.put_nowait(error_message( self.to_write.put_nowait(messages.error_message(
cur_id, ERR_UNKNOWN_ERROR, cur_id, const.ERR_UNKNOWN_ERROR,
'Unknown error.')) 'Unknown error.'))
last_id = cur_id last_id = cur_id
@ -435,8 +305,8 @@ class ActiveConnection:
else: else:
iden = None iden = None
final_message = error_message( final_message = messages.error_message(
iden, ERR_INVALID_FORMAT, error_msg) iden, const.ERR_INVALID_FORMAT, error_msg)
except TypeError as err: except TypeError as err:
if wsock.closed: if wsock.closed:
@ -485,170 +355,3 @@ class ActiveConnection:
self.debug("Closed connection") self.debug("Closed connection")
return wsock return wsock
def async_response(func):
"""Decorate an async function to handle WebSocket API messages."""
async def handle_msg_response(hass, connection, msg):
"""Create a response and handle exception."""
try:
await func(hass, connection, msg)
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Unexpected exception")
connection.send_message_outside(error_message(
msg['id'], 'unknown', 'Unexpected error occurred'))
@callback
@wraps(func)
def schedule_handler(hass, connection, msg):
"""Schedule the handler."""
hass.async_create_task(handle_msg_response(hass, connection, msg))
return schedule_handler
@callback
def handle_subscribe_events(hass, connection, msg):
"""Handle subscribe events command.
Async friendly.
"""
async def forward_events(event):
"""Forward events to websocket."""
if event.event_type == EVENT_TIME_CHANGED:
return
connection.send_message_outside(event_message(msg['id'], event))
connection.event_listeners[msg['id']] = hass.bus.async_listen(
msg['event_type'], forward_events)
connection.to_write.put_nowait(result_message(msg['id']))
@callback
def handle_unsubscribe_events(hass, connection, msg):
"""Handle unsubscribe events command.
Async friendly.
"""
subscription = msg['subscription']
if subscription in connection.event_listeners:
connection.event_listeners.pop(subscription)()
connection.to_write.put_nowait(result_message(msg['id']))
else:
connection.to_write.put_nowait(error_message(
msg['id'], ERR_NOT_FOUND, 'Subscription not found.'))
@async_response
async def handle_call_service(hass, connection, msg):
"""Handle call service command.
Async friendly.
"""
blocking = True
if (msg['domain'] == 'homeassistant' and
msg['service'] in ['restart', 'stop']):
blocking = False
await hass.services.async_call(
msg['domain'], msg['service'], msg.get('service_data'), blocking,
connection.context(msg))
connection.send_message_outside(result_message(msg['id']))
@callback
def handle_get_states(hass, connection, msg):
"""Handle get states command.
Async friendly.
"""
connection.to_write.put_nowait(result_message(
msg['id'], hass.states.async_all()))
@async_response
async def handle_get_services(hass, connection, msg):
"""Handle get services command.
Async friendly.
"""
descriptions = await async_get_all_descriptions(hass)
connection.send_message_outside(
result_message(msg['id'], descriptions))
@callback
def handle_get_config(hass, connection, msg):
"""Handle get config command.
Async friendly.
"""
connection.to_write.put_nowait(result_message(
msg['id'], hass.config.as_dict()))
@callback
def handle_ping(hass, connection, msg):
"""Handle ping command.
Async friendly.
"""
connection.to_write.put_nowait(pong_message(msg['id']))
def ws_require_user(
only_owner=False, only_system_user=False, allow_system_user=True,
only_active_user=True, only_inactive_user=False):
"""Decorate function validating login user exist in current WS connection.
Will write out error message if not authenticated.
"""
def validator(func):
"""Decorate func."""
@wraps(func)
def check_current_user(hass: HomeAssistant,
connection: ActiveConnection,
msg):
"""Check current user."""
def output_error(message_id, message):
"""Output error message."""
connection.send_message_outside(error_message(
msg['id'], message_id, message))
if connection.user is None:
output_error('no_user', 'Not authenticated as a user')
return
if only_owner and not connection.user.is_owner:
output_error('only_owner', 'Only allowed as owner')
return
if (only_system_user and
not connection.user.system_generated):
output_error('only_system_user',
'Only allowed as system user')
return
if (not allow_system_user
and connection.user.system_generated):
output_error('not_system_user', 'Not allowed as system user')
return
if (only_active_user and
not connection.user.is_active):
output_error('only_active_user',
'Only allowed as active user')
return
if only_inactive_user and connection.user.is_active:
output_error('only_inactive_user',
'Not allowed as active user')
return
return func(hass, connection, msg)
return check_current_user
return validator

View file

@ -0,0 +1,183 @@
"""Commands part of Websocket API."""
import voluptuous as vol
from homeassistant.const import MATCH_ALL, EVENT_TIME_CHANGED
from homeassistant.core import callback
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.service import async_get_all_descriptions
from . import const, decorators, messages
TYPE_CALL_SERVICE = 'call_service'
TYPE_EVENT = 'event'
TYPE_GET_CONFIG = 'get_config'
TYPE_GET_SERVICES = 'get_services'
TYPE_GET_STATES = 'get_states'
TYPE_PING = 'ping'
TYPE_PONG = 'pong'
TYPE_SUBSCRIBE_EVENTS = 'subscribe_events'
TYPE_UNSUBSCRIBE_EVENTS = 'unsubscribe_events'
@callback
def async_register_commands(hass):
"""Register commands."""
async_reg = hass.components.websocket_api.async_register_command
async_reg(TYPE_SUBSCRIBE_EVENTS, handle_subscribe_events,
SCHEMA_SUBSCRIBE_EVENTS)
async_reg(TYPE_UNSUBSCRIBE_EVENTS, handle_unsubscribe_events,
SCHEMA_UNSUBSCRIBE_EVENTS)
async_reg(TYPE_CALL_SERVICE, handle_call_service, SCHEMA_CALL_SERVICE)
async_reg(TYPE_GET_STATES, handle_get_states, SCHEMA_GET_STATES)
async_reg(TYPE_GET_SERVICES, handle_get_services, SCHEMA_GET_SERVICES)
async_reg(TYPE_GET_CONFIG, handle_get_config, SCHEMA_GET_CONFIG)
async_reg(TYPE_PING, handle_ping, SCHEMA_PING)
SCHEMA_SUBSCRIBE_EVENTS = messages.BASE_COMMAND_MESSAGE_SCHEMA.extend({
vol.Required('type'): TYPE_SUBSCRIBE_EVENTS,
vol.Optional('event_type', default=MATCH_ALL): str,
})
SCHEMA_UNSUBSCRIBE_EVENTS = messages.BASE_COMMAND_MESSAGE_SCHEMA.extend({
vol.Required('type'): TYPE_UNSUBSCRIBE_EVENTS,
vol.Required('subscription'): cv.positive_int,
})
SCHEMA_CALL_SERVICE = messages.BASE_COMMAND_MESSAGE_SCHEMA.extend({
vol.Required('type'): TYPE_CALL_SERVICE,
vol.Required('domain'): str,
vol.Required('service'): str,
vol.Optional('service_data'): dict
})
SCHEMA_GET_STATES = messages.BASE_COMMAND_MESSAGE_SCHEMA.extend({
vol.Required('type'): TYPE_GET_STATES,
})
SCHEMA_GET_SERVICES = messages.BASE_COMMAND_MESSAGE_SCHEMA.extend({
vol.Required('type'): TYPE_GET_SERVICES,
})
SCHEMA_GET_CONFIG = messages.BASE_COMMAND_MESSAGE_SCHEMA.extend({
vol.Required('type'): TYPE_GET_CONFIG,
})
SCHEMA_PING = messages.BASE_COMMAND_MESSAGE_SCHEMA.extend({
vol.Required('type'): TYPE_PING,
})
def event_message(iden, event):
"""Return an event message."""
return {
'id': iden,
'type': TYPE_EVENT,
'event': event.as_dict(),
}
def pong_message(iden):
"""Return a pong message."""
return {
'id': iden,
'type': TYPE_PONG,
}
@callback
def handle_subscribe_events(hass, connection, msg):
"""Handle subscribe events command.
Async friendly.
"""
async def forward_events(event):
"""Forward events to websocket."""
if event.event_type == EVENT_TIME_CHANGED:
return
connection.send_message_outside(event_message(msg['id'], event))
connection.event_listeners[msg['id']] = hass.bus.async_listen(
msg['event_type'], forward_events)
connection.to_write.put_nowait(messages.result_message(msg['id']))
@callback
def handle_unsubscribe_events(hass, connection, msg):
"""Handle unsubscribe events command.
Async friendly.
"""
subscription = msg['subscription']
if subscription in connection.event_listeners:
connection.event_listeners.pop(subscription)()
connection.to_write.put_nowait(messages.result_message(msg['id']))
else:
connection.to_write.put_nowait(messages.error_message(
msg['id'], const.ERR_NOT_FOUND, 'Subscription not found.'))
@decorators.async_response
async def handle_call_service(hass, connection, msg):
"""Handle call service command.
Async friendly.
"""
blocking = True
if (msg['domain'] == 'homeassistant' and
msg['service'] in ['restart', 'stop']):
blocking = False
await hass.services.async_call(
msg['domain'], msg['service'], msg.get('service_data'), blocking,
connection.context(msg))
connection.send_message_outside(messages.result_message(msg['id']))
@callback
def handle_get_states(hass, connection, msg):
"""Handle get states command.
Async friendly.
"""
connection.to_write.put_nowait(messages.result_message(
msg['id'], hass.states.async_all()))
@decorators.async_response
async def handle_get_services(hass, connection, msg):
"""Handle get services command.
Async friendly.
"""
descriptions = await async_get_all_descriptions(hass)
connection.send_message_outside(
messages.result_message(msg['id'], descriptions))
@callback
def handle_get_config(hass, connection, msg):
"""Handle get config command.
Async friendly.
"""
connection.to_write.put_nowait(messages.result_message(
msg['id'], hass.config.as_dict()))
@callback
def handle_ping(hass, connection, msg):
"""Handle ping command.
Async friendly.
"""
connection.to_write.put_nowait(pong_message(msg['id']))

View file

@ -0,0 +1,8 @@
"""Websocket constants."""
ERR_ID_REUSE = 1
ERR_INVALID_FORMAT = 2
ERR_NOT_FOUND = 3
ERR_UNKNOWN_COMMAND = 4
ERR_UNKNOWN_ERROR = 5
TYPE_RESULT = 'result'

View file

@ -0,0 +1,101 @@
"""Decorators for the Websocket API."""
from functools import wraps
import logging
from homeassistant.core import callback
from . import messages
_LOGGER = logging.getLogger(__name__)
def async_response(func):
"""Decorate an async function to handle WebSocket API messages."""
async def handle_msg_response(hass, connection, msg):
"""Create a response and handle exception."""
try:
await func(hass, connection, msg)
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Unexpected exception")
connection.send_message_outside(messages.error_message(
msg['id'], 'unknown', 'Unexpected error occurred'))
@callback
@wraps(func)
def schedule_handler(hass, connection, msg):
"""Schedule the handler."""
hass.async_create_task(handle_msg_response(hass, connection, msg))
return schedule_handler
def require_owner(func):
"""Websocket decorator to require user to be an owner."""
@wraps(func)
def with_owner(hass, connection, msg):
"""Check owner and call function."""
user = connection.request.get('hass_user')
if user is None or not user.is_owner:
connection.to_write.put_nowait(messages.error_message(
msg['id'], 'unauthorized', 'This command is for owners only.'))
return
func(hass, connection, msg)
return with_owner
def ws_require_user(
only_owner=False, only_system_user=False, allow_system_user=True,
only_active_user=True, only_inactive_user=False):
"""Decorate function validating login user exist in current WS connection.
Will write out error message if not authenticated.
"""
def validator(func):
"""Decorate func."""
@wraps(func)
def check_current_user(hass, connection, msg):
"""Check current user."""
def output_error(message_id, message):
"""Output error message."""
connection.send_message_outside(messages.error_message(
msg['id'], message_id, message))
if connection.user is None:
output_error('no_user', 'Not authenticated as a user')
return
if only_owner and not connection.user.is_owner:
output_error('only_owner', 'Only allowed as owner')
return
if (only_system_user and
not connection.user.system_generated):
output_error('only_system_user',
'Only allowed as system user')
return
if (not allow_system_user
and connection.user.system_generated):
output_error('not_system_user', 'Not allowed as system user')
return
if (only_active_user and
not connection.user.is_active):
output_error('only_active_user',
'Only allowed as active user')
return
if only_inactive_user and connection.user.is_active:
output_error('only_inactive_user',
'Not allowed as active user')
return
return func(hass, connection, msg)
return check_current_user
return validator

View file

@ -0,0 +1,42 @@
"""Message templates for websocket commands."""
import voluptuous as vol
from homeassistant.helpers import config_validation as cv
from . import const
# Minimal requirements of a message
MINIMAL_MESSAGE_SCHEMA = vol.Schema({
vol.Required('id'): cv.positive_int,
vol.Required('type'): cv.string,
}, extra=vol.ALLOW_EXTRA)
# Base schema to extend by message handlers
BASE_COMMAND_MESSAGE_SCHEMA = vol.Schema({
vol.Required('id'): cv.positive_int,
})
def result_message(iden, result=None):
"""Return a success result message."""
return {
'id': iden,
'type': const.TYPE_RESULT,
'success': True,
'result': result,
}
def error_message(iden, code, message):
"""Return an error result message."""
return {
'id': iden,
'type': const.TYPE_RESULT,
'success': False,
'error': {
'code': code,
'message': message,
},
}

View file

@ -7,7 +7,8 @@ import pytest
from homeassistant.setup import setup_component, async_setup_component from homeassistant.setup import setup_component, async_setup_component
from homeassistant.const import ATTR_ENTITY_PICTURE from homeassistant.const import ATTR_ENTITY_PICTURE
from homeassistant.components import camera, http, websocket_api from homeassistant.components import camera, http
from homeassistant.components.websocket_api.const import TYPE_RESULT
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.util.async_ import run_coroutine_threadsafe from homeassistant.util.async_ import run_coroutine_threadsafe
@ -150,7 +151,7 @@ async def test_webocket_camera_thumbnail(hass, hass_ws_client, mock_camera):
msg = await client.receive_json() msg = await client.receive_json()
assert msg['id'] == 5 assert msg['id'] == 5
assert msg['type'] == websocket_api.TYPE_RESULT assert msg['type'] == TYPE_RESULT
assert msg['success'] assert msg['success']
assert msg['result']['content_type'] == 'image/jpeg' assert msg['result']['content_type'] == 'image/jpeg'
assert msg['result']['content'] == \ assert msg['result']['content'] == \

View file

@ -9,7 +9,7 @@ from homeassistant.setup import async_setup_component
from homeassistant.components.frontend import ( from homeassistant.components.frontend import (
DOMAIN, CONF_JS_VERSION, CONF_THEMES, CONF_EXTRA_HTML_URL, DOMAIN, CONF_JS_VERSION, CONF_THEMES, CONF_EXTRA_HTML_URL,
CONF_EXTRA_HTML_URL_ES5) CONF_EXTRA_HTML_URL_ES5)
from homeassistant.components import websocket_api as wapi from homeassistant.components.websocket_api.const import TYPE_RESULT
from tests.common import mock_coro from tests.common import mock_coro
@ -213,7 +213,7 @@ async def test_missing_themes(hass, hass_ws_client):
msg = await client.receive_json() msg = await client.receive_json()
assert msg['id'] == 5 assert msg['id'] == 5
assert msg['type'] == wapi.TYPE_RESULT assert msg['type'] == TYPE_RESULT
assert msg['success'] assert msg['success']
assert msg['result']['default_theme'] == 'default' assert msg['result']['default_theme'] == 'default'
assert msg['result']['themes'] == {} assert msg['result']['themes'] == {}
@ -252,7 +252,7 @@ async def test_get_panels(hass, hass_ws_client):
msg = await client.receive_json() msg = await client.receive_json()
assert msg['id'] == 5 assert msg['id'] == 5
assert msg['type'] == wapi.TYPE_RESULT assert msg['type'] == TYPE_RESULT
assert msg['success'] assert msg['success']
assert msg['result']['map']['component_name'] == 'map' assert msg['result']['map']['component_name'] == 'map'
assert msg['result']['map']['url_path'] == 'map' assert msg['result']['map']['url_path'] == 'map'
@ -275,7 +275,7 @@ async def test_get_translations(hass, hass_ws_client):
msg = await client.receive_json() msg = await client.receive_json()
assert msg['id'] == 5 assert msg['id'] == 5
assert msg['type'] == wapi.TYPE_RESULT assert msg['type'] == TYPE_RESULT
assert msg['success'] assert msg['success']
assert msg['result'] == {'resources': {'lang': 'nl'}} assert msg['result'] == {'resources': {'lang': 'nl'}}

View file

@ -3,7 +3,7 @@ from unittest.mock import patch
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from homeassistant.components import websocket_api as wapi from homeassistant.components.websocket_api.const import TYPE_RESULT
async def test_deprecated_lovelace_ui(hass, hass_ws_client): async def test_deprecated_lovelace_ui(hass, hass_ws_client):
@ -20,7 +20,7 @@ async def test_deprecated_lovelace_ui(hass, hass_ws_client):
msg = await client.receive_json() msg = await client.receive_json()
assert msg['id'] == 5 assert msg['id'] == 5
assert msg['type'] == wapi.TYPE_RESULT assert msg['type'] == TYPE_RESULT
assert msg['success'] assert msg['success']
assert msg['result'] == {'hello': 'world'} assert msg['result'] == {'hello': 'world'}
@ -39,7 +39,7 @@ async def test_deprecated_lovelace_ui_not_found(hass, hass_ws_client):
msg = await client.receive_json() msg = await client.receive_json()
assert msg['id'] == 5 assert msg['id'] == 5
assert msg['type'] == wapi.TYPE_RESULT assert msg['type'] == TYPE_RESULT
assert msg['success'] is False assert msg['success'] is False
assert msg['error']['code'] == 'file_not_found' assert msg['error']['code'] == 'file_not_found'
@ -58,7 +58,7 @@ async def test_deprecated_lovelace_ui_load_err(hass, hass_ws_client):
msg = await client.receive_json() msg = await client.receive_json()
assert msg['id'] == 5 assert msg['id'] == 5
assert msg['type'] == wapi.TYPE_RESULT assert msg['type'] == TYPE_RESULT
assert msg['success'] is False assert msg['success'] is False
assert msg['error']['code'] == 'load_error' assert msg['error']['code'] == 'load_error'
@ -77,7 +77,7 @@ async def test_lovelace_ui(hass, hass_ws_client):
msg = await client.receive_json() msg = await client.receive_json()
assert msg['id'] == 5 assert msg['id'] == 5
assert msg['type'] == wapi.TYPE_RESULT assert msg['type'] == TYPE_RESULT
assert msg['success'] assert msg['success']
assert msg['result'] == {'hello': 'world'} assert msg['result'] == {'hello': 'world'}
@ -96,7 +96,7 @@ async def test_lovelace_ui_not_found(hass, hass_ws_client):
msg = await client.receive_json() msg = await client.receive_json()
assert msg['id'] == 5 assert msg['id'] == 5
assert msg['type'] == wapi.TYPE_RESULT assert msg['type'] == TYPE_RESULT
assert msg['success'] is False assert msg['success'] is False
assert msg['error']['code'] == 'file_not_found' assert msg['error']['code'] == 'file_not_found'
@ -115,6 +115,6 @@ async def test_lovelace_ui_load_err(hass, hass_ws_client):
msg = await client.receive_json() msg = await client.receive_json()
assert msg['id'] == 5 assert msg['id'] == 5
assert msg['type'] == wapi.TYPE_RESULT assert msg['type'] == TYPE_RESULT
assert msg['success'] is False assert msg['success'] is False
assert msg['error']['code'] == 'load_error' assert msg['error']['code'] == 'load_error'

View file

@ -3,7 +3,7 @@ import base64
from unittest.mock import patch from unittest.mock import patch
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from homeassistant.components import websocket_api from homeassistant.components.websocket_api.const import TYPE_RESULT
from tests.common import mock_coro from tests.common import mock_coro
@ -30,7 +30,7 @@ async def test_get_panels(hass, hass_ws_client):
msg = await client.receive_json() msg = await client.receive_json()
assert msg['id'] == 5 assert msg['id'] == 5
assert msg['type'] == websocket_api.TYPE_RESULT assert msg['type'] == TYPE_RESULT
assert msg['success'] assert msg['success']
assert msg['result']['content_type'] == 'image/jpeg' assert msg['result']['content_type'] == 'image/jpeg'
assert msg['result']['content'] == \ assert msg['result']['content'] == \

View file

@ -1,5 +1,5 @@
"""The tests for the persistent notification component.""" """The tests for the persistent notification component."""
from homeassistant.components import websocket_api from homeassistant.components.websocket_api.const import TYPE_RESULT
from homeassistant.setup import setup_component, async_setup_component from homeassistant.setup import setup_component, async_setup_component
import homeassistant.components.persistent_notification as pn import homeassistant.components.persistent_notification as pn
@ -151,7 +151,7 @@ async def test_ws_get_notifications(hass, hass_ws_client):
}) })
msg = await client.receive_json() msg = await client.receive_json()
assert msg['id'] == 5 assert msg['id'] == 5
assert msg['type'] == websocket_api.TYPE_RESULT assert msg['type'] == TYPE_RESULT
assert msg['success'] assert msg['success']
notifications = msg['result'] notifications = msg['result']
assert len(notifications) == 0 assert len(notifications) == 0
@ -165,7 +165,7 @@ async def test_ws_get_notifications(hass, hass_ws_client):
}) })
msg = await client.receive_json() msg = await client.receive_json()
assert msg['id'] == 6 assert msg['id'] == 6
assert msg['type'] == websocket_api.TYPE_RESULT assert msg['type'] == TYPE_RESULT
assert msg['success'] assert msg['success']
notifications = msg['result'] notifications = msg['result']
assert len(notifications) == 1 assert len(notifications) == 1

View file

@ -1,558 +0,0 @@
"""Tests for the Home Assistant Websocket API."""
import asyncio
from unittest.mock import patch, Mock
from aiohttp import WSMsgType
from async_timeout import timeout
import pytest
from homeassistant.core import callback
from homeassistant.components import websocket_api as wapi
from homeassistant.setup import async_setup_component
from tests.common import mock_coro, async_mock_service
API_PASSWORD = 'test1234'
@pytest.fixture
def websocket_client(hass, hass_ws_client):
"""Create a websocket client."""
return hass.loop.run_until_complete(hass_ws_client(hass))
@pytest.fixture
def no_auth_websocket_client(hass, loop, aiohttp_client):
"""Websocket connection that requires authentication."""
assert loop.run_until_complete(
async_setup_component(hass, 'websocket_api', {
'http': {
'api_password': API_PASSWORD
}
}))
client = loop.run_until_complete(aiohttp_client(hass.http.app))
ws = loop.run_until_complete(client.ws_connect(wapi.URL))
auth_ok = loop.run_until_complete(ws.receive_json())
assert auth_ok['type'] == wapi.TYPE_AUTH_REQUIRED
yield ws
if not ws.closed:
loop.run_until_complete(ws.close())
@pytest.fixture
def mock_low_queue():
"""Mock a low queue."""
with patch.object(wapi, 'MAX_PENDING_MSG', 5):
yield
@asyncio.coroutine
def test_auth_via_msg(no_auth_websocket_client):
"""Test authenticating."""
yield from no_auth_websocket_client.send_json({
'type': wapi.TYPE_AUTH,
'api_password': API_PASSWORD
})
msg = yield from no_auth_websocket_client.receive_json()
assert msg['type'] == wapi.TYPE_AUTH_OK
@asyncio.coroutine
def test_auth_via_msg_incorrect_pass(no_auth_websocket_client):
"""Test authenticating."""
with patch('homeassistant.components.websocket_api.process_wrong_login',
return_value=mock_coro()) as mock_process_wrong_login:
yield from no_auth_websocket_client.send_json({
'type': wapi.TYPE_AUTH,
'api_password': API_PASSWORD + 'wrong'
})
msg = yield from no_auth_websocket_client.receive_json()
assert mock_process_wrong_login.called
assert msg['type'] == wapi.TYPE_AUTH_INVALID
assert msg['message'] == 'Invalid access token or password'
@asyncio.coroutine
def test_pre_auth_only_auth_allowed(no_auth_websocket_client):
"""Verify that before authentication, only auth messages are allowed."""
yield from no_auth_websocket_client.send_json({
'type': wapi.TYPE_CALL_SERVICE,
'domain': 'domain_test',
'service': 'test_service',
'service_data': {
'hello': 'world'
}
})
msg = yield from no_auth_websocket_client.receive_json()
assert msg['type'] == wapi.TYPE_AUTH_INVALID
assert msg['message'].startswith('Message incorrectly formatted')
@asyncio.coroutine
def test_invalid_message_format(websocket_client):
"""Test sending invalid JSON."""
yield from websocket_client.send_json({'type': 5})
msg = yield from websocket_client.receive_json()
assert msg['type'] == wapi.TYPE_RESULT
error = msg['error']
assert error['code'] == wapi.ERR_INVALID_FORMAT
assert error['message'].startswith('Message incorrectly formatted')
@asyncio.coroutine
def test_invalid_json(websocket_client):
"""Test sending invalid JSON."""
yield from websocket_client.send_str('this is not JSON')
msg = yield from websocket_client.receive()
assert msg.type == WSMsgType.close
@asyncio.coroutine
def test_quiting_hass(hass, websocket_client):
"""Test sending invalid JSON."""
with patch.object(hass.loop, 'stop'):
yield from hass.async_stop()
msg = yield from websocket_client.receive()
assert msg.type == WSMsgType.CLOSE
@asyncio.coroutine
def test_call_service(hass, websocket_client):
"""Test call service command."""
calls = []
@callback
def service_call(call):
calls.append(call)
hass.services.async_register('domain_test', 'test_service', service_call)
yield from websocket_client.send_json({
'id': 5,
'type': wapi.TYPE_CALL_SERVICE,
'domain': 'domain_test',
'service': 'test_service',
'service_data': {
'hello': 'world'
}
})
msg = yield from websocket_client.receive_json()
assert msg['id'] == 5
assert msg['type'] == wapi.TYPE_RESULT
assert msg['success']
assert len(calls) == 1
call = calls[0]
assert call.domain == 'domain_test'
assert call.service == 'test_service'
assert call.data == {'hello': 'world'}
@asyncio.coroutine
def test_subscribe_unsubscribe_events(hass, websocket_client):
"""Test subscribe/unsubscribe events command."""
init_count = sum(hass.bus.async_listeners().values())
yield from websocket_client.send_json({
'id': 5,
'type': wapi.TYPE_SUBSCRIBE_EVENTS,
'event_type': 'test_event'
})
msg = yield from websocket_client.receive_json()
assert msg['id'] == 5
assert msg['type'] == wapi.TYPE_RESULT
assert msg['success']
# Verify we have a new listener
assert sum(hass.bus.async_listeners().values()) == init_count + 1
hass.bus.async_fire('ignore_event')
hass.bus.async_fire('test_event', {'hello': 'world'})
hass.bus.async_fire('ignore_event')
with timeout(3, loop=hass.loop):
msg = yield from websocket_client.receive_json()
assert msg['id'] == 5
assert msg['type'] == wapi.TYPE_EVENT
event = msg['event']
assert event['event_type'] == 'test_event'
assert event['data'] == {'hello': 'world'}
assert event['origin'] == 'LOCAL'
yield from websocket_client.send_json({
'id': 6,
'type': wapi.TYPE_UNSUBSCRIBE_EVENTS,
'subscription': 5
})
msg = yield from websocket_client.receive_json()
assert msg['id'] == 6
assert msg['type'] == wapi.TYPE_RESULT
assert msg['success']
# Check our listener got unsubscribed
assert sum(hass.bus.async_listeners().values()) == init_count
@asyncio.coroutine
def test_get_states(hass, websocket_client):
"""Test get_states command."""
hass.states.async_set('greeting.hello', 'world')
hass.states.async_set('greeting.bye', 'universe')
yield from websocket_client.send_json({
'id': 5,
'type': wapi.TYPE_GET_STATES,
})
msg = yield from websocket_client.receive_json()
assert msg['id'] == 5
assert msg['type'] == wapi.TYPE_RESULT
assert msg['success']
states = []
for state in hass.states.async_all():
state = state.as_dict()
state['last_changed'] = state['last_changed'].isoformat()
state['last_updated'] = state['last_updated'].isoformat()
states.append(state)
assert msg['result'] == states
@asyncio.coroutine
def test_get_services(hass, websocket_client):
"""Test get_services command."""
yield from websocket_client.send_json({
'id': 5,
'type': wapi.TYPE_GET_SERVICES,
})
msg = yield from websocket_client.receive_json()
assert msg['id'] == 5
assert msg['type'] == wapi.TYPE_RESULT
assert msg['success']
assert msg['result'] == hass.services.async_services()
@asyncio.coroutine
def test_get_config(hass, websocket_client):
"""Test get_config command."""
yield from websocket_client.send_json({
'id': 5,
'type': wapi.TYPE_GET_CONFIG,
})
msg = yield from websocket_client.receive_json()
assert msg['id'] == 5
assert msg['type'] == wapi.TYPE_RESULT
assert msg['success']
if 'components' in msg['result']:
msg['result']['components'] = set(msg['result']['components'])
if 'whitelist_external_dirs' in msg['result']:
msg['result']['whitelist_external_dirs'] = \
set(msg['result']['whitelist_external_dirs'])
assert msg['result'] == hass.config.as_dict()
@asyncio.coroutine
def test_ping(websocket_client):
"""Test get_panels command."""
yield from websocket_client.send_json({
'id': 5,
'type': wapi.TYPE_PING,
})
msg = yield from websocket_client.receive_json()
assert msg['id'] == 5
assert msg['type'] == wapi.TYPE_PONG
@asyncio.coroutine
def test_pending_msg_overflow(hass, mock_low_queue, websocket_client):
"""Test get_panels command."""
for idx in range(10):
yield from websocket_client.send_json({
'id': idx + 1,
'type': wapi.TYPE_PING,
})
msg = yield from websocket_client.receive()
assert msg.type == WSMsgType.close
@asyncio.coroutine
def test_unknown_command(websocket_client):
"""Test get_panels command."""
yield from websocket_client.send_json({
'id': 5,
'type': 'unknown_command',
})
msg = yield from websocket_client.receive_json()
assert not msg['success']
assert msg['error']['code'] == wapi.ERR_UNKNOWN_COMMAND
async def test_auth_active_with_token(hass, aiohttp_client, hass_access_token):
"""Test authenticating with a token."""
assert await async_setup_component(hass, 'websocket_api', {
'http': {
'api_password': API_PASSWORD
}
})
client = await aiohttp_client(hass.http.app)
async with client.ws_connect(wapi.URL) as ws:
with patch('homeassistant.auth.AuthManager.active') as auth_active:
auth_active.return_value = True
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
await ws.send_json({
'type': wapi.TYPE_AUTH,
'access_token': hass_access_token
})
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_OK
async def test_auth_active_user_inactive(hass, aiohttp_client,
hass_access_token):
"""Test authenticating with a token."""
refresh_token = await hass.auth.async_validate_access_token(
hass_access_token)
refresh_token.user.is_active = False
assert await async_setup_component(hass, 'websocket_api', {
'http': {
'api_password': API_PASSWORD
}
})
client = await aiohttp_client(hass.http.app)
async with client.ws_connect(wapi.URL) as ws:
with patch('homeassistant.auth.AuthManager.active') as auth_active:
auth_active.return_value = True
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
await ws.send_json({
'type': wapi.TYPE_AUTH,
'access_token': hass_access_token
})
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_INVALID
async def test_auth_active_with_password_not_allow(hass, aiohttp_client):
"""Test authenticating with a token."""
assert await async_setup_component(hass, 'websocket_api', {
'http': {
'api_password': API_PASSWORD
}
})
client = await aiohttp_client(hass.http.app)
async with client.ws_connect(wapi.URL) as ws:
with patch('homeassistant.auth.AuthManager.active',
return_value=True):
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
await ws.send_json({
'type': wapi.TYPE_AUTH,
'api_password': API_PASSWORD
})
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_INVALID
async def test_auth_legacy_support_with_password(hass, aiohttp_client):
"""Test authenticating with a token."""
assert await async_setup_component(hass, 'websocket_api', {
'http': {
'api_password': API_PASSWORD
}
})
client = await aiohttp_client(hass.http.app)
async with client.ws_connect(wapi.URL) as ws:
with patch('homeassistant.auth.AuthManager.active',
return_value=True),\
patch('homeassistant.auth.AuthManager.support_legacy',
return_value=True):
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
await ws.send_json({
'type': wapi.TYPE_AUTH,
'api_password': API_PASSWORD
})
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_OK
async def test_auth_with_invalid_token(hass, aiohttp_client):
"""Test authenticating with a token."""
assert await async_setup_component(hass, 'websocket_api', {
'http': {
'api_password': API_PASSWORD
}
})
client = await aiohttp_client(hass.http.app)
async with client.ws_connect(wapi.URL) as ws:
with patch('homeassistant.auth.AuthManager.active') as auth_active:
auth_active.return_value = True
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
await ws.send_json({
'type': wapi.TYPE_AUTH,
'access_token': 'incorrect'
})
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_INVALID
async def test_call_service_context_with_user(hass, aiohttp_client,
hass_access_token):
"""Test that the user is set in the service call context."""
assert await async_setup_component(hass, 'websocket_api', {
'http': {
'api_password': API_PASSWORD
}
})
calls = async_mock_service(hass, 'domain_test', 'test_service')
client = await aiohttp_client(hass.http.app)
async with client.ws_connect(wapi.URL) as ws:
with patch('homeassistant.auth.AuthManager.active') as auth_active:
auth_active.return_value = True
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
await ws.send_json({
'type': wapi.TYPE_AUTH,
'access_token': hass_access_token
})
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_OK
await ws.send_json({
'id': 5,
'type': wapi.TYPE_CALL_SERVICE,
'domain': 'domain_test',
'service': 'test_service',
'service_data': {
'hello': 'world'
}
})
msg = await ws.receive_json()
assert msg['success']
refresh_token = await hass.auth.async_validate_access_token(
hass_access_token)
assert len(calls) == 1
call = calls[0]
assert call.domain == 'domain_test'
assert call.service == 'test_service'
assert call.data == {'hello': 'world'}
assert call.context.user_id == refresh_token.user.id
async def test_call_service_context_no_user(hass, aiohttp_client):
"""Test that connection without user sets context."""
assert await async_setup_component(hass, 'websocket_api', {
'http': {
'api_password': API_PASSWORD
}
})
calls = async_mock_service(hass, 'domain_test', 'test_service')
client = await aiohttp_client(hass.http.app)
async with client.ws_connect(wapi.URL) as ws:
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
await ws.send_json({
'type': wapi.TYPE_AUTH,
'api_password': API_PASSWORD
})
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_OK
await ws.send_json({
'id': 5,
'type': wapi.TYPE_CALL_SERVICE,
'domain': 'domain_test',
'service': 'test_service',
'service_data': {
'hello': 'world'
}
})
msg = await ws.receive_json()
assert msg['success']
assert len(calls) == 1
call = calls[0]
assert call.domain == 'domain_test'
assert call.service == 'test_service'
assert call.data == {'hello': 'world'}
assert call.context.user_id is None
async def test_handler_failing(hass, websocket_client):
"""Test a command that raises."""
hass.components.websocket_api.async_register_command(
'bla', Mock(side_effect=TypeError),
wapi.BASE_COMMAND_MESSAGE_SCHEMA.extend({'type': 'bla'}))
await websocket_client.send_json({
'id': 5,
'type': 'bla',
})
msg = await websocket_client.receive_json()
assert msg['id'] == 5
assert msg['type'] == wapi.TYPE_RESULT
assert not msg['success']
assert msg['error']['code'] == wapi.ERR_UNKNOWN_ERROR

View file

@ -0,0 +1,2 @@
"""Tests for the websocket API."""
API_PASSWORD = 'test1234'

View file

@ -0,0 +1,35 @@
"""Fixtures for websocket tests."""
import pytest
from homeassistant.setup import async_setup_component
from homeassistant.components import websocket_api as wapi
from . import API_PASSWORD
@pytest.fixture
def websocket_client(hass, hass_ws_client):
"""Create a websocket client."""
return hass.loop.run_until_complete(hass_ws_client(hass))
@pytest.fixture
def no_auth_websocket_client(hass, loop, aiohttp_client):
"""Websocket connection that requires authentication."""
assert loop.run_until_complete(
async_setup_component(hass, 'websocket_api', {
'http': {
'api_password': API_PASSWORD
}
}))
client = loop.run_until_complete(aiohttp_client(hass.http.app))
ws = loop.run_until_complete(client.ws_connect(wapi.URL))
auth_ok = loop.run_until_complete(ws.receive_json())
assert auth_ok['type'] == wapi.TYPE_AUTH_REQUIRED
yield ws
if not ws.closed:
loop.run_until_complete(ws.close())

View file

@ -0,0 +1,186 @@
"""Test auth of websocket API."""
from unittest.mock import patch
from homeassistant.components import websocket_api as wapi
from homeassistant.components.websocket_api import commands
from homeassistant.setup import async_setup_component
from tests.common import mock_coro
from . import API_PASSWORD
async def test_auth_via_msg(no_auth_websocket_client):
"""Test authenticating."""
await no_auth_websocket_client.send_json({
'type': wapi.TYPE_AUTH,
'api_password': API_PASSWORD
})
msg = await no_auth_websocket_client.receive_json()
assert msg['type'] == wapi.TYPE_AUTH_OK
async def test_auth_via_msg_incorrect_pass(no_auth_websocket_client):
"""Test authenticating."""
with patch('homeassistant.components.websocket_api.process_wrong_login',
return_value=mock_coro()) as mock_process_wrong_login:
await no_auth_websocket_client.send_json({
'type': wapi.TYPE_AUTH,
'api_password': API_PASSWORD + 'wrong'
})
msg = await no_auth_websocket_client.receive_json()
assert mock_process_wrong_login.called
assert msg['type'] == wapi.TYPE_AUTH_INVALID
assert msg['message'] == 'Invalid access token or password'
async def test_pre_auth_only_auth_allowed(no_auth_websocket_client):
"""Verify that before authentication, only auth messages are allowed."""
await no_auth_websocket_client.send_json({
'type': commands.TYPE_CALL_SERVICE,
'domain': 'domain_test',
'service': 'test_service',
'service_data': {
'hello': 'world'
}
})
msg = await no_auth_websocket_client.receive_json()
assert msg['type'] == wapi.TYPE_AUTH_INVALID
assert msg['message'].startswith('Message incorrectly formatted')
async def test_auth_active_with_token(hass, aiohttp_client, hass_access_token):
"""Test authenticating with a token."""
assert await async_setup_component(hass, 'websocket_api', {
'http': {
'api_password': API_PASSWORD
}
})
client = await aiohttp_client(hass.http.app)
async with client.ws_connect(wapi.URL) as ws:
with patch('homeassistant.auth.AuthManager.active') as auth_active:
auth_active.return_value = True
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
await ws.send_json({
'type': wapi.TYPE_AUTH,
'access_token': hass_access_token
})
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_OK
async def test_auth_active_user_inactive(hass, aiohttp_client,
hass_access_token):
"""Test authenticating with a token."""
refresh_token = await hass.auth.async_validate_access_token(
hass_access_token)
refresh_token.user.is_active = False
assert await async_setup_component(hass, 'websocket_api', {
'http': {
'api_password': API_PASSWORD
}
})
client = await aiohttp_client(hass.http.app)
async with client.ws_connect(wapi.URL) as ws:
with patch('homeassistant.auth.AuthManager.active') as auth_active:
auth_active.return_value = True
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
await ws.send_json({
'type': wapi.TYPE_AUTH,
'access_token': hass_access_token
})
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_INVALID
async def test_auth_active_with_password_not_allow(hass, aiohttp_client):
"""Test authenticating with a token."""
assert await async_setup_component(hass, 'websocket_api', {
'http': {
'api_password': API_PASSWORD
}
})
client = await aiohttp_client(hass.http.app)
async with client.ws_connect(wapi.URL) as ws:
with patch('homeassistant.auth.AuthManager.active',
return_value=True):
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
await ws.send_json({
'type': wapi.TYPE_AUTH,
'api_password': API_PASSWORD
})
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_INVALID
async def test_auth_legacy_support_with_password(hass, aiohttp_client):
"""Test authenticating with a token."""
assert await async_setup_component(hass, 'websocket_api', {
'http': {
'api_password': API_PASSWORD
}
})
client = await aiohttp_client(hass.http.app)
async with client.ws_connect(wapi.URL) as ws:
with patch('homeassistant.auth.AuthManager.active',
return_value=True),\
patch('homeassistant.auth.AuthManager.support_legacy',
return_value=True):
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
await ws.send_json({
'type': wapi.TYPE_AUTH,
'api_password': API_PASSWORD
})
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_OK
async def test_auth_with_invalid_token(hass, aiohttp_client):
"""Test authenticating with a token."""
assert await async_setup_component(hass, 'websocket_api', {
'http': {
'api_password': API_PASSWORD
}
})
client = await aiohttp_client(hass.http.app)
async with client.ws_connect(wapi.URL) as ws:
with patch('homeassistant.auth.AuthManager.active') as auth_active:
auth_active.return_value = True
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
await ws.send_json({
'type': wapi.TYPE_AUTH,
'access_token': 'incorrect'
})
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_INVALID

View file

@ -0,0 +1,260 @@
"""Tests for WebSocket API commands."""
from unittest.mock import patch
from async_timeout import timeout
from homeassistant.core import callback
from homeassistant.components import websocket_api as wapi
from homeassistant.components.websocket_api import const, commands
from homeassistant.setup import async_setup_component
from tests.common import async_mock_service
from . import API_PASSWORD
async def test_call_service(hass, websocket_client):
"""Test call service command."""
calls = []
@callback
def service_call(call):
calls.append(call)
hass.services.async_register('domain_test', 'test_service', service_call)
await websocket_client.send_json({
'id': 5,
'type': commands.TYPE_CALL_SERVICE,
'domain': 'domain_test',
'service': 'test_service',
'service_data': {
'hello': 'world'
}
})
msg = await websocket_client.receive_json()
assert msg['id'] == 5
assert msg['type'] == const.TYPE_RESULT
assert msg['success']
assert len(calls) == 1
call = calls[0]
assert call.domain == 'domain_test'
assert call.service == 'test_service'
assert call.data == {'hello': 'world'}
async def test_subscribe_unsubscribe_events(hass, websocket_client):
"""Test subscribe/unsubscribe events command."""
init_count = sum(hass.bus.async_listeners().values())
await websocket_client.send_json({
'id': 5,
'type': commands.TYPE_SUBSCRIBE_EVENTS,
'event_type': 'test_event'
})
msg = await websocket_client.receive_json()
assert msg['id'] == 5
assert msg['type'] == const.TYPE_RESULT
assert msg['success']
# Verify we have a new listener
assert sum(hass.bus.async_listeners().values()) == init_count + 1
hass.bus.async_fire('ignore_event')
hass.bus.async_fire('test_event', {'hello': 'world'})
hass.bus.async_fire('ignore_event')
with timeout(3, loop=hass.loop):
msg = await websocket_client.receive_json()
assert msg['id'] == 5
assert msg['type'] == commands.TYPE_EVENT
event = msg['event']
assert event['event_type'] == 'test_event'
assert event['data'] == {'hello': 'world'}
assert event['origin'] == 'LOCAL'
await websocket_client.send_json({
'id': 6,
'type': commands.TYPE_UNSUBSCRIBE_EVENTS,
'subscription': 5
})
msg = await websocket_client.receive_json()
assert msg['id'] == 6
assert msg['type'] == const.TYPE_RESULT
assert msg['success']
# Check our listener got unsubscribed
assert sum(hass.bus.async_listeners().values()) == init_count
async def test_get_states(hass, websocket_client):
"""Test get_states command."""
hass.states.async_set('greeting.hello', 'world')
hass.states.async_set('greeting.bye', 'universe')
await websocket_client.send_json({
'id': 5,
'type': commands.TYPE_GET_STATES,
})
msg = await websocket_client.receive_json()
assert msg['id'] == 5
assert msg['type'] == const.TYPE_RESULT
assert msg['success']
states = []
for state in hass.states.async_all():
state = state.as_dict()
state['last_changed'] = state['last_changed'].isoformat()
state['last_updated'] = state['last_updated'].isoformat()
states.append(state)
assert msg['result'] == states
async def test_get_services(hass, websocket_client):
"""Test get_services command."""
await websocket_client.send_json({
'id': 5,
'type': commands.TYPE_GET_SERVICES,
})
msg = await websocket_client.receive_json()
assert msg['id'] == 5
assert msg['type'] == const.TYPE_RESULT
assert msg['success']
assert msg['result'] == hass.services.async_services()
async def test_get_config(hass, websocket_client):
"""Test get_config command."""
await websocket_client.send_json({
'id': 5,
'type': commands.TYPE_GET_CONFIG,
})
msg = await websocket_client.receive_json()
assert msg['id'] == 5
assert msg['type'] == const.TYPE_RESULT
assert msg['success']
if 'components' in msg['result']:
msg['result']['components'] = set(msg['result']['components'])
if 'whitelist_external_dirs' in msg['result']:
msg['result']['whitelist_external_dirs'] = \
set(msg['result']['whitelist_external_dirs'])
assert msg['result'] == hass.config.as_dict()
async def test_ping(websocket_client):
"""Test get_panels command."""
await websocket_client.send_json({
'id': 5,
'type': commands.TYPE_PING,
})
msg = await websocket_client.receive_json()
assert msg['id'] == 5
assert msg['type'] == commands.TYPE_PONG
async def test_call_service_context_with_user(hass, aiohttp_client,
hass_access_token):
"""Test that the user is set in the service call context."""
assert await async_setup_component(hass, 'websocket_api', {
'http': {
'api_password': API_PASSWORD
}
})
calls = async_mock_service(hass, 'domain_test', 'test_service')
client = await aiohttp_client(hass.http.app)
async with client.ws_connect(wapi.URL) as ws:
with patch('homeassistant.auth.AuthManager.active') as auth_active:
auth_active.return_value = True
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
await ws.send_json({
'type': wapi.TYPE_AUTH,
'access_token': hass_access_token
})
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_OK
await ws.send_json({
'id': 5,
'type': commands.TYPE_CALL_SERVICE,
'domain': 'domain_test',
'service': 'test_service',
'service_data': {
'hello': 'world'
}
})
msg = await ws.receive_json()
assert msg['success']
refresh_token = await hass.auth.async_validate_access_token(
hass_access_token)
assert len(calls) == 1
call = calls[0]
assert call.domain == 'domain_test'
assert call.service == 'test_service'
assert call.data == {'hello': 'world'}
assert call.context.user_id == refresh_token.user.id
async def test_call_service_context_no_user(hass, aiohttp_client):
"""Test that connection without user sets context."""
assert await async_setup_component(hass, 'websocket_api', {
'http': {
'api_password': API_PASSWORD
}
})
calls = async_mock_service(hass, 'domain_test', 'test_service')
client = await aiohttp_client(hass.http.app)
async with client.ws_connect(wapi.URL) as ws:
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
await ws.send_json({
'type': wapi.TYPE_AUTH,
'api_password': API_PASSWORD
})
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_OK
await ws.send_json({
'id': 5,
'type': commands.TYPE_CALL_SERVICE,
'domain': 'domain_test',
'service': 'test_service',
'service_data': {
'hello': 'world'
}
})
msg = await ws.receive_json()
assert msg['success']
assert len(calls) == 1
call = calls[0]
assert call.domain == 'domain_test'
assert call.service == 'test_service'
assert call.data == {'hello': 'world'}
assert call.context.user_id is None

View file

@ -0,0 +1,92 @@
"""Tests for the Home Assistant Websocket API."""
import asyncio
from unittest.mock import patch, Mock
from aiohttp import WSMsgType
import pytest
from homeassistant.components import websocket_api as wapi
from homeassistant.components.websocket_api import const, commands, messages
@pytest.fixture
def mock_low_queue():
"""Mock a low queue."""
with patch.object(wapi, 'MAX_PENDING_MSG', 5):
yield
@asyncio.coroutine
def test_invalid_message_format(websocket_client):
"""Test sending invalid JSON."""
yield from websocket_client.send_json({'type': 5})
msg = yield from websocket_client.receive_json()
assert msg['type'] == const.TYPE_RESULT
error = msg['error']
assert error['code'] == const.ERR_INVALID_FORMAT
assert error['message'].startswith('Message incorrectly formatted')
@asyncio.coroutine
def test_invalid_json(websocket_client):
"""Test sending invalid JSON."""
yield from websocket_client.send_str('this is not JSON')
msg = yield from websocket_client.receive()
assert msg.type == WSMsgType.close
@asyncio.coroutine
def test_quiting_hass(hass, websocket_client):
"""Test sending invalid JSON."""
with patch.object(hass.loop, 'stop'):
yield from hass.async_stop()
msg = yield from websocket_client.receive()
assert msg.type == WSMsgType.CLOSE
@asyncio.coroutine
def test_pending_msg_overflow(hass, mock_low_queue, websocket_client):
"""Test get_panels command."""
for idx in range(10):
yield from websocket_client.send_json({
'id': idx + 1,
'type': commands.TYPE_PING,
})
msg = yield from websocket_client.receive()
assert msg.type == WSMsgType.close
@asyncio.coroutine
def test_unknown_command(websocket_client):
"""Test get_panels command."""
yield from websocket_client.send_json({
'id': 5,
'type': 'unknown_command',
})
msg = yield from websocket_client.receive_json()
assert not msg['success']
assert msg['error']['code'] == const.ERR_UNKNOWN_COMMAND
async def test_handler_failing(hass, websocket_client):
"""Test a command that raises."""
hass.components.websocket_api.async_register_command(
'bla', Mock(side_effect=TypeError),
messages.BASE_COMMAND_MESSAGE_SCHEMA.extend({'type': 'bla'}))
await websocket_client.send_json({
'id': 5,
'type': 'bla',
})
msg = await websocket_client.receive_json()
assert msg['id'] == 5
assert msg['type'] == const.TYPE_RESULT
assert not msg['success']
assert msg['error']['code'] == const.ERR_UNKNOWN_ERROR