Fix resolving Matrix room aliases (#101928)
This commit is contained in:
parent
1176003b51
commit
30ba78cf82
5 changed files with 125 additions and 49 deletions
|
@ -2,6 +2,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Sequence
|
||||
import logging
|
||||
import mimetypes
|
||||
import os
|
||||
|
@ -17,6 +18,7 @@ from nio.responses import (
|
|||
JoinResponse,
|
||||
LoginError,
|
||||
Response,
|
||||
RoomResolveAliasResponse,
|
||||
UploadError,
|
||||
UploadResponse,
|
||||
WhoamiError,
|
||||
|
@ -53,6 +55,9 @@ CONF_COMMANDS = "commands"
|
|||
CONF_WORD = "word"
|
||||
CONF_EXPRESSION = "expression"
|
||||
|
||||
CONF_USERNAME_REGEX = "^@[^:]*:.*"
|
||||
CONF_ROOMS_REGEX = "^[!|#][^:]*:.*"
|
||||
|
||||
EVENT_MATRIX_COMMAND = "matrix_command"
|
||||
|
||||
DEFAULT_CONTENT_TYPE = "application/octet-stream"
|
||||
|
@ -65,7 +70,9 @@ ATTR_IMAGES = "images" # optional images
|
|||
|
||||
WordCommand = NewType("WordCommand", str)
|
||||
ExpressionCommand = NewType("ExpressionCommand", re.Pattern)
|
||||
RoomID = NewType("RoomID", str)
|
||||
RoomAlias = NewType("RoomAlias", str) # Starts with "#"
|
||||
RoomID = NewType("RoomID", str) # Starts with "!"
|
||||
RoomAnyID = RoomID | RoomAlias
|
||||
|
||||
|
||||
class ConfigCommand(TypedDict, total=False):
|
||||
|
@ -83,7 +90,9 @@ COMMAND_SCHEMA = vol.All(
|
|||
vol.Exclusive(CONF_WORD, "trigger"): cv.string,
|
||||
vol.Exclusive(CONF_EXPRESSION, "trigger"): cv.is_regex,
|
||||
vol.Required(CONF_NAME): cv.string,
|
||||
vol.Optional(CONF_ROOMS): vol.All(cv.ensure_list, [cv.string]),
|
||||
vol.Optional(CONF_ROOMS): vol.All(
|
||||
cv.ensure_list, [cv.matches_regex(CONF_ROOMS_REGEX)]
|
||||
),
|
||||
}
|
||||
),
|
||||
cv.has_at_least_one_key(CONF_WORD, CONF_EXPRESSION),
|
||||
|
@ -95,10 +104,10 @@ CONFIG_SCHEMA = vol.Schema(
|
|||
{
|
||||
vol.Required(CONF_HOMESERVER): cv.url,
|
||||
vol.Optional(CONF_VERIFY_SSL, default=True): cv.boolean,
|
||||
vol.Required(CONF_USERNAME): cv.matches_regex("@[^:]*:.*"),
|
||||
vol.Required(CONF_USERNAME): cv.matches_regex(CONF_USERNAME_REGEX),
|
||||
vol.Required(CONF_PASSWORD): cv.string,
|
||||
vol.Optional(CONF_ROOMS, default=[]): vol.All(
|
||||
cv.ensure_list, [cv.string]
|
||||
cv.ensure_list, [cv.matches_regex(CONF_ROOMS_REGEX)]
|
||||
),
|
||||
vol.Optional(CONF_COMMANDS, default=[]): [COMMAND_SCHEMA],
|
||||
}
|
||||
|
@ -116,7 +125,9 @@ SERVICE_SCHEMA_SEND_MESSAGE = vol.Schema(
|
|||
),
|
||||
vol.Optional(ATTR_IMAGES): vol.All(cv.ensure_list, [cv.string]),
|
||||
},
|
||||
vol.Required(ATTR_TARGET): vol.All(cv.ensure_list, [cv.string]),
|
||||
vol.Required(ATTR_TARGET): vol.All(
|
||||
cv.ensure_list, [cv.matches_regex(CONF_ROOMS_REGEX)]
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
|
@ -160,7 +171,7 @@ class MatrixBot:
|
|||
verify_ssl: bool,
|
||||
username: str,
|
||||
password: str,
|
||||
listening_rooms: list[RoomID],
|
||||
listening_rooms: list[RoomAnyID],
|
||||
commands: list[ConfigCommand],
|
||||
) -> None:
|
||||
"""Set up the client."""
|
||||
|
@ -178,11 +189,10 @@ class MatrixBot:
|
|||
homeserver=self._homeserver, user=self._mx_id, ssl=self._verify_tls
|
||||
)
|
||||
|
||||
self._listening_rooms = listening_rooms
|
||||
|
||||
self._listening_rooms: dict[RoomAnyID, RoomID] = {}
|
||||
self._word_commands: dict[RoomID, dict[WordCommand, ConfigCommand]] = {}
|
||||
self._expression_commands: dict[RoomID, list[ConfigCommand]] = {}
|
||||
self._load_commands(commands)
|
||||
self._unparsed_commands = commands
|
||||
|
||||
async def stop_client(event: HassEvent) -> None:
|
||||
"""Run once when Home Assistant stops."""
|
||||
|
@ -195,6 +205,8 @@ class MatrixBot:
|
|||
"""Run once when Home Assistant finished startup."""
|
||||
self._access_tokens = await self._get_auth_tokens()
|
||||
await self._login()
|
||||
await self._resolve_room_aliases(listening_rooms)
|
||||
self._load_commands(commands)
|
||||
await self._join_rooms()
|
||||
# Sync once so that we don't respond to past events.
|
||||
await self._client.sync(timeout=30_000)
|
||||
|
@ -211,7 +223,7 @@ class MatrixBot:
|
|||
def _load_commands(self, commands: list[ConfigCommand]) -> None:
|
||||
for command in commands:
|
||||
# Set the command for all listening_rooms, unless otherwise specified.
|
||||
command.setdefault(CONF_ROOMS, self._listening_rooms) # type: ignore[misc]
|
||||
command.setdefault(CONF_ROOMS, list(self._listening_rooms.values())) # type: ignore[misc]
|
||||
|
||||
# COMMAND_SCHEMA guarantees that exactly one of CONF_WORD and CONF_expression are set.
|
||||
if (word_command := command.get(CONF_WORD)) is not None:
|
||||
|
@ -262,24 +274,60 @@ class MatrixBot:
|
|||
}
|
||||
self.hass.bus.async_fire(EVENT_MATRIX_COMMAND, message_data)
|
||||
|
||||
async def _join_room(self, room_id_or_alias: str) -> None:
|
||||
async def _resolve_room_alias(
|
||||
self, room_alias_or_id: RoomAnyID
|
||||
) -> dict[RoomAnyID, RoomID]:
|
||||
"""Resolve a single RoomAlias if needed."""
|
||||
if room_alias_or_id.startswith("!"):
|
||||
room_id = RoomID(room_alias_or_id)
|
||||
_LOGGER.debug("Will listen to room_id '%s'", room_id)
|
||||
elif room_alias_or_id.startswith("#"):
|
||||
room_alias = RoomAlias(room_alias_or_id)
|
||||
resolve_response = await self._client.room_resolve_alias(room_alias)
|
||||
if isinstance(resolve_response, RoomResolveAliasResponse):
|
||||
room_id = RoomID(resolve_response.room_id)
|
||||
_LOGGER.debug(
|
||||
"Will listen to room_alias '%s' as room_id '%s'",
|
||||
room_alias_or_id,
|
||||
room_id,
|
||||
)
|
||||
else:
|
||||
_LOGGER.error(
|
||||
"Could not resolve '%s' to a room_id: '%s'",
|
||||
room_alias_or_id,
|
||||
resolve_response,
|
||||
)
|
||||
return {}
|
||||
# The config schema guarantees it's a valid room alias or id, so room_id is always set.
|
||||
return {room_alias_or_id: room_id}
|
||||
|
||||
async def _resolve_room_aliases(self, listening_rooms: list[RoomAnyID]) -> None:
|
||||
"""Resolve any RoomAliases into RoomIDs for the purpose of client interactions."""
|
||||
resolved_rooms = [
|
||||
self.hass.async_create_task(self._resolve_room_alias(room_alias_or_id))
|
||||
for room_alias_or_id in listening_rooms
|
||||
]
|
||||
for resolved_room in asyncio.as_completed(resolved_rooms):
|
||||
self._listening_rooms |= await resolved_room
|
||||
|
||||
async def _join_room(self, room_id: RoomID, room_alias_or_id: RoomAnyID) -> None:
|
||||
"""Join a room or do nothing if already joined."""
|
||||
join_response = await self._client.join(room_id_or_alias)
|
||||
join_response = await self._client.join(room_id)
|
||||
|
||||
if isinstance(join_response, JoinResponse):
|
||||
_LOGGER.debug("Joined or already in room '%s'", room_id_or_alias)
|
||||
_LOGGER.debug("Joined or already in room '%s'", room_alias_or_id)
|
||||
elif isinstance(join_response, JoinError):
|
||||
_LOGGER.error(
|
||||
"Could not join room '%s': %s",
|
||||
room_id_or_alias,
|
||||
room_alias_or_id,
|
||||
join_response,
|
||||
)
|
||||
|
||||
async def _join_rooms(self) -> None:
|
||||
"""Join the Matrix rooms that we listen for commands in."""
|
||||
rooms = [
|
||||
self.hass.async_create_task(self._join_room(room_id))
|
||||
for room_id in self._listening_rooms
|
||||
self.hass.async_create_task(self._join_room(room_id, room_alias_or_id))
|
||||
for room_alias_or_id, room_id in self._listening_rooms.items()
|
||||
]
|
||||
await asyncio.wait(rooms)
|
||||
|
||||
|
@ -356,11 +404,11 @@ class MatrixBot:
|
|||
await self._store_auth_token(self._client.access_token)
|
||||
|
||||
async def _handle_room_send(
|
||||
self, target_room: RoomID, message_type: str, content: dict
|
||||
self, target_room: RoomAnyID, message_type: str, content: dict
|
||||
) -> None:
|
||||
"""Wrap _client.room_send and handle ErrorResponses."""
|
||||
response: Response = await self._client.room_send(
|
||||
room_id=target_room,
|
||||
room_id=self._listening_rooms.get(target_room, target_room),
|
||||
message_type=message_type,
|
||||
content=content,
|
||||
)
|
||||
|
@ -374,7 +422,7 @@ class MatrixBot:
|
|||
_LOGGER.debug("Message delivered to room '%s'", target_room)
|
||||
|
||||
async def _handle_multi_room_send(
|
||||
self, target_rooms: list[RoomID], message_type: str, content: dict
|
||||
self, target_rooms: Sequence[RoomAnyID], message_type: str, content: dict
|
||||
) -> None:
|
||||
"""Wrap _handle_room_send for multiple target_rooms."""
|
||||
_tasks = []
|
||||
|
@ -390,7 +438,9 @@ class MatrixBot:
|
|||
)
|
||||
await asyncio.wait(_tasks)
|
||||
|
||||
async def _send_image(self, image_path: str, target_rooms: list[RoomID]) -> None:
|
||||
async def _send_image(
|
||||
self, image_path: str, target_rooms: Sequence[RoomAnyID]
|
||||
) -> None:
|
||||
"""Upload an image, then send it to all target_rooms."""
|
||||
_is_allowed_path = await self.hass.async_add_executor_job(
|
||||
self.hass.config.is_allowed_path, image_path
|
||||
|
@ -442,7 +492,7 @@ class MatrixBot:
|
|||
)
|
||||
|
||||
async def _send_message(
|
||||
self, message: str, target_rooms: list[RoomID], data: dict | None
|
||||
self, message: str, target_rooms: list[RoomAnyID], data: dict | None
|
||||
) -> None:
|
||||
"""Send a message to the Matrix server."""
|
||||
content = {"msgtype": "m.text", "body": message}
|
||||
|
|
|
@ -14,6 +14,8 @@ from nio import (
|
|||
LoginError,
|
||||
LoginResponse,
|
||||
Response,
|
||||
RoomResolveAliasError,
|
||||
RoomResolveAliasResponse,
|
||||
UploadResponse,
|
||||
WhoamiError,
|
||||
WhoamiResponse,
|
||||
|
@ -48,8 +50,15 @@ from tests.common import async_capture_events
|
|||
|
||||
TEST_NOTIFIER_NAME = "matrix_notify"
|
||||
|
||||
TEST_HOMESERVER = "example.com"
|
||||
TEST_DEFAULT_ROOM = "!DefaultNotificationRoom:example.com"
|
||||
TEST_JOINABLE_ROOMS = ["!RoomIdString:example.com", "#RoomAliasString:example.com"]
|
||||
TEST_ROOM_A_ID = "!RoomA-ID:example.com"
|
||||
TEST_ROOM_B_ID = "!RoomB-ID:example.com"
|
||||
TEST_ROOM_B_ALIAS = "#RoomB-Alias:example.com"
|
||||
TEST_JOINABLE_ROOMS = {
|
||||
TEST_ROOM_A_ID: TEST_ROOM_A_ID,
|
||||
TEST_ROOM_B_ALIAS: TEST_ROOM_B_ID,
|
||||
}
|
||||
TEST_BAD_ROOM = "!UninvitedRoom:example.com"
|
||||
TEST_MXID = "@user:example.com"
|
||||
TEST_DEVICE_ID = "FAKEID"
|
||||
|
@ -65,8 +74,16 @@ class _MockAsyncClient(AsyncClient):
|
|||
async def close(self):
|
||||
return None
|
||||
|
||||
async def room_resolve_alias(self, room_alias: str):
|
||||
if room_id := TEST_JOINABLE_ROOMS.get(room_alias):
|
||||
return RoomResolveAliasResponse(
|
||||
room_alias=room_alias, room_id=room_id, servers=[TEST_HOMESERVER]
|
||||
)
|
||||
else:
|
||||
return RoomResolveAliasError(message=f"Could not resolve {room_alias}")
|
||||
|
||||
async def join(self, room_id: RoomID):
|
||||
if room_id in TEST_JOINABLE_ROOMS:
|
||||
if room_id in TEST_JOINABLE_ROOMS.values():
|
||||
return JoinResponse(room_id=room_id)
|
||||
else:
|
||||
return JoinError(message="Not allowed to join this room.")
|
||||
|
@ -102,10 +119,10 @@ class _MockAsyncClient(AsyncClient):
|
|||
async def room_send(self, *args, **kwargs):
|
||||
if not self.logged_in:
|
||||
raise LocalProtocolError
|
||||
if kwargs["room_id"] in TEST_JOINABLE_ROOMS:
|
||||
return Response()
|
||||
else:
|
||||
if kwargs["room_id"] not in TEST_JOINABLE_ROOMS.values():
|
||||
return ErrorResponse(message="Cannot send a message in this room.")
|
||||
else:
|
||||
return Response()
|
||||
|
||||
async def sync(self, *args, **kwargs):
|
||||
return None
|
||||
|
@ -123,7 +140,7 @@ MOCK_CONFIG_DATA = {
|
|||
CONF_USERNAME: TEST_MXID,
|
||||
CONF_PASSWORD: TEST_PASSWORD,
|
||||
CONF_VERIFY_SSL: True,
|
||||
CONF_ROOMS: TEST_JOINABLE_ROOMS,
|
||||
CONF_ROOMS: list(TEST_JOINABLE_ROOMS),
|
||||
CONF_COMMANDS: [
|
||||
{
|
||||
CONF_WORD: "WordTrigger",
|
||||
|
@ -143,35 +160,35 @@ MOCK_CONFIG_DATA = {
|
|||
}
|
||||
|
||||
MOCK_WORD_COMMANDS = {
|
||||
"!RoomIdString:example.com": {
|
||||
TEST_ROOM_A_ID: {
|
||||
"WordTrigger": {
|
||||
"word": "WordTrigger",
|
||||
"name": "WordTriggerEventName",
|
||||
"rooms": ["!RoomIdString:example.com", "#RoomAliasString:example.com"],
|
||||
"rooms": [TEST_ROOM_A_ID, TEST_ROOM_B_ID],
|
||||
}
|
||||
},
|
||||
"#RoomAliasString:example.com": {
|
||||
TEST_ROOM_B_ID: {
|
||||
"WordTrigger": {
|
||||
"word": "WordTrigger",
|
||||
"name": "WordTriggerEventName",
|
||||
"rooms": ["!RoomIdString:example.com", "#RoomAliasString:example.com"],
|
||||
"rooms": [TEST_ROOM_A_ID, TEST_ROOM_B_ID],
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
MOCK_EXPRESSION_COMMANDS = {
|
||||
"!RoomIdString:example.com": [
|
||||
TEST_ROOM_A_ID: [
|
||||
{
|
||||
"expression": re.compile("My name is (?P<name>.*)"),
|
||||
"name": "ExpressionTriggerEventName",
|
||||
"rooms": ["!RoomIdString:example.com", "#RoomAliasString:example.com"],
|
||||
"rooms": [TEST_ROOM_A_ID, TEST_ROOM_B_ID],
|
||||
}
|
||||
],
|
||||
"#RoomAliasString:example.com": [
|
||||
TEST_ROOM_B_ID: [
|
||||
{
|
||||
"expression": re.compile("My name is (?P<name>.*)"),
|
||||
"name": "ExpressionTriggerEventName",
|
||||
"rooms": ["!RoomIdString:example.com", "#RoomAliasString:example.com"],
|
||||
"rooms": [TEST_ROOM_A_ID, TEST_ROOM_B_ID],
|
||||
}
|
||||
],
|
||||
}
|
||||
|
|
|
@ -12,8 +12,8 @@ from homeassistant.core import HomeAssistant
|
|||
from .conftest import (
|
||||
MOCK_EXPRESSION_COMMANDS,
|
||||
MOCK_WORD_COMMANDS,
|
||||
TEST_JOINABLE_ROOMS,
|
||||
TEST_NOTIFIER_NAME,
|
||||
TEST_ROOM_A_ID,
|
||||
)
|
||||
|
||||
|
||||
|
@ -34,12 +34,13 @@ async def test_services(hass: HomeAssistant, matrix_bot: MatrixBot):
|
|||
async def test_commands(hass, matrix_bot: MatrixBot, command_events):
|
||||
"""Test that the configured commands were parsed correctly."""
|
||||
|
||||
await hass.async_start()
|
||||
assert len(command_events) == 0
|
||||
|
||||
assert matrix_bot._word_commands == MOCK_WORD_COMMANDS
|
||||
assert matrix_bot._expression_commands == MOCK_EXPRESSION_COMMANDS
|
||||
|
||||
room_id = TEST_JOINABLE_ROOMS[0]
|
||||
room_id = TEST_ROOM_A_ID
|
||||
room = MatrixRoom(room_id=room_id, own_user_id=matrix_bot._mx_id)
|
||||
|
||||
# Test single-word command.
|
||||
|
|
|
@ -5,18 +5,24 @@ from homeassistant.components.matrix import MatrixBot
|
|||
from tests.components.matrix.conftest import TEST_BAD_ROOM, TEST_JOINABLE_ROOMS
|
||||
|
||||
|
||||
async def test_join(matrix_bot: MatrixBot, caplog):
|
||||
async def test_join(hass, matrix_bot: MatrixBot, caplog):
|
||||
"""Test joining configured rooms."""
|
||||
|
||||
# Join configured rooms.
|
||||
await matrix_bot._join_rooms()
|
||||
await hass.async_start()
|
||||
for room_id in TEST_JOINABLE_ROOMS:
|
||||
assert f"Joined or already in room '{room_id}'" in caplog.messages
|
||||
|
||||
# Joining a disallowed room should not raise an exception.
|
||||
matrix_bot._listening_rooms = [TEST_BAD_ROOM]
|
||||
matrix_bot._listening_rooms = {TEST_BAD_ROOM: TEST_BAD_ROOM}
|
||||
await matrix_bot._join_rooms()
|
||||
assert (
|
||||
f"Could not join room '{TEST_BAD_ROOM}': JoinError: Not allowed to join this room."
|
||||
in caplog.messages
|
||||
)
|
||||
|
||||
|
||||
async def test_resolve_aliases(hass, matrix_bot: MatrixBot):
|
||||
"""Test resolving configured room aliases into room ids."""
|
||||
|
||||
await hass.async_start()
|
||||
assert matrix_bot._listening_rooms == TEST_JOINABLE_ROOMS
|
|
@ -17,30 +17,32 @@ async def test_send_message(
|
|||
hass: HomeAssistant, matrix_bot: MatrixBot, image_path, matrix_events, caplog
|
||||
):
|
||||
"""Test the send_message service."""
|
||||
|
||||
await hass.async_start()
|
||||
assert len(matrix_events) == 0
|
||||
await matrix_bot._login()
|
||||
|
||||
# Send a message without an attached image.
|
||||
data = {ATTR_MESSAGE: "Test message", ATTR_TARGET: TEST_JOINABLE_ROOMS}
|
||||
data = {ATTR_MESSAGE: "Test message", ATTR_TARGET: list(TEST_JOINABLE_ROOMS)}
|
||||
await hass.services.async_call(
|
||||
MATRIX_DOMAIN, SERVICE_SEND_MESSAGE, data, blocking=True
|
||||
)
|
||||
|
||||
for room_id in TEST_JOINABLE_ROOMS:
|
||||
assert f"Message delivered to room '{room_id}'" in caplog.messages
|
||||
for room_alias_or_id in TEST_JOINABLE_ROOMS:
|
||||
assert f"Message delivered to room '{room_alias_or_id}'" in caplog.messages
|
||||
|
||||
# Send an HTML message without an attached image.
|
||||
data = {
|
||||
ATTR_MESSAGE: "Test message",
|
||||
ATTR_TARGET: TEST_JOINABLE_ROOMS,
|
||||
ATTR_TARGET: list(TEST_JOINABLE_ROOMS),
|
||||
ATTR_DATA: {ATTR_FORMAT: FORMAT_HTML},
|
||||
}
|
||||
await hass.services.async_call(
|
||||
MATRIX_DOMAIN, SERVICE_SEND_MESSAGE, data, blocking=True
|
||||
)
|
||||
|
||||
for room_id in TEST_JOINABLE_ROOMS:
|
||||
assert f"Message delivered to room '{room_id}'" in caplog.messages
|
||||
for room_alias_or_id in TEST_JOINABLE_ROOMS:
|
||||
assert f"Message delivered to room '{room_alias_or_id}'" in caplog.messages
|
||||
|
||||
# Send a message with an attached image.
|
||||
data[ATTR_DATA] = {ATTR_IMAGES: [image_path.name]}
|
||||
|
@ -48,8 +50,8 @@ async def test_send_message(
|
|||
MATRIX_DOMAIN, SERVICE_SEND_MESSAGE, data, blocking=True
|
||||
)
|
||||
|
||||
for room_id in TEST_JOINABLE_ROOMS:
|
||||
assert f"Message delivered to room '{room_id}'" in caplog.messages
|
||||
for room_alias_or_id in TEST_JOINABLE_ROOMS:
|
||||
assert f"Message delivered to room '{room_alias_or_id}'" in caplog.messages
|
||||
|
||||
|
||||
async def test_unsendable_message(
|
||||
|
|
Loading…
Add table
Reference in a new issue