Allow confirming local push notifications (#54947)
* Allow confirming local push notifications * Fix from Zac * Add tests
This commit is contained in:
parent
f77e93ceeb
commit
677abcd484
7 changed files with 397 additions and 104 deletions
121
homeassistant/components/mobile_app/websocket_api.py
Normal file
121
homeassistant/components/mobile_app/websocket_api.py
Normal file
|
@ -0,0 +1,121 @@
|
|||
"""Mobile app websocket API."""
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import wraps
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import websocket_api
|
||||
from homeassistant.core import callback
|
||||
|
||||
from .const import CONF_USER_ID, DATA_CONFIG_ENTRIES, DATA_PUSH_CHANNEL, DOMAIN
|
||||
from .push_notification import PushChannel
|
||||
|
||||
|
||||
@callback
|
||||
def async_setup_commands(hass):
|
||||
"""Set up the mobile app websocket API."""
|
||||
websocket_api.async_register_command(hass, handle_push_notification_channel)
|
||||
websocket_api.async_register_command(hass, handle_push_notification_confirm)
|
||||
|
||||
|
||||
def _ensure_webhook_access(func):
|
||||
"""Decorate WS function to ensure user owns the webhook ID."""
|
||||
|
||||
@callback
|
||||
@wraps(func)
|
||||
def with_webhook_access(hass, connection, msg):
|
||||
# Validate that the webhook ID is registered to the user of the websocket connection
|
||||
config_entry = hass.data[DOMAIN][DATA_CONFIG_ENTRIES].get(msg["webhook_id"])
|
||||
|
||||
if config_entry is None:
|
||||
connection.send_error(
|
||||
msg["id"], websocket_api.ERR_NOT_FOUND, "Webhook ID not found"
|
||||
)
|
||||
return
|
||||
|
||||
if config_entry.data[CONF_USER_ID] != connection.user.id:
|
||||
connection.send_error(
|
||||
msg["id"],
|
||||
websocket_api.ERR_UNAUTHORIZED,
|
||||
"User not linked to this webhook ID",
|
||||
)
|
||||
return
|
||||
|
||||
func(hass, connection, msg)
|
||||
|
||||
return with_webhook_access
|
||||
|
||||
|
||||
@callback
|
||||
@_ensure_webhook_access
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "mobile_app/push_notification_confirm",
|
||||
vol.Required("webhook_id"): str,
|
||||
vol.Required("confirm_id"): str,
|
||||
}
|
||||
)
|
||||
def handle_push_notification_confirm(hass, connection, msg):
|
||||
"""Confirm receipt of a push notification."""
|
||||
channel: PushChannel | None = hass.data[DOMAIN][DATA_PUSH_CHANNEL].get(
|
||||
msg["webhook_id"]
|
||||
)
|
||||
if channel is None:
|
||||
connection.send_error(
|
||||
msg["id"],
|
||||
websocket_api.ERR_NOT_FOUND,
|
||||
"Push notification channel not found",
|
||||
)
|
||||
return
|
||||
|
||||
if channel.async_confirm_notification(msg["confirm_id"]):
|
||||
connection.send_result(msg["id"])
|
||||
else:
|
||||
connection.send_error(
|
||||
msg["id"],
|
||||
websocket_api.ERR_NOT_FOUND,
|
||||
"Push notification channel not found",
|
||||
)
|
||||
|
||||
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "mobile_app/push_notification_channel",
|
||||
vol.Required("webhook_id"): str,
|
||||
vol.Optional("support_confirm", default=False): bool,
|
||||
}
|
||||
)
|
||||
@_ensure_webhook_access
|
||||
@websocket_api.async_response
|
||||
async def handle_push_notification_channel(hass, connection, msg):
|
||||
"""Set up a direct push notification channel."""
|
||||
webhook_id = msg["webhook_id"]
|
||||
registered_channels: dict[str, PushChannel] = hass.data[DOMAIN][DATA_PUSH_CHANNEL]
|
||||
|
||||
if webhook_id in registered_channels:
|
||||
await registered_channels[webhook_id].async_teardown()
|
||||
|
||||
@callback
|
||||
def on_channel_teardown():
|
||||
"""Handle teardown."""
|
||||
if registered_channels.get(webhook_id) == channel:
|
||||
registered_channels.pop(webhook_id)
|
||||
|
||||
# Remove subscription from connection if still exists
|
||||
connection.subscriptions.pop(msg["id"], None)
|
||||
|
||||
channel = registered_channels[webhook_id] = PushChannel(
|
||||
hass,
|
||||
webhook_id,
|
||||
msg["support_confirm"],
|
||||
lambda data: connection.send_message(
|
||||
websocket_api.messages.event_message(msg["id"], data)
|
||||
),
|
||||
on_channel_teardown,
|
||||
)
|
||||
|
||||
connection.subscriptions[msg["id"]] = lambda: hass.async_create_task(
|
||||
channel.async_teardown()
|
||||
)
|
||||
connection.send_result(msg["id"])
|
Loading…
Add table
Add a link
Reference in a new issue