From d7375f1a9c4a69858a65a56bd524f5a78ecab23c Mon Sep 17 00:00:00 2001 From: Wictor Date: Sun, 3 Apr 2022 05:39:14 +0200 Subject: [PATCH] Refactor telegram_bot polling/webhooks platforms and add tests (#66433) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Pär Berge --- .../components/telegram_bot/__init__.py | 227 ++++++++---------- .../components/telegram_bot/polling.py | 79 ++---- .../components/telegram_bot/webhooks.py | 151 +++++++----- requirements_test_all.txt | 6 + tests/components/telegram_bot/__init__.py | 1 + tests/components/telegram_bot/conftest.py | 187 +++++++++++++++ .../telegram_bot/test_telegram_bot.py | 112 +++++++++ 7 files changed, 514 insertions(+), 249 deletions(-) create mode 100644 tests/components/telegram_bot/__init__.py create mode 100644 tests/components/telegram_bot/conftest.py create mode 100644 tests/components/telegram_bot/test_telegram_bot.py diff --git a/homeassistant/components/telegram_bot/__init__.py b/homeassistant/components/telegram_bot/__init__.py index cebdd4f4573..29dbabcbbfe 100644 --- a/homeassistant/components/telegram_bot/__init__.py +++ b/homeassistant/components/telegram_bot/__init__.py @@ -1,20 +1,27 @@ """Support to send and receive Telegram messages.""" +from __future__ import annotations + from functools import partial import importlib import io from ipaddress import ip_network import logging +from typing import Any import requests from requests.auth import HTTPBasicAuth, HTTPDigestAuth from telegram import ( Bot, + CallbackQuery, InlineKeyboardButton, InlineKeyboardMarkup, + Message, ReplyKeyboardMarkup, ReplyKeyboardRemove, + Update, ) from telegram.error import TelegramError +from telegram.ext import CallbackContext, Filters from telegram.parsemode import ParseMode from telegram.utils.request import Request import voluptuous as vol @@ -311,14 +318,15 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: return False for p_config in config[DOMAIN]: - + # Each platform config gets its own bot + bot = initialize_bot(p_config) p_type = p_config.get(CONF_PLATFORM) platform = importlib.import_module(f".{p_config[CONF_PLATFORM]}", __name__) _LOGGER.info("Setting up %s.%s", DOMAIN, p_type) try: - receiver_service = await platform.async_setup_platform(hass, p_config) + receiver_service = await platform.async_setup_platform(hass, bot, p_config) if receiver_service is False: _LOGGER.error("Failed to initialize Telegram bot %s", p_type) return False @@ -327,7 +335,6 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: _LOGGER.exception("Error setting up platform %s", p_type) return False - bot = initialize_bot(p_config) notify_service = TelegramNotificationService( hass, bot, p_config.get(CONF_ALLOWED_CHAT_IDS), p_config.get(ATTR_PARSER) ) @@ -416,7 +423,6 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: def initialize_bot(p_config): """Initialize telegram bot with proxy support.""" - api_key = p_config.get(CONF_API_KEY) proxy_url = p_config.get(CONF_PROXY_URL) proxy_params = p_config.get(CONF_PROXY_PARAMS) @@ -435,7 +441,6 @@ class TelegramNotificationService: def __init__(self, hass, bot, allowed_chat_ids, parser): """Initialize the service.""" - self.allowed_chat_ids = allowed_chat_ids self._default_user = self.allowed_chat_ids[0] self._last_message_id = {user: None for user in self.allowed_chat_ids} @@ -495,7 +500,6 @@ class TelegramNotificationService: - a string like: `/cmd1, /cmd2, /cmd3` - or a string like: `text_b1:/cmd1, text_b2:/cmd2` """ - buttons = [] if isinstance(row_keyboard, str): for key in row_keyboard.split(","): @@ -566,7 +570,6 @@ class TelegramNotificationService: def _send_msg(self, func_send, msg_error, message_tag, *args_msg, **kwargs_msg): """Send one message.""" - try: out = func_send(*args_msg, **kwargs_msg) if not isinstance(out, bool) and hasattr(out, ATTR_MESSAGEID): @@ -857,131 +860,99 @@ class TelegramNotificationService: class BaseTelegramBotEntity: """The base class for the telegram bot.""" - def __init__(self, hass, allowed_chat_ids): + def __init__(self, hass, config): """Initialize the bot base class.""" - self.allowed_chat_ids = allowed_chat_ids + self.allowed_chat_ids = config[CONF_ALLOWED_CHAT_IDS] self.hass = hass - def _get_message_data(self, msg_data): - """Return boolean msg_data_is_ok and dict msg_data.""" - if not msg_data: - return False, None - bad_fields = ( - "text" not in msg_data and "data" not in msg_data and "chat" not in msg_data - ) - if bad_fields or "from" not in msg_data: - # Message is not correct. - _LOGGER.error("Incoming message does not have required data (%s)", msg_data) - return False, None + def handle_update(self, update: Update, context: CallbackContext) -> bool: + """Handle updates from bot dispatcher set up by the respective platform.""" + _LOGGER.debug("Handling update %s", update) + if not self.authorize_update(update): + return False - if ( - msg_data["from"].get("id") not in self.allowed_chat_ids - and msg_data["message"]["chat"].get("id") not in self.allowed_chat_ids - ): - # Neither from id nor chat id was in allowed_chat_ids, - # origin is not allowed. - _LOGGER.error("Incoming message is not allowed (%s)", msg_data) - return True, None - - data = { - ATTR_USER_ID: msg_data["from"]["id"], - ATTR_FROM_FIRST: msg_data["from"]["first_name"], - } - if "message_id" in msg_data: - data[ATTR_MSGID] = msg_data["message_id"] - if "last_name" in msg_data["from"]: - data[ATTR_FROM_LAST] = msg_data["from"]["last_name"] - if "chat" in msg_data: - data[ATTR_CHAT_ID] = msg_data["chat"]["id"] - elif ATTR_MESSAGE in msg_data and "chat" in msg_data[ATTR_MESSAGE]: - data[ATTR_CHAT_ID] = msg_data[ATTR_MESSAGE]["chat"]["id"] - - return True, data - - def _get_channel_post_data(self, msg_data): - """Return boolean msg_data_is_ok and dict msg_data.""" - if not msg_data: - return False, None - - if "sender_chat" in msg_data and "chat" in msg_data and "text" in msg_data: - if ( - msg_data["sender_chat"].get("id") not in self.allowed_chat_ids - and msg_data["chat"].get("id") not in self.allowed_chat_ids - ): - # Neither sender_chat id nor chat id was in allowed_chat_ids, - # origin is not allowed. - _LOGGER.error("Incoming message is not allowed (%s)", msg_data) - return True, None - - data = { - ATTR_MSGID: msg_data["message_id"], - ATTR_CHAT_ID: msg_data["chat"]["id"], - ATTR_TEXT: msg_data["text"], - } - return True, data - - _LOGGER.error("Incoming message does not have required data (%s)", msg_data) - return False, None - - def process_message(self, data): - """Check for basic message rules and fire an event if message is ok.""" - if ATTR_MSG in data or ATTR_EDITED_MSG in data: - event = EVENT_TELEGRAM_COMMAND - if ATTR_MSG in data: - data = data.get(ATTR_MSG) - else: - data = data.get(ATTR_EDITED_MSG) - message_ok, event_data = self._get_message_data(data) - if event_data is None: - return message_ok - - if ATTR_MSGID in data: - event_data[ATTR_MSGID] = data[ATTR_MSGID] - - if "text" in data: - if data["text"][0] == "/": - pieces = data["text"].split(" ") - event_data[ATTR_COMMAND] = pieces[0] - event_data[ATTR_ARGS] = pieces[1:] - else: - event_data[ATTR_TEXT] = data["text"] - event = EVENT_TELEGRAM_TEXT - else: - _LOGGER.warning("Message without text data received: %s", data) - event_data[ATTR_TEXT] = str(data) - event = EVENT_TELEGRAM_TEXT - - self.hass.bus.async_fire(event, event_data) - return True - if ATTR_CALLBACK_QUERY in data: - event = EVENT_TELEGRAM_CALLBACK - data = data.get(ATTR_CALLBACK_QUERY) - message_ok, event_data = self._get_message_data(data) - if event_data is None: - return message_ok - - query_data = event_data[ATTR_DATA] = data[ATTR_DATA] - - if query_data[0] == "/": - pieces = query_data.split(" ") - event_data[ATTR_COMMAND] = pieces[0] - event_data[ATTR_ARGS] = pieces[1:] - - event_data[ATTR_MSG] = data[ATTR_MSG] - event_data[ATTR_CHAT_INSTANCE] = data[ATTR_CHAT_INSTANCE] - event_data[ATTR_MSGID] = data[ATTR_MSGID] - - self.hass.bus.async_fire(event, event_data) - return True - if ATTR_CHANNEL_POST in data: - event = EVENT_TELEGRAM_TEXT - data = data.get(ATTR_CHANNEL_POST) - message_ok, event_data = self._get_channel_post_data(data) - if event_data is None: - return message_ok - - self.hass.bus.async_fire(event, event_data) + # establish event type: text, command or callback_query + if update.callback_query: + # NOTE: Check for callback query first since effective message will be populated with the message + # in .callback_query (python-telegram-bot docs are wrong) + event_type, event_data = self._get_callback_query_event_data( + update.callback_query + ) + elif update.effective_message: + event_type, event_data = self._get_message_event_data( + update.effective_message + ) + else: + _LOGGER.warning("Unhandled update: %s", update) return True - _LOGGER.warning("Message with unknown data received: %s", data) + _LOGGER.debug("Firing event %s: %s", event_type, event_data) + self.hass.bus.fire(event_type, event_data) return True + + @staticmethod + def _get_command_event_data(command_text: str) -> dict[str, str | list]: + if not command_text.startswith("/"): + return {} + command_parts = command_text.split() + command = command_parts[0] + args = command_parts[1:] + return {ATTR_COMMAND: command, ATTR_ARGS: args} + + def _get_message_event_data(self, message: Message) -> tuple[str, dict[str, Any]]: + event_data: dict[str, Any] = { + ATTR_MSGID: message.message_id, + ATTR_CHAT_ID: message.chat.id, + } + if Filters.command.filter(message): + # This is a command message - set event type to command and split data into command and args + event_type = EVENT_TELEGRAM_COMMAND + event_data.update(self._get_command_event_data(message.text)) + else: + event_type = EVENT_TELEGRAM_TEXT + event_data[ATTR_TEXT] = message.text + + if message.from_user: + event_data.update( + { + ATTR_USER_ID: message.from_user.id, + ATTR_FROM_FIRST: message.from_user.first_name, + ATTR_FROM_LAST: message.from_user.last_name, + } + ) + + return event_type, event_data + + def _get_callback_query_event_data( + self, callback_query: CallbackQuery + ) -> tuple[str, dict[str, Any]]: + event_type = EVENT_TELEGRAM_CALLBACK + event_data: dict[str, Any] = { + ATTR_MSGID: callback_query.id, + ATTR_CHAT_INSTANCE: callback_query.chat_instance, + ATTR_DATA: callback_query.data, + ATTR_MSG: None, + ATTR_CHAT_ID: None, + } + if callback_query.message: + event_data[ATTR_MSG] = callback_query.message.to_dict() + event_data[ATTR_CHAT_ID] = callback_query.message.chat.id + + # Split data into command and args if possible + event_data.update(self._get_command_event_data(callback_query.data)) + + return event_type, event_data + + def authorize_update(self, update: Update) -> bool: + """Make sure either user or chat is in allowed_chat_ids.""" + from_user = update.effective_user.id if update.effective_user else None + from_chat = update.effective_chat.id if update.effective_chat else None + if from_user in self.allowed_chat_ids or from_chat in self.allowed_chat_ids: + return True + _LOGGER.error( + "Unauthorized update - neither user id %s nor chat id %s is in allowed chats: %s", + from_user, + from_chat, + self.allowed_chat_ids, + ) + return False diff --git a/homeassistant/components/telegram_bot/polling.py b/homeassistant/components/telegram_bot/polling.py index b617826411d..4c0de6ade1e 100644 --- a/homeassistant/components/telegram_bot/polling.py +++ b/homeassistant/components/telegram_bot/polling.py @@ -3,31 +3,21 @@ import logging from telegram import Update from telegram.error import NetworkError, RetryAfter, TelegramError, TimedOut -from telegram.ext import CallbackContext, Dispatcher, Handler, Updater -from telegram.utils.types import HandlerArg +from telegram.ext import CallbackContext, TypeHandler, Updater from homeassistant.const import EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP -from . import CONF_ALLOWED_CHAT_IDS, BaseTelegramBotEntity, initialize_bot +from . import BaseTelegramBotEntity _LOGGER = logging.getLogger(__name__) -async def async_setup_platform(hass, config): +async def async_setup_platform(hass, bot, config): """Set up the Telegram polling platform.""" - bot = initialize_bot(config) - pol = TelegramPoll(bot, hass, config[CONF_ALLOWED_CHAT_IDS]) + pollbot = PollBot(hass, bot, config) - def _start_bot(_event): - """Start the bot.""" - pol.start_polling() - - def _stop_bot(_event): - """Stop the bot.""" - pol.stop_polling() - - hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, _start_bot) - hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, _stop_bot) + hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, pollbot.start_polling) + hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, pollbot.stop_polling) return True @@ -43,57 +33,28 @@ def process_error(update: Update, context: CallbackContext): _LOGGER.error('Update "%s" caused error: "%s"', update, context.error) -def message_handler(handler): - """Create messages handler.""" +class PollBot(BaseTelegramBotEntity): + """ + Controls the Updater object that holds the bot and a dispatcher. - class MessageHandler(Handler): - """Telegram bot message handler.""" - - def __init__(self): - """Initialize the messages handler instance.""" - super().__init__(handler) - - def check_update(self, update): - """Check is update valid.""" - return isinstance(update, Update) - - def handle_update( - self, - update: HandlerArg, - dispatcher: Dispatcher, - check_result: object, - context: CallbackContext = None, - ): - """Handle update.""" - optional_args = self.collect_optional_args(dispatcher, update) - context.args = optional_args - return self.callback(update, context) - - return MessageHandler() - - -class TelegramPoll(BaseTelegramBotEntity): - """Asyncio telegram incoming message handler.""" - - def __init__(self, bot, hass, allowed_chat_ids): - """Initialize the polling instance.""" - - BaseTelegramBotEntity.__init__(self, hass, allowed_chat_ids) + The dispatcher is set up by the super class to pass telegram updates to `self.handle_update` + """ + def __init__(self, hass, bot, config): + """Create Updater and Dispatcher before calling super().""" + self.bot = bot self.updater = Updater(bot=bot, workers=4) self.dispatcher = self.updater.dispatcher - - self.dispatcher.add_handler(message_handler(self.process_update)) + self.dispatcher.add_handler(TypeHandler(Update, self.handle_update)) self.dispatcher.add_error_handler(process_error) + super().__init__(hass, config) - def start_polling(self): + def start_polling(self, event=None): """Start the polling task.""" + _LOGGER.debug("Starting polling") self.updater.start_polling() - def stop_polling(self): + def stop_polling(self, event=None): """Stop the polling task.""" + _LOGGER.debug("Stopping polling") self.updater.stop() - - def process_update(self, update: HandlerArg, context: CallbackContext): - """Process incoming message.""" - self.process_message(update.to_dict()) diff --git a/homeassistant/components/telegram_bot/webhooks.py b/homeassistant/components/telegram_bot/webhooks.py index c1e86129ebb..8b94cb66496 100644 --- a/homeassistant/components/telegram_bot/webhooks.py +++ b/homeassistant/components/telegram_bot/webhooks.py @@ -4,90 +4,115 @@ from http import HTTPStatus from ipaddress import ip_address import logging +from telegram import Update from telegram.error import TimedOut +from telegram.ext import Dispatcher, TypeHandler from homeassistant.components.http import HomeAssistantView from homeassistant.const import EVENT_HOMEASSISTANT_STOP from homeassistant.helpers.network import get_url -from . import ( - CONF_ALLOWED_CHAT_IDS, - CONF_TRUSTED_NETWORKS, - CONF_URL, - BaseTelegramBotEntity, - initialize_bot, -) +from . import CONF_TRUSTED_NETWORKS, CONF_URL, BaseTelegramBotEntity _LOGGER = logging.getLogger(__name__) -TELEGRAM_HANDLER_URL = "/api/telegram_webhooks" -REMOVE_HANDLER_URL = "" +TELEGRAM_WEBHOOK_URL = "/api/telegram_webhooks" +REMOVE_WEBHOOK_URL = "" -async def async_setup_platform(hass, config): +async def async_setup_platform(hass, bot, config): """Set up the Telegram webhooks platform.""" + pushbot = PushBot(hass, bot, config) - bot = initialize_bot(config) - - current_status = await hass.async_add_executor_job(bot.getWebhookInfo) - if not (base_url := config.get(CONF_URL)): - base_url = get_url(hass, require_ssl=True, allow_internal=False) - - # Some logging of Bot current status: - last_error_date = getattr(current_status, "last_error_date", None) - if (last_error_date is not None) and (isinstance(last_error_date, int)): - last_error_date = dt.datetime.fromtimestamp(last_error_date) - _LOGGER.info( - "Telegram webhook last_error_date: %s. Status: %s", - last_error_date, - current_status, - ) - else: - _LOGGER.debug("telegram webhook Status: %s", current_status) - - handler_url = f"{base_url}{TELEGRAM_HANDLER_URL}" - if not handler_url.startswith("https"): - _LOGGER.error("Invalid telegram webhook %s must be https", handler_url) + if not pushbot.webhook_url.startswith("https"): + _LOGGER.error("Invalid telegram webhook %s must be https", pushbot.webhook_url) return False - def _try_to_set_webhook(): - retry_num = 0 - while retry_num < 3: - try: - return bot.setWebhook(handler_url, timeout=5) - except TimedOut: - retry_num += 1 - _LOGGER.warning("Timeout trying to set webhook (retry #%d)", retry_num) + webhook_registered = await pushbot.register_webhook() + if not webhook_registered: + return False - if current_status and current_status["url"] != handler_url: - result = await hass.async_add_executor_job(_try_to_set_webhook) - if result: - _LOGGER.info("Set new telegram webhook %s", handler_url) - else: - _LOGGER.error("Set telegram webhook failed %s", handler_url) - return False - - hass.bus.async_listen_once( - EVENT_HOMEASSISTANT_STOP, lambda event: bot.setWebhook(REMOVE_HANDLER_URL) - ) + hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, pushbot.deregister_webhook) hass.http.register_view( - BotPushReceiver( - hass, config[CONF_ALLOWED_CHAT_IDS], config[CONF_TRUSTED_NETWORKS] - ) + PushBotView(hass, bot, pushbot.dispatcher, config[CONF_TRUSTED_NETWORKS]) ) return True -class BotPushReceiver(HomeAssistantView, BaseTelegramBotEntity): - """Handle pushes from Telegram.""" +class PushBot(BaseTelegramBotEntity): + """Handles all the push/webhook logic and passes telegram updates to `self.handle_update`.""" + + def __init__(self, hass, bot, config): + """Create Dispatcher before calling super().""" + self.bot = bot + self.trusted_networks = config[CONF_TRUSTED_NETWORKS] + # Dumb dispatcher that just gets our updates to our handler callback (self.handle_update) + self.dispatcher = Dispatcher(bot, None) + self.dispatcher.add_handler(TypeHandler(Update, self.handle_update)) + super().__init__(hass, config) + + self.base_url = config.get(CONF_URL) or get_url( + hass, require_ssl=True, allow_internal=False + ) + self.webhook_url = f"{self.base_url}{TELEGRAM_WEBHOOK_URL}" + + def _try_to_set_webhook(self): + _LOGGER.debug("Registering webhook URL: %s", self.webhook_url) + retry_num = 0 + while retry_num < 3: + try: + return self.bot.set_webhook(self.webhook_url, timeout=5) + except TimedOut: + retry_num += 1 + _LOGGER.warning("Timeout trying to set webhook (retry #%d)", retry_num) + + return False + + async def register_webhook(self): + """Query telegram and register the URL for our webhook.""" + current_status = await self.hass.async_add_executor_job( + self.bot.get_webhook_info + ) + # Some logging of Bot current status: + last_error_date = getattr(current_status, "last_error_date", None) + if (last_error_date is not None) and (isinstance(last_error_date, int)): + last_error_date = dt.datetime.fromtimestamp(last_error_date) + _LOGGER.debug( + "Telegram webhook last_error_date: %s. Status: %s", + last_error_date, + current_status, + ) + else: + _LOGGER.debug("telegram webhook status: %s", current_status) + + if current_status and current_status["url"] != self.webhook_url: + result = await self.hass.async_add_executor_job(self._try_to_set_webhook) + if result: + _LOGGER.info("Set new telegram webhook %s", self.webhook_url) + else: + _LOGGER.error("Set telegram webhook failed %s", self.webhook_url) + return False + + return True + + def deregister_webhook(self, event=None): + """Query telegram and deregister the URL for our webhook.""" + _LOGGER.debug("Deregistering webhook URL") + return self.bot.delete_webhook() + + +class PushBotView(HomeAssistantView): + """View for handling webhook calls from Telegram.""" requires_auth = False - url = TELEGRAM_HANDLER_URL + url = TELEGRAM_WEBHOOK_URL name = "telegram_webhooks" - def __init__(self, hass, allowed_chat_ids, trusted_networks): - """Initialize the class.""" - BaseTelegramBotEntity.__init__(self, hass, allowed_chat_ids) + def __init__(self, hass, bot, dispatcher, trusted_networks): + """Initialize by storing stuff needed for setting up our webhook endpoint.""" + self.hass = hass + self.bot = bot + self.dispatcher = dispatcher self.trusted_networks = trusted_networks async def post(self, request): @@ -98,10 +123,12 @@ class BotPushReceiver(HomeAssistantView, BaseTelegramBotEntity): return self.json_message("Access denied", HTTPStatus.UNAUTHORIZED) try: - data = await request.json() + update_data = await request.json() except ValueError: return self.json_message("Invalid JSON", HTTPStatus.BAD_REQUEST) - if not self.process_message(data): - return self.json_message("Invalid message", HTTPStatus.BAD_REQUEST) + update = Update.de_json(update_data, self.bot) + _LOGGER.debug("Received Update on %s: %s", self.url, update) + await self.hass.async_add_executor_job(self.dispatcher.process_update, update) + return None diff --git a/requirements_test_all.txt b/requirements_test_all.txt index 6e7bc79fd0e..bd746c58bd7 100644 --- a/requirements_test_all.txt +++ b/requirements_test_all.txt @@ -26,6 +26,9 @@ PyQRCode==1.2.1 # homeassistant.components.rmvtransport PyRMVtransport==0.3.3 +# homeassistant.components.telegram_bot +PySocks==1.7.1 + # homeassistant.components.switchbot # PySwitchbot==0.13.3 @@ -1259,6 +1262,9 @@ python-songpal==0.14.1 # homeassistant.components.tado python-tado==0.12.0 +# homeassistant.components.telegram_bot +python-telegram-bot==13.1 + # homeassistant.components.awair python_awair==0.2.3 diff --git a/tests/components/telegram_bot/__init__.py b/tests/components/telegram_bot/__init__.py new file mode 100644 index 00000000000..810b1bc268f --- /dev/null +++ b/tests/components/telegram_bot/__init__.py @@ -0,0 +1 @@ +"""Tests for telegram_bot integration.""" diff --git a/tests/components/telegram_bot/conftest.py b/tests/components/telegram_bot/conftest.py new file mode 100644 index 00000000000..61818e4c377 --- /dev/null +++ b/tests/components/telegram_bot/conftest.py @@ -0,0 +1,187 @@ +"""Tests for the telegram_bot integration.""" +from unittest.mock import patch + +import pytest +from telegram.ext.dispatcher import Dispatcher + +from homeassistant.components.telegram_bot import ( + CONF_ALLOWED_CHAT_IDS, + CONF_TRUSTED_NETWORKS, + DOMAIN, +) +from homeassistant.const import CONF_API_KEY, CONF_PLATFORM, CONF_URL +from homeassistant.setup import async_setup_component + + +@pytest.fixture +def config_webhooks(): + """Fixture for a webhooks platform configuration.""" + return { + DOMAIN: [ + { + CONF_PLATFORM: "webhooks", + CONF_URL: "https://test", + CONF_TRUSTED_NETWORKS: ["127.0.0.1"], + CONF_API_KEY: "1234567890:ABC", + CONF_ALLOWED_CHAT_IDS: [ + # "me" + 12345678, + # Some chat + -123456789, + ], + } + ] + } + + +@pytest.fixture +def config_polling(): + """Fixture for a polling platform configuration.""" + return { + DOMAIN: [ + { + CONF_PLATFORM: "polling", + CONF_API_KEY: "1234567890:ABC", + CONF_ALLOWED_CHAT_IDS: [ + # "me" + 12345678, + # Some chat + -123456789, + ], + } + ] + } + + +@pytest.fixture +def mock_register_webhook(): + """Mock calls made by telegram_bot when (de)registering webhook.""" + with patch( + "homeassistant.components.telegram_bot.webhooks.PushBot.register_webhook", + return_value=True, + ), patch( + "homeassistant.components.telegram_bot.webhooks.PushBot.deregister_webhook", + return_value=True, + ): + yield + + +@pytest.fixture +def update_message_command(): + """Fixture for mocking an incoming update of type message/command.""" + return { + "update_id": 1, + "message": { + "message_id": 1, + "from": { + "id": 12345678, + "is_bot": False, + "first_name": "Firstname", + "username": "some_username", + "language_code": "en", + }, + "chat": { + "id": -123456789, + "title": "SomeChat", + "type": "group", + "all_members_are_administrators": True, + }, + "date": 1644518189, + "text": "/command", + "entities": [ + { + "type": "bot_command", + "offset": 0, + "length": 7, + } + ], + }, + } + + +@pytest.fixture +def update_message_text(): + """Fixture for mocking an incoming update of type message/text.""" + return { + "update_id": 1, + "message": { + "message_id": 1, + "date": 1441645532, + "from": { + "id": 12345678, + "is_bot": False, + "last_name": "Test Lastname", + "first_name": "Test Firstname", + "username": "Testusername", + }, + "chat": { + "last_name": "Test Lastname", + "id": 1111111, + "type": "private", + "first_name": "Test Firstname", + "username": "Testusername", + }, + "text": "HELLO", + }, + } + + +@pytest.fixture +def unauthorized_update_message_text(update_message_text): + """Fixture for mocking an incoming update of type message/text that is not in our `allowed_chat_ids`.""" + update_message_text["message"]["from"]["id"] = 1234 + update_message_text["message"]["chat"]["id"] = 1234 + return update_message_text + + +@pytest.fixture +def update_callback_query(): + """Fixture for mocking an incoming update of type callback_query.""" + return { + "update_id": 1, + "callback_query": { + "id": "4382bfdwdsb323b2d9", + "from": { + "id": 12345678, + "type": "private", + "is_bot": False, + "last_name": "Test Lastname", + "first_name": "Test Firstname", + "username": "Testusername", + }, + "chat_instance": "aaa111", + "data": "Data from button callback", + "inline_message_id": "1234csdbsk4839", + }, + } + + +@pytest.fixture +async def webhook_platform(hass, config_webhooks, mock_register_webhook): + """Fixture for setting up the webhooks platform using appropriate config and mocks.""" + await async_setup_component( + hass, + DOMAIN, + config_webhooks, + ) + await hass.async_block_till_done() + + +@pytest.fixture +async def polling_platform(hass, config_polling): + """Fixture for setting up the polling platform using appropriate config and mocks.""" + await async_setup_component( + hass, + DOMAIN, + config_polling, + ) + await hass.async_block_till_done() + + +@pytest.fixture(autouse=True) +def clear_dispatcher(): + """Clear the singleton that telegram.ext.dispatcher.Dispatcher sets on itself.""" + yield + Dispatcher._set_singleton(None) + # This is how python-telegram-bot resets the dispatcher in their test suite + Dispatcher._Dispatcher__singleton_semaphore.release() diff --git a/tests/components/telegram_bot/test_telegram_bot.py b/tests/components/telegram_bot/test_telegram_bot.py new file mode 100644 index 00000000000..9b099a180f7 --- /dev/null +++ b/tests/components/telegram_bot/test_telegram_bot.py @@ -0,0 +1,112 @@ +"""Tests for the telegram_bot component.""" +from telegram import Update +from telegram.ext.dispatcher import Dispatcher + +from homeassistant.components.telegram_bot import DOMAIN, SERVICE_SEND_MESSAGE +from homeassistant.components.telegram_bot.webhooks import TELEGRAM_WEBHOOK_URL + +from tests.common import async_capture_events + + +async def test_webhook_platform_init(hass, webhook_platform): + """Test initialization of the webhooks platform.""" + assert hass.services.has_service(DOMAIN, SERVICE_SEND_MESSAGE) is True + + +async def test_polling_platform_init(hass, polling_platform): + """Test initialization of the polling platform.""" + assert hass.services.has_service(DOMAIN, SERVICE_SEND_MESSAGE) is True + + +async def test_webhook_endpoint_generates_telegram_text_event( + hass, webhook_platform, hass_client, update_message_text +): + """POST to the configured webhook endpoint and assert fired `telegram_text` event.""" + client = await hass_client() + events = async_capture_events(hass, "telegram_text") + + response = await client.post(TELEGRAM_WEBHOOK_URL, json=update_message_text) + assert response.status == 200 + assert (await response.read()).decode("utf-8") == "" + + # Make sure event has fired + await hass.async_block_till_done() + + assert len(events) == 1 + assert events[0].data["text"] == update_message_text["message"]["text"] + + +async def test_webhook_endpoint_generates_telegram_command_event( + hass, webhook_platform, hass_client, update_message_command +): + """POST to the configured webhook endpoint and assert fired `telegram_command` event.""" + client = await hass_client() + events = async_capture_events(hass, "telegram_command") + + response = await client.post(TELEGRAM_WEBHOOK_URL, json=update_message_command) + assert response.status == 200 + assert (await response.read()).decode("utf-8") == "" + + # Make sure event has fired + await hass.async_block_till_done() + + assert len(events) == 1 + assert events[0].data["command"] == update_message_command["message"]["text"] + + +async def test_webhook_endpoint_generates_telegram_callback_event( + hass, webhook_platform, hass_client, update_callback_query +): + """POST to the configured webhook endpoint and assert fired `telegram_callback` event.""" + client = await hass_client() + events = async_capture_events(hass, "telegram_callback") + + response = await client.post(TELEGRAM_WEBHOOK_URL, json=update_callback_query) + assert response.status == 200 + assert (await response.read()).decode("utf-8") == "" + + # Make sure event has fired + await hass.async_block_till_done() + + assert len(events) == 1 + assert events[0].data["data"] == update_callback_query["callback_query"]["data"] + + +async def test_polling_platform_message_text_update( + hass, polling_platform, update_message_text +): + """Provide the `PollBot`s `Dispatcher` with an `Update` and assert fired `telegram_text` event.""" + events = async_capture_events(hass, "telegram_text") + + def telegram_dispatcher_callback(): + dispatcher = Dispatcher.get_instance() + update = Update.de_json(update_message_text, dispatcher.bot) + dispatcher.process_update(update) + + # python-telegram-bots `Updater` uses threading, so we need to schedule its callback in a sync context. + await hass.async_add_executor_job(telegram_dispatcher_callback) + + # Make sure event has fired + await hass.async_block_till_done() + + assert len(events) == 1 + assert events[0].data["text"] == update_message_text["message"]["text"] + + +async def test_webhook_endpoint_unauthorized_update_doesnt_generate_telegram_text_event( + hass, webhook_platform, hass_client, unauthorized_update_message_text +): + """Update with unauthorized user/chat should not trigger event.""" + client = await hass_client() + events = async_capture_events(hass, "telegram_text") + + response = await client.post( + TELEGRAM_WEBHOOK_URL, json=unauthorized_update_message_text + ) + assert response.status == 200 + assert (await response.read()).decode("utf-8") == "" + + # Make sure any events would have fired + await hass.async_block_till_done() + + assert len(events) == 0