hass-core/homeassistant/components/websocket_api/commands.py
J. Nick Koston f0f817c361
Serialize websocket event message once (#40453)
Since most of the json serialize work for the websocket was done
multiple times for the same message, we can avoid the overhead
of serializing the same message many times (once per websocket
client) with a cache.
2020-09-22 08:47:04 -05:00

395 lines
12 KiB
Python

"""Commands part of Websocket API."""
import asyncio
import logging
import voluptuous as vol
from homeassistant.auth.permissions.const import CAT_ENTITIES, POLICY_READ
from homeassistant.components.websocket_api.const import ERR_NOT_FOUND
from homeassistant.const import EVENT_STATE_CHANGED, EVENT_TIME_CHANGED, MATCH_ALL
from homeassistant.core import DOMAIN as HASS_DOMAIN, callback
from homeassistant.exceptions import (
HomeAssistantError,
ServiceNotFound,
TemplateError,
Unauthorized,
)
from homeassistant.helpers import config_validation as cv, entity
from homeassistant.helpers.event import TrackTemplate, async_track_template_result
from homeassistant.helpers.service import async_get_all_descriptions
from homeassistant.loader import IntegrationNotFound, async_get_integration
from . import const, decorators, messages
# Logger for this module, looked up by the module's ``__name__``.
_LOGGER = logging.getLogger(__name__)
# mypy: allow-untyped-calls, allow-untyped-defs
@callback
def async_register_commands(hass, async_reg):
    """Register the websocket command handlers defined in this module.

    Every handler is handed to ``async_reg`` together with ``hass``, in the
    order in which it is listed below.
    """
    handlers = (
        handle_subscribe_events,
        handle_unsubscribe_events,
        handle_call_service,
        handle_get_states,
        handle_get_services,
        handle_get_config,
        handle_ping,
        handle_render_template,
        handle_manifest_list,
        handle_manifest_get,
        handle_entity_source,
        handle_subscribe_trigger,
        handle_test_condition,
    )
    for handler in handlers:
        async_reg(hass, handler)
def pong_message(iden):
    """Build the pong reply for the message with id ``iden``."""
    return dict(id=iden, type="pong")
@callback
@decorators.websocket_command(
    {
        vol.Required("type"): "subscribe_events",
        vol.Optional("event_type", default=MATCH_ALL): str,
    }
)
def handle_subscribe_events(hass, connection, msg):
    """Subscribe the connection to bus events and forward them to the client.

    Raises ``Unauthorized`` when the requested event type is not in
    ``SUBSCRIBE_WHITELIST`` and the connection's user is not an admin.

    State changed subscriptions only forward events for entities the user
    may read. Every other subscription forwards all events except time
    changed events.
    """
    # Imported locally because of a circular dependency.
    # pylint: disable=import-outside-toplevel
    from .permissions import SUBSCRIBE_WHITELIST

    event_type = msg["event_type"]

    if not (event_type in SUBSCRIBE_WHITELIST or connection.user.is_admin):
        raise Unauthorized

    iden = msg["id"]

    if event_type == EVENT_STATE_CHANGED:

        @callback
        def forward_events(event):
            """Forward state changed events to websocket."""
            may_read = connection.user.permissions.check_entity(
                event.data["entity_id"], POLICY_READ
            )
            if may_read:
                connection.send_message(messages.cached_event_message(iden, event))

    else:

        @callback
        def forward_events(event):
            """Forward events to websocket."""
            if event.event_type != EVENT_TIME_CHANGED:
                connection.send_message(messages.cached_event_message(iden, event))

    # The value returned by async_listen is stored so that the unsubscribe
    # command can call it later.
    unsubscribe = hass.bus.async_listen(event_type, forward_events)
    connection.subscriptions[iden] = unsubscribe

    connection.send_message(messages.result_message(iden))
@callback
@decorators.websocket_command(
    {
        vol.Required("type"): "unsubscribe_events",
        vol.Required("subscription"): cv.positive_int,
    }
)
def handle_unsubscribe_events(hass, connection, msg):
    """Cancel a subscription previously stored on the connection.

    Replies with an ``ERR_NOT_FOUND`` error when the connection has no
    subscription stored under the given id.
    """
    subscription = msg["subscription"]

    if subscription not in connection.subscriptions:
        not_found = messages.error_message(
            msg["id"], const.ERR_NOT_FOUND, "Subscription not found."
        )
        connection.send_message(not_found)
        return

    # The stored value is the callable that tears the subscription down.
    unsubscribe = connection.subscriptions.pop(subscription)
    unsubscribe()
    connection.send_message(messages.result_message(msg["id"]))
@decorators.websocket_command(
    {
        vol.Required("type"): "call_service",
        vol.Required("domain"): str,
        vol.Required("service"): str,
        vol.Optional("service_data"): dict,
    }
)
@decorators.async_response
async def handle_call_service(hass, connection, msg):
    """Handle call service command.

    Calls the requested service and replies with a result message that
    carries the context the call was made with.

    Error replies:
    - ``ERR_NOT_FOUND`` when the requested service itself does not exist.
    - ``ERR_HOME_ASSISTANT_ERROR`` when a different service was reported as
      not found, or when the call raised ``HomeAssistantError``.
    - ``ERR_UNKNOWN_ERROR`` for any other exception.
    """
    # Calls to the core "restart" and "stop" services are made non-blocking.
    # NOTE(review): presumably because waiting for them to finish would keep
    # the response from ever being sent - confirm.
    blocking = True
    if msg["domain"] == HASS_DOMAIN and msg["service"] in ["restart", "stop"]:
        blocking = False

    try:
        # Build the context exactly once so that the context reported back to
        # the client is the very object the service call was made with.
        # NOTE(review): connection.context() appears to create a new context
        # on each call, which made the reported context differ from the one
        # actually used - confirm against the connection implementation.
        context = connection.context(msg)
        await hass.services.async_call(
            msg["domain"],
            msg["service"],
            msg.get("service_data"),
            blocking,
            context,
        )
        connection.send_message(
            messages.result_message(msg["id"], {"context": context})
        )
    except ServiceNotFound as err:
        if err.domain == msg["domain"] and err.service == msg["service"]:
            connection.send_message(
                messages.error_message(
                    msg["id"], const.ERR_NOT_FOUND, "Service not found."
                )
            )
        else:
            # A different service than the one requested was not found.
            connection.send_message(
                messages.error_message(
                    msg["id"], const.ERR_HOME_ASSISTANT_ERROR, str(err)
                )
            )
    except HomeAssistantError as err:
        connection.logger.exception(err)
        connection.send_message(
            messages.error_message(msg["id"], const.ERR_HOME_ASSISTANT_ERROR, str(err))
        )
    except Exception as err:  # pylint: disable=broad-except
        # Last-resort boundary: log the failure and still answer the client.
        connection.logger.exception(err)
        connection.send_message(
            messages.error_message(msg["id"], const.ERR_UNKNOWN_ERROR, str(err))
        )
@callback
@decorators.websocket_command({vol.Required("type"): "get_states"})
def handle_get_states(hass, connection, msg):
    """Send the states the connection's user is allowed to read.

    Users with read access to all entities receive every state; for other
    users the states are filtered entity by entity.
    """
    permissions = connection.user.permissions
    unrestricted = permissions.access_all_entities("read")

    states = hass.states.async_all()
    if not unrestricted:
        may_read = permissions.check_entity
        states = [item for item in states if may_read(item.entity_id, "read")]

    connection.send_message(messages.result_message(msg["id"], states))
@decorators.websocket_command({vol.Required("type"): "get_services"})
@decorators.async_response
async def handle_get_services(hass, connection, msg):
    """Send the descriptions of all services as the command result."""
    service_descriptions = await async_get_all_descriptions(hass)
    response = messages.result_message(msg["id"], service_descriptions)
    connection.send_message(response)
@callback
@decorators.websocket_command({vol.Required("type"): "get_config"})
def handle_get_config(hass, connection, msg):
    """Send ``hass.config`` in its dictionary form as the command result."""
    response = messages.result_message(msg["id"], hass.config.as_dict())
    connection.send_message(response)
@decorators.websocket_command({vol.Required("type"): "manifest/list"})
@decorators.async_response
async def handle_manifest_list(hass, connection, msg):
    """Send the manifests of all loaded integrations.

    Entries of ``hass.config.components`` that contain a dot are platforms
    and are skipped.
    """
    domains = [name for name in hass.config.components if "." not in name]
    integrations = await asyncio.gather(
        *(async_get_integration(hass, name) for name in domains)
    )
    manifests = [loaded.manifest for loaded in integrations]
    connection.send_result(msg["id"], manifests)
@decorators.websocket_command(
    {vol.Required("type"): "manifest/get", vol.Required("integration"): str}
)
@decorators.async_response
async def handle_manifest_get(hass, connection, msg):
    """Send the manifest of the integration named in the request.

    Replies with an ``ERR_NOT_FOUND`` error when the integration cannot be
    found.
    """
    try:
        integration = await async_get_integration(hass, msg["integration"])
    except IntegrationNotFound:
        connection.send_error(msg["id"], const.ERR_NOT_FOUND, "Integration not found")
    else:
        connection.send_result(msg["id"], integration.manifest)
@callback
@decorators.websocket_command({vol.Required("type"): "ping"})
def handle_ping(hass, connection, msg):
    """Answer a ping with a pong message carrying the same id."""
    reply = pong_message(msg["id"])
    connection.send_message(reply)
@callback
@decorators.websocket_command(
    {
        vol.Required("type"): "render_template",
        vol.Required("template"): cv.template,
        vol.Optional("entity_ids"): cv.entity_ids,
        vol.Optional("variables"): dict,
    }
)
def handle_render_template(hass, connection, msg):
    """Track a template and send an event message for each reported result.

    The tracker's ``async_remove`` is stored as the subscription for this
    message id. When a reported result is a ``TemplateError`` it is logged
    and ``None`` is sent as the result instead.
    """
    template = msg["template"]
    template.hass = hass
    variables = msg.get("variables")
    # Assigned further down; the listener only reads it once it has run.
    info = None

    @callback
    def _template_listener(event, updates):
        update = updates.pop()
        rendered = update.result

        if isinstance(rendered, TemplateError):
            _LOGGER.error(
                "TemplateError('%s') while processing template '%s'",
                rendered,
                update.template,
            )
            rendered = None

        payload = {"result": rendered, "listeners": info.listeners}  # type: ignore
        connection.send_message(messages.event_message(msg["id"], payload))

    tracked = [TrackTemplate(template, variables)]
    info = async_track_template_result(hass, tracked, _template_listener)

    connection.subscriptions[msg["id"]] = info.async_remove

    connection.send_result(msg["id"])

    # The refresh is scheduled on the loop only after the result was queued.
    hass.loop.call_soon_threadsafe(info.async_refresh)
@callback
@decorators.websocket_command(
    {vol.Required("type"): "entity/source", vol.Optional("entity_id"): [cv.entity_id]}
)
def handle_entity_source(hass, connection, msg):
    """Send the sources of entities the connection's user may read.

    Without ``entity_id`` in the request, every readable entity source is
    returned. With ``entity_id``, ``Unauthorized`` is raised for the first
    listed entity the user may not read, and an ``ERR_NOT_FOUND`` error is
    sent for the first listed entity that has no source.
    """
    all_sources = entity.entity_sources(hass)
    may_read = connection.user.permissions.check_entity

    if "entity_id" in msg:
        requested = {}
        for entity_id in msg["entity_id"]:
            if not may_read(entity_id, "read"):
                raise Unauthorized(
                    context=connection.context(msg),
                    permission=POLICY_READ,
                    perm_category=CAT_ENTITIES,
                )

            found = all_sources.get(entity_id)
            if found is None:
                connection.send_error(msg["id"], ERR_NOT_FOUND, "Entity not found")
                return

            requested[entity_id] = found

        connection.send_result(msg["id"], requested)
        return

    # No explicit entity list: reply with everything the user may read.
    if connection.user.permissions.access_all_entities("read"):
        readable = all_sources
    else:
        readable = {
            entity_id: origin
            for entity_id, origin in all_sources.items()
            if may_read(entity_id, "read")
        }

    connection.send_message(messages.result_message(msg["id"], readable))
@callback
@decorators.websocket_command(
    {
        vol.Required("type"): "subscribe_trigger",
        vol.Required("trigger"): cv.TRIGGER_SCHEMA,
        vol.Optional("variables"): dict,
    }
)
@decorators.require_admin
@decorators.async_response
async def handle_subscribe_trigger(hass, connection, msg):
    """Set up the requested trigger and forward each firing to the client.

    The value returned by the trigger setup is stored as the subscription
    for this message id; a do-nothing callable is stored when that value is
    falsy.
    """
    # Imported locally because of a circular dependency.
    # pylint: disable=import-outside-toplevel
    from homeassistant.helpers import trigger

    trigger_config = await trigger.async_validate_trigger_config(hass, msg["trigger"])

    @callback
    def forward_triggers(variables, context=None):
        """Forward events to websocket."""
        payload = {"variables": variables, "context": context}
        connection.send_message(messages.event_message(msg["id"], payload))

    def _noop_unsubscribe():
        """Stand in for a missing unsubscribe function."""
        return None

    unsubscribe = await trigger.async_initialize_triggers(
        hass,
        trigger_config,
        forward_triggers,
        const.DOMAIN,
        const.DOMAIN,
        connection.logger.log,
        variables=msg.get("variables"),
    )
    if not unsubscribe:
        # Some triggers won't return an unsub function. Since the caller
        # expects a subscription, a fake one is stored instead.
        unsubscribe = _noop_unsubscribe

    connection.subscriptions[msg["id"]] = unsubscribe
    connection.send_result(msg["id"])
@decorators.websocket_command(
    {
        vol.Required("type"): "test_condition",
        vol.Required("condition"): cv.CONDITION_SCHEMA,
        vol.Optional("variables"): dict,
    }
)
@decorators.require_admin
@decorators.async_response
async def handle_test_condition(hass, connection, msg):
    """Evaluate the given condition once and send its outcome as the result."""
    # Imported locally because of a circular dependency.
    # pylint: disable=import-outside-toplevel
    from homeassistant.helpers import condition

    checker = await condition.async_from_config(hass, msg["condition"])
    outcome = checker(hass, msg.get("variables"))
    connection.send_result(msg["id"], {"result": outcome})