From 9be2704c11ab0a05b9995bae5d91a9d2b43b25ed Mon Sep 17 00:00:00 2001 From: Ian Byrne Date: Sun, 13 Nov 2022 00:11:06 +0000 Subject: [PATCH] Add option to include attachments from remote URL to Discord notifications (#74811) * Add option to include attachments from remote URL to Discord notifications * Use aiohttp instead of requests for Discord Notify URL retrieval * Refactor discord notifications code * Remove unecessary images check in discord notifier --- homeassistant/components/discord/notify.py | 72 ++++++++++++++-- tests/components/discord/conftest.py | 45 ++++++++++ tests/components/discord/test_notify.py | 96 ++++++++++++++++++++++ 3 files changed, 208 insertions(+), 5 deletions(-) create mode 100644 tests/components/discord/conftest.py create mode 100644 tests/components/discord/test_notify.py diff --git a/homeassistant/components/discord/notify.py b/homeassistant/components/discord/notify.py index d97ce7042bc..8fcab7cefba 100644 --- a/homeassistant/components/discord/notify.py +++ b/homeassistant/components/discord/notify.py @@ -1,6 +1,7 @@ """Discord platform for notify component.""" from __future__ import annotations +from io import BytesIO import logging import os.path from typing import Any, cast @@ -15,6 +16,7 @@ from homeassistant.components.notify import ( ) from homeassistant.const import CONF_API_TOKEN from homeassistant.core import HomeAssistant +from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType _LOGGER = logging.getLogger(__name__) @@ -30,6 +32,10 @@ ATTR_EMBED_THUMBNAIL = "thumbnail" ATTR_EMBED_IMAGE = "image" ATTR_EMBED_URL = "url" ATTR_IMAGES = "images" +ATTR_URLS = "urls" +ATTR_VERIFY_SSL = "verify_ssl" + +MAX_ALLOWED_DOWNLOAD_SIZE_BYTES = 8000000 async def async_get_service( @@ -61,11 +67,54 @@ class DiscordNotificationService(BaseNotificationService): return False return True + async def async_get_file_from_url( + self, url: str, verify_ssl: bool, max_file_size: int + ) -> bytearray | None: + """Retrieve file bytes from URL.""" + if not self.hass.config.is_allowed_external_url(url): + _LOGGER.error("URL not allowed: %s", url) + return None + + session = async_get_clientsession(self.hass) + + async with session.get( + url, + ssl=verify_ssl, + timeout=30, + raise_for_status=True, + ) as resp: + content_length = resp.headers.get("Content-Length") + + if content_length is not None and int(content_length) > max_file_size: + _LOGGER.error( + "Attachment too large (Content-Length reports %s). Max size: %s bytes", + int(content_length), + max_file_size, + ) + return None + + file_size = 0 + byte_chunks = bytearray() + + async for byte_chunk, _ in resp.content.iter_chunks(): + file_size += len(byte_chunk) + if file_size > max_file_size: + _LOGGER.error( + "Attachment too large (Stream reports %s). Max size: %s bytes", + file_size, + max_file_size, + ) + return None + + byte_chunks.extend(byte_chunk) + + return byte_chunks + async def async_send_message(self, message: str, **kwargs: Any) -> None: """Login to Discord, send message to channel(s) and log out.""" nextcord.VoiceClient.warn_nacl = False discord_bot = nextcord.Client() - images = None + images = [] embedding = None if ATTR_TARGET not in kwargs: @@ -100,15 +149,28 @@ class DiscordNotificationService(BaseNotificationService): embeds.append(embed) if ATTR_IMAGES in data: - images = [] - for image in data.get(ATTR_IMAGES, []): image_exists = await self.hass.async_add_executor_job( self.file_exists, image ) + filename = os.path.basename(image) + if image_exists: - images.append(image) + images.append((image, filename)) + + if ATTR_URLS in data: + for url in data.get(ATTR_URLS, []): + file = await self.async_get_file_from_url( + url, + data.get(ATTR_VERIFY_SSL, True), + MAX_ALLOWED_DOWNLOAD_SIZE_BYTES, + ) + + if file is not None: + filename = os.path.basename(url) + + images.append((BytesIO(file), filename)) await discord_bot.login(self.token) @@ -116,7 +178,7 @@ class DiscordNotificationService(BaseNotificationService): for channelid in kwargs[ATTR_TARGET]: channelid = int(channelid) # Must create new instances of File for each channel. - files = [nextcord.File(image) for image in images] if images else [] + files = [nextcord.File(image, filename) for image, filename in images] try: channel = cast( Messageable, await discord_bot.fetch_channel(channelid) diff --git a/tests/components/discord/conftest.py b/tests/components/discord/conftest.py new file mode 100644 index 00000000000..c98944fdc85 --- /dev/null +++ b/tests/components/discord/conftest.py @@ -0,0 +1,45 @@ +"""Discord notification test helpers.""" +from http import HTTPStatus + +import pytest + +from homeassistant.components.discord.notify import DiscordNotificationService +from homeassistant.core import HomeAssistant + +from tests.test_util.aiohttp import AiohttpClientMocker + +MESSAGE = "Testing Discord Messenger platform" +CONTENT = b"TestContent" +URL_ATTACHMENT = "http://127.0.0.1:8080/image.jpg" +TARGET = "1234567890" + + +@pytest.fixture +def discord_notification_service(hass: HomeAssistant) -> DiscordNotificationService: + """Set up discord notification service.""" + hass.config.allowlist_external_urls.add(URL_ATTACHMENT) + return DiscordNotificationService(hass, "token") + + +@pytest.fixture +def discord_aiohttp_mock_factory( + aioclient_mock: AiohttpClientMocker, +) -> AiohttpClientMocker: + """Create Discord service mock from factory.""" + + def _discord_aiohttp_mock_factory( + headers: dict[str, str] = None, + ) -> AiohttpClientMocker: + if headers is not None: + aioclient_mock.get( + URL_ATTACHMENT, status=HTTPStatus.OK, content=CONTENT, headers=headers + ) + else: + aioclient_mock.get( + URL_ATTACHMENT, + status=HTTPStatus.OK, + content=CONTENT, + ) + return aioclient_mock + + return _discord_aiohttp_mock_factory diff --git a/tests/components/discord/test_notify.py b/tests/components/discord/test_notify.py new file mode 100644 index 00000000000..810898cdf73 --- /dev/null +++ b/tests/components/discord/test_notify.py @@ -0,0 +1,96 @@ +"""Test Discord notify.""" +import logging + +import pytest + +from homeassistant.components.discord.notify import DiscordNotificationService + +from .conftest import CONTENT, MESSAGE, URL_ATTACHMENT + +from tests.test_util.aiohttp import AiohttpClientMocker + + +async def test_send_message_without_target_logs_error( + discord_notification_service: DiscordNotificationService, + discord_aiohttp_mock_factory: AiohttpClientMocker, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test send message.""" + discord_aiohttp_mock = discord_aiohttp_mock_factory() + with caplog.at_level( + logging.ERROR, logger="homeassistant.components.discord.notify" + ): + await discord_notification_service.async_send_message(MESSAGE) + assert "No target specified" in caplog.text + assert discord_aiohttp_mock.call_count == 0 + + +async def test_get_file_from_url( + discord_notification_service: DiscordNotificationService, + discord_aiohttp_mock_factory: AiohttpClientMocker, +) -> None: + """Test getting a file from a URL.""" + headers = {"Content-Length": str(len(CONTENT))} + discord_aiohttp_mock = discord_aiohttp_mock_factory(headers) + result = await discord_notification_service.async_get_file_from_url( + URL_ATTACHMENT, True, len(CONTENT) + ) + + assert discord_aiohttp_mock.call_count == 1 + assert result == bytearray(CONTENT) + + +async def test_get_file_from_url_not_on_allowlist( + discord_notification_service: DiscordNotificationService, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test getting file from URL that isn't on the allowlist.""" + url = "http://dodgyurl.com" + with caplog.at_level( + logging.WARNING, logger="homeassistant.components.discord.notify" + ): + result = await discord_notification_service.async_get_file_from_url( + url, True, len(CONTENT) + ) + + assert f"URL not allowed: {url}" in caplog.text + assert result is None + + +async def test_get_file_from_url_with_large_attachment( + discord_notification_service: DiscordNotificationService, + discord_aiohttp_mock_factory: AiohttpClientMocker, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test getting file from URL with large attachment (per Content-Length header) throws error.""" + headers = {"Content-Length": str(len(CONTENT) + 1)} + discord_aiohttp_mock = discord_aiohttp_mock_factory(headers) + with caplog.at_level( + logging.WARNING, logger="homeassistant.components.discord.notify" + ): + result = await discord_notification_service.async_get_file_from_url( + URL_ATTACHMENT, True, len(CONTENT) + ) + + assert discord_aiohttp_mock.call_count == 1 + assert "Attachment too large (Content-Length reports" in caplog.text + assert result is None + + +async def test_get_file_from_url_with_large_attachment_no_header( + discord_notification_service: DiscordNotificationService, + discord_aiohttp_mock_factory: AiohttpClientMocker, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test getting file from URL with large attachment (per content length) throws error.""" + discord_aiohttp_mock = discord_aiohttp_mock_factory() + with caplog.at_level( + logging.WARNING, logger="homeassistant.components.discord.notify" + ): + result = await discord_notification_service.async_get_file_from_url( + URL_ATTACHMENT, True, len(CONTENT) - 1 + ) + + assert discord_aiohttp_mock.call_count == 1 + assert "Attachment too large (Stream reports" in caplog.text + assert result is None