diff --git a/homeassistant/components/frontend/__init__.py b/homeassistant/components/frontend/__init__.py index 4a181c00c02..564ba286b96 100644 --- a/homeassistant/components/frontend/__init__.py +++ b/homeassistant/components/frontend/__init__.py @@ -16,8 +16,9 @@ import voluptuous as vol import jinja2 import homeassistant.helpers.config_validation as cv -from homeassistant.components.http import HomeAssistantView +from homeassistant.components.http.view import HomeAssistantView from homeassistant.components.http.const import KEY_AUTHENTICATED +from homeassistant.components import websocket_api from homeassistant.config import find_config_file, load_yaml_config_file from homeassistant.const import CONF_NAME, EVENT_THEMES_UPDATED from homeassistant.core import callback @@ -94,6 +95,10 @@ SERVICE_RELOAD_THEMES = 'reload_themes' SERVICE_SET_THEME_SCHEMA = vol.Schema({ vol.Required(CONF_NAME): cv.string, }) +WS_TYPE_GET_PANELS = 'get_panels' +SCHEMA_GET_PANELS = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend({ + vol.Required('type'): WS_TYPE_GET_PANELS, +}) class AbstractPanel: @@ -291,6 +296,8 @@ def add_manifest_json_key(key, val): @asyncio.coroutine def async_setup(hass, config): """Set up the serving of the frontend.""" + hass.components.websocket_api.async_register_command( + WS_TYPE_GET_PANELS, websocket_handle_get_panels, SCHEMA_GET_PANELS) hass.http.register_view(ManifestJSONView) conf = config.get(DOMAIN, {}) @@ -597,3 +604,18 @@ def _is_latest(js_option, request): useragent = request.headers.get('User-Agent') return useragent and hass_frontend.version(useragent) + + +def websocket_handle_get_panels(hass, connection, msg): + """Handle get panels command. + + Async friendly. + """ + panels = { + panel: + connection.hass.data[DATA_PANELS][panel].to_response( + connection.hass, connection.request) + for panel in connection.hass.data[DATA_PANELS]} + + connection.to_write.put_nowait(websocket_api.result_message( + msg['id'], panels)) diff --git a/homeassistant/components/websocket_api.py b/homeassistant/components/websocket_api.py index 1e23ad19897..84c92631572 100644 --- a/homeassistant/components/websocket_api.py +++ b/homeassistant/components/websocket_api.py @@ -18,8 +18,8 @@ from voluptuous.humanize import humanize_error from homeassistant.const import ( MATCH_ALL, EVENT_TIME_CHANGED, EVENT_HOMEASSISTANT_STOP, __version__) -from homeassistant.components import frontend from homeassistant.core import callback +from homeassistant.loader import bind_hass from homeassistant.remote import JSONEncoder from homeassistant.helpers import config_validation as cv from homeassistant.helpers.service import async_get_all_descriptions @@ -46,7 +46,6 @@ TYPE_AUTH_REQUIRED = 'auth_required' TYPE_CALL_SERVICE = 'call_service' TYPE_EVENT = 'event' TYPE_GET_CONFIG = 'get_config' -TYPE_GET_PANELS = 'get_panels' TYPE_GET_SERVICES = 'get_services' TYPE_GET_STATES = 'get_states' TYPE_PING = 'ping' @@ -64,62 +63,56 @@ AUTH_MESSAGE_SCHEMA = vol.Schema({ vol.Required('api_password'): str, }) -SUBSCRIBE_EVENTS_MESSAGE_SCHEMA = vol.Schema({ +# Minimal requirements of a message +MINIMAL_MESSAGE_SCHEMA = vol.Schema({ vol.Required('id'): cv.positive_int, + vol.Required('type'): cv.string, +}, extra=vol.ALLOW_EXTRA) +# Base schema to extend by message handlers +BASE_COMMAND_MESSAGE_SCHEMA = vol.Schema({ + vol.Required('id'): cv.positive_int, +}) + + +SCHEMA_SUBSCRIBE_EVENTS = BASE_COMMAND_MESSAGE_SCHEMA.extend({ vol.Required('type'): TYPE_SUBSCRIBE_EVENTS, vol.Optional('event_type', default=MATCH_ALL): str, }) -UNSUBSCRIBE_EVENTS_MESSAGE_SCHEMA = vol.Schema({ - vol.Required('id'): cv.positive_int, + +SCHEMA_UNSUBSCRIBE_EVENTS = BASE_COMMAND_MESSAGE_SCHEMA.extend({ vol.Required('type'): TYPE_UNSUBSCRIBE_EVENTS, vol.Required('subscription'): cv.positive_int, }) -CALL_SERVICE_MESSAGE_SCHEMA = vol.Schema({ - vol.Required('id'): cv.positive_int, + +SCHEMA_CALL_SERVICE = BASE_COMMAND_MESSAGE_SCHEMA.extend({ vol.Required('type'): TYPE_CALL_SERVICE, vol.Required('domain'): str, vol.Required('service'): str, vol.Optional('service_data'): dict }) -GET_STATES_MESSAGE_SCHEMA = vol.Schema({ - vol.Required('id'): cv.positive_int, + +SCHEMA_GET_STATES = BASE_COMMAND_MESSAGE_SCHEMA.extend({ vol.Required('type'): TYPE_GET_STATES, }) -GET_SERVICES_MESSAGE_SCHEMA = vol.Schema({ - vol.Required('id'): cv.positive_int, + +SCHEMA_GET_SERVICES = BASE_COMMAND_MESSAGE_SCHEMA.extend({ vol.Required('type'): TYPE_GET_SERVICES, }) -GET_CONFIG_MESSAGE_SCHEMA = vol.Schema({ - vol.Required('id'): cv.positive_int, + +SCHEMA_GET_CONFIG = BASE_COMMAND_MESSAGE_SCHEMA.extend({ vol.Required('type'): TYPE_GET_CONFIG, }) -GET_PANELS_MESSAGE_SCHEMA = vol.Schema({ - vol.Required('id'): cv.positive_int, - vol.Required('type'): TYPE_GET_PANELS, -}) -PING_MESSAGE_SCHEMA = vol.Schema({ - vol.Required('id'): cv.positive_int, +SCHEMA_PING = BASE_COMMAND_MESSAGE_SCHEMA.extend({ vol.Required('type'): TYPE_PING, }) -BASE_COMMAND_MESSAGE_SCHEMA = vol.Schema({ - vol.Required('id'): cv.positive_int, - vol.Required('type'): vol.Any(TYPE_CALL_SERVICE, - TYPE_SUBSCRIBE_EVENTS, - TYPE_UNSUBSCRIBE_EVENTS, - TYPE_GET_STATES, - TYPE_GET_SERVICES, - TYPE_GET_CONFIG, - TYPE_GET_PANELS, - TYPE_PING) -}, extra=vol.ALLOW_EXTRA) # Define the possible errors that occur when connections are cancelled. # Originally, this was just asyncio.CancelledError, but issue #9546 showed @@ -191,9 +184,36 @@ def result_message(iden, result=None): } +@bind_hass +@callback +def async_register_command(hass, command, handler, schema): + """Register a websocket command.""" + handlers = hass.data.get(DOMAIN) + if handlers is None: + handlers = hass.data[DOMAIN] = {} + handlers[command] = (handler, schema) + + async def async_setup(hass, config): """Initialize the websocket API.""" hass.http.register_view(WebsocketAPIView) + + async_register_command(hass, TYPE_SUBSCRIBE_EVENTS, + handle_subscribe_events, SCHEMA_SUBSCRIBE_EVENTS) + async_register_command(hass, TYPE_UNSUBSCRIBE_EVENTS, + handle_unsubscribe_events, + SCHEMA_UNSUBSCRIBE_EVENTS) + async_register_command(hass, TYPE_CALL_SERVICE, + handle_call_service, SCHEMA_CALL_SERVICE) + async_register_command(hass, TYPE_GET_STATES, + handle_get_states, SCHEMA_GET_STATES) + async_register_command(hass, TYPE_GET_SERVICES, + handle_get_services, SCHEMA_GET_SERVICES) + async_register_command(hass, TYPE_GET_CONFIG, + handle_get_config, SCHEMA_GET_CONFIG) + async_register_command(hass, TYPE_PING, + handle_ping, SCHEMA_PING) + return True @@ -316,10 +336,11 @@ class ActiveConnection: msg = await wsock.receive_json() last_id = 0 + handlers = self.hass.data[DOMAIN] while msg: self.debug("Received", msg) - msg = BASE_COMMAND_MESSAGE_SCHEMA(msg) + msg = MINIMAL_MESSAGE_SCHEMA(msg) cur_id = msg['id'] if cur_id <= last_id: @@ -327,9 +348,13 @@ class ActiveConnection: cur_id, ERR_ID_REUSE, 'Identifier values have to increase.')) + elif msg['type'] not in handlers: + # Unknown command + break + else: - handler_name = 'handle_{}'.format(msg['type']) - getattr(self, handler_name)(msg) + handler, schema = handlers[msg['type']] + handler(self.hass, self, schema(msg)) last_id = cur_id msg = await wsock.receive_json() @@ -403,109 +428,89 @@ class ActiveConnection: return wsock - def handle_subscribe_events(self, msg): - """Handle subscribe events command. - Async friendly. - """ - msg = SUBSCRIBE_EVENTS_MESSAGE_SCHEMA(msg) +def handle_subscribe_events(hass, connection, msg): + """Handle subscribe events command. - async def forward_events(event): - """Forward events to websocket.""" - if event.event_type == EVENT_TIME_CHANGED: - return + Async friendly. + """ + async def forward_events(event): + """Forward events to websocket.""" + if event.event_type == EVENT_TIME_CHANGED: + return - self.send_message_outside(event_message(msg['id'], event)) + connection.send_message_outside(event_message(msg['id'], event)) - self.event_listeners[msg['id']] = self.hass.bus.async_listen( - msg['event_type'], forward_events) + connection.event_listeners[msg['id']] = hass.bus.async_listen( + msg['event_type'], forward_events) - self.to_write.put_nowait(result_message(msg['id'])) + connection.to_write.put_nowait(result_message(msg['id'])) - def handle_unsubscribe_events(self, msg): - """Handle unsubscribe events command. - Async friendly. - """ - msg = UNSUBSCRIBE_EVENTS_MESSAGE_SCHEMA(msg) +def handle_unsubscribe_events(hass, connection, msg): + """Handle unsubscribe events command. - subscription = msg['subscription'] + Async friendly. + """ + subscription = msg['subscription'] - if subscription in self.event_listeners: - self.event_listeners.pop(subscription)() - self.to_write.put_nowait(result_message(msg['id'])) - else: - self.to_write.put_nowait(error_message( - msg['id'], ERR_NOT_FOUND, - 'Subscription not found.')) + if subscription in connection.event_listeners: + connection.event_listeners.pop(subscription)() + connection.to_write.put_nowait(result_message(msg['id'])) + else: + connection.to_write.put_nowait(error_message( + msg['id'], ERR_NOT_FOUND, 'Subscription not found.')) - def handle_call_service(self, msg): - """Handle call service command. - Async friendly. - """ - msg = CALL_SERVICE_MESSAGE_SCHEMA(msg) +def handle_call_service(hass, connection, msg): + """Handle call service command. - async def call_service_helper(msg): - """Call a service and fire complete message.""" - await self.hass.services.async_call( - msg['domain'], msg['service'], msg.get('service_data'), True) - self.send_message_outside(result_message(msg['id'])) + Async friendly. + """ + async def call_service_helper(msg): + """Call a service and fire complete message.""" + await hass.services.async_call( + msg['domain'], msg['service'], msg.get('service_data'), True) + connection.send_message_outside(result_message(msg['id'])) - self.hass.async_add_job(call_service_helper(msg)) + hass.async_add_job(call_service_helper(msg)) - def handle_get_states(self, msg): - """Handle get states command. - Async friendly. - """ - msg = GET_STATES_MESSAGE_SCHEMA(msg) +def handle_get_states(hass, connection, msg): + """Handle get states command. - self.to_write.put_nowait(result_message( - msg['id'], self.hass.states.async_all())) + Async friendly. + """ + connection.to_write.put_nowait(result_message( + msg['id'], hass.states.async_all())) - def handle_get_services(self, msg): - """Handle get services command. - Async friendly. - """ - msg = GET_SERVICES_MESSAGE_SCHEMA(msg) +def handle_get_services(hass, connection, msg): + """Handle get services command. - async def get_services_helper(msg): - """Get available services and fire complete message.""" - descriptions = await async_get_all_descriptions(self.hass) - self.send_message_outside(result_message(msg['id'], descriptions)) + Async friendly. + """ + async def get_services_helper(msg): + """Get available services and fire complete message.""" + descriptions = await async_get_all_descriptions(hass) + connection.send_message_outside( + result_message(msg['id'], descriptions)) - self.hass.async_add_job(get_services_helper(msg)) + hass.async_add_job(get_services_helper(msg)) - def handle_get_config(self, msg): - """Handle get config command. - Async friendly. - """ - msg = GET_CONFIG_MESSAGE_SCHEMA(msg) +def handle_get_config(hass, connection, msg): + """Handle get config command. - self.to_write.put_nowait(result_message( - msg['id'], self.hass.config.as_dict())) + Async friendly. + """ + connection.to_write.put_nowait(result_message( + msg['id'], hass.config.as_dict())) - def handle_get_panels(self, msg): - """Handle get panels command. - Async friendly. - """ - msg = GET_PANELS_MESSAGE_SCHEMA(msg) - panels = { - panel: - self.hass.data[frontend.DATA_PANELS][panel].to_response( - self.hass, self.request) - for panel in self.hass.data[frontend.DATA_PANELS]} +def handle_ping(hass, connection, msg): + """Handle ping command. - self.to_write.put_nowait(result_message( - msg['id'], panels)) - - def handle_ping(self, msg): - """Handle ping command. - - Async friendly. - """ - self.to_write.put_nowait(pong_message(msg['id'])) + Async friendly. + """ + connection.to_write.put_nowait(pong_message(msg['id'])) diff --git a/tests/components/conftest.py b/tests/components/conftest.py new file mode 100644 index 00000000000..53caeb80783 --- /dev/null +++ b/tests/components/conftest.py @@ -0,0 +1,22 @@ +"""Fixtures for component testing.""" +import pytest + +from homeassistant.setup import async_setup_component + + +@pytest.fixture +def hass_ws_client(aiohttp_client): + """Websocket client fixture connected to websocket server.""" + async def create_client(hass): + """Create a websocket client.""" + wapi = hass.components.websocket_api + assert await async_setup_component(hass, 'websocket_api') + + client = await aiohttp_client(hass.http.app) + websocket = await client.ws_connect(wapi.URL) + auth_ok = await websocket.receive_json() + assert auth_ok['type'] == wapi.TYPE_AUTH_OK + + return websocket + + return create_client diff --git a/tests/components/test_frontend.py b/tests/components/test_frontend.py index c742e215738..973544495d7 100644 --- a/tests/components/test_frontend.py +++ b/tests/components/test_frontend.py @@ -9,6 +9,7 @@ from homeassistant.setup import async_setup_component from homeassistant.components.frontend import ( DOMAIN, CONF_JS_VERSION, CONF_THEMES, CONF_EXTRA_HTML_URL, CONF_EXTRA_HTML_URL_ES5, DATA_PANELS) +from homeassistant.components import websocket_api as wapi @pytest.fixture @@ -189,3 +190,26 @@ def test_panel_without_path(hass): 'test_component', 'nonexistant_file') yield from async_setup_component(hass, 'frontend', {}) assert 'test_component' not in hass.data[DATA_PANELS] + + +async def test_get_panels(hass, hass_ws_client): + """Test get_panels command.""" + await async_setup_component(hass, 'frontend') + await hass.components.frontend.async_register_built_in_panel( + 'map', 'Map', 'mdi:account-location') + + client = await hass_ws_client(hass) + await client.send_json({ + 'id': 5, + 'type': 'get_panels', + }) + + msg = await client.receive_json() + + assert msg['id'] == 5 + assert msg['type'] == wapi.TYPE_RESULT + assert msg['success'] + assert msg['result']['map']['component_name'] == 'map' + assert msg['result']['map']['url_path'] == 'map' + assert msg['result']['map']['icon'] == 'mdi:account-location' + assert msg['result']['map']['title'] == 'Map' diff --git a/tests/components/test_websocket_api.py b/tests/components/test_websocket_api.py index 4deccf65209..0a130e507d4 100644 --- a/tests/components/test_websocket_api.py +++ b/tests/components/test_websocket_api.py @@ -7,7 +7,7 @@ from async_timeout import timeout import pytest from homeassistant.core import callback -from homeassistant.components import websocket_api as wapi, frontend +from homeassistant.components import websocket_api as wapi from homeassistant.setup import async_setup_component from tests.common import mock_coro @@ -16,20 +16,9 @@ API_PASSWORD = 'test1234' @pytest.fixture -def websocket_client(loop, hass, aiohttp_client): - """Websocket client fixture connected to websocket server.""" - assert loop.run_until_complete( - async_setup_component(hass, 'websocket_api')) - - client = loop.run_until_complete(aiohttp_client(hass.http.app)) - ws = loop.run_until_complete(client.ws_connect(wapi.URL)) - auth_ok = loop.run_until_complete(ws.receive_json()) - assert auth_ok['type'] == wapi.TYPE_AUTH_OK - - yield ws - - if not ws.closed: - loop.run_until_complete(ws.close()) +def websocket_client(hass, hass_ws_client): + """Create a websocket client.""" + return hass.loop.run_until_complete(hass_ws_client(hass)) @pytest.fixture @@ -289,31 +278,6 @@ def test_get_config(hass, websocket_client): assert msg['result'] == hass.config.as_dict() -@asyncio.coroutine -def test_get_panels(hass, websocket_client): - """Test get_panels command.""" - yield from hass.components.frontend.async_register_built_in_panel( - 'map', 'Map', 'mdi:account-location') - hass.data[frontend.DATA_JS_VERSION] = 'es5' - yield from websocket_client.send_json({ - 'id': 5, - 'type': wapi.TYPE_GET_PANELS, - }) - - msg = yield from websocket_client.receive_json() - assert msg['id'] == 5 - assert msg['type'] == wapi.TYPE_RESULT - assert msg['success'] - assert msg['result'] == {'map': { - 'component_name': 'map', - 'url_path': 'map', - 'config': None, - 'url': None, - 'icon': 'mdi:account-location', - 'title': 'Map', - }} - - @asyncio.coroutine def test_ping(websocket_client): """Test get_panels command.""" @@ -337,3 +301,15 @@ def test_pending_msg_overflow(hass, mock_low_queue, websocket_client): }) msg = yield from websocket_client.receive() assert msg.type == WSMsgType.close + + +@asyncio.coroutine +def test_unknown_command(websocket_client): + """Test get_panels command.""" + yield from websocket_client.send_json({ + 'id': 5, + 'type': 'unknown_command', + }) + + msg = yield from websocket_client.receive() + assert msg.type == WSMsgType.close