From 1e107724976f34884359e1bf1ff1a3190176c797 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Mon, 17 May 2021 11:06:42 -0700 Subject: [PATCH] Add support for local push channels to mobile_app (#50750) --- .../components/mobile_app/__init__.py | 60 ++++++++++++- homeassistant/components/mobile_app/const.py | 1 + .../components/mobile_app/manifest.json | 2 +- homeassistant/components/mobile_app/notify.py | 14 ++- .../components/websocket_api/commands.py | 1 - tests/components/mobile_app/test_notify.py | 85 +++++++++++++++++-- 6 files changed, 150 insertions(+), 13 deletions(-) diff --git a/homeassistant/components/mobile_app/__init__.py b/homeassistant/components/mobile_app/__init__.py index 0fe1386d7ce..951c6f3beaf 100644 --- a/homeassistant/components/mobile_app/__init__.py +++ b/homeassistant/components/mobile_app/__init__.py @@ -1,13 +1,15 @@ """Integrates Native Apps to Home Assistant.""" from contextlib import suppress -from homeassistant.components import cloud, notify as hass_notify +import voluptuous as vol + +from homeassistant.components import cloud, notify as hass_notify, websocket_api from homeassistant.components.webhook import ( async_register as webhook_register, async_unregister as webhook_unregister, ) from homeassistant.const import ATTR_DEVICE_ID, CONF_WEBHOOK_ID -from homeassistant.core import HomeAssistant +from homeassistant.core import HomeAssistant, callback from homeassistant.helpers import device_registry as dr, discovery from homeassistant.helpers.typing import ConfigType @@ -17,9 +19,11 @@ from .const import ( ATTR_MODEL, ATTR_OS_VERSION, CONF_CLOUDHOOK_URL, + CONF_USER_ID, DATA_CONFIG_ENTRIES, DATA_DELETED_IDS, DATA_DEVICES, + DATA_PUSH_CHANNEL, DATA_STORE, DOMAIN, STORAGE_KEY, @@ -46,6 +50,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType): DATA_CONFIG_ENTRIES: {}, DATA_DELETED_IDS: app_config.get(DATA_DELETED_IDS, []), DATA_DEVICES: {}, + DATA_PUSH_CHANNEL: {}, DATA_STORE: store, } @@ -61,6 +66,8 @@ async def async_setup(hass: HomeAssistant, config: ConfigType): discovery.async_load_platform(hass, "notify", DOMAIN, {}, config) ) + websocket_api.async_register_command(hass, handle_push_notification_channel) + return True @@ -120,3 +127,52 @@ async def async_remove_entry(hass, entry): if CONF_CLOUDHOOK_URL in entry.data: with suppress(cloud.CloudNotAvailable): await cloud.async_delete_cloudhook(hass, entry.data[CONF_WEBHOOK_ID]) + + +@callback +@websocket_api.websocket_command( + { + vol.Required("type"): "mobile_app/push_notification_channel", + vol.Required("webhook_id"): str, + } +) +def handle_push_notification_channel(hass, connection, msg): + """Set up a direct push notification channel.""" + webhook_id = msg["webhook_id"] + + # Validate that the webhook ID is registered to the user of the websocket connection + config_entry = hass.data[DOMAIN][DATA_CONFIG_ENTRIES].get(webhook_id) + + if config_entry is None: + connection.send_error( + msg["id"], websocket_api.ERR_NOT_FOUND, "Webhook ID not found" + ) + return + + if config_entry.data[CONF_USER_ID] != connection.user.id: + connection.send_error( + msg["id"], + websocket_api.ERR_UNAUTHORIZED, + "User not linked to this webhook ID", + ) + return + + registered_channels = hass.data[DOMAIN][DATA_PUSH_CHANNEL] + + if webhook_id in registered_channels: + registered_channels.pop(webhook_id)() + + @callback + def forward_push_notification(data): + """Forward events to websocket.""" + connection.send_message(websocket_api.messages.event_message(msg["id"], data)) + + @callback + def unsub(): + # pylint: disable=comparison-with-callable + if registered_channels.get(webhook_id) == forward_push_notification: + registered_channels.pop(webhook_id) + + registered_channels[webhook_id] = forward_push_notification + connection.subscriptions[msg["id"]] = unsub + connection.send_result(msg["id"]) diff --git a/homeassistant/components/mobile_app/const.py b/homeassistant/components/mobile_app/const.py index af828ce423e..e375ec55ff2 100644 --- a/homeassistant/components/mobile_app/const.py +++ b/homeassistant/components/mobile_app/const.py @@ -14,6 +14,7 @@ DATA_DELETED_IDS = "deleted_ids" DATA_DEVICES = "devices" DATA_STORE = "store" DATA_NOTIFY = "notify" +DATA_PUSH_CHANNEL = "push_channel" ATTR_APP_DATA = "app_data" ATTR_APP_ID = "app_id" diff --git a/homeassistant/components/mobile_app/manifest.json b/homeassistant/components/mobile_app/manifest.json index 2372ee0c515..d850d9ab469 100644 --- a/homeassistant/components/mobile_app/manifest.json +++ b/homeassistant/components/mobile_app/manifest.json @@ -4,7 +4,7 @@ "config_flow": true, "documentation": "https://www.home-assistant.io/integrations/mobile_app", "requirements": ["PyNaCl==1.3.0", "emoji==1.2.0"], - "dependencies": ["http", "webhook", "person", "tag"], + "dependencies": ["http", "webhook", "person", "tag", "websocket_api"], "after_dependencies": ["cloud", "camera", "notify"], "codeowners": ["@robbiet480"], "quality_scale": "internal", diff --git a/homeassistant/components/mobile_app/notify.py b/homeassistant/components/mobile_app/notify.py index 803f00764e7..1acb9f25c0c 100644 --- a/homeassistant/components/mobile_app/notify.py +++ b/homeassistant/components/mobile_app/notify.py @@ -37,6 +37,7 @@ from .const import ( ATTR_PUSH_URL, DATA_CONFIG_ENTRIES, DATA_NOTIFY, + DATA_PUSH_CHANNEL, DOMAIN, ) from .util import supports_push @@ -119,7 +120,13 @@ class MobileAppNotificationService(BaseNotificationService): if kwargs.get(ATTR_DATA) is not None: data[ATTR_DATA] = kwargs.get(ATTR_DATA) + local_push_channels = self.hass.data[DOMAIN][DATA_PUSH_CHANNEL] + for target in targets: + if target in local_push_channels: + local_push_channels[target](data) + continue + entry = self.hass.data[DOMAIN][DATA_CONFIG_ENTRIES][target] entry_data = entry.data @@ -127,7 +134,8 @@ class MobileAppNotificationService(BaseNotificationService): push_token = app_data[ATTR_PUSH_TOKEN] push_url = app_data[ATTR_PUSH_URL] - data[ATTR_PUSH_TOKEN] = push_token + target_data = dict(data) + target_data[ATTR_PUSH_TOKEN] = push_token reg_info = { ATTR_APP_ID: entry_data[ATTR_APP_ID], @@ -136,12 +144,12 @@ class MobileAppNotificationService(BaseNotificationService): if ATTR_OS_VERSION in entry_data: reg_info[ATTR_OS_VERSION] = entry_data[ATTR_OS_VERSION] - data["registration_info"] = reg_info + target_data["registration_info"] = reg_info try: with async_timeout.timeout(10): response = await async_get_clientsession(self._hass).post( - push_url, json=data + push_url, json=target_data ) result = await response.json() diff --git a/homeassistant/components/websocket_api/commands.py b/homeassistant/components/websocket_api/commands.py index af2c914bfbd..53ff6d1da26 100644 --- a/homeassistant/components/websocket_api/commands.py +++ b/homeassistant/components/websocket_api/commands.py @@ -395,7 +395,6 @@ def handle_entity_source(hass, connection, msg): connection.send_result(msg["id"], sources) -@callback @decorators.websocket_command( { vol.Required("type"): "subscribe_trigger", diff --git a/tests/components/mobile_app/test_notify.py b/tests/components/mobile_app/test_notify.py index 8823fefd92c..9c4ca146898 100644 --- a/tests/components/mobile_app/test_notify.py +++ b/tests/components/mobile_app/test_notify.py @@ -1,5 +1,6 @@ """Notify platform tests for mobile_app.""" -# pylint: disable=redefined-outer-name +from datetime import datetime, timedelta + import pytest from homeassistant.components.mobile_app.const import DOMAIN @@ -9,12 +10,10 @@ from tests.common import MockConfigEntry @pytest.fixture -async def setup_push_receiver(hass, aioclient_mock): +async def setup_push_receiver(hass, aioclient_mock, hass_admin_user): """Fixture that sets up a mocked push receiver.""" push_url = "https://mobile-push.home-assistant.dev/push" - from datetime import datetime, timedelta - now = datetime.now() + timedelta(hours=24) iso_time = now.strftime("%Y-%m-%dT%H:%M:%SZ") @@ -47,8 +46,8 @@ async def setup_push_receiver(hass, aioclient_mock): "os_version": "5.0.6", "secret": "123abc", "supports_encryption": False, - "user_id": "1a2b3c", - "webhook_id": "webhook_id", + "user_id": hass_admin_user.id, + "webhook_id": "mock-webhook_id", }, domain=DOMAIN, source="registration", @@ -118,3 +117,77 @@ async def test_notify_works(hass, aioclient_mock, setup_push_receiver): assert call_json["message"] == "Hello world" assert call_json["registration_info"]["app_id"] == "io.homeassistant.mobile_app" assert call_json["registration_info"]["app_version"] == "1.0" + + +async def test_notify_ws_works( + hass, aioclient_mock, setup_push_receiver, hass_ws_client +): + """Test notify works.""" + client = await hass_ws_client(hass) + + await client.send_json( + { + "id": 5, + "type": "mobile_app/push_notification_channel", + "webhook_id": "mock-webhook_id", + } + ) + + sub_result = await client.receive_json() + assert sub_result["success"] + + assert await hass.services.async_call( + "notify", "mobile_app_test", {"message": "Hello world"}, blocking=True + ) + + assert len(aioclient_mock.mock_calls) == 0 + + msg_result = await client.receive_json() + assert msg_result["event"] == {"message": "Hello world"} + + # Unsubscribe, now it should go over http + await client.send_json( + { + "id": 6, + "type": "unsubscribe_events", + "subscription": 5, + } + ) + sub_result = await client.receive_json() + assert sub_result["success"] + + assert await hass.services.async_call( + "notify", "mobile_app_test", {"message": "Hello world 2"}, blocking=True + ) + + assert len(aioclient_mock.mock_calls) == 1 + + # Test non-existing webhook ID + await client.send_json( + { + "id": 7, + "type": "mobile_app/push_notification_channel", + "webhook_id": "non-existing", + } + ) + sub_result = await client.receive_json() + assert not sub_result["success"] + assert sub_result["error"] == { + "code": "not_found", + "message": "Webhook ID not found", + } + + # Test webhook ID linked to other user + await client.send_json( + { + "id": 8, + "type": "mobile_app/push_notification_channel", + "webhook_id": "webhook_id_2", + } + ) + sub_result = await client.receive_json() + assert not sub_result["success"] + assert sub_result["error"] == { + "code": "unauthorized", + "message": "User not linked to this webhook ID", + }