Add context to telegram_bot events (#109920)

* Add context for received messages events

* Add context for sent messages events

* ruff

* ruff

* ruff

* Removed user_id mapping

* Add tests
This commit is contained in:
Denis Shulyaka 2024-05-14 16:48:59 +03:00 committed by GitHub
parent 121966245b
commit 9add251b0a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 94 additions and 25 deletions

View file

@ -36,7 +36,7 @@ from homeassistant.const import (
HTTP_BEARER_AUTHENTICATION,
HTTP_DIGEST_AUTHENTICATION,
)
from homeassistant.core import HomeAssistant, ServiceCall
from homeassistant.core import Context, HomeAssistant, ServiceCall
from homeassistant.exceptions import TemplateError
from homeassistant.helpers import config_validation as cv, issue_registry as ir
from homeassistant.helpers.typing import ConfigType
@ -426,7 +426,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
_LOGGER.debug("New telegram message %s: %s", msgtype, kwargs)
if msgtype == SERVICE_SEND_MESSAGE:
await notify_service.send_message(**kwargs)
await notify_service.send_message(context=service.context, **kwargs)
elif msgtype in [
SERVICE_SEND_PHOTO,
SERVICE_SEND_ANIMATION,
@ -434,19 +434,23 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
SERVICE_SEND_VOICE,
SERVICE_SEND_DOCUMENT,
]:
await notify_service.send_file(msgtype, **kwargs)
await notify_service.send_file(msgtype, context=service.context, **kwargs)
elif msgtype == SERVICE_SEND_STICKER:
await notify_service.send_sticker(**kwargs)
await notify_service.send_sticker(context=service.context, **kwargs)
elif msgtype == SERVICE_SEND_LOCATION:
await notify_service.send_location(**kwargs)
await notify_service.send_location(context=service.context, **kwargs)
elif msgtype == SERVICE_SEND_POLL:
await notify_service.send_poll(**kwargs)
await notify_service.send_poll(context=service.context, **kwargs)
elif msgtype == SERVICE_ANSWER_CALLBACK_QUERY:
await notify_service.answer_callback_query(**kwargs)
await notify_service.answer_callback_query(
context=service.context, **kwargs
)
elif msgtype == SERVICE_DELETE_MESSAGE:
await notify_service.delete_message(**kwargs)
await notify_service.delete_message(context=service.context, **kwargs)
else:
await notify_service.edit_message(msgtype, **kwargs)
await notify_service.edit_message(
msgtype, context=service.context, **kwargs
)
# Register notification services
for service_notif, schema in SERVICE_MAP.items():
@ -663,7 +667,7 @@ class TelegramNotificationService:
return params
async def _send_msg(
self, func_send, msg_error, message_tag, *args_msg, **kwargs_msg
self, func_send, msg_error, message_tag, *args_msg, context=None, **kwargs_msg
):
"""Send one message."""
try:
@ -684,7 +688,9 @@ class TelegramNotificationService:
}
if message_tag is not None:
event_data[ATTR_MESSAGE_TAG] = message_tag
self.hass.bus.async_fire(EVENT_TELEGRAM_SENT, event_data)
self.hass.bus.async_fire(
EVENT_TELEGRAM_SENT, event_data, context=context
)
elif not isinstance(out, bool):
_LOGGER.warning(
"Update last message: out_type:%s, out=%s", type(out), out
@ -696,7 +702,7 @@ class TelegramNotificationService:
return None
return out
async def send_message(self, message="", target=None, **kwargs):
async def send_message(self, message="", target=None, context=None, **kwargs):
"""Send a message to one or multiple pre-allowed chat IDs."""
title = kwargs.get(ATTR_TITLE)
text = f"{title}\n{message}" if title else message
@ -715,15 +721,21 @@ class TelegramNotificationService:
reply_to_message_id=params[ATTR_REPLY_TO_MSGID],
reply_markup=params[ATTR_REPLYMARKUP],
read_timeout=params[ATTR_TIMEOUT],
context=context,
)
async def delete_message(self, chat_id=None, **kwargs):
async def delete_message(self, chat_id=None, context=None, **kwargs):
"""Delete a previously sent message."""
chat_id = self._get_target_chat_ids(chat_id)[0]
message_id, _ = self._get_msg_ids(kwargs, chat_id)
_LOGGER.debug("Delete message %s in chat ID %s", message_id, chat_id)
deleted = await self._send_msg(
self.bot.delete_message, "Error deleting message", None, chat_id, message_id
self.bot.delete_message,
"Error deleting message",
None,
chat_id,
message_id,
context=context,
)
# reduce message_id anyway:
if self._last_message_id[chat_id] is not None:
@ -731,7 +743,7 @@ class TelegramNotificationService:
self._last_message_id[chat_id] -= 1
return deleted
async def edit_message(self, type_edit, chat_id=None, **kwargs):
async def edit_message(self, type_edit, chat_id=None, context=None, **kwargs):
"""Edit a previously sent message."""
chat_id = self._get_target_chat_ids(chat_id)[0]
message_id, inline_message_id = self._get_msg_ids(kwargs, chat_id)
@ -759,6 +771,7 @@ class TelegramNotificationService:
disable_web_page_preview=params[ATTR_DISABLE_WEB_PREV],
reply_markup=params[ATTR_REPLYMARKUP],
read_timeout=params[ATTR_TIMEOUT],
context=context,
)
if type_edit == SERVICE_EDIT_CAPTION:
return await self._send_msg(
@ -772,6 +785,7 @@ class TelegramNotificationService:
reply_markup=params[ATTR_REPLYMARKUP],
read_timeout=params[ATTR_TIMEOUT],
parse_mode=params[ATTR_PARSER],
context=context,
)
return await self._send_msg(
@ -783,10 +797,11 @@ class TelegramNotificationService:
inline_message_id=inline_message_id,
reply_markup=params[ATTR_REPLYMARKUP],
read_timeout=params[ATTR_TIMEOUT],
context=context,
)
async def answer_callback_query(
self, message, callback_query_id, show_alert=False, **kwargs
self, message, callback_query_id, show_alert=False, context=None, **kwargs
):
"""Answer a callback originated with a press in an inline keyboard."""
params = self._get_msg_kwargs(kwargs)
@ -804,9 +819,12 @@ class TelegramNotificationService:
text=message,
show_alert=show_alert,
read_timeout=params[ATTR_TIMEOUT],
context=context,
)
async def send_file(self, file_type=SERVICE_SEND_PHOTO, target=None, **kwargs):
async def send_file(
self, file_type=SERVICE_SEND_PHOTO, target=None, context=None, **kwargs
):
"""Send a photo, sticker, video, or document."""
params = self._get_msg_kwargs(kwargs)
file_content = await load_data(
@ -836,6 +854,7 @@ class TelegramNotificationService:
reply_markup=params[ATTR_REPLYMARKUP],
read_timeout=params[ATTR_TIMEOUT],
parse_mode=params[ATTR_PARSER],
context=context,
)
elif file_type == SERVICE_SEND_STICKER:
@ -849,6 +868,7 @@ class TelegramNotificationService:
reply_to_message_id=params[ATTR_REPLY_TO_MSGID],
reply_markup=params[ATTR_REPLYMARKUP],
read_timeout=params[ATTR_TIMEOUT],
context=context,
)
elif file_type == SERVICE_SEND_VIDEO:
@ -864,6 +884,7 @@ class TelegramNotificationService:
reply_markup=params[ATTR_REPLYMARKUP],
read_timeout=params[ATTR_TIMEOUT],
parse_mode=params[ATTR_PARSER],
context=context,
)
elif file_type == SERVICE_SEND_DOCUMENT:
await self._send_msg(
@ -878,6 +899,7 @@ class TelegramNotificationService:
reply_markup=params[ATTR_REPLYMARKUP],
read_timeout=params[ATTR_TIMEOUT],
parse_mode=params[ATTR_PARSER],
context=context,
)
elif file_type == SERVICE_SEND_VOICE:
await self._send_msg(
@ -891,6 +913,7 @@ class TelegramNotificationService:
reply_to_message_id=params[ATTR_REPLY_TO_MSGID],
reply_markup=params[ATTR_REPLYMARKUP],
read_timeout=params[ATTR_TIMEOUT],
context=context,
)
elif file_type == SERVICE_SEND_ANIMATION:
await self._send_msg(
@ -905,13 +928,14 @@ class TelegramNotificationService:
reply_markup=params[ATTR_REPLYMARKUP],
read_timeout=params[ATTR_TIMEOUT],
parse_mode=params[ATTR_PARSER],
context=context,
)
file_content.seek(0)
else:
_LOGGER.error("Can't send file with kwargs: %s", kwargs)
async def send_sticker(self, target=None, **kwargs):
async def send_sticker(self, target=None, context=None, **kwargs):
"""Send a sticker from a telegram sticker pack."""
params = self._get_msg_kwargs(kwargs)
stickerid = kwargs.get(ATTR_STICKER_ID)
@ -927,11 +951,14 @@ class TelegramNotificationService:
reply_to_message_id=params[ATTR_REPLY_TO_MSGID],
reply_markup=params[ATTR_REPLYMARKUP],
read_timeout=params[ATTR_TIMEOUT],
context=context,
)
else:
await self.send_file(SERVICE_SEND_STICKER, target, **kwargs)
async def send_location(self, latitude, longitude, target=None, **kwargs):
async def send_location(
self, latitude, longitude, target=None, context=None, **kwargs
):
"""Send a location."""
latitude = float(latitude)
longitude = float(longitude)
@ -950,6 +977,7 @@ class TelegramNotificationService:
disable_notification=params[ATTR_DISABLE_NOTIF],
reply_to_message_id=params[ATTR_REPLY_TO_MSGID],
read_timeout=params[ATTR_TIMEOUT],
context=context,
)
async def send_poll(
@ -959,6 +987,7 @@ class TelegramNotificationService:
is_anonymous,
allows_multiple_answers,
target=None,
context=None,
**kwargs,
):
"""Send a poll."""
@ -979,14 +1008,15 @@ class TelegramNotificationService:
disable_notification=params[ATTR_DISABLE_NOTIF],
reply_to_message_id=params[ATTR_REPLY_TO_MSGID],
read_timeout=params[ATTR_TIMEOUT],
context=context,
)
async def leave_chat(self, chat_id=None):
async def leave_chat(self, chat_id=None, context=None):
"""Remove bot from chat."""
chat_id = self._get_target_chat_ids(chat_id)[0]
_LOGGER.debug("Leave from chat ID %s", chat_id)
return await self._send_msg(
self.bot.leave_chat, "Error leaving chat", None, chat_id
self.bot.leave_chat, "Error leaving chat", None, chat_id, context=context
)
@ -1019,8 +1049,10 @@ class BaseTelegramBotEntity:
_LOGGER.warning("Unhandled update: %s", update)
return True
event_context = Context()
_LOGGER.debug("Firing event %s: %s", event_type, event_data)
self.hass.bus.async_fire(event_type, event_data)
self.hass.bus.async_fire(event_type, event_data, context=event_context)
return True
@staticmethod

View file

@ -1,9 +1,11 @@
"""Tests for the telegram_bot integration."""
from datetime import datetime
from unittest.mock import patch
import pytest
from telegram import User
from telegram import Chat, Message, User
from telegram.constants import ChatType
from homeassistant.components.telegram_bot import (
CONF_ALLOWED_CHAT_IDS,
@ -79,6 +81,11 @@ def mock_register_webhook():
def mock_external_calls():
"""Mock calls that make calls to the live Telegram API."""
test_user = User(123456, "Testbot", True)
message = Message(
message_id=12345,
date=datetime.now(),
chat=Chat(id=123456, type=ChatType.PRIVATE),
)
with (
patch(
"telegram.Bot.get_me",
@ -92,6 +99,10 @@ def mock_external_calls():
"telegram.Bot.bot",
test_user,
),
patch(
"telegram.Bot.send_message",
return_value=message,
),
patch("telegram.ext.Updater._bootstrap"),
):
yield

View file

@ -4,9 +4,13 @@ from unittest.mock import AsyncMock, patch
from telegram import Update
from homeassistant.components.telegram_bot import DOMAIN, SERVICE_SEND_MESSAGE
from homeassistant.components.telegram_bot import (
ATTR_MESSAGE,
DOMAIN,
SERVICE_SEND_MESSAGE,
)
from homeassistant.components.telegram_bot.webhooks import TELEGRAM_WEBHOOK_URL
from homeassistant.core import HomeAssistant
from homeassistant.core import Context, HomeAssistant
from homeassistant.setup import async_setup_component
from tests.common import async_capture_events
@ -23,6 +27,24 @@ async def test_polling_platform_init(hass: HomeAssistant, polling_platform) -> N
assert hass.services.has_service(DOMAIN, SERVICE_SEND_MESSAGE) is True
async def test_send_message(hass: HomeAssistant, webhook_platform) -> None:
"""Test the send_message service."""
context = Context()
events = async_capture_events(hass, "telegram_sent")
await hass.services.async_call(
DOMAIN,
SERVICE_SEND_MESSAGE,
{ATTR_MESSAGE: "test_message"},
blocking=True,
context=context,
)
await hass.async_block_till_done()
assert len(events) == 1
assert events[0].context == context
async def test_webhook_endpoint_generates_telegram_text_event(
hass: HomeAssistant,
webhook_platform,
@ -47,6 +69,7 @@ async def test_webhook_endpoint_generates_telegram_text_event(
assert len(events) == 1
assert events[0].data["text"] == update_message_text["message"]["text"]
assert isinstance(events[0].context, Context)
async def test_webhook_endpoint_generates_telegram_command_event(
@ -73,6 +96,7 @@ async def test_webhook_endpoint_generates_telegram_command_event(
assert len(events) == 1
assert events[0].data["command"] == update_message_command["message"]["text"]
assert isinstance(events[0].context, Context)
async def test_webhook_endpoint_generates_telegram_callback_event(
@ -99,6 +123,7 @@ async def test_webhook_endpoint_generates_telegram_callback_event(
assert len(events) == 1
assert events[0].data["data"] == update_callback_query["callback_query"]["data"]
assert isinstance(events[0].context, Context)
async def test_polling_platform_message_text_update(
@ -140,6 +165,7 @@ async def test_polling_platform_message_text_update(
assert len(events) == 1
assert events[0].data["text"] == update_message_text["message"]["text"]
assert isinstance(events[0].context, Context)
async def test_webhook_endpoint_unauthorized_update_doesnt_generate_telegram_text_event(