From 9add251b0a7e7b8362e8584017b25e002b0536db Mon Sep 17 00:00:00 2001 From: Denis Shulyaka Date: Tue, 14 May 2024 16:48:59 +0300 Subject: [PATCH] Add context to `telegram_bot` events (#109920) * Add context for received messages events * Add context for sent messages events * ruff * ruff * ruff * Removed user_id mapping * Add tests --- .../components/telegram_bot/__init__.py | 76 +++++++++++++------ tests/components/telegram_bot/conftest.py | 13 +++- .../telegram_bot/test_telegram_bot.py | 30 +++++++- 3 files changed, 94 insertions(+), 25 deletions(-) diff --git a/homeassistant/components/telegram_bot/__init__.py b/homeassistant/components/telegram_bot/__init__.py index 4c1eb8ff795..7a056665ed4 100644 --- a/homeassistant/components/telegram_bot/__init__.py +++ b/homeassistant/components/telegram_bot/__init__.py @@ -36,7 +36,7 @@ from homeassistant.const import ( HTTP_BEARER_AUTHENTICATION, HTTP_DIGEST_AUTHENTICATION, ) -from homeassistant.core import HomeAssistant, ServiceCall +from homeassistant.core import Context, HomeAssistant, ServiceCall from homeassistant.exceptions import TemplateError from homeassistant.helpers import config_validation as cv, issue_registry as ir from homeassistant.helpers.typing import ConfigType @@ -426,7 +426,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: _LOGGER.debug("New telegram message %s: %s", msgtype, kwargs) if msgtype == SERVICE_SEND_MESSAGE: - await notify_service.send_message(**kwargs) + await notify_service.send_message(context=service.context, **kwargs) elif msgtype in [ SERVICE_SEND_PHOTO, SERVICE_SEND_ANIMATION, @@ -434,19 +434,23 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: SERVICE_SEND_VOICE, SERVICE_SEND_DOCUMENT, ]: - await notify_service.send_file(msgtype, **kwargs) + await notify_service.send_file(msgtype, context=service.context, **kwargs) elif msgtype == SERVICE_SEND_STICKER: - await notify_service.send_sticker(**kwargs) + await notify_service.send_sticker(context=service.context, **kwargs) elif msgtype == SERVICE_SEND_LOCATION: - await notify_service.send_location(**kwargs) + await notify_service.send_location(context=service.context, **kwargs) elif msgtype == SERVICE_SEND_POLL: - await notify_service.send_poll(**kwargs) + await notify_service.send_poll(context=service.context, **kwargs) elif msgtype == SERVICE_ANSWER_CALLBACK_QUERY: - await notify_service.answer_callback_query(**kwargs) + await notify_service.answer_callback_query( + context=service.context, **kwargs + ) elif msgtype == SERVICE_DELETE_MESSAGE: - await notify_service.delete_message(**kwargs) + await notify_service.delete_message(context=service.context, **kwargs) else: - await notify_service.edit_message(msgtype, **kwargs) + await notify_service.edit_message( + msgtype, context=service.context, **kwargs + ) # Register notification services for service_notif, schema in SERVICE_MAP.items(): @@ -663,7 +667,7 @@ class TelegramNotificationService: return params async def _send_msg( - self, func_send, msg_error, message_tag, *args_msg, **kwargs_msg + self, func_send, msg_error, message_tag, *args_msg, context=None, **kwargs_msg ): """Send one message.""" try: @@ -684,7 +688,9 @@ class TelegramNotificationService: } if message_tag is not None: event_data[ATTR_MESSAGE_TAG] = message_tag - self.hass.bus.async_fire(EVENT_TELEGRAM_SENT, event_data) + self.hass.bus.async_fire( + EVENT_TELEGRAM_SENT, event_data, context=context + ) elif not isinstance(out, bool): _LOGGER.warning( "Update last message: out_type:%s, out=%s", type(out), out @@ -696,7 +702,7 @@ class TelegramNotificationService: return None return out - async def send_message(self, message="", target=None, **kwargs): + async def send_message(self, message="", target=None, context=None, **kwargs): """Send a message to one or multiple pre-allowed chat IDs.""" title = kwargs.get(ATTR_TITLE) text = f"{title}\n{message}" if title else message @@ -715,15 +721,21 @@ class TelegramNotificationService: reply_to_message_id=params[ATTR_REPLY_TO_MSGID], reply_markup=params[ATTR_REPLYMARKUP], read_timeout=params[ATTR_TIMEOUT], + context=context, ) - async def delete_message(self, chat_id=None, **kwargs): + async def delete_message(self, chat_id=None, context=None, **kwargs): """Delete a previously sent message.""" chat_id = self._get_target_chat_ids(chat_id)[0] message_id, _ = self._get_msg_ids(kwargs, chat_id) _LOGGER.debug("Delete message %s in chat ID %s", message_id, chat_id) deleted = await self._send_msg( - self.bot.delete_message, "Error deleting message", None, chat_id, message_id + self.bot.delete_message, + "Error deleting message", + None, + chat_id, + message_id, + context=context, ) # reduce message_id anyway: if self._last_message_id[chat_id] is not None: @@ -731,7 +743,7 @@ class TelegramNotificationService: self._last_message_id[chat_id] -= 1 return deleted - async def edit_message(self, type_edit, chat_id=None, **kwargs): + async def edit_message(self, type_edit, chat_id=None, context=None, **kwargs): """Edit a previously sent message.""" chat_id = self._get_target_chat_ids(chat_id)[0] message_id, inline_message_id = self._get_msg_ids(kwargs, chat_id) @@ -759,6 +771,7 @@ class TelegramNotificationService: disable_web_page_preview=params[ATTR_DISABLE_WEB_PREV], reply_markup=params[ATTR_REPLYMARKUP], read_timeout=params[ATTR_TIMEOUT], + context=context, ) if type_edit == SERVICE_EDIT_CAPTION: return await self._send_msg( @@ -772,6 +785,7 @@ class TelegramNotificationService: reply_markup=params[ATTR_REPLYMARKUP], read_timeout=params[ATTR_TIMEOUT], parse_mode=params[ATTR_PARSER], + context=context, ) return await self._send_msg( @@ -783,10 +797,11 @@ class TelegramNotificationService: inline_message_id=inline_message_id, reply_markup=params[ATTR_REPLYMARKUP], read_timeout=params[ATTR_TIMEOUT], + context=context, ) async def answer_callback_query( - self, message, callback_query_id, show_alert=False, **kwargs + self, message, callback_query_id, show_alert=False, context=None, **kwargs ): """Answer a callback originated with a press in an inline keyboard.""" params = self._get_msg_kwargs(kwargs) @@ -804,9 +819,12 @@ class TelegramNotificationService: text=message, show_alert=show_alert, read_timeout=params[ATTR_TIMEOUT], + context=context, ) - async def send_file(self, file_type=SERVICE_SEND_PHOTO, target=None, **kwargs): + async def send_file( + self, file_type=SERVICE_SEND_PHOTO, target=None, context=None, **kwargs + ): """Send a photo, sticker, video, or document.""" params = self._get_msg_kwargs(kwargs) file_content = await load_data( @@ -836,6 +854,7 @@ class TelegramNotificationService: reply_markup=params[ATTR_REPLYMARKUP], read_timeout=params[ATTR_TIMEOUT], parse_mode=params[ATTR_PARSER], + context=context, ) elif file_type == SERVICE_SEND_STICKER: @@ -849,6 +868,7 @@ class TelegramNotificationService: reply_to_message_id=params[ATTR_REPLY_TO_MSGID], reply_markup=params[ATTR_REPLYMARKUP], read_timeout=params[ATTR_TIMEOUT], + context=context, ) elif file_type == SERVICE_SEND_VIDEO: @@ -864,6 +884,7 @@ class TelegramNotificationService: reply_markup=params[ATTR_REPLYMARKUP], read_timeout=params[ATTR_TIMEOUT], parse_mode=params[ATTR_PARSER], + context=context, ) elif file_type == SERVICE_SEND_DOCUMENT: await self._send_msg( @@ -878,6 +899,7 @@ class TelegramNotificationService: reply_markup=params[ATTR_REPLYMARKUP], read_timeout=params[ATTR_TIMEOUT], parse_mode=params[ATTR_PARSER], + context=context, ) elif file_type == SERVICE_SEND_VOICE: await self._send_msg( @@ -891,6 +913,7 @@ class TelegramNotificationService: reply_to_message_id=params[ATTR_REPLY_TO_MSGID], reply_markup=params[ATTR_REPLYMARKUP], read_timeout=params[ATTR_TIMEOUT], + context=context, ) elif file_type == SERVICE_SEND_ANIMATION: await self._send_msg( @@ -905,13 +928,14 @@ class TelegramNotificationService: reply_markup=params[ATTR_REPLYMARKUP], read_timeout=params[ATTR_TIMEOUT], parse_mode=params[ATTR_PARSER], + context=context, ) file_content.seek(0) else: _LOGGER.error("Can't send file with kwargs: %s", kwargs) - async def send_sticker(self, target=None, **kwargs): + async def send_sticker(self, target=None, context=None, **kwargs): """Send a sticker from a telegram sticker pack.""" params = self._get_msg_kwargs(kwargs) stickerid = kwargs.get(ATTR_STICKER_ID) @@ -927,11 +951,14 @@ class TelegramNotificationService: reply_to_message_id=params[ATTR_REPLY_TO_MSGID], reply_markup=params[ATTR_REPLYMARKUP], read_timeout=params[ATTR_TIMEOUT], + context=context, ) else: await self.send_file(SERVICE_SEND_STICKER, target, **kwargs) - async def send_location(self, latitude, longitude, target=None, **kwargs): + async def send_location( + self, latitude, longitude, target=None, context=None, **kwargs + ): """Send a location.""" latitude = float(latitude) longitude = float(longitude) @@ -950,6 +977,7 @@ class TelegramNotificationService: disable_notification=params[ATTR_DISABLE_NOTIF], reply_to_message_id=params[ATTR_REPLY_TO_MSGID], read_timeout=params[ATTR_TIMEOUT], + context=context, ) async def send_poll( @@ -959,6 +987,7 @@ class TelegramNotificationService: is_anonymous, allows_multiple_answers, target=None, + context=None, **kwargs, ): """Send a poll.""" @@ -979,14 +1008,15 @@ class TelegramNotificationService: disable_notification=params[ATTR_DISABLE_NOTIF], reply_to_message_id=params[ATTR_REPLY_TO_MSGID], read_timeout=params[ATTR_TIMEOUT], + context=context, ) - async def leave_chat(self, chat_id=None): + async def leave_chat(self, chat_id=None, context=None): """Remove bot from chat.""" chat_id = self._get_target_chat_ids(chat_id)[0] _LOGGER.debug("Leave from chat ID %s", chat_id) return await self._send_msg( - self.bot.leave_chat, "Error leaving chat", None, chat_id + self.bot.leave_chat, "Error leaving chat", None, chat_id, context=context ) @@ -1019,8 +1049,10 @@ class BaseTelegramBotEntity: _LOGGER.warning("Unhandled update: %s", update) return True + event_context = Context() + _LOGGER.debug("Firing event %s: %s", event_type, event_data) - self.hass.bus.async_fire(event_type, event_data) + self.hass.bus.async_fire(event_type, event_data, context=event_context) return True @staticmethod diff --git a/tests/components/telegram_bot/conftest.py b/tests/components/telegram_bot/conftest.py index 0906b6afcbd..6ea5d1446dd 100644 --- a/tests/components/telegram_bot/conftest.py +++ b/tests/components/telegram_bot/conftest.py @@ -1,9 +1,11 @@ """Tests for the telegram_bot integration.""" +from datetime import datetime from unittest.mock import patch import pytest -from telegram import User +from telegram import Chat, Message, User +from telegram.constants import ChatType from homeassistant.components.telegram_bot import ( CONF_ALLOWED_CHAT_IDS, @@ -79,6 +81,11 @@ def mock_register_webhook(): def mock_external_calls(): """Mock calls that make calls to the live Telegram API.""" test_user = User(123456, "Testbot", True) + message = Message( + message_id=12345, + date=datetime.now(), + chat=Chat(id=123456, type=ChatType.PRIVATE), + ) with ( patch( "telegram.Bot.get_me", @@ -92,6 +99,10 @@ def mock_external_calls(): "telegram.Bot.bot", test_user, ), + patch( + "telegram.Bot.send_message", + return_value=message, + ), patch("telegram.ext.Updater._bootstrap"), ): yield diff --git a/tests/components/telegram_bot/test_telegram_bot.py b/tests/components/telegram_bot/test_telegram_bot.py index d6588535b4f..b748b58ad1a 100644 --- a/tests/components/telegram_bot/test_telegram_bot.py +++ b/tests/components/telegram_bot/test_telegram_bot.py @@ -4,9 +4,13 @@ from unittest.mock import AsyncMock, patch from telegram import Update -from homeassistant.components.telegram_bot import DOMAIN, SERVICE_SEND_MESSAGE +from homeassistant.components.telegram_bot import ( + ATTR_MESSAGE, + DOMAIN, + SERVICE_SEND_MESSAGE, +) from homeassistant.components.telegram_bot.webhooks import TELEGRAM_WEBHOOK_URL -from homeassistant.core import HomeAssistant +from homeassistant.core import Context, HomeAssistant from homeassistant.setup import async_setup_component from tests.common import async_capture_events @@ -23,6 +27,24 @@ async def test_polling_platform_init(hass: HomeAssistant, polling_platform) -> N assert hass.services.has_service(DOMAIN, SERVICE_SEND_MESSAGE) is True +async def test_send_message(hass: HomeAssistant, webhook_platform) -> None: + """Test the send_message service.""" + context = Context() + events = async_capture_events(hass, "telegram_sent") + + await hass.services.async_call( + DOMAIN, + SERVICE_SEND_MESSAGE, + {ATTR_MESSAGE: "test_message"}, + blocking=True, + context=context, + ) + await hass.async_block_till_done() + + assert len(events) == 1 + assert events[0].context == context + + async def test_webhook_endpoint_generates_telegram_text_event( hass: HomeAssistant, webhook_platform, @@ -47,6 +69,7 @@ async def test_webhook_endpoint_generates_telegram_text_event( assert len(events) == 1 assert events[0].data["text"] == update_message_text["message"]["text"] + assert isinstance(events[0].context, Context) async def test_webhook_endpoint_generates_telegram_command_event( @@ -73,6 +96,7 @@ async def test_webhook_endpoint_generates_telegram_command_event( assert len(events) == 1 assert events[0].data["command"] == update_message_command["message"]["text"] + assert isinstance(events[0].context, Context) async def test_webhook_endpoint_generates_telegram_callback_event( @@ -99,6 +123,7 @@ async def test_webhook_endpoint_generates_telegram_callback_event( assert len(events) == 1 assert events[0].data["data"] == update_callback_query["callback_query"]["data"] + assert isinstance(events[0].context, Context) async def test_polling_platform_message_text_update( @@ -140,6 +165,7 @@ async def test_polling_platform_message_text_update( assert len(events) == 1 assert events[0].data["text"] == update_message_text["message"]["text"] + assert isinstance(events[0].context, Context) async def test_webhook_endpoint_unauthorized_update_doesnt_generate_telegram_text_event(