Add missing type hints to websocket_api (#50915)
This commit is contained in:
parent
dc65f279a7
commit
42ff687c32
11 changed files with 251 additions and 159 deletions
|
@ -1,11 +1,12 @@
|
||||||
"""WebSocket based API for Home Assistant."""
|
"""WebSocket based API for Home Assistant."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import cast
|
from typing import Final, cast
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
|
from homeassistant.helpers.typing import ConfigType
|
||||||
from homeassistant.loader import bind_hass
|
from homeassistant.loader import bind_hass
|
||||||
|
|
||||||
from . import commands, connection, const, decorators, http, messages # noqa: F401
|
from . import commands, connection, const, decorators, http, messages # noqa: F401
|
||||||
|
@ -34,11 +35,9 @@ from .messages import ( # noqa: F401
|
||||||
result_message,
|
result_message,
|
||||||
)
|
)
|
||||||
|
|
||||||
# mypy: allow-untyped-calls, allow-untyped-defs
|
DOMAIN: Final = const.DOMAIN
|
||||||
|
|
||||||
DOMAIN = const.DOMAIN
|
DEPENDENCIES: Final[tuple[str]] = ("http",)
|
||||||
|
|
||||||
DEPENDENCIES = ("http",)
|
|
||||||
|
|
||||||
|
|
||||||
@bind_hass
|
@bind_hass
|
||||||
|
@ -53,8 +52,8 @@ def async_register_command(
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
if handler is None:
|
if handler is None:
|
||||||
handler = cast(const.WebSocketCommandHandler, command_or_handler)
|
handler = cast(const.WebSocketCommandHandler, command_or_handler)
|
||||||
command = handler._ws_command # type: ignore
|
command = handler._ws_command # type: ignore[attr-defined]
|
||||||
schema = handler._ws_schema # type: ignore
|
schema = handler._ws_schema # type: ignore[attr-defined]
|
||||||
else:
|
else:
|
||||||
command = command_or_handler
|
command = command_or_handler
|
||||||
handlers = hass.data.get(DOMAIN)
|
handlers = hass.data.get(DOMAIN)
|
||||||
|
@ -63,8 +62,8 @@ def async_register_command(
|
||||||
handlers[command] = (handler, schema)
|
handlers[command] = (handler, schema)
|
||||||
|
|
||||||
|
|
||||||
async def async_setup(hass, config):
|
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||||
"""Initialize the websocket API."""
|
"""Initialize the websocket API."""
|
||||||
hass.http.register_view(http.WebsocketAPIView)
|
hass.http.register_view(http.WebsocketAPIView())
|
||||||
commands.async_register_commands(hass, async_register_command)
|
commands.async_register_commands(hass, async_register_command)
|
||||||
return True
|
return True
|
||||||
|
|
|
@ -1,22 +1,31 @@
|
||||||
"""Handle the auth of a connection."""
|
"""Handle the auth of a connection."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
from typing import TYPE_CHECKING, Any, Final
|
||||||
|
|
||||||
|
from aiohttp.web import Request
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
from voluptuous.humanize import humanize_error
|
from voluptuous.humanize import humanize_error
|
||||||
|
|
||||||
from homeassistant.auth.models import RefreshToken, User
|
from homeassistant.auth.models import RefreshToken, User
|
||||||
from homeassistant.components.http.ban import process_success_login, process_wrong_login
|
from homeassistant.components.http.ban import process_success_login, process_wrong_login
|
||||||
from homeassistant.const import __version__
|
from homeassistant.const import __version__
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
|
||||||
from .connection import ActiveConnection
|
from .connection import ActiveConnection
|
||||||
from .error import Disconnect
|
from .error import Disconnect
|
||||||
|
|
||||||
# mypy: allow-untyped-calls, allow-untyped-defs
|
if TYPE_CHECKING:
|
||||||
|
from .http import WebSocketAdapter
|
||||||
|
|
||||||
TYPE_AUTH = "auth"
|
|
||||||
TYPE_AUTH_INVALID = "auth_invalid"
|
|
||||||
TYPE_AUTH_OK = "auth_ok"
|
|
||||||
TYPE_AUTH_REQUIRED = "auth_required"
|
|
||||||
|
|
||||||
AUTH_MESSAGE_SCHEMA = vol.Schema(
|
TYPE_AUTH: Final = "auth"
|
||||||
|
TYPE_AUTH_INVALID: Final = "auth_invalid"
|
||||||
|
TYPE_AUTH_OK: Final = "auth_ok"
|
||||||
|
TYPE_AUTH_REQUIRED: Final = "auth_required"
|
||||||
|
|
||||||
|
AUTH_MESSAGE_SCHEMA: Final = vol.Schema(
|
||||||
{
|
{
|
||||||
vol.Required("type"): TYPE_AUTH,
|
vol.Required("type"): TYPE_AUTH,
|
||||||
vol.Exclusive("api_password", "auth"): str,
|
vol.Exclusive("api_password", "auth"): str,
|
||||||
|
@ -25,17 +34,17 @@ AUTH_MESSAGE_SCHEMA = vol.Schema(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def auth_ok_message():
|
def auth_ok_message() -> dict[str, str]:
|
||||||
"""Return an auth_ok message."""
|
"""Return an auth_ok message."""
|
||||||
return {"type": TYPE_AUTH_OK, "ha_version": __version__}
|
return {"type": TYPE_AUTH_OK, "ha_version": __version__}
|
||||||
|
|
||||||
|
|
||||||
def auth_required_message():
|
def auth_required_message() -> dict[str, str]:
|
||||||
"""Return an auth_required message."""
|
"""Return an auth_required message."""
|
||||||
return {"type": TYPE_AUTH_REQUIRED, "ha_version": __version__}
|
return {"type": TYPE_AUTH_REQUIRED, "ha_version": __version__}
|
||||||
|
|
||||||
|
|
||||||
def auth_invalid_message(message):
|
def auth_invalid_message(message: str) -> dict[str, str]:
|
||||||
"""Return an auth_invalid message."""
|
"""Return an auth_invalid message."""
|
||||||
return {"type": TYPE_AUTH_INVALID, "message": message}
|
return {"type": TYPE_AUTH_INVALID, "message": message}
|
||||||
|
|
||||||
|
@ -43,16 +52,20 @@ def auth_invalid_message(message):
|
||||||
class AuthPhase:
|
class AuthPhase:
|
||||||
"""Connection that requires client to authenticate first."""
|
"""Connection that requires client to authenticate first."""
|
||||||
|
|
||||||
def __init__(self, logger, hass, send_message, request):
|
def __init__(
|
||||||
|
self,
|
||||||
|
logger: WebSocketAdapter,
|
||||||
|
hass: HomeAssistant,
|
||||||
|
send_message: Callable[[str | dict[str, Any]], None],
|
||||||
|
request: Request,
|
||||||
|
) -> None:
|
||||||
"""Initialize the authentiated connection."""
|
"""Initialize the authentiated connection."""
|
||||||
self._hass = hass
|
self._hass = hass
|
||||||
self._send_message = send_message
|
self._send_message = send_message
|
||||||
self._logger = logger
|
self._logger = logger
|
||||||
self._request = request
|
self._request = request
|
||||||
self._authenticated = False
|
|
||||||
self._connection = None
|
|
||||||
|
|
||||||
async def async_handle(self, msg):
|
async def async_handle(self, msg: dict[str, str]) -> ActiveConnection:
|
||||||
"""Handle authentication."""
|
"""Handle authentication."""
|
||||||
try:
|
try:
|
||||||
msg = AUTH_MESSAGE_SCHEMA(msg)
|
msg = AUTH_MESSAGE_SCHEMA(msg)
|
||||||
|
|
|
@ -1,6 +1,10 @@
|
||||||
"""Commands part of Websocket API."""
|
"""Commands part of Websocket API."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from collections.abc import Callable
|
||||||
import json
|
import json
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
|
@ -8,7 +12,7 @@ from homeassistant.auth.permissions.const import CAT_ENTITIES, POLICY_READ
|
||||||
from homeassistant.bootstrap import SIGNAL_BOOTSTRAP_INTEGRATONS
|
from homeassistant.bootstrap import SIGNAL_BOOTSTRAP_INTEGRATONS
|
||||||
from homeassistant.components.websocket_api.const import ERR_NOT_FOUND
|
from homeassistant.components.websocket_api.const import ERR_NOT_FOUND
|
||||||
from homeassistant.const import EVENT_STATE_CHANGED, EVENT_TIME_CHANGED, MATCH_ALL
|
from homeassistant.const import EVENT_STATE_CHANGED, EVENT_TIME_CHANGED, MATCH_ALL
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import Context, Event, HomeAssistant, callback
|
||||||
from homeassistant.exceptions import (
|
from homeassistant.exceptions import (
|
||||||
HomeAssistantError,
|
HomeAssistantError,
|
||||||
ServiceNotFound,
|
ServiceNotFound,
|
||||||
|
@ -17,19 +21,25 @@ from homeassistant.exceptions import (
|
||||||
)
|
)
|
||||||
from homeassistant.helpers import config_validation as cv, entity, template
|
from homeassistant.helpers import config_validation as cv, entity, template
|
||||||
from homeassistant.helpers.dispatcher import async_dispatcher_connect
|
from homeassistant.helpers.dispatcher import async_dispatcher_connect
|
||||||
from homeassistant.helpers.event import TrackTemplate, async_track_template_result
|
from homeassistant.helpers.event import (
|
||||||
|
TrackTemplate,
|
||||||
|
TrackTemplateResult,
|
||||||
|
async_track_template_result,
|
||||||
|
)
|
||||||
from homeassistant.helpers.json import ExtendedJSONEncoder
|
from homeassistant.helpers.json import ExtendedJSONEncoder
|
||||||
from homeassistant.helpers.service import async_get_all_descriptions
|
from homeassistant.helpers.service import async_get_all_descriptions
|
||||||
from homeassistant.loader import IntegrationNotFound, async_get_integration
|
from homeassistant.loader import IntegrationNotFound, async_get_integration
|
||||||
from homeassistant.setup import DATA_SETUP_TIME, async_get_loaded_integrations
|
from homeassistant.setup import DATA_SETUP_TIME, async_get_loaded_integrations
|
||||||
|
|
||||||
from . import const, decorators, messages
|
from . import const, decorators, messages
|
||||||
|
from .connection import ActiveConnection
|
||||||
# mypy: allow-untyped-calls, allow-untyped-defs
|
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_register_commands(hass, async_reg):
|
def async_register_commands(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
async_reg: Callable[[HomeAssistant, const.WebSocketCommandHandler], None],
|
||||||
|
) -> None:
|
||||||
"""Register commands."""
|
"""Register commands."""
|
||||||
async_reg(hass, handle_call_service)
|
async_reg(hass, handle_call_service)
|
||||||
async_reg(hass, handle_entity_source)
|
async_reg(hass, handle_entity_source)
|
||||||
|
@ -49,7 +59,7 @@ def async_register_commands(hass, async_reg):
|
||||||
async_reg(hass, handle_unsubscribe_events)
|
async_reg(hass, handle_unsubscribe_events)
|
||||||
|
|
||||||
|
|
||||||
def pong_message(iden):
|
def pong_message(iden: int) -> dict[str, Any]:
|
||||||
"""Return a pong message."""
|
"""Return a pong message."""
|
||||||
return {"id": iden, "type": "pong"}
|
return {"id": iden, "type": "pong"}
|
||||||
|
|
||||||
|
@ -61,7 +71,9 @@ def pong_message(iden):
|
||||||
vol.Optional("event_type", default=MATCH_ALL): str,
|
vol.Optional("event_type", default=MATCH_ALL): str,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def handle_subscribe_events(hass, connection, msg):
|
def handle_subscribe_events(
|
||||||
|
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||||
|
) -> None:
|
||||||
"""Handle subscribe events command."""
|
"""Handle subscribe events command."""
|
||||||
# Circular dep
|
# Circular dep
|
||||||
# pylint: disable=import-outside-toplevel
|
# pylint: disable=import-outside-toplevel
|
||||||
|
@ -75,7 +87,7 @@ def handle_subscribe_events(hass, connection, msg):
|
||||||
if event_type == EVENT_STATE_CHANGED:
|
if event_type == EVENT_STATE_CHANGED:
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def forward_events(event):
|
def forward_events(event: Event) -> None:
|
||||||
"""Forward state changed events to websocket."""
|
"""Forward state changed events to websocket."""
|
||||||
if not connection.user.permissions.check_entity(
|
if not connection.user.permissions.check_entity(
|
||||||
event.data["entity_id"], POLICY_READ
|
event.data["entity_id"], POLICY_READ
|
||||||
|
@ -87,7 +99,7 @@ def handle_subscribe_events(hass, connection, msg):
|
||||||
else:
|
else:
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def forward_events(event):
|
def forward_events(event: Event) -> None:
|
||||||
"""Forward events to websocket."""
|
"""Forward events to websocket."""
|
||||||
if event.event_type == EVENT_TIME_CHANGED:
|
if event.event_type == EVENT_TIME_CHANGED:
|
||||||
return
|
return
|
||||||
|
@ -107,11 +119,13 @@ def handle_subscribe_events(hass, connection, msg):
|
||||||
vol.Required("type"): "subscribe_bootstrap_integrations",
|
vol.Required("type"): "subscribe_bootstrap_integrations",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def handle_subscribe_bootstrap_integrations(hass, connection, msg):
|
def handle_subscribe_bootstrap_integrations(
|
||||||
|
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||||
|
) -> None:
|
||||||
"""Handle subscribe bootstrap integrations command."""
|
"""Handle subscribe bootstrap integrations command."""
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def forward_bootstrap_integrations(message):
|
def forward_bootstrap_integrations(message: dict[str, Any]) -> None:
|
||||||
"""Forward bootstrap integrations to websocket."""
|
"""Forward bootstrap integrations to websocket."""
|
||||||
connection.send_message(messages.event_message(msg["id"], message))
|
connection.send_message(messages.event_message(msg["id"], message))
|
||||||
|
|
||||||
|
@ -129,7 +143,9 @@ def handle_subscribe_bootstrap_integrations(hass, connection, msg):
|
||||||
vol.Required("subscription"): cv.positive_int,
|
vol.Required("subscription"): cv.positive_int,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def handle_unsubscribe_events(hass, connection, msg):
|
def handle_unsubscribe_events(
|
||||||
|
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||||
|
) -> None:
|
||||||
"""Handle unsubscribe events command."""
|
"""Handle unsubscribe events command."""
|
||||||
subscription = msg["subscription"]
|
subscription = msg["subscription"]
|
||||||
|
|
||||||
|
@ -154,7 +170,9 @@ def handle_unsubscribe_events(hass, connection, msg):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@decorators.async_response
|
@decorators.async_response
|
||||||
async def handle_call_service(hass, connection, msg):
|
async def handle_call_service(
|
||||||
|
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||||
|
) -> None:
|
||||||
"""Handle call service command."""
|
"""Handle call service command."""
|
||||||
blocking = True
|
blocking = True
|
||||||
# We do not support templates.
|
# We do not support templates.
|
||||||
|
@ -206,7 +224,9 @@ async def handle_call_service(hass, connection, msg):
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@decorators.websocket_command({vol.Required("type"): "get_states"})
|
@decorators.websocket_command({vol.Required("type"): "get_states"})
|
||||||
def handle_get_states(hass, connection, msg):
|
def handle_get_states(
|
||||||
|
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||||
|
) -> None:
|
||||||
"""Handle get states command."""
|
"""Handle get states command."""
|
||||||
if connection.user.permissions.access_all_entities("read"):
|
if connection.user.permissions.access_all_entities("read"):
|
||||||
states = hass.states.async_all()
|
states = hass.states.async_all()
|
||||||
|
@ -223,7 +243,9 @@ def handle_get_states(hass, connection, msg):
|
||||||
|
|
||||||
@decorators.websocket_command({vol.Required("type"): "get_services"})
|
@decorators.websocket_command({vol.Required("type"): "get_services"})
|
||||||
@decorators.async_response
|
@decorators.async_response
|
||||||
async def handle_get_services(hass, connection, msg):
|
async def handle_get_services(
|
||||||
|
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||||
|
) -> None:
|
||||||
"""Handle get services command."""
|
"""Handle get services command."""
|
||||||
descriptions = await async_get_all_descriptions(hass)
|
descriptions = await async_get_all_descriptions(hass)
|
||||||
connection.send_message(messages.result_message(msg["id"], descriptions))
|
connection.send_message(messages.result_message(msg["id"], descriptions))
|
||||||
|
@ -231,14 +253,18 @@ async def handle_get_services(hass, connection, msg):
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@decorators.websocket_command({vol.Required("type"): "get_config"})
|
@decorators.websocket_command({vol.Required("type"): "get_config"})
|
||||||
def handle_get_config(hass, connection, msg):
|
def handle_get_config(
|
||||||
|
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||||
|
) -> None:
|
||||||
"""Handle get config command."""
|
"""Handle get config command."""
|
||||||
connection.send_message(messages.result_message(msg["id"], hass.config.as_dict()))
|
connection.send_message(messages.result_message(msg["id"], hass.config.as_dict()))
|
||||||
|
|
||||||
|
|
||||||
@decorators.websocket_command({vol.Required("type"): "manifest/list"})
|
@decorators.websocket_command({vol.Required("type"): "manifest/list"})
|
||||||
@decorators.async_response
|
@decorators.async_response
|
||||||
async def handle_manifest_list(hass, connection, msg):
|
async def handle_manifest_list(
|
||||||
|
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||||
|
) -> None:
|
||||||
"""Handle integrations command."""
|
"""Handle integrations command."""
|
||||||
loaded_integrations = async_get_loaded_integrations(hass)
|
loaded_integrations = async_get_loaded_integrations(hass)
|
||||||
integrations = await asyncio.gather(
|
integrations = await asyncio.gather(
|
||||||
|
@ -253,7 +279,9 @@ async def handle_manifest_list(hass, connection, msg):
|
||||||
{vol.Required("type"): "manifest/get", vol.Required("integration"): str}
|
{vol.Required("type"): "manifest/get", vol.Required("integration"): str}
|
||||||
)
|
)
|
||||||
@decorators.async_response
|
@decorators.async_response
|
||||||
async def handle_manifest_get(hass, connection, msg):
|
async def handle_manifest_get(
|
||||||
|
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||||
|
) -> None:
|
||||||
"""Handle integrations command."""
|
"""Handle integrations command."""
|
||||||
try:
|
try:
|
||||||
integration = await async_get_integration(hass, msg["integration"])
|
integration = await async_get_integration(hass, msg["integration"])
|
||||||
|
@ -264,7 +292,9 @@ async def handle_manifest_get(hass, connection, msg):
|
||||||
|
|
||||||
@decorators.websocket_command({vol.Required("type"): "integration/setup_info"})
|
@decorators.websocket_command({vol.Required("type"): "integration/setup_info"})
|
||||||
@decorators.async_response
|
@decorators.async_response
|
||||||
async def handle_integration_setup_info(hass, connection, msg):
|
async def handle_integration_setup_info(
|
||||||
|
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||||
|
) -> None:
|
||||||
"""Handle integrations command."""
|
"""Handle integrations command."""
|
||||||
connection.send_result(
|
connection.send_result(
|
||||||
msg["id"],
|
msg["id"],
|
||||||
|
@ -277,7 +307,9 @@ async def handle_integration_setup_info(hass, connection, msg):
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@decorators.websocket_command({vol.Required("type"): "ping"})
|
@decorators.websocket_command({vol.Required("type"): "ping"})
|
||||||
def handle_ping(hass, connection, msg):
|
def handle_ping(
|
||||||
|
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||||
|
) -> None:
|
||||||
"""Handle ping command."""
|
"""Handle ping command."""
|
||||||
connection.send_message(pong_message(msg["id"]))
|
connection.send_message(pong_message(msg["id"]))
|
||||||
|
|
||||||
|
@ -293,10 +325,12 @@ def handle_ping(hass, connection, msg):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@decorators.async_response
|
@decorators.async_response
|
||||||
async def handle_render_template(hass, connection, msg):
|
async def handle_render_template(
|
||||||
|
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||||
|
) -> None:
|
||||||
"""Handle render_template command."""
|
"""Handle render_template command."""
|
||||||
template_str = msg["template"]
|
template_str = msg["template"]
|
||||||
template_obj = template.Template(template_str, hass)
|
template_obj = template.Template(template_str, hass) # type: ignore[no-untyped-call]
|
||||||
variables = msg.get("variables")
|
variables = msg.get("variables")
|
||||||
timeout = msg.get("timeout")
|
timeout = msg.get("timeout")
|
||||||
info = None
|
info = None
|
||||||
|
@ -319,7 +353,7 @@ async def handle_render_template(hass, connection, msg):
|
||||||
return
|
return
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _template_listener(event, updates):
|
def _template_listener(event: Event, updates: list[TrackTemplateResult]) -> None:
|
||||||
nonlocal info
|
nonlocal info
|
||||||
track_template_result = updates.pop()
|
track_template_result = updates.pop()
|
||||||
result = track_template_result.result
|
result = track_template_result.result
|
||||||
|
@ -329,7 +363,7 @@ async def handle_render_template(hass, connection, msg):
|
||||||
|
|
||||||
connection.send_message(
|
connection.send_message(
|
||||||
messages.event_message(
|
messages.event_message(
|
||||||
msg["id"], {"result": result, "listeners": info.listeners} # type: ignore
|
msg["id"], {"result": result, "listeners": info.listeners} # type: ignore[attr-defined]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -356,7 +390,9 @@ async def handle_render_template(hass, connection, msg):
|
||||||
@decorators.websocket_command(
|
@decorators.websocket_command(
|
||||||
{vol.Required("type"): "entity/source", vol.Optional("entity_id"): [cv.entity_id]}
|
{vol.Required("type"): "entity/source", vol.Optional("entity_id"): [cv.entity_id]}
|
||||||
)
|
)
|
||||||
def handle_entity_source(hass, connection, msg):
|
def handle_entity_source(
|
||||||
|
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||||
|
) -> None:
|
||||||
"""Handle entity source command."""
|
"""Handle entity source command."""
|
||||||
raw_sources = entity.entity_sources(hass)
|
raw_sources = entity.entity_sources(hass)
|
||||||
entity_perm = connection.user.permissions.check_entity
|
entity_perm = connection.user.permissions.check_entity
|
||||||
|
@ -404,7 +440,9 @@ def handle_entity_source(hass, connection, msg):
|
||||||
)
|
)
|
||||||
@decorators.require_admin
|
@decorators.require_admin
|
||||||
@decorators.async_response
|
@decorators.async_response
|
||||||
async def handle_subscribe_trigger(hass, connection, msg):
|
async def handle_subscribe_trigger(
|
||||||
|
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||||
|
) -> None:
|
||||||
"""Handle subscribe trigger command."""
|
"""Handle subscribe trigger command."""
|
||||||
# Circular dep
|
# Circular dep
|
||||||
# pylint: disable=import-outside-toplevel
|
# pylint: disable=import-outside-toplevel
|
||||||
|
@ -413,7 +451,9 @@ async def handle_subscribe_trigger(hass, connection, msg):
|
||||||
trigger_config = await trigger.async_validate_trigger_config(hass, msg["trigger"])
|
trigger_config = await trigger.async_validate_trigger_config(hass, msg["trigger"])
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def forward_triggers(variables, context=None):
|
def forward_triggers(
|
||||||
|
variables: dict[str, Any], context: Context | None = None
|
||||||
|
) -> None:
|
||||||
"""Forward events to websocket."""
|
"""Forward events to websocket."""
|
||||||
message = messages.event_message(
|
message = messages.event_message(
|
||||||
msg["id"], {"variables": variables, "context": context}
|
msg["id"], {"variables": variables, "context": context}
|
||||||
|
@ -449,7 +489,9 @@ async def handle_subscribe_trigger(hass, connection, msg):
|
||||||
)
|
)
|
||||||
@decorators.require_admin
|
@decorators.require_admin
|
||||||
@decorators.async_response
|
@decorators.async_response
|
||||||
async def handle_test_condition(hass, connection, msg):
|
async def handle_test_condition(
|
||||||
|
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||||
|
) -> None:
|
||||||
"""Handle test condition command."""
|
"""Handle test condition command."""
|
||||||
# Circular dep
|
# Circular dep
|
||||||
# pylint: disable=import-outside-toplevel
|
# pylint: disable=import-outside-toplevel
|
||||||
|
@ -470,7 +512,9 @@ async def handle_test_condition(hass, connection, msg):
|
||||||
)
|
)
|
||||||
@decorators.require_admin
|
@decorators.require_admin
|
||||||
@decorators.async_response
|
@decorators.async_response
|
||||||
async def handle_execute_script(hass, connection, msg):
|
async def handle_execute_script(
|
||||||
|
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||||
|
) -> None:
|
||||||
"""Handle execute script command."""
|
"""Handle execute script command."""
|
||||||
# Circular dep
|
# Circular dep
|
||||||
# pylint: disable=import-outside-toplevel
|
# pylint: disable=import-outside-toplevel
|
||||||
|
|
|
@ -3,48 +3,50 @@ from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import Hashable
|
from collections.abc import Hashable
|
||||||
from typing import Any, Callable
|
from typing import TYPE_CHECKING, Any, Callable
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.core import Context, callback
|
from homeassistant.auth.models import RefreshToken, User
|
||||||
|
from homeassistant.core import Context, HomeAssistant, callback
|
||||||
from homeassistant.exceptions import HomeAssistantError, Unauthorized
|
from homeassistant.exceptions import HomeAssistantError, Unauthorized
|
||||||
|
|
||||||
from . import const, messages
|
from . import const, messages
|
||||||
|
|
||||||
# mypy: allow-untyped-calls, allow-untyped-defs
|
if TYPE_CHECKING:
|
||||||
|
from .http import WebSocketAdapter
|
||||||
|
|
||||||
|
|
||||||
class ActiveConnection:
|
class ActiveConnection:
|
||||||
"""Handle an active websocket client connection."""
|
"""Handle an active websocket client connection."""
|
||||||
|
|
||||||
def __init__(self, logger, hass, send_message, user, refresh_token):
|
def __init__(
|
||||||
|
self,
|
||||||
|
logger: WebSocketAdapter,
|
||||||
|
hass: HomeAssistant,
|
||||||
|
send_message: Callable[[str | dict[str, Any]], None],
|
||||||
|
user: User,
|
||||||
|
refresh_token: RefreshToken,
|
||||||
|
) -> None:
|
||||||
"""Initialize an active connection."""
|
"""Initialize an active connection."""
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
self.hass = hass
|
self.hass = hass
|
||||||
self.send_message = send_message
|
self.send_message = send_message
|
||||||
self.user = user
|
self.user = user
|
||||||
if refresh_token:
|
self.refresh_token_id = refresh_token.id
|
||||||
self.refresh_token_id = refresh_token.id
|
|
||||||
else:
|
|
||||||
self.refresh_token_id = None
|
|
||||||
|
|
||||||
self.subscriptions: dict[Hashable, Callable[[], Any]] = {}
|
self.subscriptions: dict[Hashable, Callable[[], Any]] = {}
|
||||||
self.last_id = 0
|
self.last_id = 0
|
||||||
|
|
||||||
def context(self, msg):
|
def context(self, msg: dict[str, Any]) -> Context:
|
||||||
"""Return a context."""
|
"""Return a context."""
|
||||||
user = self.user
|
return Context(user_id=self.user.id)
|
||||||
if user is None:
|
|
||||||
return Context()
|
|
||||||
return Context(user_id=user.id)
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def send_result(self, msg_id: int, result: Any | None = None) -> None:
|
def send_result(self, msg_id: int, result: Any | None = None) -> None:
|
||||||
"""Send a result message."""
|
"""Send a result message."""
|
||||||
self.send_message(messages.result_message(msg_id, result))
|
self.send_message(messages.result_message(msg_id, result))
|
||||||
|
|
||||||
async def send_big_result(self, msg_id, result):
|
async def send_big_result(self, msg_id: int, result: Any) -> None:
|
||||||
"""Send a result message that would be expensive to JSON serialize."""
|
"""Send a result message that would be expensive to JSON serialize."""
|
||||||
content = await self.hass.async_add_executor_job(
|
content = await self.hass.async_add_executor_job(
|
||||||
const.JSON_DUMP, messages.result_message(msg_id, result)
|
const.JSON_DUMP, messages.result_message(msg_id, result)
|
||||||
|
@ -57,7 +59,7 @@ class ActiveConnection:
|
||||||
self.send_message(messages.error_message(msg_id, code, message))
|
self.send_message(messages.error_message(msg_id, code, message))
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_handle(self, msg):
|
def async_handle(self, msg: dict[str, Any]) -> None:
|
||||||
"""Handle a single incoming message."""
|
"""Handle a single incoming message."""
|
||||||
handlers = self.hass.data[const.DOMAIN]
|
handlers = self.hass.data[const.DOMAIN]
|
||||||
|
|
||||||
|
@ -102,13 +104,13 @@ class ActiveConnection:
|
||||||
self.last_id = cur_id
|
self.last_id = cur_id
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_close(self):
|
def async_close(self) -> None:
|
||||||
"""Close down connection."""
|
"""Close down connection."""
|
||||||
for unsub in self.subscriptions.values():
|
for unsub in self.subscriptions.values():
|
||||||
unsub()
|
unsub()
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_handle_exception(self, msg, err):
|
def async_handle_exception(self, msg: dict[str, Any], err: Exception) -> None:
|
||||||
"""Handle an exception while processing a handler."""
|
"""Handle an exception while processing a handler."""
|
||||||
log_handler = self.logger.error
|
log_handler = self.logger.error
|
||||||
|
|
||||||
|
|
|
@ -1,9 +1,11 @@
|
||||||
"""Websocket constants."""
|
"""Websocket constants."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from concurrent import futures
|
from concurrent import futures
|
||||||
from functools import partial
|
from functools import partial
|
||||||
import json
|
import json
|
||||||
from typing import TYPE_CHECKING, Callable
|
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Final
|
||||||
|
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.helpers.json import JSONEncoder
|
from homeassistant.helpers.json import JSONEncoder
|
||||||
|
@ -12,37 +14,42 @@ if TYPE_CHECKING:
|
||||||
from .connection import ActiveConnection
|
from .connection import ActiveConnection
|
||||||
|
|
||||||
|
|
||||||
WebSocketCommandHandler = Callable[[HomeAssistant, "ActiveConnection", dict], None]
|
WebSocketCommandHandler = Callable[
|
||||||
|
[HomeAssistant, "ActiveConnection", Dict[str, Any]], None
|
||||||
|
]
|
||||||
|
AsyncWebSocketCommandHandler = Callable[
|
||||||
|
[HomeAssistant, "ActiveConnection", Dict[str, Any]], Awaitable[None]
|
||||||
|
]
|
||||||
|
|
||||||
DOMAIN = "websocket_api"
|
DOMAIN: Final = "websocket_api"
|
||||||
URL = "/api/websocket"
|
URL: Final = "/api/websocket"
|
||||||
PENDING_MSG_PEAK = 512
|
PENDING_MSG_PEAK: Final = 512
|
||||||
PENDING_MSG_PEAK_TIME = 5
|
PENDING_MSG_PEAK_TIME: Final = 5
|
||||||
MAX_PENDING_MSG = 2048
|
MAX_PENDING_MSG: Final = 2048
|
||||||
|
|
||||||
ERR_ID_REUSE = "id_reuse"
|
ERR_ID_REUSE: Final = "id_reuse"
|
||||||
ERR_INVALID_FORMAT = "invalid_format"
|
ERR_INVALID_FORMAT: Final = "invalid_format"
|
||||||
ERR_NOT_FOUND = "not_found"
|
ERR_NOT_FOUND: Final = "not_found"
|
||||||
ERR_NOT_SUPPORTED = "not_supported"
|
ERR_NOT_SUPPORTED: Final = "not_supported"
|
||||||
ERR_HOME_ASSISTANT_ERROR = "home_assistant_error"
|
ERR_HOME_ASSISTANT_ERROR: Final = "home_assistant_error"
|
||||||
ERR_UNKNOWN_COMMAND = "unknown_command"
|
ERR_UNKNOWN_COMMAND: Final = "unknown_command"
|
||||||
ERR_UNKNOWN_ERROR = "unknown_error"
|
ERR_UNKNOWN_ERROR: Final = "unknown_error"
|
||||||
ERR_UNAUTHORIZED = "unauthorized"
|
ERR_UNAUTHORIZED: Final = "unauthorized"
|
||||||
ERR_TIMEOUT = "timeout"
|
ERR_TIMEOUT: Final = "timeout"
|
||||||
ERR_TEMPLATE_ERROR = "template_error"
|
ERR_TEMPLATE_ERROR: Final = "template_error"
|
||||||
|
|
||||||
TYPE_RESULT = "result"
|
TYPE_RESULT: Final = "result"
|
||||||
|
|
||||||
# Define the possible errors that occur when connections are cancelled.
|
# Define the possible errors that occur when connections are cancelled.
|
||||||
# Originally, this was just asyncio.CancelledError, but issue #9546 showed
|
# Originally, this was just asyncio.CancelledError, but issue #9546 showed
|
||||||
# that futures.CancelledErrors can also occur in some situations.
|
# that futures.CancelledErrors can also occur in some situations.
|
||||||
CANCELLATION_ERRORS = (asyncio.CancelledError, futures.CancelledError)
|
CANCELLATION_ERRORS: Final = (asyncio.CancelledError, futures.CancelledError)
|
||||||
|
|
||||||
# Event types
|
# Event types
|
||||||
SIGNAL_WEBSOCKET_CONNECTED = "websocket_connected"
|
SIGNAL_WEBSOCKET_CONNECTED: Final = "websocket_connected"
|
||||||
SIGNAL_WEBSOCKET_DISCONNECTED = "websocket_disconnected"
|
SIGNAL_WEBSOCKET_DISCONNECTED: Final = "websocket_disconnected"
|
||||||
|
|
||||||
# Data used to store the current connection list
|
# Data used to store the current connection list
|
||||||
DATA_CONNECTIONS = f"{DOMAIN}.connections"
|
DATA_CONNECTIONS: Final = f"{DOMAIN}.connections"
|
||||||
|
|
||||||
JSON_DUMP = partial(json.dumps, cls=JSONEncoder, allow_nan=False)
|
JSON_DUMP: Final = partial(json.dumps, cls=JSONEncoder, allow_nan=False)
|
||||||
|
|
|
@ -2,9 +2,10 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import Awaitable
|
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Callable
|
from typing import Any, Callable
|
||||||
|
|
||||||
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.exceptions import Unauthorized
|
from homeassistant.exceptions import Unauthorized
|
||||||
|
@ -12,10 +13,13 @@ from homeassistant.exceptions import Unauthorized
|
||||||
from . import const, messages
|
from . import const, messages
|
||||||
from .connection import ActiveConnection
|
from .connection import ActiveConnection
|
||||||
|
|
||||||
# mypy: allow-untyped-calls, allow-untyped-defs
|
|
||||||
|
|
||||||
|
async def _handle_async_response(
|
||||||
async def _handle_async_response(func, hass, connection, msg):
|
func: const.AsyncWebSocketCommandHandler,
|
||||||
|
hass: HomeAssistant,
|
||||||
|
connection: ActiveConnection,
|
||||||
|
msg: dict[str, Any],
|
||||||
|
) -> None:
|
||||||
"""Create a response and handle exception."""
|
"""Create a response and handle exception."""
|
||||||
try:
|
try:
|
||||||
await func(hass, connection, msg)
|
await func(hass, connection, msg)
|
||||||
|
@ -24,13 +28,15 @@ async def _handle_async_response(func, hass, connection, msg):
|
||||||
|
|
||||||
|
|
||||||
def async_response(
|
def async_response(
|
||||||
func: Callable[[HomeAssistant, ActiveConnection, dict], Awaitable[None]]
|
func: const.AsyncWebSocketCommandHandler,
|
||||||
) -> const.WebSocketCommandHandler:
|
) -> const.WebSocketCommandHandler:
|
||||||
"""Decorate an async function to handle WebSocket API messages."""
|
"""Decorate an async function to handle WebSocket API messages."""
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def schedule_handler(hass, connection, msg):
|
def schedule_handler(
|
||||||
|
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||||
|
) -> None:
|
||||||
"""Schedule the handler."""
|
"""Schedule the handler."""
|
||||||
# As the webserver is now started before the start
|
# As the webserver is now started before the start
|
||||||
# event we do not want to block for websocket responders
|
# event we do not want to block for websocket responders
|
||||||
|
@ -43,7 +49,9 @@ def require_admin(func: const.WebSocketCommandHandler) -> const.WebSocketCommand
|
||||||
"""Websocket decorator to require user to be an admin."""
|
"""Websocket decorator to require user to be an admin."""
|
||||||
|
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def with_admin(hass, connection, msg):
|
def with_admin(
|
||||||
|
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||||
|
) -> None:
|
||||||
"""Check admin and call function."""
|
"""Check admin and call function."""
|
||||||
user = connection.user
|
user = connection.user
|
||||||
|
|
||||||
|
@ -56,34 +64,32 @@ def require_admin(func: const.WebSocketCommandHandler) -> const.WebSocketCommand
|
||||||
|
|
||||||
|
|
||||||
def ws_require_user(
|
def ws_require_user(
|
||||||
only_owner=False,
|
only_owner: bool = False,
|
||||||
only_system_user=False,
|
only_system_user: bool = False,
|
||||||
allow_system_user=True,
|
allow_system_user: bool = True,
|
||||||
only_active_user=True,
|
only_active_user: bool = True,
|
||||||
only_inactive_user=False,
|
only_inactive_user: bool = False,
|
||||||
):
|
) -> Callable[[const.WebSocketCommandHandler], const.WebSocketCommandHandler]:
|
||||||
"""Decorate function validating login user exist in current WS connection.
|
"""Decorate function validating login user exist in current WS connection.
|
||||||
|
|
||||||
Will write out error message if not authenticated.
|
Will write out error message if not authenticated.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def validator(func):
|
def validator(func: const.WebSocketCommandHandler) -> const.WebSocketCommandHandler:
|
||||||
"""Decorate func."""
|
"""Decorate func."""
|
||||||
|
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def check_current_user(hass, connection, msg):
|
def check_current_user(
|
||||||
|
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||||
|
) -> None:
|
||||||
"""Check current user."""
|
"""Check current user."""
|
||||||
|
|
||||||
def output_error(message_id, message):
|
def output_error(message_id: str, message: str) -> None:
|
||||||
"""Output error message."""
|
"""Output error message."""
|
||||||
connection.send_message(
|
connection.send_message(
|
||||||
messages.error_message(msg["id"], message_id, message)
|
messages.error_message(msg["id"], message_id, message)
|
||||||
)
|
)
|
||||||
|
|
||||||
if connection.user is None:
|
|
||||||
output_error("no_user", "Not authenticated as a user")
|
|
||||||
return
|
|
||||||
|
|
||||||
if only_owner and not connection.user.is_owner:
|
if only_owner and not connection.user.is_owner:
|
||||||
output_error("only_owner", "Only allowed as owner")
|
output_error("only_owner", "Only allowed as owner")
|
||||||
return
|
return
|
||||||
|
@ -112,16 +118,16 @@ def ws_require_user(
|
||||||
|
|
||||||
|
|
||||||
def websocket_command(
|
def websocket_command(
|
||||||
schema: dict,
|
schema: dict[vol.Marker, Any],
|
||||||
) -> Callable[[const.WebSocketCommandHandler], const.WebSocketCommandHandler]:
|
) -> Callable[[const.WebSocketCommandHandler], const.WebSocketCommandHandler]:
|
||||||
"""Tag a function as a websocket command."""
|
"""Tag a function as a websocket command."""
|
||||||
command = schema["type"]
|
command = schema["type"]
|
||||||
|
|
||||||
def decorate(func):
|
def decorate(func: const.WebSocketCommandHandler) -> const.WebSocketCommandHandler:
|
||||||
"""Decorate ws command function."""
|
"""Decorate ws command function."""
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
func._ws_schema = messages.BASE_COMMAND_MESSAGE_SCHEMA.extend(schema)
|
func._ws_schema = messages.BASE_COMMAND_MESSAGE_SCHEMA.extend(schema) # type: ignore[attr-defined]
|
||||||
func._ws_command = command
|
func._ws_command = command # type: ignore[attr-defined]
|
||||||
return func
|
return func
|
||||||
|
|
||||||
return decorate
|
return decorate
|
||||||
|
|
|
@ -2,15 +2,18 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from collections.abc import Callable
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
|
import datetime as dt
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Any, Final
|
||||||
|
|
||||||
from aiohttp import WSMsgType, web
|
from aiohttp import WSMsgType, web
|
||||||
import async_timeout
|
import async_timeout
|
||||||
|
|
||||||
from homeassistant.components.http import HomeAssistantView
|
from homeassistant.components.http import HomeAssistantView
|
||||||
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
|
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import Event, HomeAssistant, callback
|
||||||
from homeassistant.helpers.event import async_call_later
|
from homeassistant.helpers.event import async_call_later
|
||||||
|
|
||||||
from .auth import AuthPhase, auth_required_message
|
from .auth import AuthPhase, auth_required_message
|
||||||
|
@ -27,16 +30,15 @@ from .const import (
|
||||||
from .error import Disconnect
|
from .error import Disconnect
|
||||||
from .messages import message_to_json
|
from .messages import message_to_json
|
||||||
|
|
||||||
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
|
_WS_LOGGER: Final = logging.getLogger(f"{__name__}.connection")
|
||||||
_WS_LOGGER = logging.getLogger(f"{__name__}.connection")
|
|
||||||
|
|
||||||
|
|
||||||
class WebsocketAPIView(HomeAssistantView):
|
class WebsocketAPIView(HomeAssistantView):
|
||||||
"""View to serve a websockets endpoint."""
|
"""View to serve a websockets endpoint."""
|
||||||
|
|
||||||
name = "websocketapi"
|
name: str = "websocketapi"
|
||||||
url = URL
|
url: str = URL
|
||||||
requires_auth = False
|
requires_auth: bool = False
|
||||||
|
|
||||||
async def get(self, request: web.Request) -> web.WebSocketResponse:
|
async def get(self, request: web.Request) -> web.WebSocketResponse:
|
||||||
"""Handle an incoming websocket connection."""
|
"""Handle an incoming websocket connection."""
|
||||||
|
@ -46,7 +48,7 @@ class WebsocketAPIView(HomeAssistantView):
|
||||||
class WebSocketAdapter(logging.LoggerAdapter):
|
class WebSocketAdapter(logging.LoggerAdapter):
|
||||||
"""Add connection id to websocket messages."""
|
"""Add connection id to websocket messages."""
|
||||||
|
|
||||||
def process(self, msg, kwargs):
|
def process(self, msg: str, kwargs: Any) -> tuple[str, Any]:
|
||||||
"""Add connid to websocket log messages."""
|
"""Add connid to websocket log messages."""
|
||||||
return f'[{self.extra["connid"]}] {msg}', kwargs
|
return f'[{self.extra["connid"]}] {msg}', kwargs
|
||||||
|
|
||||||
|
@ -54,20 +56,21 @@ class WebSocketAdapter(logging.LoggerAdapter):
|
||||||
class WebSocketHandler:
|
class WebSocketHandler:
|
||||||
"""Handle an active websocket client connection."""
|
"""Handle an active websocket client connection."""
|
||||||
|
|
||||||
def __init__(self, hass, request):
|
def __init__(self, hass: HomeAssistant, request: web.Request) -> None:
|
||||||
"""Initialize an active connection."""
|
"""Initialize an active connection."""
|
||||||
self.hass = hass
|
self.hass = hass
|
||||||
self.request = request
|
self.request = request
|
||||||
self.wsock: web.WebSocketResponse | None = None
|
self.wsock: web.WebSocketResponse | None = None
|
||||||
self._to_write: asyncio.Queue = asyncio.Queue(maxsize=MAX_PENDING_MSG)
|
self._to_write: asyncio.Queue = asyncio.Queue(maxsize=MAX_PENDING_MSG)
|
||||||
self._handle_task = None
|
self._handle_task: asyncio.Task | None = None
|
||||||
self._writer_task = None
|
self._writer_task: asyncio.Task | None = None
|
||||||
self._logger = WebSocketAdapter(_WS_LOGGER, {"connid": id(self)})
|
self._logger = WebSocketAdapter(_WS_LOGGER, {"connid": id(self)})
|
||||||
self._peak_checker_unsub = None
|
self._peak_checker_unsub: Callable[[], None] | None = None
|
||||||
|
|
||||||
async def _writer(self):
|
async def _writer(self) -> None:
|
||||||
"""Write outgoing messages."""
|
"""Write outgoing messages."""
|
||||||
# Exceptions if Socket disconnected or cancelled by connection handler
|
# Exceptions if Socket disconnected or cancelled by connection handler
|
||||||
|
assert self.wsock is not None
|
||||||
with suppress(RuntimeError, ConnectionResetError, *CANCELLATION_ERRORS):
|
with suppress(RuntimeError, ConnectionResetError, *CANCELLATION_ERRORS):
|
||||||
while not self.wsock.closed:
|
while not self.wsock.closed:
|
||||||
message = await self._to_write.get()
|
message = await self._to_write.get()
|
||||||
|
@ -78,12 +81,12 @@ class WebSocketHandler:
|
||||||
await self.wsock.send_str(message)
|
await self.wsock.send_str(message)
|
||||||
|
|
||||||
# Clean up the peaker checker when we shut down the writer
|
# Clean up the peaker checker when we shut down the writer
|
||||||
if self._peak_checker_unsub:
|
if self._peak_checker_unsub is not None:
|
||||||
self._peak_checker_unsub()
|
self._peak_checker_unsub()
|
||||||
self._peak_checker_unsub = None
|
self._peak_checker_unsub = None
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _send_message(self, message):
|
def _send_message(self, message: str | dict[str, Any]) -> None:
|
||||||
"""Send a message to the client.
|
"""Send a message to the client.
|
||||||
|
|
||||||
Closes connection if the client is not reading the messages.
|
Closes connection if the client is not reading the messages.
|
||||||
|
@ -114,7 +117,7 @@ class WebSocketHandler:
|
||||||
)
|
)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _check_write_peak(self, _):
|
def _check_write_peak(self, _utc_time: dt.datetime) -> None:
|
||||||
"""Check that we are no longer above the write peak."""
|
"""Check that we are no longer above the write peak."""
|
||||||
self._peak_checker_unsub = None
|
self._peak_checker_unsub = None
|
||||||
|
|
||||||
|
@ -129,10 +132,12 @@ class WebSocketHandler:
|
||||||
self._cancel()
|
self._cancel()
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _cancel(self):
|
def _cancel(self) -> None:
|
||||||
"""Cancel the connection."""
|
"""Cancel the connection."""
|
||||||
self._handle_task.cancel()
|
if self._handle_task is not None:
|
||||||
self._writer_task.cancel()
|
self._handle_task.cancel()
|
||||||
|
if self._writer_task is not None:
|
||||||
|
self._writer_task.cancel()
|
||||||
|
|
||||||
async def async_handle(self) -> web.WebSocketResponse:
|
async def async_handle(self) -> web.WebSocketResponse:
|
||||||
"""Handle a websocket response."""
|
"""Handle a websocket response."""
|
||||||
|
@ -143,7 +148,7 @@ class WebSocketHandler:
|
||||||
self._handle_task = asyncio.current_task()
|
self._handle_task = asyncio.current_task()
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def handle_hass_stop(event):
|
def handle_hass_stop(event: Event) -> None:
|
||||||
"""Cancel this connection."""
|
"""Cancel this connection."""
|
||||||
self._cancel()
|
self._cancel()
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||||
|
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any, Final
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
|
@ -17,28 +17,27 @@ from homeassistant.util.yaml.loader import JSON_TYPE
|
||||||
|
|
||||||
from . import const
|
from . import const
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER: Final = logging.getLogger(__name__)
|
||||||
# mypy: allow-untyped-defs
|
|
||||||
|
|
||||||
# Minimal requirements of a message
|
# Minimal requirements of a message
|
||||||
MINIMAL_MESSAGE_SCHEMA = vol.Schema(
|
MINIMAL_MESSAGE_SCHEMA: Final = vol.Schema(
|
||||||
{vol.Required("id"): cv.positive_int, vol.Required("type"): cv.string},
|
{vol.Required("id"): cv.positive_int, vol.Required("type"): cv.string},
|
||||||
extra=vol.ALLOW_EXTRA,
|
extra=vol.ALLOW_EXTRA,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Base schema to extend by message handlers
|
# Base schema to extend by message handlers
|
||||||
BASE_COMMAND_MESSAGE_SCHEMA = vol.Schema({vol.Required("id"): cv.positive_int})
|
BASE_COMMAND_MESSAGE_SCHEMA: Final = vol.Schema({vol.Required("id"): cv.positive_int})
|
||||||
|
|
||||||
IDEN_TEMPLATE = "__IDEN__"
|
IDEN_TEMPLATE: Final = "__IDEN__"
|
||||||
IDEN_JSON_TEMPLATE = '"__IDEN__"'
|
IDEN_JSON_TEMPLATE: Final = '"__IDEN__"'
|
||||||
|
|
||||||
|
|
||||||
def result_message(iden: int, result: Any = None) -> dict:
|
def result_message(iden: int, result: Any = None) -> dict[str, Any]:
|
||||||
"""Return a success result message."""
|
"""Return a success result message."""
|
||||||
return {"id": iden, "type": const.TYPE_RESULT, "success": True, "result": result}
|
return {"id": iden, "type": const.TYPE_RESULT, "success": True, "result": result}
|
||||||
|
|
||||||
|
|
||||||
def error_message(iden: int, code: str, message: str) -> dict:
|
def error_message(iden: int | None, code: str, message: str) -> dict[str, Any]:
|
||||||
"""Return an error result message."""
|
"""Return an error result message."""
|
||||||
return {
|
return {
|
||||||
"id": iden,
|
"id": iden,
|
||||||
|
@ -48,7 +47,7 @@ def error_message(iden: int, code: str, message: str) -> dict:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def event_message(iden: JSON_TYPE, event: Any) -> dict:
|
def event_message(iden: JSON_TYPE, event: Any) -> dict[str, Any]:
|
||||||
"""Return an event message."""
|
"""Return an event message."""
|
||||||
return {"id": iden, "type": "event", "event": event}
|
return {"id": iden, "type": "event", "event": event}
|
||||||
|
|
||||||
|
@ -75,7 +74,7 @@ def _cached_event_message(event: Event) -> str:
|
||||||
return message_to_json(event_message(IDEN_TEMPLATE, event))
|
return message_to_json(event_message(IDEN_TEMPLATE, event))
|
||||||
|
|
||||||
|
|
||||||
def message_to_json(message: Any) -> str:
|
def message_to_json(message: dict[str, Any]) -> str:
|
||||||
"""Serialize a websocket message to json."""
|
"""Serialize a websocket message to json."""
|
||||||
try:
|
try:
|
||||||
return const.JSON_DUMP(message)
|
return const.JSON_DUMP(message)
|
||||||
|
|
|
@ -2,6 +2,10 @@
|
||||||
|
|
||||||
Separate file to avoid circular imports.
|
Separate file to avoid circular imports.
|
||||||
"""
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Final
|
||||||
|
|
||||||
from homeassistant.components.frontend import EVENT_PANELS_UPDATED
|
from homeassistant.components.frontend import EVENT_PANELS_UPDATED
|
||||||
from homeassistant.components.lovelace.const import EVENT_LOVELACE_UPDATED
|
from homeassistant.components.lovelace.const import EVENT_LOVELACE_UPDATED
|
||||||
from homeassistant.components.persistent_notification import (
|
from homeassistant.components.persistent_notification import (
|
||||||
|
@ -22,7 +26,7 @@ from homeassistant.helpers.entity_registry import EVENT_ENTITY_REGISTRY_UPDATED
|
||||||
|
|
||||||
# These are events that do not contain any sensitive data
|
# These are events that do not contain any sensitive data
|
||||||
# Except for state_changed, which is handled accordingly.
|
# Except for state_changed, which is handled accordingly.
|
||||||
SUBSCRIBE_ALLOWLIST = {
|
SUBSCRIBE_ALLOWLIST: Final[set[str]] = {
|
||||||
EVENT_AREA_REGISTRY_UPDATED,
|
EVENT_AREA_REGISTRY_UPDATED,
|
||||||
EVENT_COMPONENT_LOADED,
|
EVENT_COMPONENT_LOADED,
|
||||||
EVENT_CORE_CONFIG_UPDATE,
|
EVENT_CORE_CONFIG_UPDATE,
|
||||||
|
|
|
@ -1,7 +1,12 @@
|
||||||
"""Entity to track connections to websocket API."""
|
"""Entity to track connections to websocket API."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from homeassistant.components.sensor import SensorEntity
|
from homeassistant.components.sensor import SensorEntity
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||||
|
from homeassistant.helpers.typing import ConfigType
|
||||||
|
|
||||||
from .const import (
|
from .const import (
|
||||||
DATA_CONNECTIONS,
|
DATA_CONNECTIONS,
|
||||||
|
@ -9,10 +14,13 @@ from .const import (
|
||||||
SIGNAL_WEBSOCKET_DISCONNECTED,
|
SIGNAL_WEBSOCKET_DISCONNECTED,
|
||||||
)
|
)
|
||||||
|
|
||||||
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
|
|
||||||
|
|
||||||
|
async def async_setup_platform(
|
||||||
async def async_setup_platform(hass, config, async_add_entities, discovery_info=None):
|
hass: HomeAssistant,
|
||||||
|
config: ConfigType,
|
||||||
|
async_add_entities: AddEntitiesCallback,
|
||||||
|
discovery_info: dict[str, Any] | None = None,
|
||||||
|
) -> None:
|
||||||
"""Set up the API streams platform."""
|
"""Set up the API streams platform."""
|
||||||
entity = APICount()
|
entity = APICount()
|
||||||
|
|
||||||
|
@ -22,11 +30,11 @@ async def async_setup_platform(hass, config, async_add_entities, discovery_info=
|
||||||
class APICount(SensorEntity):
|
class APICount(SensorEntity):
|
||||||
"""Entity to represent how many people are connected to the stream API."""
|
"""Entity to represent how many people are connected to the stream API."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
"""Initialize the API count."""
|
"""Initialize the API count."""
|
||||||
self.count = 0
|
self.count = 0
|
||||||
|
|
||||||
async def async_added_to_hass(self):
|
async def async_added_to_hass(self) -> None:
|
||||||
"""Added to hass."""
|
"""Added to hass."""
|
||||||
self.async_on_remove(
|
self.async_on_remove(
|
||||||
self.hass.helpers.dispatcher.async_dispatcher_connect(
|
self.hass.helpers.dispatcher.async_dispatcher_connect(
|
||||||
|
@ -40,21 +48,21 @@ class APICount(SensorEntity):
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self) -> str:
|
||||||
"""Return name of entity."""
|
"""Return name of entity."""
|
||||||
return "Connected clients"
|
return "Connected clients"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def state(self):
|
def state(self) -> int:
|
||||||
"""Return current API count."""
|
"""Return current API count."""
|
||||||
return self.count
|
return self.count
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def unit_of_measurement(self):
|
def unit_of_measurement(self) -> str:
|
||||||
"""Return the unit of measurement."""
|
"""Return the unit of measurement."""
|
||||||
return "clients"
|
return "clients"
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _update_count(self):
|
def _update_count(self) -> None:
|
||||||
self.count = self.hass.data.get(DATA_CONNECTIONS, 0)
|
self.count = self.hass.data.get(DATA_CONNECTIONS, 0)
|
||||||
self.async_write_ha_state()
|
self.async_write_ha_state()
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
"""Test WebSocket Connection class."""
|
"""Test WebSocket Connection class."""
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
|
@ -8,6 +9,8 @@ from homeassistant import exceptions
|
||||||
from homeassistant.components import websocket_api
|
from homeassistant.components import websocket_api
|
||||||
from homeassistant.components.websocket_api import const
|
from homeassistant.components.websocket_api import const
|
||||||
|
|
||||||
|
from tests.common import MockUser
|
||||||
|
|
||||||
|
|
||||||
async def test_send_big_result(hass, websocket_client):
|
async def test_send_big_result(hass, websocket_client):
|
||||||
"""Test sending big results over the WS."""
|
"""Test sending big results over the WS."""
|
||||||
|
@ -31,8 +34,10 @@ async def test_send_big_result(hass, websocket_client):
|
||||||
async def test_exception_handling():
|
async def test_exception_handling():
|
||||||
"""Test handling of exceptions."""
|
"""Test handling of exceptions."""
|
||||||
send_messages = []
|
send_messages = []
|
||||||
|
user = MockUser()
|
||||||
|
refresh_token = Mock()
|
||||||
conn = websocket_api.ActiveConnection(
|
conn = websocket_api.ActiveConnection(
|
||||||
logging.getLogger(__name__), None, send_messages.append, None, None
|
logging.getLogger(__name__), None, send_messages.append, user, refresh_token
|
||||||
)
|
)
|
||||||
|
|
||||||
for (exc, code, err) in (
|
for (exc, code, err) in (
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue