diff --git a/homeassistant/components/websocket_api/commands.py b/homeassistant/components/websocket_api/commands.py index 54539158148..0f52685ca2d 100644 --- a/homeassistant/components/websocket_api/commands.py +++ b/homeassistant/components/websocket_api/commands.py @@ -105,7 +105,7 @@ def pong_message(iden: int) -> dict[str, Any]: def _forward_events_check_permissions( send_message: Callable[[bytes | str | dict[str, Any] | Callable[[], str]], None], user: User, - msg_id: int, + message_id_as_bytes: bytes, event: Event, ) -> None: """Forward state changed events to websocket.""" @@ -118,17 +118,17 @@ def _forward_events_check_permissions( and not permissions.check_entity(event.data["entity_id"], POLICY_READ) ): return - send_message(messages.cached_event_message(msg_id, event)) + send_message(messages.cached_event_message(message_id_as_bytes, event)) @callback def _forward_events_unconditional( send_message: Callable[[bytes | str | dict[str, Any] | Callable[[], str]], None], - msg_id: int, + message_id_as_bytes: bytes, event: Event, ) -> None: """Forward events to websocket.""" - send_message(messages.cached_event_message(msg_id, event)) + send_message(messages.cached_event_message(message_id_as_bytes, event)) @callback @@ -152,16 +152,18 @@ def handle_subscribe_events( ) raise Unauthorized(user_id=connection.user.id) + message_id_as_bytes = str(msg["id"]).encode() + if event_type == EVENT_STATE_CHANGED: forward_events = partial( _forward_events_check_permissions, connection.send_message, connection.user, - msg["id"], + message_id_as_bytes, ) else: forward_events = partial( - _forward_events_unconditional, connection.send_message, msg["id"] + _forward_events_unconditional, connection.send_message, message_id_as_bytes ) connection.subscriptions[msg["id"]] = hass.bus.async_listen( @@ -366,7 +368,7 @@ def _forward_entity_changes( send_message: Callable[[str | bytes | dict[str, Any] | Callable[[], str]], None], entity_ids: set[str], user: User, - msg_id: int, + message_id_as_bytes: bytes, event: Event[EventStateChangedData], ) -> None: """Forward entity state changed events to websocket.""" @@ -382,7 +384,7 @@ def _forward_entity_changes( and not permissions.check_entity(event.data["entity_id"], POLICY_READ) ): return - send_message(messages.cached_state_diff_message(msg_id, event)) + send_message(messages.cached_state_diff_message(message_id_as_bytes, event)) @callback @@ -401,6 +403,7 @@ def handle_subscribe_entities( # state changed events or we will introduce a race condition # where some states are missed states = _async_get_allowed_states(hass, connection) + message_id_as_bytes = str(msg["id"]).encode() connection.subscriptions[msg["id"]] = hass.bus.async_listen( EVENT_STATE_CHANGED, partial( @@ -408,7 +411,7 @@ def handle_subscribe_entities( connection.send_message, entity_ids, connection.user, - msg["id"], + message_id_as_bytes, ), ) connection.send_result(msg["id"]) diff --git a/homeassistant/components/websocket_api/messages.py b/homeassistant/components/websocket_api/messages.py index 75a9c9999d4..98db92dfef7 100644 --- a/homeassistant/components/websocket_api/messages.py +++ b/homeassistant/components/websocket_api/messages.py @@ -109,7 +109,7 @@ def event_message(iden: int, event: Any) -> dict[str, Any]: return {"id": iden, "type": "event", "event": event} -def cached_event_message(iden: int, event: Event) -> bytes: +def cached_event_message(message_id_as_bytes: bytes, event: Event) -> bytes: """Return an event message. Serialize to json once per message. @@ -122,7 +122,7 @@ def cached_event_message(iden: int, event: Event) -> bytes: ( _partial_cached_event_message(event)[:-1], b',"id":', - str(iden).encode(), + message_id_as_bytes, b"}", ) ) @@ -141,7 +141,9 @@ def _partial_cached_event_message(event: Event) -> bytes: ) -def cached_state_diff_message(iden: int, event: Event[EventStateChangedData]) -> bytes: +def cached_state_diff_message( + message_id_as_bytes: bytes, event: Event[EventStateChangedData] +) -> bytes: """Return an event message. Serialize to json once per message. @@ -154,7 +156,7 @@ def cached_state_diff_message(iden: int, event: Event[EventStateChangedData]) -> ( _partial_cached_state_diff_message(event)[:-1], b',"id":', - str(iden).encode(), + message_id_as_bytes, b"}", ) ) diff --git a/tests/components/websocket_api/test_messages.py b/tests/components/websocket_api/test_messages.py index 350aed8b5f7..6294b6a2628 100644 --- a/tests/components/websocket_api/test_messages.py +++ b/tests/components/websocket_api/test_messages.py @@ -32,11 +32,11 @@ async def test_cached_event_message(hass: HomeAssistant) -> None: assert len(events) == 2 lru_event_cache.cache_clear() - msg0 = cached_event_message(2, events[0]) - assert msg0 == cached_event_message(2, events[0]) + msg0 = cached_event_message(b"2", events[0]) + assert msg0 == cached_event_message(b"2", events[0]) - msg1 = cached_event_message(2, events[1]) - assert msg1 == cached_event_message(2, events[1]) + msg1 = cached_event_message(b"2", events[1]) + assert msg1 == cached_event_message(b"2", events[1]) assert msg0 != msg1 @@ -45,7 +45,7 @@ async def test_cached_event_message(hass: HomeAssistant) -> None: assert cache_info.misses == 2 assert cache_info.currsize == 2 - cached_event_message(2, events[1]) + cached_event_message(b"2", events[1]) cache_info = lru_event_cache.cache_info() assert cache_info.hits == 3 assert cache_info.misses == 2 @@ -70,9 +70,9 @@ async def test_cached_event_message_with_different_idens(hass: HomeAssistant) -> lru_event_cache.cache_clear() - msg0 = cached_event_message(2, events[0]) - msg1 = cached_event_message(3, events[0]) - msg2 = cached_event_message(4, events[0]) + msg0 = cached_event_message(b"2", events[0]) + msg1 = cached_event_message(b"3", events[0]) + msg2 = cached_event_message(b"4", events[0]) assert msg0 != msg1 assert msg0 != msg2