Allow confirming local push notifications (#54947)

* Allow confirming local push notifications

* Fix from Zac

* Add tests
This commit is contained in:
Paulus Schoutsen 2021-09-22 14:17:04 -07:00 committed by GitHub
parent f77e93ceeb
commit 677abcd484
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 397 additions and 104 deletions

View file

@ -1,25 +1,23 @@
"""Integrates Native Apps to Home Assistant.""" """Integrates Native Apps to Home Assistant."""
from contextlib import suppress from contextlib import suppress
import voluptuous as vol from homeassistant.components import cloud, notify as hass_notify
from homeassistant.components import cloud, notify as hass_notify, websocket_api
from homeassistant.components.webhook import ( from homeassistant.components.webhook import (
async_register as webhook_register, async_register as webhook_register,
async_unregister as webhook_unregister, async_unregister as webhook_unregister,
) )
from homeassistant.const import ATTR_DEVICE_ID, CONF_WEBHOOK_ID from homeassistant.const import ATTR_DEVICE_ID, CONF_WEBHOOK_ID
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant
from homeassistant.helpers import device_registry as dr, discovery from homeassistant.helpers import device_registry as dr, discovery
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from . import websocket_api
from .const import ( from .const import (
ATTR_DEVICE_NAME, ATTR_DEVICE_NAME,
ATTR_MANUFACTURER, ATTR_MANUFACTURER,
ATTR_MODEL, ATTR_MODEL,
ATTR_OS_VERSION, ATTR_OS_VERSION,
CONF_CLOUDHOOK_URL, CONF_CLOUDHOOK_URL,
CONF_USER_ID,
DATA_CONFIG_ENTRIES, DATA_CONFIG_ENTRIES,
DATA_DELETED_IDS, DATA_DELETED_IDS,
DATA_DEVICES, DATA_DEVICES,
@ -66,7 +64,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
discovery.async_load_platform(hass, "notify", DOMAIN, {}, config) discovery.async_load_platform(hass, "notify", DOMAIN, {}, config)
) )
websocket_api.async_register_command(hass, handle_push_notification_channel) websocket_api.async_setup_commands(hass)
return True return True
@ -127,52 +125,3 @@ async def async_remove_entry(hass, entry):
if CONF_CLOUDHOOK_URL in entry.data: if CONF_CLOUDHOOK_URL in entry.data:
with suppress(cloud.CloudNotAvailable): with suppress(cloud.CloudNotAvailable):
await cloud.async_delete_cloudhook(hass, entry.data[CONF_WEBHOOK_ID]) await cloud.async_delete_cloudhook(hass, entry.data[CONF_WEBHOOK_ID])
@callback
@websocket_api.websocket_command(
{
vol.Required("type"): "mobile_app/push_notification_channel",
vol.Required("webhook_id"): str,
}
)
def handle_push_notification_channel(hass, connection, msg):
"""Set up a direct push notification channel."""
webhook_id = msg["webhook_id"]
# Validate that the webhook ID is registered to the user of the websocket connection
config_entry = hass.data[DOMAIN][DATA_CONFIG_ENTRIES].get(webhook_id)
if config_entry is None:
connection.send_error(
msg["id"], websocket_api.ERR_NOT_FOUND, "Webhook ID not found"
)
return
if config_entry.data[CONF_USER_ID] != connection.user.id:
connection.send_error(
msg["id"],
websocket_api.ERR_UNAUTHORIZED,
"User not linked to this webhook ID",
)
return
registered_channels = hass.data[DOMAIN][DATA_PUSH_CHANNEL]
if webhook_id in registered_channels:
registered_channels.pop(webhook_id)
@callback
def forward_push_notification(data):
"""Forward events to websocket."""
connection.send_message(websocket_api.messages.event_message(msg["id"], data))
@callback
def unsub():
# pylint: disable=comparison-with-callable
if registered_channels.get(webhook_id) == forward_push_notification:
registered_channels.pop(webhook_id)
registered_channels[webhook_id] = forward_push_notification
connection.subscriptions[msg["id"]] = unsub
connection.send_result(msg["id"])

View file

@ -1,5 +1,6 @@
"""Support for mobile_app push notifications.""" """Support for mobile_app push notifications."""
import asyncio import asyncio
from functools import partial
import logging import logging
import aiohttp import aiohttp
@ -124,9 +125,15 @@ class MobileAppNotificationService(BaseNotificationService):
for target in targets: for target in targets:
if target in local_push_channels: if target in local_push_channels:
local_push_channels[target](data) local_push_channels[target].async_send_notification(
data, partial(self._async_send_remote_message_target, target)
)
continue continue
await self._async_send_remote_message_target(target, data)
async def _async_send_remote_message_target(self, target, data):
"""Send a message to a target."""
entry = self.hass.data[DOMAIN][DATA_CONFIG_ENTRIES][target] entry = self.hass.data[DOMAIN][DATA_CONFIG_ENTRIES][target]
entry_data = entry.data entry_data = entry.data
@ -155,7 +162,7 @@ class MobileAppNotificationService(BaseNotificationService):
if response.status in (HTTP_OK, HTTP_CREATED, HTTP_ACCEPTED): if response.status in (HTTP_OK, HTTP_CREATED, HTTP_ACCEPTED):
log_rate_limits(self.hass, entry_data[ATTR_DEVICE_NAME], result) log_rate_limits(self.hass, entry_data[ATTR_DEVICE_NAME], result)
continue return
fallback_error = result.get("errorMessage", "Unknown error") fallback_error = result.get("errorMessage", "Unknown error")
fallback_message = ( fallback_message = (
@ -166,9 +173,7 @@ class MobileAppNotificationService(BaseNotificationService):
if "message" in result: if "message" in result:
if message[-1] not in [".", "?", "!"]: if message[-1] not in [".", "?", "!"]:
message += "." message += "."
message += ( message += " This message is generated externally to Home Assistant."
" This message is generated externally to Home Assistant."
)
if response.status == HTTP_TOO_MANY_REQUESTS: if response.status == HTTP_TOO_MANY_REQUESTS:
_LOGGER.warning(message) _LOGGER.warning(message)

View file

@ -0,0 +1,90 @@
"""Push notification handling."""
import asyncio
from typing import Callable
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.event import async_call_later
from homeassistant.util.uuid import random_uuid_hex
PUSH_CONFIRM_TIMEOUT = 10 # seconds
class PushChannel:
"""Class that represents a push channel."""
def __init__(
self,
hass: HomeAssistant,
webhook_id: str,
support_confirm: bool,
send_message: Callable[[dict], None],
on_teardown: Callable[[], None],
) -> None:
"""Initialize a local push channel."""
self.hass = hass
self.webhook_id = webhook_id
self.support_confirm = support_confirm
self._send_message = send_message
self.on_teardown = on_teardown
self.pending_confirms = {}
@callback
def async_send_notification(self, data, fallback_send):
"""Send a push notification."""
if not self.support_confirm:
self._send_message(data)
return
confirm_id = random_uuid_hex()
data["hass_confirm_id"] = confirm_id
async def handle_push_failed(_=None):
"""Handle a failed local push notification."""
# Remove this handler from the pending dict
# If it didn't exist we hit a race condition between call_later and another
# push failing and tearing down the connection.
if self.pending_confirms.pop(confirm_id, None) is None:
return
# Drop local channel if it's still open
if self.on_teardown is not None:
await self.async_teardown()
await fallback_send(data)
self.pending_confirms[confirm_id] = {
"unsub_scheduled_push_failed": async_call_later(
self.hass, PUSH_CONFIRM_TIMEOUT, handle_push_failed
),
"handle_push_failed": handle_push_failed,
}
self._send_message(data)
@callback
def async_confirm_notification(self, confirm_id) -> bool:
"""Confirm a push notification.
Returns if confirmation successful.
"""
if confirm_id not in self.pending_confirms:
return False
self.pending_confirms.pop(confirm_id)["unsub_scheduled_push_failed"]()
return True
async def async_teardown(self):
"""Tear down this channel."""
# Tear down is in progress
if self.on_teardown is None:
return
self.on_teardown()
self.on_teardown = None
cancel_pending_local_tasks = [
actions["handle_push_failed"]()
for actions in self.pending_confirms.values()
]
if cancel_pending_local_tasks:
await asyncio.gather(*cancel_pending_local_tasks)

View file

@ -0,0 +1,121 @@
"""Mobile app websocket API."""
from __future__ import annotations
from functools import wraps
import voluptuous as vol
from homeassistant.components import websocket_api
from homeassistant.core import callback
from .const import CONF_USER_ID, DATA_CONFIG_ENTRIES, DATA_PUSH_CHANNEL, DOMAIN
from .push_notification import PushChannel
@callback
def async_setup_commands(hass):
"""Set up the mobile app websocket API."""
websocket_api.async_register_command(hass, handle_push_notification_channel)
websocket_api.async_register_command(hass, handle_push_notification_confirm)
def _ensure_webhook_access(func):
"""Decorate WS function to ensure user owns the webhook ID."""
@callback
@wraps(func)
def with_webhook_access(hass, connection, msg):
# Validate that the webhook ID is registered to the user of the websocket connection
config_entry = hass.data[DOMAIN][DATA_CONFIG_ENTRIES].get(msg["webhook_id"])
if config_entry is None:
connection.send_error(
msg["id"], websocket_api.ERR_NOT_FOUND, "Webhook ID not found"
)
return
if config_entry.data[CONF_USER_ID] != connection.user.id:
connection.send_error(
msg["id"],
websocket_api.ERR_UNAUTHORIZED,
"User not linked to this webhook ID",
)
return
func(hass, connection, msg)
return with_webhook_access
@callback
@_ensure_webhook_access
@websocket_api.websocket_command(
{
vol.Required("type"): "mobile_app/push_notification_confirm",
vol.Required("webhook_id"): str,
vol.Required("confirm_id"): str,
}
)
def handle_push_notification_confirm(hass, connection, msg):
"""Confirm receipt of a push notification."""
channel: PushChannel | None = hass.data[DOMAIN][DATA_PUSH_CHANNEL].get(
msg["webhook_id"]
)
if channel is None:
connection.send_error(
msg["id"],
websocket_api.ERR_NOT_FOUND,
"Push notification channel not found",
)
return
if channel.async_confirm_notification(msg["confirm_id"]):
connection.send_result(msg["id"])
else:
connection.send_error(
msg["id"],
websocket_api.ERR_NOT_FOUND,
"Push notification channel not found",
)
@websocket_api.websocket_command(
{
vol.Required("type"): "mobile_app/push_notification_channel",
vol.Required("webhook_id"): str,
vol.Optional("support_confirm", default=False): bool,
}
)
@_ensure_webhook_access
@websocket_api.async_response
async def handle_push_notification_channel(hass, connection, msg):
"""Set up a direct push notification channel."""
webhook_id = msg["webhook_id"]
registered_channels: dict[str, PushChannel] = hass.data[DOMAIN][DATA_PUSH_CHANNEL]
if webhook_id in registered_channels:
await registered_channels[webhook_id].async_teardown()
@callback
def on_channel_teardown():
"""Handle teardown."""
if registered_channels.get(webhook_id) == channel:
registered_channels.pop(webhook_id)
# Remove subscription from connection if still exists
connection.subscriptions.pop(msg["id"], None)
channel = registered_channels[webhook_id] = PushChannel(
hass,
webhook_id,
msg["support_confirm"],
lambda data: connection.send_message(
websocket_api.messages.event_message(msg["id"], data)
),
on_channel_teardown,
)
connection.subscriptions[msg["id"]] = lambda: hass.async_create_task(
channel.async_teardown()
)
connection.send_result(msg["id"])

View file

@ -104,8 +104,8 @@ class ActiveConnection:
self.last_id = cur_id self.last_id = cur_id
@callback @callback
def async_close(self) -> None: def async_handle_close(self) -> None:
"""Close down connection.""" """Handle closing down connection."""
for unsub in self.subscriptions.values(): for unsub in self.subscriptions.values():
unsub() unsub()

View file

@ -231,7 +231,7 @@ class WebSocketHandler:
unsub_stop() unsub_stop()
if connection is not None: if connection is not None:
connection.async_close() connection.async_handle_close()
try: try:
self._to_write.put_nowait(None) self._to_write.put_nowait(None)

View file

@ -1,5 +1,6 @@
"""Notify platform tests for mobile_app.""" """Notify platform tests for mobile_app."""
from datetime import datetime, timedelta from datetime import datetime, timedelta
from unittest.mock import patch
import pytest import pytest
@ -204,3 +205,130 @@ async def test_notify_ws_works(
"code": "unauthorized", "code": "unauthorized",
"message": "User not linked to this webhook ID", "message": "User not linked to this webhook ID",
} }
async def test_notify_ws_confirming_works(
hass, aioclient_mock, setup_push_receiver, hass_ws_client
):
"""Test notify confirming works."""
client = await hass_ws_client(hass)
await client.send_json(
{
"id": 5,
"type": "mobile_app/push_notification_channel",
"webhook_id": "mock-webhook_id",
"support_confirm": True,
}
)
sub_result = await client.receive_json()
assert sub_result["success"]
# Sent a message that will be delivered locally
assert await hass.services.async_call(
"notify", "mobile_app_test", {"message": "Hello world"}, blocking=True
)
msg_result = await client.receive_json()
confirm_id = msg_result["event"].pop("hass_confirm_id")
assert confirm_id is not None
assert msg_result["event"] == {"message": "Hello world"}
# Try to confirm with incorrect confirm ID
await client.send_json(
{
"id": 6,
"type": "mobile_app/push_notification_confirm",
"webhook_id": "mock-webhook_id",
"confirm_id": "incorrect-confirm-id",
}
)
result = await client.receive_json()
assert not result["success"]
assert result["error"] == {
"code": "not_found",
"message": "Push notification channel not found",
}
# Confirm with correct confirm ID
await client.send_json(
{
"id": 7,
"type": "mobile_app/push_notification_confirm",
"webhook_id": "mock-webhook_id",
"confirm_id": confirm_id,
}
)
result = await client.receive_json()
assert result["success"]
# Drop local push channel and try to confirm another message
await client.send_json(
{
"id": 8,
"type": "unsubscribe_events",
"subscription": 5,
}
)
sub_result = await client.receive_json()
assert sub_result["success"]
await client.send_json(
{
"id": 9,
"type": "mobile_app/push_notification_confirm",
"webhook_id": "mock-webhook_id",
"confirm_id": confirm_id,
}
)
result = await client.receive_json()
assert not result["success"]
assert result["error"] == {
"code": "not_found",
"message": "Push notification channel not found",
}
async def test_notify_ws_not_confirming(
hass, aioclient_mock, setup_push_receiver, hass_ws_client
):
"""Test we go via cloud when failed to confirm."""
client = await hass_ws_client(hass)
await client.send_json(
{
"id": 5,
"type": "mobile_app/push_notification_channel",
"webhook_id": "mock-webhook_id",
"support_confirm": True,
}
)
sub_result = await client.receive_json()
assert sub_result["success"]
assert await hass.services.async_call(
"notify", "mobile_app_test", {"message": "Hello world 1"}, blocking=True
)
with patch(
"homeassistant.components.mobile_app.push_notification.PUSH_CONFIRM_TIMEOUT", 0
):
assert await hass.services.async_call(
"notify", "mobile_app_test", {"message": "Hello world 2"}, blocking=True
)
await hass.async_block_till_done()
# When we fail, all unconfirmed ones and failed one are sent via cloud
assert len(aioclient_mock.mock_calls) == 2
# All future ones also go via cloud
assert await hass.services.async_call(
"notify", "mobile_app_test", {"message": "Hello world 3"}, blocking=True
)
assert len(aioclient_mock.mock_calls) == 3