Avoid re-encoding the message id as bytes for every event/state change (#116460)

This commit is contained in:
J. Nick Koston 2024-04-30 12:02:28 -05:00 committed by GitHub
parent fbe1781ebc
commit 9995207817
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 26 additions and 21 deletions

View file

@ -105,7 +105,7 @@ def pong_message(iden: int) -> dict[str, Any]:
def _forward_events_check_permissions( def _forward_events_check_permissions(
send_message: Callable[[bytes | str | dict[str, Any] | Callable[[], str]], None], send_message: Callable[[bytes | str | dict[str, Any] | Callable[[], str]], None],
user: User, user: User,
msg_id: int, message_id_as_bytes: bytes,
event: Event, event: Event,
) -> None: ) -> None:
"""Forward state changed events to websocket.""" """Forward state changed events to websocket."""
@ -118,17 +118,17 @@ def _forward_events_check_permissions(
and not permissions.check_entity(event.data["entity_id"], POLICY_READ) and not permissions.check_entity(event.data["entity_id"], POLICY_READ)
): ):
return return
send_message(messages.cached_event_message(msg_id, event)) send_message(messages.cached_event_message(message_id_as_bytes, event))
@callback @callback
def _forward_events_unconditional( def _forward_events_unconditional(
send_message: Callable[[bytes | str | dict[str, Any] | Callable[[], str]], None], send_message: Callable[[bytes | str | dict[str, Any] | Callable[[], str]], None],
msg_id: int, message_id_as_bytes: bytes,
event: Event, event: Event,
) -> None: ) -> None:
"""Forward events to websocket.""" """Forward events to websocket."""
send_message(messages.cached_event_message(msg_id, event)) send_message(messages.cached_event_message(message_id_as_bytes, event))
@callback @callback
@ -152,16 +152,18 @@ def handle_subscribe_events(
) )
raise Unauthorized(user_id=connection.user.id) raise Unauthorized(user_id=connection.user.id)
message_id_as_bytes = str(msg["id"]).encode()
if event_type == EVENT_STATE_CHANGED: if event_type == EVENT_STATE_CHANGED:
forward_events = partial( forward_events = partial(
_forward_events_check_permissions, _forward_events_check_permissions,
connection.send_message, connection.send_message,
connection.user, connection.user,
msg["id"], message_id_as_bytes,
) )
else: else:
forward_events = partial( forward_events = partial(
_forward_events_unconditional, connection.send_message, msg["id"] _forward_events_unconditional, connection.send_message, message_id_as_bytes
) )
connection.subscriptions[msg["id"]] = hass.bus.async_listen( connection.subscriptions[msg["id"]] = hass.bus.async_listen(
@ -366,7 +368,7 @@ def _forward_entity_changes(
send_message: Callable[[str | bytes | dict[str, Any] | Callable[[], str]], None], send_message: Callable[[str | bytes | dict[str, Any] | Callable[[], str]], None],
entity_ids: set[str], entity_ids: set[str],
user: User, user: User,
msg_id: int, message_id_as_bytes: bytes,
event: Event[EventStateChangedData], event: Event[EventStateChangedData],
) -> None: ) -> None:
"""Forward entity state changed events to websocket.""" """Forward entity state changed events to websocket."""
@ -382,7 +384,7 @@ def _forward_entity_changes(
and not permissions.check_entity(event.data["entity_id"], POLICY_READ) and not permissions.check_entity(event.data["entity_id"], POLICY_READ)
): ):
return return
send_message(messages.cached_state_diff_message(msg_id, event)) send_message(messages.cached_state_diff_message(message_id_as_bytes, event))
@callback @callback
@ -401,6 +403,7 @@ def handle_subscribe_entities(
# state changed events or we will introduce a race condition # state changed events or we will introduce a race condition
# where some states are missed # where some states are missed
states = _async_get_allowed_states(hass, connection) states = _async_get_allowed_states(hass, connection)
message_id_as_bytes = str(msg["id"]).encode()
connection.subscriptions[msg["id"]] = hass.bus.async_listen( connection.subscriptions[msg["id"]] = hass.bus.async_listen(
EVENT_STATE_CHANGED, EVENT_STATE_CHANGED,
partial( partial(
@ -408,7 +411,7 @@ def handle_subscribe_entities(
connection.send_message, connection.send_message,
entity_ids, entity_ids,
connection.user, connection.user,
msg["id"], message_id_as_bytes,
), ),
) )
connection.send_result(msg["id"]) connection.send_result(msg["id"])

View file

@ -109,7 +109,7 @@ def event_message(iden: int, event: Any) -> dict[str, Any]:
return {"id": iden, "type": "event", "event": event} return {"id": iden, "type": "event", "event": event}
def cached_event_message(iden: int, event: Event) -> bytes: def cached_event_message(message_id_as_bytes: bytes, event: Event) -> bytes:
"""Return an event message. """Return an event message.
Serialize to json once per message. Serialize to json once per message.
@ -122,7 +122,7 @@ def cached_event_message(iden: int, event: Event) -> bytes:
( (
_partial_cached_event_message(event)[:-1], _partial_cached_event_message(event)[:-1],
b',"id":', b',"id":',
str(iden).encode(), message_id_as_bytes,
b"}", b"}",
) )
) )
@ -141,7 +141,9 @@ def _partial_cached_event_message(event: Event) -> bytes:
) )
def cached_state_diff_message(iden: int, event: Event[EventStateChangedData]) -> bytes: def cached_state_diff_message(
message_id_as_bytes: bytes, event: Event[EventStateChangedData]
) -> bytes:
"""Return an event message. """Return an event message.
Serialize to json once per message. Serialize to json once per message.
@ -154,7 +156,7 @@ def cached_state_diff_message(iden: int, event: Event[EventStateChangedData]) ->
( (
_partial_cached_state_diff_message(event)[:-1], _partial_cached_state_diff_message(event)[:-1],
b',"id":', b',"id":',
str(iden).encode(), message_id_as_bytes,
b"}", b"}",
) )
) )

View file

@ -32,11 +32,11 @@ async def test_cached_event_message(hass: HomeAssistant) -> None:
assert len(events) == 2 assert len(events) == 2
lru_event_cache.cache_clear() lru_event_cache.cache_clear()
msg0 = cached_event_message(2, events[0]) msg0 = cached_event_message(b"2", events[0])
assert msg0 == cached_event_message(2, events[0]) assert msg0 == cached_event_message(b"2", events[0])
msg1 = cached_event_message(2, events[1]) msg1 = cached_event_message(b"2", events[1])
assert msg1 == cached_event_message(2, events[1]) assert msg1 == cached_event_message(b"2", events[1])
assert msg0 != msg1 assert msg0 != msg1
@ -45,7 +45,7 @@ async def test_cached_event_message(hass: HomeAssistant) -> None:
assert cache_info.misses == 2 assert cache_info.misses == 2
assert cache_info.currsize == 2 assert cache_info.currsize == 2
cached_event_message(2, events[1]) cached_event_message(b"2", events[1])
cache_info = lru_event_cache.cache_info() cache_info = lru_event_cache.cache_info()
assert cache_info.hits == 3 assert cache_info.hits == 3
assert cache_info.misses == 2 assert cache_info.misses == 2
@ -70,9 +70,9 @@ async def test_cached_event_message_with_different_idens(hass: HomeAssistant) ->
lru_event_cache.cache_clear() lru_event_cache.cache_clear()
msg0 = cached_event_message(2, events[0]) msg0 = cached_event_message(b"2", events[0])
msg1 = cached_event_message(3, events[0]) msg1 = cached_event_message(b"3", events[0])
msg2 = cached_event_message(4, events[0]) msg2 = cached_event_message(b"4", events[0])
assert msg0 != msg1 assert msg0 != msg1
assert msg0 != msg2 assert msg0 != msg2