From 30ba78cf82d2622cd6b9f0bc06ce20c68ce6dd75 Mon Sep 17 00:00:00 2001 From: Paarth Shah Date: Mon, 23 Oct 2023 01:35:41 -0700 Subject: [PATCH] Fix resolving Matrix room aliases (#101928) --- homeassistant/components/matrix/__init__.py | 92 ++++++++++++++----- tests/components/matrix/conftest.py | 45 ++++++--- tests/components/matrix/test_matrix_bot.py | 5 +- .../{test_join_rooms.py => test_rooms.py} | 14 ++- tests/components/matrix/test_send_message.py | 18 ++-- 5 files changed, 125 insertions(+), 49 deletions(-) rename tests/components/matrix/{test_join_rooms.py => test_rooms.py} (60%) diff --git a/homeassistant/components/matrix/__init__.py b/homeassistant/components/matrix/__init__.py index cf7bcce7b3c..f9ef3593fe6 100644 --- a/homeassistant/components/matrix/__init__.py +++ b/homeassistant/components/matrix/__init__.py @@ -2,6 +2,7 @@ from __future__ import annotations import asyncio +from collections.abc import Sequence import logging import mimetypes import os @@ -17,6 +18,7 @@ from nio.responses import ( JoinResponse, LoginError, Response, + RoomResolveAliasResponse, UploadError, UploadResponse, WhoamiError, @@ -53,6 +55,9 @@ CONF_COMMANDS = "commands" CONF_WORD = "word" CONF_EXPRESSION = "expression" +CONF_USERNAME_REGEX = "^@[^:]*:.*" +CONF_ROOMS_REGEX = "^[!|#][^:]*:.*" + EVENT_MATRIX_COMMAND = "matrix_command" DEFAULT_CONTENT_TYPE = "application/octet-stream" @@ -65,7 +70,9 @@ ATTR_IMAGES = "images" # optional images WordCommand = NewType("WordCommand", str) ExpressionCommand = NewType("ExpressionCommand", re.Pattern) -RoomID = NewType("RoomID", str) +RoomAlias = NewType("RoomAlias", str) # Starts with "#" +RoomID = NewType("RoomID", str) # Starts with "!" +RoomAnyID = RoomID | RoomAlias class ConfigCommand(TypedDict, total=False): @@ -83,7 +90,9 @@ COMMAND_SCHEMA = vol.All( vol.Exclusive(CONF_WORD, "trigger"): cv.string, vol.Exclusive(CONF_EXPRESSION, "trigger"): cv.is_regex, vol.Required(CONF_NAME): cv.string, - vol.Optional(CONF_ROOMS): vol.All(cv.ensure_list, [cv.string]), + vol.Optional(CONF_ROOMS): vol.All( + cv.ensure_list, [cv.matches_regex(CONF_ROOMS_REGEX)] + ), } ), cv.has_at_least_one_key(CONF_WORD, CONF_EXPRESSION), @@ -95,10 +104,10 @@ CONFIG_SCHEMA = vol.Schema( { vol.Required(CONF_HOMESERVER): cv.url, vol.Optional(CONF_VERIFY_SSL, default=True): cv.boolean, - vol.Required(CONF_USERNAME): cv.matches_regex("@[^:]*:.*"), + vol.Required(CONF_USERNAME): cv.matches_regex(CONF_USERNAME_REGEX), vol.Required(CONF_PASSWORD): cv.string, vol.Optional(CONF_ROOMS, default=[]): vol.All( - cv.ensure_list, [cv.string] + cv.ensure_list, [cv.matches_regex(CONF_ROOMS_REGEX)] ), vol.Optional(CONF_COMMANDS, default=[]): [COMMAND_SCHEMA], } @@ -116,7 +125,9 @@ SERVICE_SCHEMA_SEND_MESSAGE = vol.Schema( ), vol.Optional(ATTR_IMAGES): vol.All(cv.ensure_list, [cv.string]), }, - vol.Required(ATTR_TARGET): vol.All(cv.ensure_list, [cv.string]), + vol.Required(ATTR_TARGET): vol.All( + cv.ensure_list, [cv.matches_regex(CONF_ROOMS_REGEX)] + ), } ) @@ -160,7 +171,7 @@ class MatrixBot: verify_ssl: bool, username: str, password: str, - listening_rooms: list[RoomID], + listening_rooms: list[RoomAnyID], commands: list[ConfigCommand], ) -> None: """Set up the client.""" @@ -178,11 +189,10 @@ class MatrixBot: homeserver=self._homeserver, user=self._mx_id, ssl=self._verify_tls ) - self._listening_rooms = listening_rooms - + self._listening_rooms: dict[RoomAnyID, RoomID] = {} self._word_commands: dict[RoomID, dict[WordCommand, ConfigCommand]] = {} self._expression_commands: dict[RoomID, list[ConfigCommand]] = {} - self._load_commands(commands) + self._unparsed_commands = commands async def stop_client(event: HassEvent) -> None: """Run once when Home Assistant stops.""" @@ -195,6 +205,8 @@ class MatrixBot: """Run once when Home Assistant finished startup.""" self._access_tokens = await self._get_auth_tokens() await self._login() + await self._resolve_room_aliases(listening_rooms) + self._load_commands(commands) await self._join_rooms() # Sync once so that we don't respond to past events. await self._client.sync(timeout=30_000) @@ -211,7 +223,7 @@ class MatrixBot: def _load_commands(self, commands: list[ConfigCommand]) -> None: for command in commands: # Set the command for all listening_rooms, unless otherwise specified. - command.setdefault(CONF_ROOMS, self._listening_rooms) # type: ignore[misc] + command.setdefault(CONF_ROOMS, list(self._listening_rooms.values())) # type: ignore[misc] # COMMAND_SCHEMA guarantees that exactly one of CONF_WORD and CONF_expression are set. if (word_command := command.get(CONF_WORD)) is not None: @@ -262,24 +274,60 @@ class MatrixBot: } self.hass.bus.async_fire(EVENT_MATRIX_COMMAND, message_data) - async def _join_room(self, room_id_or_alias: str) -> None: + async def _resolve_room_alias( + self, room_alias_or_id: RoomAnyID + ) -> dict[RoomAnyID, RoomID]: + """Resolve a single RoomAlias if needed.""" + if room_alias_or_id.startswith("!"): + room_id = RoomID(room_alias_or_id) + _LOGGER.debug("Will listen to room_id '%s'", room_id) + elif room_alias_or_id.startswith("#"): + room_alias = RoomAlias(room_alias_or_id) + resolve_response = await self._client.room_resolve_alias(room_alias) + if isinstance(resolve_response, RoomResolveAliasResponse): + room_id = RoomID(resolve_response.room_id) + _LOGGER.debug( + "Will listen to room_alias '%s' as room_id '%s'", + room_alias_or_id, + room_id, + ) + else: + _LOGGER.error( + "Could not resolve '%s' to a room_id: '%s'", + room_alias_or_id, + resolve_response, + ) + return {} + # The config schema guarantees it's a valid room alias or id, so room_id is always set. + return {room_alias_or_id: room_id} + + async def _resolve_room_aliases(self, listening_rooms: list[RoomAnyID]) -> None: + """Resolve any RoomAliases into RoomIDs for the purpose of client interactions.""" + resolved_rooms = [ + self.hass.async_create_task(self._resolve_room_alias(room_alias_or_id)) + for room_alias_or_id in listening_rooms + ] + for resolved_room in asyncio.as_completed(resolved_rooms): + self._listening_rooms |= await resolved_room + + async def _join_room(self, room_id: RoomID, room_alias_or_id: RoomAnyID) -> None: """Join a room or do nothing if already joined.""" - join_response = await self._client.join(room_id_or_alias) + join_response = await self._client.join(room_id) if isinstance(join_response, JoinResponse): - _LOGGER.debug("Joined or already in room '%s'", room_id_or_alias) + _LOGGER.debug("Joined or already in room '%s'", room_alias_or_id) elif isinstance(join_response, JoinError): _LOGGER.error( "Could not join room '%s': %s", - room_id_or_alias, + room_alias_or_id, join_response, ) async def _join_rooms(self) -> None: """Join the Matrix rooms that we listen for commands in.""" rooms = [ - self.hass.async_create_task(self._join_room(room_id)) - for room_id in self._listening_rooms + self.hass.async_create_task(self._join_room(room_id, room_alias_or_id)) + for room_alias_or_id, room_id in self._listening_rooms.items() ] await asyncio.wait(rooms) @@ -356,11 +404,11 @@ class MatrixBot: await self._store_auth_token(self._client.access_token) async def _handle_room_send( - self, target_room: RoomID, message_type: str, content: dict + self, target_room: RoomAnyID, message_type: str, content: dict ) -> None: """Wrap _client.room_send and handle ErrorResponses.""" response: Response = await self._client.room_send( - room_id=target_room, + room_id=self._listening_rooms.get(target_room, target_room), message_type=message_type, content=content, ) @@ -374,7 +422,7 @@ class MatrixBot: _LOGGER.debug("Message delivered to room '%s'", target_room) async def _handle_multi_room_send( - self, target_rooms: list[RoomID], message_type: str, content: dict + self, target_rooms: Sequence[RoomAnyID], message_type: str, content: dict ) -> None: """Wrap _handle_room_send for multiple target_rooms.""" _tasks = [] @@ -390,7 +438,9 @@ class MatrixBot: ) await asyncio.wait(_tasks) - async def _send_image(self, image_path: str, target_rooms: list[RoomID]) -> None: + async def _send_image( + self, image_path: str, target_rooms: Sequence[RoomAnyID] + ) -> None: """Upload an image, then send it to all target_rooms.""" _is_allowed_path = await self.hass.async_add_executor_job( self.hass.config.is_allowed_path, image_path @@ -442,7 +492,7 @@ class MatrixBot: ) async def _send_message( - self, message: str, target_rooms: list[RoomID], data: dict | None + self, message: str, target_rooms: list[RoomAnyID], data: dict | None ) -> None: """Send a message to the Matrix server.""" content = {"msgtype": "m.text", "body": message} diff --git a/tests/components/matrix/conftest.py b/tests/components/matrix/conftest.py index d0970b96019..1198d7e6012 100644 --- a/tests/components/matrix/conftest.py +++ b/tests/components/matrix/conftest.py @@ -14,6 +14,8 @@ from nio import ( LoginError, LoginResponse, Response, + RoomResolveAliasError, + RoomResolveAliasResponse, UploadResponse, WhoamiError, WhoamiResponse, @@ -48,8 +50,15 @@ from tests.common import async_capture_events TEST_NOTIFIER_NAME = "matrix_notify" +TEST_HOMESERVER = "example.com" TEST_DEFAULT_ROOM = "!DefaultNotificationRoom:example.com" -TEST_JOINABLE_ROOMS = ["!RoomIdString:example.com", "#RoomAliasString:example.com"] +TEST_ROOM_A_ID = "!RoomA-ID:example.com" +TEST_ROOM_B_ID = "!RoomB-ID:example.com" +TEST_ROOM_B_ALIAS = "#RoomB-Alias:example.com" +TEST_JOINABLE_ROOMS = { + TEST_ROOM_A_ID: TEST_ROOM_A_ID, + TEST_ROOM_B_ALIAS: TEST_ROOM_B_ID, +} TEST_BAD_ROOM = "!UninvitedRoom:example.com" TEST_MXID = "@user:example.com" TEST_DEVICE_ID = "FAKEID" @@ -65,8 +74,16 @@ class _MockAsyncClient(AsyncClient): async def close(self): return None + async def room_resolve_alias(self, room_alias: str): + if room_id := TEST_JOINABLE_ROOMS.get(room_alias): + return RoomResolveAliasResponse( + room_alias=room_alias, room_id=room_id, servers=[TEST_HOMESERVER] + ) + else: + return RoomResolveAliasError(message=f"Could not resolve {room_alias}") + async def join(self, room_id: RoomID): - if room_id in TEST_JOINABLE_ROOMS: + if room_id in TEST_JOINABLE_ROOMS.values(): return JoinResponse(room_id=room_id) else: return JoinError(message="Not allowed to join this room.") @@ -102,10 +119,10 @@ class _MockAsyncClient(AsyncClient): async def room_send(self, *args, **kwargs): if not self.logged_in: raise LocalProtocolError - if kwargs["room_id"] in TEST_JOINABLE_ROOMS: - return Response() - else: + if kwargs["room_id"] not in TEST_JOINABLE_ROOMS.values(): return ErrorResponse(message="Cannot send a message in this room.") + else: + return Response() async def sync(self, *args, **kwargs): return None @@ -123,7 +140,7 @@ MOCK_CONFIG_DATA = { CONF_USERNAME: TEST_MXID, CONF_PASSWORD: TEST_PASSWORD, CONF_VERIFY_SSL: True, - CONF_ROOMS: TEST_JOINABLE_ROOMS, + CONF_ROOMS: list(TEST_JOINABLE_ROOMS), CONF_COMMANDS: [ { CONF_WORD: "WordTrigger", @@ -143,35 +160,35 @@ MOCK_CONFIG_DATA = { } MOCK_WORD_COMMANDS = { - "!RoomIdString:example.com": { + TEST_ROOM_A_ID: { "WordTrigger": { "word": "WordTrigger", "name": "WordTriggerEventName", - "rooms": ["!RoomIdString:example.com", "#RoomAliasString:example.com"], + "rooms": [TEST_ROOM_A_ID, TEST_ROOM_B_ID], } }, - "#RoomAliasString:example.com": { + TEST_ROOM_B_ID: { "WordTrigger": { "word": "WordTrigger", "name": "WordTriggerEventName", - "rooms": ["!RoomIdString:example.com", "#RoomAliasString:example.com"], + "rooms": [TEST_ROOM_A_ID, TEST_ROOM_B_ID], } }, } MOCK_EXPRESSION_COMMANDS = { - "!RoomIdString:example.com": [ + TEST_ROOM_A_ID: [ { "expression": re.compile("My name is (?P.*)"), "name": "ExpressionTriggerEventName", - "rooms": ["!RoomIdString:example.com", "#RoomAliasString:example.com"], + "rooms": [TEST_ROOM_A_ID, TEST_ROOM_B_ID], } ], - "#RoomAliasString:example.com": [ + TEST_ROOM_B_ID: [ { "expression": re.compile("My name is (?P.*)"), "name": "ExpressionTriggerEventName", - "rooms": ["!RoomIdString:example.com", "#RoomAliasString:example.com"], + "rooms": [TEST_ROOM_A_ID, TEST_ROOM_B_ID], } ], } diff --git a/tests/components/matrix/test_matrix_bot.py b/tests/components/matrix/test_matrix_bot.py index 0b150a629fe..0048f6665e8 100644 --- a/tests/components/matrix/test_matrix_bot.py +++ b/tests/components/matrix/test_matrix_bot.py @@ -12,8 +12,8 @@ from homeassistant.core import HomeAssistant from .conftest import ( MOCK_EXPRESSION_COMMANDS, MOCK_WORD_COMMANDS, - TEST_JOINABLE_ROOMS, TEST_NOTIFIER_NAME, + TEST_ROOM_A_ID, ) @@ -34,12 +34,13 @@ async def test_services(hass: HomeAssistant, matrix_bot: MatrixBot): async def test_commands(hass, matrix_bot: MatrixBot, command_events): """Test that the configured commands were parsed correctly.""" + await hass.async_start() assert len(command_events) == 0 assert matrix_bot._word_commands == MOCK_WORD_COMMANDS assert matrix_bot._expression_commands == MOCK_EXPRESSION_COMMANDS - room_id = TEST_JOINABLE_ROOMS[0] + room_id = TEST_ROOM_A_ID room = MatrixRoom(room_id=room_id, own_user_id=matrix_bot._mx_id) # Test single-word command. diff --git a/tests/components/matrix/test_join_rooms.py b/tests/components/matrix/test_rooms.py similarity index 60% rename from tests/components/matrix/test_join_rooms.py rename to tests/components/matrix/test_rooms.py index 54856b91ac3..29081b80fd5 100644 --- a/tests/components/matrix/test_join_rooms.py +++ b/tests/components/matrix/test_rooms.py @@ -5,18 +5,24 @@ from homeassistant.components.matrix import MatrixBot from tests.components.matrix.conftest import TEST_BAD_ROOM, TEST_JOINABLE_ROOMS -async def test_join(matrix_bot: MatrixBot, caplog): +async def test_join(hass, matrix_bot: MatrixBot, caplog): """Test joining configured rooms.""" - # Join configured rooms. - await matrix_bot._join_rooms() + await hass.async_start() for room_id in TEST_JOINABLE_ROOMS: assert f"Joined or already in room '{room_id}'" in caplog.messages # Joining a disallowed room should not raise an exception. - matrix_bot._listening_rooms = [TEST_BAD_ROOM] + matrix_bot._listening_rooms = {TEST_BAD_ROOM: TEST_BAD_ROOM} await matrix_bot._join_rooms() assert ( f"Could not join room '{TEST_BAD_ROOM}': JoinError: Not allowed to join this room." in caplog.messages ) + + +async def test_resolve_aliases(hass, matrix_bot: MatrixBot): + """Test resolving configured room aliases into room ids.""" + + await hass.async_start() + assert matrix_bot._listening_rooms == TEST_JOINABLE_ROOMS diff --git a/tests/components/matrix/test_send_message.py b/tests/components/matrix/test_send_message.py index 34964f2b091..47c3e08aa48 100644 --- a/tests/components/matrix/test_send_message.py +++ b/tests/components/matrix/test_send_message.py @@ -17,30 +17,32 @@ async def test_send_message( hass: HomeAssistant, matrix_bot: MatrixBot, image_path, matrix_events, caplog ): """Test the send_message service.""" + + await hass.async_start() assert len(matrix_events) == 0 await matrix_bot._login() # Send a message without an attached image. - data = {ATTR_MESSAGE: "Test message", ATTR_TARGET: TEST_JOINABLE_ROOMS} + data = {ATTR_MESSAGE: "Test message", ATTR_TARGET: list(TEST_JOINABLE_ROOMS)} await hass.services.async_call( MATRIX_DOMAIN, SERVICE_SEND_MESSAGE, data, blocking=True ) - for room_id in TEST_JOINABLE_ROOMS: - assert f"Message delivered to room '{room_id}'" in caplog.messages + for room_alias_or_id in TEST_JOINABLE_ROOMS: + assert f"Message delivered to room '{room_alias_or_id}'" in caplog.messages # Send an HTML message without an attached image. data = { ATTR_MESSAGE: "Test message", - ATTR_TARGET: TEST_JOINABLE_ROOMS, + ATTR_TARGET: list(TEST_JOINABLE_ROOMS), ATTR_DATA: {ATTR_FORMAT: FORMAT_HTML}, } await hass.services.async_call( MATRIX_DOMAIN, SERVICE_SEND_MESSAGE, data, blocking=True ) - for room_id in TEST_JOINABLE_ROOMS: - assert f"Message delivered to room '{room_id}'" in caplog.messages + for room_alias_or_id in TEST_JOINABLE_ROOMS: + assert f"Message delivered to room '{room_alias_or_id}'" in caplog.messages # Send a message with an attached image. data[ATTR_DATA] = {ATTR_IMAGES: [image_path.name]} @@ -48,8 +50,8 @@ async def test_send_message( MATRIX_DOMAIN, SERVICE_SEND_MESSAGE, data, blocking=True ) - for room_id in TEST_JOINABLE_ROOMS: - assert f"Message delivered to room '{room_id}'" in caplog.messages + for room_alias_or_id in TEST_JOINABLE_ROOMS: + assert f"Message delivered to room '{room_alias_or_id}'" in caplog.messages async def test_unsendable_message(