Add missing type hints to websocket_api (#50915)

This commit is contained in:
Ruslan Sayfutdinov 2021-05-21 17:39:18 +01:00 committed by GitHub
parent dc65f279a7
commit 42ff687c32
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 251 additions and 159 deletions

View file

@ -1,11 +1,12 @@
"""WebSocket based API for Home Assistant.""" """WebSocket based API for Home Assistant."""
from __future__ import annotations from __future__ import annotations
from typing import cast from typing import Final, cast
import voluptuous as vol import voluptuous as vol
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
from . import commands, connection, const, decorators, http, messages # noqa: F401 from . import commands, connection, const, decorators, http, messages # noqa: F401
@ -34,11 +35,9 @@ from .messages import ( # noqa: F401
result_message, result_message,
) )
# mypy: allow-untyped-calls, allow-untyped-defs DOMAIN: Final = const.DOMAIN
DOMAIN = const.DOMAIN DEPENDENCIES: Final[tuple[str]] = ("http",)
DEPENDENCIES = ("http",)
@bind_hass @bind_hass
@ -53,8 +52,8 @@ def async_register_command(
# pylint: disable=protected-access # pylint: disable=protected-access
if handler is None: if handler is None:
handler = cast(const.WebSocketCommandHandler, command_or_handler) handler = cast(const.WebSocketCommandHandler, command_or_handler)
command = handler._ws_command # type: ignore command = handler._ws_command # type: ignore[attr-defined]
schema = handler._ws_schema # type: ignore schema = handler._ws_schema # type: ignore[attr-defined]
else: else:
command = command_or_handler command = command_or_handler
handlers = hass.data.get(DOMAIN) handlers = hass.data.get(DOMAIN)
@ -63,8 +62,8 @@ def async_register_command(
handlers[command] = (handler, schema) handlers[command] = (handler, schema)
async def async_setup(hass, config): async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Initialize the websocket API.""" """Initialize the websocket API."""
hass.http.register_view(http.WebsocketAPIView) hass.http.register_view(http.WebsocketAPIView())
commands.async_register_commands(hass, async_register_command) commands.async_register_commands(hass, async_register_command)
return True return True

View file

@ -1,22 +1,31 @@
"""Handle the auth of a connection.""" """Handle the auth of a connection."""
from __future__ import annotations
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Final
from aiohttp.web import Request
import voluptuous as vol import voluptuous as vol
from voluptuous.humanize import humanize_error from voluptuous.humanize import humanize_error
from homeassistant.auth.models import RefreshToken, User from homeassistant.auth.models import RefreshToken, User
from homeassistant.components.http.ban import process_success_login, process_wrong_login from homeassistant.components.http.ban import process_success_login, process_wrong_login
from homeassistant.const import __version__ from homeassistant.const import __version__
from homeassistant.core import HomeAssistant
from .connection import ActiveConnection from .connection import ActiveConnection
from .error import Disconnect from .error import Disconnect
# mypy: allow-untyped-calls, allow-untyped-defs if TYPE_CHECKING:
from .http import WebSocketAdapter
TYPE_AUTH = "auth"
TYPE_AUTH_INVALID = "auth_invalid"
TYPE_AUTH_OK = "auth_ok"
TYPE_AUTH_REQUIRED = "auth_required"
AUTH_MESSAGE_SCHEMA = vol.Schema( TYPE_AUTH: Final = "auth"
TYPE_AUTH_INVALID: Final = "auth_invalid"
TYPE_AUTH_OK: Final = "auth_ok"
TYPE_AUTH_REQUIRED: Final = "auth_required"
AUTH_MESSAGE_SCHEMA: Final = vol.Schema(
{ {
vol.Required("type"): TYPE_AUTH, vol.Required("type"): TYPE_AUTH,
vol.Exclusive("api_password", "auth"): str, vol.Exclusive("api_password", "auth"): str,
@ -25,17 +34,17 @@ AUTH_MESSAGE_SCHEMA = vol.Schema(
) )
def auth_ok_message(): def auth_ok_message() -> dict[str, str]:
"""Return an auth_ok message.""" """Return an auth_ok message."""
return {"type": TYPE_AUTH_OK, "ha_version": __version__} return {"type": TYPE_AUTH_OK, "ha_version": __version__}
def auth_required_message(): def auth_required_message() -> dict[str, str]:
"""Return an auth_required message.""" """Return an auth_required message."""
return {"type": TYPE_AUTH_REQUIRED, "ha_version": __version__} return {"type": TYPE_AUTH_REQUIRED, "ha_version": __version__}
def auth_invalid_message(message): def auth_invalid_message(message: str) -> dict[str, str]:
"""Return an auth_invalid message.""" """Return an auth_invalid message."""
return {"type": TYPE_AUTH_INVALID, "message": message} return {"type": TYPE_AUTH_INVALID, "message": message}
@ -43,16 +52,20 @@ def auth_invalid_message(message):
class AuthPhase: class AuthPhase:
"""Connection that requires client to authenticate first.""" """Connection that requires client to authenticate first."""
def __init__(self, logger, hass, send_message, request): def __init__(
self,
logger: WebSocketAdapter,
hass: HomeAssistant,
send_message: Callable[[str | dict[str, Any]], None],
request: Request,
) -> None:
"""Initialize the authentiated connection.""" """Initialize the authentiated connection."""
self._hass = hass self._hass = hass
self._send_message = send_message self._send_message = send_message
self._logger = logger self._logger = logger
self._request = request self._request = request
self._authenticated = False
self._connection = None
async def async_handle(self, msg): async def async_handle(self, msg: dict[str, str]) -> ActiveConnection:
"""Handle authentication.""" """Handle authentication."""
try: try:
msg = AUTH_MESSAGE_SCHEMA(msg) msg = AUTH_MESSAGE_SCHEMA(msg)

View file

@ -1,6 +1,10 @@
"""Commands part of Websocket API.""" """Commands part of Websocket API."""
from __future__ import annotations
import asyncio import asyncio
from collections.abc import Callable
import json import json
from typing import Any
import voluptuous as vol import voluptuous as vol
@ -8,7 +12,7 @@ from homeassistant.auth.permissions.const import CAT_ENTITIES, POLICY_READ
from homeassistant.bootstrap import SIGNAL_BOOTSTRAP_INTEGRATONS from homeassistant.bootstrap import SIGNAL_BOOTSTRAP_INTEGRATONS
from homeassistant.components.websocket_api.const import ERR_NOT_FOUND from homeassistant.components.websocket_api.const import ERR_NOT_FOUND
from homeassistant.const import EVENT_STATE_CHANGED, EVENT_TIME_CHANGED, MATCH_ALL from homeassistant.const import EVENT_STATE_CHANGED, EVENT_TIME_CHANGED, MATCH_ALL
from homeassistant.core import callback from homeassistant.core import Context, Event, HomeAssistant, callback
from homeassistant.exceptions import ( from homeassistant.exceptions import (
HomeAssistantError, HomeAssistantError,
ServiceNotFound, ServiceNotFound,
@ -17,19 +21,25 @@ from homeassistant.exceptions import (
) )
from homeassistant.helpers import config_validation as cv, entity, template from homeassistant.helpers import config_validation as cv, entity, template
from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.event import TrackTemplate, async_track_template_result from homeassistant.helpers.event import (
TrackTemplate,
TrackTemplateResult,
async_track_template_result,
)
from homeassistant.helpers.json import ExtendedJSONEncoder from homeassistant.helpers.json import ExtendedJSONEncoder
from homeassistant.helpers.service import async_get_all_descriptions from homeassistant.helpers.service import async_get_all_descriptions
from homeassistant.loader import IntegrationNotFound, async_get_integration from homeassistant.loader import IntegrationNotFound, async_get_integration
from homeassistant.setup import DATA_SETUP_TIME, async_get_loaded_integrations from homeassistant.setup import DATA_SETUP_TIME, async_get_loaded_integrations
from . import const, decorators, messages from . import const, decorators, messages
from .connection import ActiveConnection
# mypy: allow-untyped-calls, allow-untyped-defs
@callback @callback
def async_register_commands(hass, async_reg): def async_register_commands(
hass: HomeAssistant,
async_reg: Callable[[HomeAssistant, const.WebSocketCommandHandler], None],
) -> None:
"""Register commands.""" """Register commands."""
async_reg(hass, handle_call_service) async_reg(hass, handle_call_service)
async_reg(hass, handle_entity_source) async_reg(hass, handle_entity_source)
@ -49,7 +59,7 @@ def async_register_commands(hass, async_reg):
async_reg(hass, handle_unsubscribe_events) async_reg(hass, handle_unsubscribe_events)
def pong_message(iden): def pong_message(iden: int) -> dict[str, Any]:
"""Return a pong message.""" """Return a pong message."""
return {"id": iden, "type": "pong"} return {"id": iden, "type": "pong"}
@ -61,7 +71,9 @@ def pong_message(iden):
vol.Optional("event_type", default=MATCH_ALL): str, vol.Optional("event_type", default=MATCH_ALL): str,
} }
) )
def handle_subscribe_events(hass, connection, msg): def handle_subscribe_events(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle subscribe events command.""" """Handle subscribe events command."""
# Circular dep # Circular dep
# pylint: disable=import-outside-toplevel # pylint: disable=import-outside-toplevel
@ -75,7 +87,7 @@ def handle_subscribe_events(hass, connection, msg):
if event_type == EVENT_STATE_CHANGED: if event_type == EVENT_STATE_CHANGED:
@callback @callback
def forward_events(event): def forward_events(event: Event) -> None:
"""Forward state changed events to websocket.""" """Forward state changed events to websocket."""
if not connection.user.permissions.check_entity( if not connection.user.permissions.check_entity(
event.data["entity_id"], POLICY_READ event.data["entity_id"], POLICY_READ
@ -87,7 +99,7 @@ def handle_subscribe_events(hass, connection, msg):
else: else:
@callback @callback
def forward_events(event): def forward_events(event: Event) -> None:
"""Forward events to websocket.""" """Forward events to websocket."""
if event.event_type == EVENT_TIME_CHANGED: if event.event_type == EVENT_TIME_CHANGED:
return return
@ -107,11 +119,13 @@ def handle_subscribe_events(hass, connection, msg):
vol.Required("type"): "subscribe_bootstrap_integrations", vol.Required("type"): "subscribe_bootstrap_integrations",
} }
) )
def handle_subscribe_bootstrap_integrations(hass, connection, msg): def handle_subscribe_bootstrap_integrations(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle subscribe bootstrap integrations command.""" """Handle subscribe bootstrap integrations command."""
@callback @callback
def forward_bootstrap_integrations(message): def forward_bootstrap_integrations(message: dict[str, Any]) -> None:
"""Forward bootstrap integrations to websocket.""" """Forward bootstrap integrations to websocket."""
connection.send_message(messages.event_message(msg["id"], message)) connection.send_message(messages.event_message(msg["id"], message))
@ -129,7 +143,9 @@ def handle_subscribe_bootstrap_integrations(hass, connection, msg):
vol.Required("subscription"): cv.positive_int, vol.Required("subscription"): cv.positive_int,
} }
) )
def handle_unsubscribe_events(hass, connection, msg): def handle_unsubscribe_events(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle unsubscribe events command.""" """Handle unsubscribe events command."""
subscription = msg["subscription"] subscription = msg["subscription"]
@ -154,7 +170,9 @@ def handle_unsubscribe_events(hass, connection, msg):
} }
) )
@decorators.async_response @decorators.async_response
async def handle_call_service(hass, connection, msg): async def handle_call_service(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle call service command.""" """Handle call service command."""
blocking = True blocking = True
# We do not support templates. # We do not support templates.
@ -206,7 +224,9 @@ async def handle_call_service(hass, connection, msg):
@callback @callback
@decorators.websocket_command({vol.Required("type"): "get_states"}) @decorators.websocket_command({vol.Required("type"): "get_states"})
def handle_get_states(hass, connection, msg): def handle_get_states(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle get states command.""" """Handle get states command."""
if connection.user.permissions.access_all_entities("read"): if connection.user.permissions.access_all_entities("read"):
states = hass.states.async_all() states = hass.states.async_all()
@ -223,7 +243,9 @@ def handle_get_states(hass, connection, msg):
@decorators.websocket_command({vol.Required("type"): "get_services"}) @decorators.websocket_command({vol.Required("type"): "get_services"})
@decorators.async_response @decorators.async_response
async def handle_get_services(hass, connection, msg): async def handle_get_services(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle get services command.""" """Handle get services command."""
descriptions = await async_get_all_descriptions(hass) descriptions = await async_get_all_descriptions(hass)
connection.send_message(messages.result_message(msg["id"], descriptions)) connection.send_message(messages.result_message(msg["id"], descriptions))
@ -231,14 +253,18 @@ async def handle_get_services(hass, connection, msg):
@callback @callback
@decorators.websocket_command({vol.Required("type"): "get_config"}) @decorators.websocket_command({vol.Required("type"): "get_config"})
def handle_get_config(hass, connection, msg): def handle_get_config(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle get config command.""" """Handle get config command."""
connection.send_message(messages.result_message(msg["id"], hass.config.as_dict())) connection.send_message(messages.result_message(msg["id"], hass.config.as_dict()))
@decorators.websocket_command({vol.Required("type"): "manifest/list"}) @decorators.websocket_command({vol.Required("type"): "manifest/list"})
@decorators.async_response @decorators.async_response
async def handle_manifest_list(hass, connection, msg): async def handle_manifest_list(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle integrations command.""" """Handle integrations command."""
loaded_integrations = async_get_loaded_integrations(hass) loaded_integrations = async_get_loaded_integrations(hass)
integrations = await asyncio.gather( integrations = await asyncio.gather(
@ -253,7 +279,9 @@ async def handle_manifest_list(hass, connection, msg):
{vol.Required("type"): "manifest/get", vol.Required("integration"): str} {vol.Required("type"): "manifest/get", vol.Required("integration"): str}
) )
@decorators.async_response @decorators.async_response
async def handle_manifest_get(hass, connection, msg): async def handle_manifest_get(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle integrations command.""" """Handle integrations command."""
try: try:
integration = await async_get_integration(hass, msg["integration"]) integration = await async_get_integration(hass, msg["integration"])
@ -264,7 +292,9 @@ async def handle_manifest_get(hass, connection, msg):
@decorators.websocket_command({vol.Required("type"): "integration/setup_info"}) @decorators.websocket_command({vol.Required("type"): "integration/setup_info"})
@decorators.async_response @decorators.async_response
async def handle_integration_setup_info(hass, connection, msg): async def handle_integration_setup_info(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle integrations command.""" """Handle integrations command."""
connection.send_result( connection.send_result(
msg["id"], msg["id"],
@ -277,7 +307,9 @@ async def handle_integration_setup_info(hass, connection, msg):
@callback @callback
@decorators.websocket_command({vol.Required("type"): "ping"}) @decorators.websocket_command({vol.Required("type"): "ping"})
def handle_ping(hass, connection, msg): def handle_ping(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle ping command.""" """Handle ping command."""
connection.send_message(pong_message(msg["id"])) connection.send_message(pong_message(msg["id"]))
@ -293,10 +325,12 @@ def handle_ping(hass, connection, msg):
} }
) )
@decorators.async_response @decorators.async_response
async def handle_render_template(hass, connection, msg): async def handle_render_template(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle render_template command.""" """Handle render_template command."""
template_str = msg["template"] template_str = msg["template"]
template_obj = template.Template(template_str, hass) template_obj = template.Template(template_str, hass) # type: ignore[no-untyped-call]
variables = msg.get("variables") variables = msg.get("variables")
timeout = msg.get("timeout") timeout = msg.get("timeout")
info = None info = None
@ -319,7 +353,7 @@ async def handle_render_template(hass, connection, msg):
return return
@callback @callback
def _template_listener(event, updates): def _template_listener(event: Event, updates: list[TrackTemplateResult]) -> None:
nonlocal info nonlocal info
track_template_result = updates.pop() track_template_result = updates.pop()
result = track_template_result.result result = track_template_result.result
@ -329,7 +363,7 @@ async def handle_render_template(hass, connection, msg):
connection.send_message( connection.send_message(
messages.event_message( messages.event_message(
msg["id"], {"result": result, "listeners": info.listeners} # type: ignore msg["id"], {"result": result, "listeners": info.listeners} # type: ignore[attr-defined]
) )
) )
@ -356,7 +390,9 @@ async def handle_render_template(hass, connection, msg):
@decorators.websocket_command( @decorators.websocket_command(
{vol.Required("type"): "entity/source", vol.Optional("entity_id"): [cv.entity_id]} {vol.Required("type"): "entity/source", vol.Optional("entity_id"): [cv.entity_id]}
) )
def handle_entity_source(hass, connection, msg): def handle_entity_source(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle entity source command.""" """Handle entity source command."""
raw_sources = entity.entity_sources(hass) raw_sources = entity.entity_sources(hass)
entity_perm = connection.user.permissions.check_entity entity_perm = connection.user.permissions.check_entity
@ -404,7 +440,9 @@ def handle_entity_source(hass, connection, msg):
) )
@decorators.require_admin @decorators.require_admin
@decorators.async_response @decorators.async_response
async def handle_subscribe_trigger(hass, connection, msg): async def handle_subscribe_trigger(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle subscribe trigger command.""" """Handle subscribe trigger command."""
# Circular dep # Circular dep
# pylint: disable=import-outside-toplevel # pylint: disable=import-outside-toplevel
@ -413,7 +451,9 @@ async def handle_subscribe_trigger(hass, connection, msg):
trigger_config = await trigger.async_validate_trigger_config(hass, msg["trigger"]) trigger_config = await trigger.async_validate_trigger_config(hass, msg["trigger"])
@callback @callback
def forward_triggers(variables, context=None): def forward_triggers(
variables: dict[str, Any], context: Context | None = None
) -> None:
"""Forward events to websocket.""" """Forward events to websocket."""
message = messages.event_message( message = messages.event_message(
msg["id"], {"variables": variables, "context": context} msg["id"], {"variables": variables, "context": context}
@ -449,7 +489,9 @@ async def handle_subscribe_trigger(hass, connection, msg):
) )
@decorators.require_admin @decorators.require_admin
@decorators.async_response @decorators.async_response
async def handle_test_condition(hass, connection, msg): async def handle_test_condition(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle test condition command.""" """Handle test condition command."""
# Circular dep # Circular dep
# pylint: disable=import-outside-toplevel # pylint: disable=import-outside-toplevel
@ -470,7 +512,9 @@ async def handle_test_condition(hass, connection, msg):
) )
@decorators.require_admin @decorators.require_admin
@decorators.async_response @decorators.async_response
async def handle_execute_script(hass, connection, msg): async def handle_execute_script(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle execute script command.""" """Handle execute script command."""
# Circular dep # Circular dep
# pylint: disable=import-outside-toplevel # pylint: disable=import-outside-toplevel

View file

@ -3,48 +3,50 @@ from __future__ import annotations
import asyncio import asyncio
from collections.abc import Hashable from collections.abc import Hashable
from typing import Any, Callable from typing import TYPE_CHECKING, Any, Callable
import voluptuous as vol import voluptuous as vol
from homeassistant.core import Context, callback from homeassistant.auth.models import RefreshToken, User
from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError, Unauthorized from homeassistant.exceptions import HomeAssistantError, Unauthorized
from . import const, messages from . import const, messages
# mypy: allow-untyped-calls, allow-untyped-defs if TYPE_CHECKING:
from .http import WebSocketAdapter
class ActiveConnection: class ActiveConnection:
"""Handle an active websocket client connection.""" """Handle an active websocket client connection."""
def __init__(self, logger, hass, send_message, user, refresh_token): def __init__(
self,
logger: WebSocketAdapter,
hass: HomeAssistant,
send_message: Callable[[str | dict[str, Any]], None],
user: User,
refresh_token: RefreshToken,
) -> None:
"""Initialize an active connection.""" """Initialize an active connection."""
self.logger = logger self.logger = logger
self.hass = hass self.hass = hass
self.send_message = send_message self.send_message = send_message
self.user = user self.user = user
if refresh_token: self.refresh_token_id = refresh_token.id
self.refresh_token_id = refresh_token.id
else:
self.refresh_token_id = None
self.subscriptions: dict[Hashable, Callable[[], Any]] = {} self.subscriptions: dict[Hashable, Callable[[], Any]] = {}
self.last_id = 0 self.last_id = 0
def context(self, msg): def context(self, msg: dict[str, Any]) -> Context:
"""Return a context.""" """Return a context."""
user = self.user return Context(user_id=self.user.id)
if user is None:
return Context()
return Context(user_id=user.id)
@callback @callback
def send_result(self, msg_id: int, result: Any | None = None) -> None: def send_result(self, msg_id: int, result: Any | None = None) -> None:
"""Send a result message.""" """Send a result message."""
self.send_message(messages.result_message(msg_id, result)) self.send_message(messages.result_message(msg_id, result))
async def send_big_result(self, msg_id, result): async def send_big_result(self, msg_id: int, result: Any) -> None:
"""Send a result message that would be expensive to JSON serialize.""" """Send a result message that would be expensive to JSON serialize."""
content = await self.hass.async_add_executor_job( content = await self.hass.async_add_executor_job(
const.JSON_DUMP, messages.result_message(msg_id, result) const.JSON_DUMP, messages.result_message(msg_id, result)
@ -57,7 +59,7 @@ class ActiveConnection:
self.send_message(messages.error_message(msg_id, code, message)) self.send_message(messages.error_message(msg_id, code, message))
@callback @callback
def async_handle(self, msg): def async_handle(self, msg: dict[str, Any]) -> None:
"""Handle a single incoming message.""" """Handle a single incoming message."""
handlers = self.hass.data[const.DOMAIN] handlers = self.hass.data[const.DOMAIN]
@ -102,13 +104,13 @@ class ActiveConnection:
self.last_id = cur_id self.last_id = cur_id
@callback @callback
def async_close(self): def async_close(self) -> None:
"""Close down connection.""" """Close down connection."""
for unsub in self.subscriptions.values(): for unsub in self.subscriptions.values():
unsub() unsub()
@callback @callback
def async_handle_exception(self, msg, err): def async_handle_exception(self, msg: dict[str, Any], err: Exception) -> None:
"""Handle an exception while processing a handler.""" """Handle an exception while processing a handler."""
log_handler = self.logger.error log_handler = self.logger.error

View file

@ -1,9 +1,11 @@
"""Websocket constants.""" """Websocket constants."""
from __future__ import annotations
import asyncio import asyncio
from concurrent import futures from concurrent import futures
from functools import partial from functools import partial
import json import json
from typing import TYPE_CHECKING, Callable from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Final
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers.json import JSONEncoder from homeassistant.helpers.json import JSONEncoder
@ -12,37 +14,42 @@ if TYPE_CHECKING:
from .connection import ActiveConnection from .connection import ActiveConnection
WebSocketCommandHandler = Callable[[HomeAssistant, "ActiveConnection", dict], None] WebSocketCommandHandler = Callable[
[HomeAssistant, "ActiveConnection", Dict[str, Any]], None
]
AsyncWebSocketCommandHandler = Callable[
[HomeAssistant, "ActiveConnection", Dict[str, Any]], Awaitable[None]
]
DOMAIN = "websocket_api" DOMAIN: Final = "websocket_api"
URL = "/api/websocket" URL: Final = "/api/websocket"
PENDING_MSG_PEAK = 512 PENDING_MSG_PEAK: Final = 512
PENDING_MSG_PEAK_TIME = 5 PENDING_MSG_PEAK_TIME: Final = 5
MAX_PENDING_MSG = 2048 MAX_PENDING_MSG: Final = 2048
ERR_ID_REUSE = "id_reuse" ERR_ID_REUSE: Final = "id_reuse"
ERR_INVALID_FORMAT = "invalid_format" ERR_INVALID_FORMAT: Final = "invalid_format"
ERR_NOT_FOUND = "not_found" ERR_NOT_FOUND: Final = "not_found"
ERR_NOT_SUPPORTED = "not_supported" ERR_NOT_SUPPORTED: Final = "not_supported"
ERR_HOME_ASSISTANT_ERROR = "home_assistant_error" ERR_HOME_ASSISTANT_ERROR: Final = "home_assistant_error"
ERR_UNKNOWN_COMMAND = "unknown_command" ERR_UNKNOWN_COMMAND: Final = "unknown_command"
ERR_UNKNOWN_ERROR = "unknown_error" ERR_UNKNOWN_ERROR: Final = "unknown_error"
ERR_UNAUTHORIZED = "unauthorized" ERR_UNAUTHORIZED: Final = "unauthorized"
ERR_TIMEOUT = "timeout" ERR_TIMEOUT: Final = "timeout"
ERR_TEMPLATE_ERROR = "template_error" ERR_TEMPLATE_ERROR: Final = "template_error"
TYPE_RESULT = "result" TYPE_RESULT: Final = "result"
# Define the possible errors that occur when connections are cancelled. # Define the possible errors that occur when connections are cancelled.
# Originally, this was just asyncio.CancelledError, but issue #9546 showed # Originally, this was just asyncio.CancelledError, but issue #9546 showed
# that futures.CancelledErrors can also occur in some situations. # that futures.CancelledErrors can also occur in some situations.
CANCELLATION_ERRORS = (asyncio.CancelledError, futures.CancelledError) CANCELLATION_ERRORS: Final = (asyncio.CancelledError, futures.CancelledError)
# Event types # Event types
SIGNAL_WEBSOCKET_CONNECTED = "websocket_connected" SIGNAL_WEBSOCKET_CONNECTED: Final = "websocket_connected"
SIGNAL_WEBSOCKET_DISCONNECTED = "websocket_disconnected" SIGNAL_WEBSOCKET_DISCONNECTED: Final = "websocket_disconnected"
# Data used to store the current connection list # Data used to store the current connection list
DATA_CONNECTIONS = f"{DOMAIN}.connections" DATA_CONNECTIONS: Final = f"{DOMAIN}.connections"
JSON_DUMP = partial(json.dumps, cls=JSONEncoder, allow_nan=False) JSON_DUMP: Final = partial(json.dumps, cls=JSONEncoder, allow_nan=False)

View file

@ -2,9 +2,10 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections.abc import Awaitable
from functools import wraps from functools import wraps
from typing import Callable from typing import Any, Callable
import voluptuous as vol
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import Unauthorized from homeassistant.exceptions import Unauthorized
@ -12,10 +13,13 @@ from homeassistant.exceptions import Unauthorized
from . import const, messages from . import const, messages
from .connection import ActiveConnection from .connection import ActiveConnection
# mypy: allow-untyped-calls, allow-untyped-defs
async def _handle_async_response(
async def _handle_async_response(func, hass, connection, msg): func: const.AsyncWebSocketCommandHandler,
hass: HomeAssistant,
connection: ActiveConnection,
msg: dict[str, Any],
) -> None:
"""Create a response and handle exception.""" """Create a response and handle exception."""
try: try:
await func(hass, connection, msg) await func(hass, connection, msg)
@ -24,13 +28,15 @@ async def _handle_async_response(func, hass, connection, msg):
def async_response( def async_response(
func: Callable[[HomeAssistant, ActiveConnection, dict], Awaitable[None]] func: const.AsyncWebSocketCommandHandler,
) -> const.WebSocketCommandHandler: ) -> const.WebSocketCommandHandler:
"""Decorate an async function to handle WebSocket API messages.""" """Decorate an async function to handle WebSocket API messages."""
@callback @callback
@wraps(func) @wraps(func)
def schedule_handler(hass, connection, msg): def schedule_handler(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Schedule the handler.""" """Schedule the handler."""
# As the webserver is now started before the start # As the webserver is now started before the start
# event we do not want to block for websocket responders # event we do not want to block for websocket responders
@ -43,7 +49,9 @@ def require_admin(func: const.WebSocketCommandHandler) -> const.WebSocketCommand
"""Websocket decorator to require user to be an admin.""" """Websocket decorator to require user to be an admin."""
@wraps(func) @wraps(func)
def with_admin(hass, connection, msg): def with_admin(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Check admin and call function.""" """Check admin and call function."""
user = connection.user user = connection.user
@ -56,34 +64,32 @@ def require_admin(func: const.WebSocketCommandHandler) -> const.WebSocketCommand
def ws_require_user( def ws_require_user(
only_owner=False, only_owner: bool = False,
only_system_user=False, only_system_user: bool = False,
allow_system_user=True, allow_system_user: bool = True,
only_active_user=True, only_active_user: bool = True,
only_inactive_user=False, only_inactive_user: bool = False,
): ) -> Callable[[const.WebSocketCommandHandler], const.WebSocketCommandHandler]:
"""Decorate function validating login user exist in current WS connection. """Decorate function validating login user exist in current WS connection.
Will write out error message if not authenticated. Will write out error message if not authenticated.
""" """
def validator(func): def validator(func: const.WebSocketCommandHandler) -> const.WebSocketCommandHandler:
"""Decorate func.""" """Decorate func."""
@wraps(func) @wraps(func)
def check_current_user(hass, connection, msg): def check_current_user(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Check current user.""" """Check current user."""
def output_error(message_id, message): def output_error(message_id: str, message: str) -> None:
"""Output error message.""" """Output error message."""
connection.send_message( connection.send_message(
messages.error_message(msg["id"], message_id, message) messages.error_message(msg["id"], message_id, message)
) )
if connection.user is None:
output_error("no_user", "Not authenticated as a user")
return
if only_owner and not connection.user.is_owner: if only_owner and not connection.user.is_owner:
output_error("only_owner", "Only allowed as owner") output_error("only_owner", "Only allowed as owner")
return return
@ -112,16 +118,16 @@ def ws_require_user(
def websocket_command( def websocket_command(
schema: dict, schema: dict[vol.Marker, Any],
) -> Callable[[const.WebSocketCommandHandler], const.WebSocketCommandHandler]: ) -> Callable[[const.WebSocketCommandHandler], const.WebSocketCommandHandler]:
"""Tag a function as a websocket command.""" """Tag a function as a websocket command."""
command = schema["type"] command = schema["type"]
def decorate(func): def decorate(func: const.WebSocketCommandHandler) -> const.WebSocketCommandHandler:
"""Decorate ws command function.""" """Decorate ws command function."""
# pylint: disable=protected-access # pylint: disable=protected-access
func._ws_schema = messages.BASE_COMMAND_MESSAGE_SCHEMA.extend(schema) func._ws_schema = messages.BASE_COMMAND_MESSAGE_SCHEMA.extend(schema) # type: ignore[attr-defined]
func._ws_command = command func._ws_command = command # type: ignore[attr-defined]
return func return func
return decorate return decorate

View file

@ -2,15 +2,18 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections.abc import Callable
from contextlib import suppress from contextlib import suppress
import datetime as dt
import logging import logging
from typing import Any, Final
from aiohttp import WSMsgType, web from aiohttp import WSMsgType, web
import async_timeout import async_timeout
from homeassistant.components.http import HomeAssistantView from homeassistant.components.http import HomeAssistantView
from homeassistant.const import EVENT_HOMEASSISTANT_STOP from homeassistant.const import EVENT_HOMEASSISTANT_STOP
from homeassistant.core import callback from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.helpers.event import async_call_later from homeassistant.helpers.event import async_call_later
from .auth import AuthPhase, auth_required_message from .auth import AuthPhase, auth_required_message
@ -27,16 +30,15 @@ from .const import (
from .error import Disconnect from .error import Disconnect
from .messages import message_to_json from .messages import message_to_json
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs _WS_LOGGER: Final = logging.getLogger(f"{__name__}.connection")
_WS_LOGGER = logging.getLogger(f"{__name__}.connection")
class WebsocketAPIView(HomeAssistantView): class WebsocketAPIView(HomeAssistantView):
"""View to serve a websockets endpoint.""" """View to serve a websockets endpoint."""
name = "websocketapi" name: str = "websocketapi"
url = URL url: str = URL
requires_auth = False requires_auth: bool = False
async def get(self, request: web.Request) -> web.WebSocketResponse: async def get(self, request: web.Request) -> web.WebSocketResponse:
"""Handle an incoming websocket connection.""" """Handle an incoming websocket connection."""
@ -46,7 +48,7 @@ class WebsocketAPIView(HomeAssistantView):
class WebSocketAdapter(logging.LoggerAdapter): class WebSocketAdapter(logging.LoggerAdapter):
"""Add connection id to websocket messages.""" """Add connection id to websocket messages."""
def process(self, msg, kwargs): def process(self, msg: str, kwargs: Any) -> tuple[str, Any]:
"""Add connid to websocket log messages.""" """Add connid to websocket log messages."""
return f'[{self.extra["connid"]}] {msg}', kwargs return f'[{self.extra["connid"]}] {msg}', kwargs
@ -54,20 +56,21 @@ class WebSocketAdapter(logging.LoggerAdapter):
class WebSocketHandler: class WebSocketHandler:
"""Handle an active websocket client connection.""" """Handle an active websocket client connection."""
def __init__(self, hass, request): def __init__(self, hass: HomeAssistant, request: web.Request) -> None:
"""Initialize an active connection.""" """Initialize an active connection."""
self.hass = hass self.hass = hass
self.request = request self.request = request
self.wsock: web.WebSocketResponse | None = None self.wsock: web.WebSocketResponse | None = None
self._to_write: asyncio.Queue = asyncio.Queue(maxsize=MAX_PENDING_MSG) self._to_write: asyncio.Queue = asyncio.Queue(maxsize=MAX_PENDING_MSG)
self._handle_task = None self._handle_task: asyncio.Task | None = None
self._writer_task = None self._writer_task: asyncio.Task | None = None
self._logger = WebSocketAdapter(_WS_LOGGER, {"connid": id(self)}) self._logger = WebSocketAdapter(_WS_LOGGER, {"connid": id(self)})
self._peak_checker_unsub = None self._peak_checker_unsub: Callable[[], None] | None = None
async def _writer(self): async def _writer(self) -> None:
"""Write outgoing messages.""" """Write outgoing messages."""
# Exceptions if Socket disconnected or cancelled by connection handler # Exceptions if Socket disconnected or cancelled by connection handler
assert self.wsock is not None
with suppress(RuntimeError, ConnectionResetError, *CANCELLATION_ERRORS): with suppress(RuntimeError, ConnectionResetError, *CANCELLATION_ERRORS):
while not self.wsock.closed: while not self.wsock.closed:
message = await self._to_write.get() message = await self._to_write.get()
@ -78,12 +81,12 @@ class WebSocketHandler:
await self.wsock.send_str(message) await self.wsock.send_str(message)
# Clean up the peaker checker when we shut down the writer # Clean up the peaker checker when we shut down the writer
if self._peak_checker_unsub: if self._peak_checker_unsub is not None:
self._peak_checker_unsub() self._peak_checker_unsub()
self._peak_checker_unsub = None self._peak_checker_unsub = None
@callback @callback
def _send_message(self, message): def _send_message(self, message: str | dict[str, Any]) -> None:
"""Send a message to the client. """Send a message to the client.
Closes connection if the client is not reading the messages. Closes connection if the client is not reading the messages.
@ -114,7 +117,7 @@ class WebSocketHandler:
) )
@callback @callback
def _check_write_peak(self, _): def _check_write_peak(self, _utc_time: dt.datetime) -> None:
"""Check that we are no longer above the write peak.""" """Check that we are no longer above the write peak."""
self._peak_checker_unsub = None self._peak_checker_unsub = None
@ -129,10 +132,12 @@ class WebSocketHandler:
self._cancel() self._cancel()
@callback @callback
def _cancel(self): def _cancel(self) -> None:
"""Cancel the connection.""" """Cancel the connection."""
self._handle_task.cancel() if self._handle_task is not None:
self._writer_task.cancel() self._handle_task.cancel()
if self._writer_task is not None:
self._writer_task.cancel()
async def async_handle(self) -> web.WebSocketResponse: async def async_handle(self) -> web.WebSocketResponse:
"""Handle a websocket response.""" """Handle a websocket response."""
@ -143,7 +148,7 @@ class WebSocketHandler:
self._handle_task = asyncio.current_task() self._handle_task = asyncio.current_task()
@callback @callback
def handle_hass_stop(event): def handle_hass_stop(event: Event) -> None:
"""Cancel this connection.""" """Cancel this connection."""
self._cancel() self._cancel()

View file

@ -3,7 +3,7 @@ from __future__ import annotations
from functools import lru_cache from functools import lru_cache
import logging import logging
from typing import Any from typing import Any, Final
import voluptuous as vol import voluptuous as vol
@ -17,28 +17,27 @@ from homeassistant.util.yaml.loader import JSON_TYPE
from . import const from . import const
_LOGGER = logging.getLogger(__name__) _LOGGER: Final = logging.getLogger(__name__)
# mypy: allow-untyped-defs
# Minimal requirements of a message # Minimal requirements of a message
MINIMAL_MESSAGE_SCHEMA = vol.Schema( MINIMAL_MESSAGE_SCHEMA: Final = vol.Schema(
{vol.Required("id"): cv.positive_int, vol.Required("type"): cv.string}, {vol.Required("id"): cv.positive_int, vol.Required("type"): cv.string},
extra=vol.ALLOW_EXTRA, extra=vol.ALLOW_EXTRA,
) )
# Base schema to extend by message handlers # Base schema to extend by message handlers
BASE_COMMAND_MESSAGE_SCHEMA = vol.Schema({vol.Required("id"): cv.positive_int}) BASE_COMMAND_MESSAGE_SCHEMA: Final = vol.Schema({vol.Required("id"): cv.positive_int})
IDEN_TEMPLATE = "__IDEN__" IDEN_TEMPLATE: Final = "__IDEN__"
IDEN_JSON_TEMPLATE = '"__IDEN__"' IDEN_JSON_TEMPLATE: Final = '"__IDEN__"'
def result_message(iden: int, result: Any = None) -> dict: def result_message(iden: int, result: Any = None) -> dict[str, Any]:
"""Return a success result message.""" """Return a success result message."""
return {"id": iden, "type": const.TYPE_RESULT, "success": True, "result": result} return {"id": iden, "type": const.TYPE_RESULT, "success": True, "result": result}
def error_message(iden: int, code: str, message: str) -> dict: def error_message(iden: int | None, code: str, message: str) -> dict[str, Any]:
"""Return an error result message.""" """Return an error result message."""
return { return {
"id": iden, "id": iden,
@ -48,7 +47,7 @@ def error_message(iden: int, code: str, message: str) -> dict:
} }
def event_message(iden: JSON_TYPE, event: Any) -> dict: def event_message(iden: JSON_TYPE, event: Any) -> dict[str, Any]:
"""Return an event message.""" """Return an event message."""
return {"id": iden, "type": "event", "event": event} return {"id": iden, "type": "event", "event": event}
@ -75,7 +74,7 @@ def _cached_event_message(event: Event) -> str:
return message_to_json(event_message(IDEN_TEMPLATE, event)) return message_to_json(event_message(IDEN_TEMPLATE, event))
def message_to_json(message: Any) -> str: def message_to_json(message: dict[str, Any]) -> str:
"""Serialize a websocket message to json.""" """Serialize a websocket message to json."""
try: try:
return const.JSON_DUMP(message) return const.JSON_DUMP(message)

View file

@ -2,6 +2,10 @@
Separate file to avoid circular imports. Separate file to avoid circular imports.
""" """
from __future__ import annotations
from typing import Final
from homeassistant.components.frontend import EVENT_PANELS_UPDATED from homeassistant.components.frontend import EVENT_PANELS_UPDATED
from homeassistant.components.lovelace.const import EVENT_LOVELACE_UPDATED from homeassistant.components.lovelace.const import EVENT_LOVELACE_UPDATED
from homeassistant.components.persistent_notification import ( from homeassistant.components.persistent_notification import (
@ -22,7 +26,7 @@ from homeassistant.helpers.entity_registry import EVENT_ENTITY_REGISTRY_UPDATED
# These are events that do not contain any sensitive data # These are events that do not contain any sensitive data
# Except for state_changed, which is handled accordingly. # Except for state_changed, which is handled accordingly.
SUBSCRIBE_ALLOWLIST = { SUBSCRIBE_ALLOWLIST: Final[set[str]] = {
EVENT_AREA_REGISTRY_UPDATED, EVENT_AREA_REGISTRY_UPDATED,
EVENT_COMPONENT_LOADED, EVENT_COMPONENT_LOADED,
EVENT_CORE_CONFIG_UPDATE, EVENT_CORE_CONFIG_UPDATE,

View file

@ -1,7 +1,12 @@
"""Entity to track connections to websocket API.""" """Entity to track connections to websocket API."""
from __future__ import annotations
from typing import Any
from homeassistant.components.sensor import SensorEntity from homeassistant.components.sensor import SensorEntity
from homeassistant.core import callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.typing import ConfigType
from .const import ( from .const import (
DATA_CONNECTIONS, DATA_CONNECTIONS,
@ -9,10 +14,13 @@ from .const import (
SIGNAL_WEBSOCKET_DISCONNECTED, SIGNAL_WEBSOCKET_DISCONNECTED,
) )
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
async def async_setup_platform(
async def async_setup_platform(hass, config, async_add_entities, discovery_info=None): hass: HomeAssistant,
config: ConfigType,
async_add_entities: AddEntitiesCallback,
discovery_info: dict[str, Any] | None = None,
) -> None:
"""Set up the API streams platform.""" """Set up the API streams platform."""
entity = APICount() entity = APICount()
@ -22,11 +30,11 @@ async def async_setup_platform(hass, config, async_add_entities, discovery_info=
class APICount(SensorEntity): class APICount(SensorEntity):
"""Entity to represent how many people are connected to the stream API.""" """Entity to represent how many people are connected to the stream API."""
def __init__(self): def __init__(self) -> None:
"""Initialize the API count.""" """Initialize the API count."""
self.count = 0 self.count = 0
async def async_added_to_hass(self): async def async_added_to_hass(self) -> None:
"""Added to hass.""" """Added to hass."""
self.async_on_remove( self.async_on_remove(
self.hass.helpers.dispatcher.async_dispatcher_connect( self.hass.helpers.dispatcher.async_dispatcher_connect(
@ -40,21 +48,21 @@ class APICount(SensorEntity):
) )
@property @property
def name(self): def name(self) -> str:
"""Return name of entity.""" """Return name of entity."""
return "Connected clients" return "Connected clients"
@property @property
def state(self): def state(self) -> int:
"""Return current API count.""" """Return current API count."""
return self.count return self.count
@property @property
def unit_of_measurement(self): def unit_of_measurement(self) -> str:
"""Return the unit of measurement.""" """Return the unit of measurement."""
return "clients" return "clients"
@callback @callback
def _update_count(self): def _update_count(self) -> None:
self.count = self.hass.data.get(DATA_CONNECTIONS, 0) self.count = self.hass.data.get(DATA_CONNECTIONS, 0)
self.async_write_ha_state() self.async_write_ha_state()

View file

@ -1,6 +1,7 @@
"""Test WebSocket Connection class.""" """Test WebSocket Connection class."""
import asyncio import asyncio
import logging import logging
from unittest.mock import Mock
import voluptuous as vol import voluptuous as vol
@ -8,6 +9,8 @@ from homeassistant import exceptions
from homeassistant.components import websocket_api from homeassistant.components import websocket_api
from homeassistant.components.websocket_api import const from homeassistant.components.websocket_api import const
from tests.common import MockUser
async def test_send_big_result(hass, websocket_client): async def test_send_big_result(hass, websocket_client):
"""Test sending big results over the WS.""" """Test sending big results over the WS."""
@ -31,8 +34,10 @@ async def test_send_big_result(hass, websocket_client):
async def test_exception_handling(): async def test_exception_handling():
"""Test handling of exceptions.""" """Test handling of exceptions."""
send_messages = [] send_messages = []
user = MockUser()
refresh_token = Mock()
conn = websocket_api.ActiveConnection( conn = websocket_api.ActiveConnection(
logging.getLogger(__name__), None, send_messages.append, None, None logging.getLogger(__name__), None, send_messages.append, user, refresh_token
) )
for (exc, code, err) in ( for (exc, code, err) in (