Add context to telegram_bot events (#109920)

* Add context for received messages events

* Add context for sent messages events

* ruff

* ruff

* ruff

* Removed user_id mapping

* Add tests
This commit is contained in:
Denis Shulyaka 2024-05-14 16:48:59 +03:00 committed by GitHub
parent 121966245b
commit 9add251b0a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 94 additions and 25 deletions

View file

@ -36,7 +36,7 @@ from homeassistant.const import (
HTTP_BEARER_AUTHENTICATION, HTTP_BEARER_AUTHENTICATION,
HTTP_DIGEST_AUTHENTICATION, HTTP_DIGEST_AUTHENTICATION,
) )
from homeassistant.core import HomeAssistant, ServiceCall from homeassistant.core import Context, HomeAssistant, ServiceCall
from homeassistant.exceptions import TemplateError from homeassistant.exceptions import TemplateError
from homeassistant.helpers import config_validation as cv, issue_registry as ir from homeassistant.helpers import config_validation as cv, issue_registry as ir
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
@ -426,7 +426,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
_LOGGER.debug("New telegram message %s: %s", msgtype, kwargs) _LOGGER.debug("New telegram message %s: %s", msgtype, kwargs)
if msgtype == SERVICE_SEND_MESSAGE: if msgtype == SERVICE_SEND_MESSAGE:
await notify_service.send_message(**kwargs) await notify_service.send_message(context=service.context, **kwargs)
elif msgtype in [ elif msgtype in [
SERVICE_SEND_PHOTO, SERVICE_SEND_PHOTO,
SERVICE_SEND_ANIMATION, SERVICE_SEND_ANIMATION,
@ -434,19 +434,23 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
SERVICE_SEND_VOICE, SERVICE_SEND_VOICE,
SERVICE_SEND_DOCUMENT, SERVICE_SEND_DOCUMENT,
]: ]:
await notify_service.send_file(msgtype, **kwargs) await notify_service.send_file(msgtype, context=service.context, **kwargs)
elif msgtype == SERVICE_SEND_STICKER: elif msgtype == SERVICE_SEND_STICKER:
await notify_service.send_sticker(**kwargs) await notify_service.send_sticker(context=service.context, **kwargs)
elif msgtype == SERVICE_SEND_LOCATION: elif msgtype == SERVICE_SEND_LOCATION:
await notify_service.send_location(**kwargs) await notify_service.send_location(context=service.context, **kwargs)
elif msgtype == SERVICE_SEND_POLL: elif msgtype == SERVICE_SEND_POLL:
await notify_service.send_poll(**kwargs) await notify_service.send_poll(context=service.context, **kwargs)
elif msgtype == SERVICE_ANSWER_CALLBACK_QUERY: elif msgtype == SERVICE_ANSWER_CALLBACK_QUERY:
await notify_service.answer_callback_query(**kwargs) await notify_service.answer_callback_query(
context=service.context, **kwargs
)
elif msgtype == SERVICE_DELETE_MESSAGE: elif msgtype == SERVICE_DELETE_MESSAGE:
await notify_service.delete_message(**kwargs) await notify_service.delete_message(context=service.context, **kwargs)
else: else:
await notify_service.edit_message(msgtype, **kwargs) await notify_service.edit_message(
msgtype, context=service.context, **kwargs
)
# Register notification services # Register notification services
for service_notif, schema in SERVICE_MAP.items(): for service_notif, schema in SERVICE_MAP.items():
@ -663,7 +667,7 @@ class TelegramNotificationService:
return params return params
async def _send_msg( async def _send_msg(
self, func_send, msg_error, message_tag, *args_msg, **kwargs_msg self, func_send, msg_error, message_tag, *args_msg, context=None, **kwargs_msg
): ):
"""Send one message.""" """Send one message."""
try: try:
@ -684,7 +688,9 @@ class TelegramNotificationService:
} }
if message_tag is not None: if message_tag is not None:
event_data[ATTR_MESSAGE_TAG] = message_tag event_data[ATTR_MESSAGE_TAG] = message_tag
self.hass.bus.async_fire(EVENT_TELEGRAM_SENT, event_data) self.hass.bus.async_fire(
EVENT_TELEGRAM_SENT, event_data, context=context
)
elif not isinstance(out, bool): elif not isinstance(out, bool):
_LOGGER.warning( _LOGGER.warning(
"Update last message: out_type:%s, out=%s", type(out), out "Update last message: out_type:%s, out=%s", type(out), out
@ -696,7 +702,7 @@ class TelegramNotificationService:
return None return None
return out return out
async def send_message(self, message="", target=None, **kwargs): async def send_message(self, message="", target=None, context=None, **kwargs):
"""Send a message to one or multiple pre-allowed chat IDs.""" """Send a message to one or multiple pre-allowed chat IDs."""
title = kwargs.get(ATTR_TITLE) title = kwargs.get(ATTR_TITLE)
text = f"{title}\n{message}" if title else message text = f"{title}\n{message}" if title else message
@ -715,15 +721,21 @@ class TelegramNotificationService:
reply_to_message_id=params[ATTR_REPLY_TO_MSGID], reply_to_message_id=params[ATTR_REPLY_TO_MSGID],
reply_markup=params[ATTR_REPLYMARKUP], reply_markup=params[ATTR_REPLYMARKUP],
read_timeout=params[ATTR_TIMEOUT], read_timeout=params[ATTR_TIMEOUT],
context=context,
) )
async def delete_message(self, chat_id=None, **kwargs): async def delete_message(self, chat_id=None, context=None, **kwargs):
"""Delete a previously sent message.""" """Delete a previously sent message."""
chat_id = self._get_target_chat_ids(chat_id)[0] chat_id = self._get_target_chat_ids(chat_id)[0]
message_id, _ = self._get_msg_ids(kwargs, chat_id) message_id, _ = self._get_msg_ids(kwargs, chat_id)
_LOGGER.debug("Delete message %s in chat ID %s", message_id, chat_id) _LOGGER.debug("Delete message %s in chat ID %s", message_id, chat_id)
deleted = await self._send_msg( deleted = await self._send_msg(
self.bot.delete_message, "Error deleting message", None, chat_id, message_id self.bot.delete_message,
"Error deleting message",
None,
chat_id,
message_id,
context=context,
) )
# reduce message_id anyway: # reduce message_id anyway:
if self._last_message_id[chat_id] is not None: if self._last_message_id[chat_id] is not None:
@ -731,7 +743,7 @@ class TelegramNotificationService:
self._last_message_id[chat_id] -= 1 self._last_message_id[chat_id] -= 1
return deleted return deleted
async def edit_message(self, type_edit, chat_id=None, **kwargs): async def edit_message(self, type_edit, chat_id=None, context=None, **kwargs):
"""Edit a previously sent message.""" """Edit a previously sent message."""
chat_id = self._get_target_chat_ids(chat_id)[0] chat_id = self._get_target_chat_ids(chat_id)[0]
message_id, inline_message_id = self._get_msg_ids(kwargs, chat_id) message_id, inline_message_id = self._get_msg_ids(kwargs, chat_id)
@ -759,6 +771,7 @@ class TelegramNotificationService:
disable_web_page_preview=params[ATTR_DISABLE_WEB_PREV], disable_web_page_preview=params[ATTR_DISABLE_WEB_PREV],
reply_markup=params[ATTR_REPLYMARKUP], reply_markup=params[ATTR_REPLYMARKUP],
read_timeout=params[ATTR_TIMEOUT], read_timeout=params[ATTR_TIMEOUT],
context=context,
) )
if type_edit == SERVICE_EDIT_CAPTION: if type_edit == SERVICE_EDIT_CAPTION:
return await self._send_msg( return await self._send_msg(
@ -772,6 +785,7 @@ class TelegramNotificationService:
reply_markup=params[ATTR_REPLYMARKUP], reply_markup=params[ATTR_REPLYMARKUP],
read_timeout=params[ATTR_TIMEOUT], read_timeout=params[ATTR_TIMEOUT],
parse_mode=params[ATTR_PARSER], parse_mode=params[ATTR_PARSER],
context=context,
) )
return await self._send_msg( return await self._send_msg(
@ -783,10 +797,11 @@ class TelegramNotificationService:
inline_message_id=inline_message_id, inline_message_id=inline_message_id,
reply_markup=params[ATTR_REPLYMARKUP], reply_markup=params[ATTR_REPLYMARKUP],
read_timeout=params[ATTR_TIMEOUT], read_timeout=params[ATTR_TIMEOUT],
context=context,
) )
async def answer_callback_query( async def answer_callback_query(
self, message, callback_query_id, show_alert=False, **kwargs self, message, callback_query_id, show_alert=False, context=None, **kwargs
): ):
"""Answer a callback originated with a press in an inline keyboard.""" """Answer a callback originated with a press in an inline keyboard."""
params = self._get_msg_kwargs(kwargs) params = self._get_msg_kwargs(kwargs)
@ -804,9 +819,12 @@ class TelegramNotificationService:
text=message, text=message,
show_alert=show_alert, show_alert=show_alert,
read_timeout=params[ATTR_TIMEOUT], read_timeout=params[ATTR_TIMEOUT],
context=context,
) )
async def send_file(self, file_type=SERVICE_SEND_PHOTO, target=None, **kwargs): async def send_file(
self, file_type=SERVICE_SEND_PHOTO, target=None, context=None, **kwargs
):
"""Send a photo, sticker, video, or document.""" """Send a photo, sticker, video, or document."""
params = self._get_msg_kwargs(kwargs) params = self._get_msg_kwargs(kwargs)
file_content = await load_data( file_content = await load_data(
@ -836,6 +854,7 @@ class TelegramNotificationService:
reply_markup=params[ATTR_REPLYMARKUP], reply_markup=params[ATTR_REPLYMARKUP],
read_timeout=params[ATTR_TIMEOUT], read_timeout=params[ATTR_TIMEOUT],
parse_mode=params[ATTR_PARSER], parse_mode=params[ATTR_PARSER],
context=context,
) )
elif file_type == SERVICE_SEND_STICKER: elif file_type == SERVICE_SEND_STICKER:
@ -849,6 +868,7 @@ class TelegramNotificationService:
reply_to_message_id=params[ATTR_REPLY_TO_MSGID], reply_to_message_id=params[ATTR_REPLY_TO_MSGID],
reply_markup=params[ATTR_REPLYMARKUP], reply_markup=params[ATTR_REPLYMARKUP],
read_timeout=params[ATTR_TIMEOUT], read_timeout=params[ATTR_TIMEOUT],
context=context,
) )
elif file_type == SERVICE_SEND_VIDEO: elif file_type == SERVICE_SEND_VIDEO:
@ -864,6 +884,7 @@ class TelegramNotificationService:
reply_markup=params[ATTR_REPLYMARKUP], reply_markup=params[ATTR_REPLYMARKUP],
read_timeout=params[ATTR_TIMEOUT], read_timeout=params[ATTR_TIMEOUT],
parse_mode=params[ATTR_PARSER], parse_mode=params[ATTR_PARSER],
context=context,
) )
elif file_type == SERVICE_SEND_DOCUMENT: elif file_type == SERVICE_SEND_DOCUMENT:
await self._send_msg( await self._send_msg(
@ -878,6 +899,7 @@ class TelegramNotificationService:
reply_markup=params[ATTR_REPLYMARKUP], reply_markup=params[ATTR_REPLYMARKUP],
read_timeout=params[ATTR_TIMEOUT], read_timeout=params[ATTR_TIMEOUT],
parse_mode=params[ATTR_PARSER], parse_mode=params[ATTR_PARSER],
context=context,
) )
elif file_type == SERVICE_SEND_VOICE: elif file_type == SERVICE_SEND_VOICE:
await self._send_msg( await self._send_msg(
@ -891,6 +913,7 @@ class TelegramNotificationService:
reply_to_message_id=params[ATTR_REPLY_TO_MSGID], reply_to_message_id=params[ATTR_REPLY_TO_MSGID],
reply_markup=params[ATTR_REPLYMARKUP], reply_markup=params[ATTR_REPLYMARKUP],
read_timeout=params[ATTR_TIMEOUT], read_timeout=params[ATTR_TIMEOUT],
context=context,
) )
elif file_type == SERVICE_SEND_ANIMATION: elif file_type == SERVICE_SEND_ANIMATION:
await self._send_msg( await self._send_msg(
@ -905,13 +928,14 @@ class TelegramNotificationService:
reply_markup=params[ATTR_REPLYMARKUP], reply_markup=params[ATTR_REPLYMARKUP],
read_timeout=params[ATTR_TIMEOUT], read_timeout=params[ATTR_TIMEOUT],
parse_mode=params[ATTR_PARSER], parse_mode=params[ATTR_PARSER],
context=context,
) )
file_content.seek(0) file_content.seek(0)
else: else:
_LOGGER.error("Can't send file with kwargs: %s", kwargs) _LOGGER.error("Can't send file with kwargs: %s", kwargs)
async def send_sticker(self, target=None, **kwargs): async def send_sticker(self, target=None, context=None, **kwargs):
"""Send a sticker from a telegram sticker pack.""" """Send a sticker from a telegram sticker pack."""
params = self._get_msg_kwargs(kwargs) params = self._get_msg_kwargs(kwargs)
stickerid = kwargs.get(ATTR_STICKER_ID) stickerid = kwargs.get(ATTR_STICKER_ID)
@ -927,11 +951,14 @@ class TelegramNotificationService:
reply_to_message_id=params[ATTR_REPLY_TO_MSGID], reply_to_message_id=params[ATTR_REPLY_TO_MSGID],
reply_markup=params[ATTR_REPLYMARKUP], reply_markup=params[ATTR_REPLYMARKUP],
read_timeout=params[ATTR_TIMEOUT], read_timeout=params[ATTR_TIMEOUT],
context=context,
) )
else: else:
await self.send_file(SERVICE_SEND_STICKER, target, **kwargs) await self.send_file(SERVICE_SEND_STICKER, target, **kwargs)
async def send_location(self, latitude, longitude, target=None, **kwargs): async def send_location(
self, latitude, longitude, target=None, context=None, **kwargs
):
"""Send a location.""" """Send a location."""
latitude = float(latitude) latitude = float(latitude)
longitude = float(longitude) longitude = float(longitude)
@ -950,6 +977,7 @@ class TelegramNotificationService:
disable_notification=params[ATTR_DISABLE_NOTIF], disable_notification=params[ATTR_DISABLE_NOTIF],
reply_to_message_id=params[ATTR_REPLY_TO_MSGID], reply_to_message_id=params[ATTR_REPLY_TO_MSGID],
read_timeout=params[ATTR_TIMEOUT], read_timeout=params[ATTR_TIMEOUT],
context=context,
) )
async def send_poll( async def send_poll(
@ -959,6 +987,7 @@ class TelegramNotificationService:
is_anonymous, is_anonymous,
allows_multiple_answers, allows_multiple_answers,
target=None, target=None,
context=None,
**kwargs, **kwargs,
): ):
"""Send a poll.""" """Send a poll."""
@ -979,14 +1008,15 @@ class TelegramNotificationService:
disable_notification=params[ATTR_DISABLE_NOTIF], disable_notification=params[ATTR_DISABLE_NOTIF],
reply_to_message_id=params[ATTR_REPLY_TO_MSGID], reply_to_message_id=params[ATTR_REPLY_TO_MSGID],
read_timeout=params[ATTR_TIMEOUT], read_timeout=params[ATTR_TIMEOUT],
context=context,
) )
async def leave_chat(self, chat_id=None): async def leave_chat(self, chat_id=None, context=None):
"""Remove bot from chat.""" """Remove bot from chat."""
chat_id = self._get_target_chat_ids(chat_id)[0] chat_id = self._get_target_chat_ids(chat_id)[0]
_LOGGER.debug("Leave from chat ID %s", chat_id) _LOGGER.debug("Leave from chat ID %s", chat_id)
return await self._send_msg( return await self._send_msg(
self.bot.leave_chat, "Error leaving chat", None, chat_id self.bot.leave_chat, "Error leaving chat", None, chat_id, context=context
) )
@ -1019,8 +1049,10 @@ class BaseTelegramBotEntity:
_LOGGER.warning("Unhandled update: %s", update) _LOGGER.warning("Unhandled update: %s", update)
return True return True
event_context = Context()
_LOGGER.debug("Firing event %s: %s", event_type, event_data) _LOGGER.debug("Firing event %s: %s", event_type, event_data)
self.hass.bus.async_fire(event_type, event_data) self.hass.bus.async_fire(event_type, event_data, context=event_context)
return True return True
@staticmethod @staticmethod

View file

@ -1,9 +1,11 @@
"""Tests for the telegram_bot integration.""" """Tests for the telegram_bot integration."""
from datetime import datetime
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from telegram import User from telegram import Chat, Message, User
from telegram.constants import ChatType
from homeassistant.components.telegram_bot import ( from homeassistant.components.telegram_bot import (
CONF_ALLOWED_CHAT_IDS, CONF_ALLOWED_CHAT_IDS,
@ -79,6 +81,11 @@ def mock_register_webhook():
def mock_external_calls(): def mock_external_calls():
"""Mock calls that make calls to the live Telegram API.""" """Mock calls that make calls to the live Telegram API."""
test_user = User(123456, "Testbot", True) test_user = User(123456, "Testbot", True)
message = Message(
message_id=12345,
date=datetime.now(),
chat=Chat(id=123456, type=ChatType.PRIVATE),
)
with ( with (
patch( patch(
"telegram.Bot.get_me", "telegram.Bot.get_me",
@ -92,6 +99,10 @@ def mock_external_calls():
"telegram.Bot.bot", "telegram.Bot.bot",
test_user, test_user,
), ),
patch(
"telegram.Bot.send_message",
return_value=message,
),
patch("telegram.ext.Updater._bootstrap"), patch("telegram.ext.Updater._bootstrap"),
): ):
yield yield

View file

@ -4,9 +4,13 @@ from unittest.mock import AsyncMock, patch
from telegram import Update from telegram import Update
from homeassistant.components.telegram_bot import DOMAIN, SERVICE_SEND_MESSAGE from homeassistant.components.telegram_bot import (
ATTR_MESSAGE,
DOMAIN,
SERVICE_SEND_MESSAGE,
)
from homeassistant.components.telegram_bot.webhooks import TELEGRAM_WEBHOOK_URL from homeassistant.components.telegram_bot.webhooks import TELEGRAM_WEBHOOK_URL
from homeassistant.core import HomeAssistant from homeassistant.core import Context, HomeAssistant
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from tests.common import async_capture_events from tests.common import async_capture_events
@ -23,6 +27,24 @@ async def test_polling_platform_init(hass: HomeAssistant, polling_platform) -> N
assert hass.services.has_service(DOMAIN, SERVICE_SEND_MESSAGE) is True assert hass.services.has_service(DOMAIN, SERVICE_SEND_MESSAGE) is True
async def test_send_message(hass: HomeAssistant, webhook_platform) -> None:
"""Test the send_message service."""
context = Context()
events = async_capture_events(hass, "telegram_sent")
await hass.services.async_call(
DOMAIN,
SERVICE_SEND_MESSAGE,
{ATTR_MESSAGE: "test_message"},
blocking=True,
context=context,
)
await hass.async_block_till_done()
assert len(events) == 1
assert events[0].context == context
async def test_webhook_endpoint_generates_telegram_text_event( async def test_webhook_endpoint_generates_telegram_text_event(
hass: HomeAssistant, hass: HomeAssistant,
webhook_platform, webhook_platform,
@ -47,6 +69,7 @@ async def test_webhook_endpoint_generates_telegram_text_event(
assert len(events) == 1 assert len(events) == 1
assert events[0].data["text"] == update_message_text["message"]["text"] assert events[0].data["text"] == update_message_text["message"]["text"]
assert isinstance(events[0].context, Context)
async def test_webhook_endpoint_generates_telegram_command_event( async def test_webhook_endpoint_generates_telegram_command_event(
@ -73,6 +96,7 @@ async def test_webhook_endpoint_generates_telegram_command_event(
assert len(events) == 1 assert len(events) == 1
assert events[0].data["command"] == update_message_command["message"]["text"] assert events[0].data["command"] == update_message_command["message"]["text"]
assert isinstance(events[0].context, Context)
async def test_webhook_endpoint_generates_telegram_callback_event( async def test_webhook_endpoint_generates_telegram_callback_event(
@ -99,6 +123,7 @@ async def test_webhook_endpoint_generates_telegram_callback_event(
assert len(events) == 1 assert len(events) == 1
assert events[0].data["data"] == update_callback_query["callback_query"]["data"] assert events[0].data["data"] == update_callback_query["callback_query"]["data"]
assert isinstance(events[0].context, Context)
async def test_polling_platform_message_text_update( async def test_polling_platform_message_text_update(
@ -140,6 +165,7 @@ async def test_polling_platform_message_text_update(
assert len(events) == 1 assert len(events) == 1
assert events[0].data["text"] == update_message_text["message"]["text"] assert events[0].data["text"] == update_message_text["message"]["text"]
assert isinstance(events[0].context, Context)
async def test_webhook_endpoint_unauthorized_update_doesnt_generate_telegram_text_event( async def test_webhook_endpoint_unauthorized_update_doesnt_generate_telegram_text_event(