Speed up reconnects by caching state serialize (#93050)
This commit is contained in:
parent
9c039a17ea
commit
99265a983a
6 changed files with 152 additions and 53 deletions
|
@ -2,7 +2,6 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from contextlib import suppress
|
|
||||||
import datetime as dt
|
import datetime as dt
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
import json
|
import json
|
||||||
|
@ -50,6 +49,17 @@ from . import const, decorators, messages
|
||||||
from .connection import ActiveConnection
|
from .connection import ActiveConnection
|
||||||
from .const import ERR_NOT_FOUND
|
from .const import ERR_NOT_FOUND
|
||||||
|
|
||||||
|
_STATES_TEMPLATE = "__STATES__"
|
||||||
|
_STATES_JSON_TEMPLATE = '"__STATES__"'
|
||||||
|
_HANDLE_SUBSCRIBE_ENTITIES_TEMPLATE = JSON_DUMP(
|
||||||
|
messages.event_message(
|
||||||
|
messages.IDEN_TEMPLATE, {messages.ENTITY_EVENT_ADD: _STATES_TEMPLATE}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
_HANDLE_GET_STATES_TEMPLATE = JSON_DUMP(
|
||||||
|
messages.result_message(messages.IDEN_TEMPLATE, _STATES_TEMPLATE)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_register_commands(
|
def async_register_commands(
|
||||||
|
@ -242,33 +252,43 @@ def handle_get_states(
|
||||||
"""Handle get states command."""
|
"""Handle get states command."""
|
||||||
states = _async_get_allowed_states(hass, connection)
|
states = _async_get_allowed_states(hass, connection)
|
||||||
|
|
||||||
# JSON serialize here so we can recover if it blows up due to the
|
|
||||||
# state machine containing unserializable data. This command is required
|
|
||||||
# to succeed for the UI to show.
|
|
||||||
response = messages.result_message(msg["id"], states)
|
|
||||||
try:
|
try:
|
||||||
connection.send_message(JSON_DUMP(response))
|
serialized_states = [state.as_dict_json() for state in states]
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
_send_handle_get_states_response(connection, msg["id"], serialized_states)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# If we can't serialize, we'll filter out unserializable states
|
||||||
|
serialized_states = []
|
||||||
|
for state in states:
|
||||||
|
try:
|
||||||
|
serialized_states.append(state.as_dict_json())
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
connection.logger.error(
|
connection.logger.error(
|
||||||
"Unable to serialize to JSON. Bad data found at %s",
|
"Unable to serialize to JSON. Bad data found at %s",
|
||||||
format_unserializable_data(
|
format_unserializable_data(
|
||||||
find_paths_unserializable_data(response, dump=JSON_DUMP)
|
find_paths_unserializable_data(state, dump=JSON_DUMP)
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
del response
|
|
||||||
|
|
||||||
# If we can't serialize, we'll filter out unserializable states
|
_send_handle_get_states_response(connection, msg["id"], serialized_states)
|
||||||
serialized = []
|
|
||||||
for state in states:
|
|
||||||
# Error is already logged above
|
|
||||||
with suppress(ValueError, TypeError):
|
|
||||||
serialized.append(JSON_DUMP(state))
|
|
||||||
|
|
||||||
# We now have partially serialized states. Craft some JSON.
|
|
||||||
response2 = JSON_DUMP(messages.result_message(msg["id"], ["TO_REPLACE"]))
|
def _send_handle_get_states_response(
|
||||||
response2 = response2.replace('"TO_REPLACE"', ", ".join(serialized))
|
connection: ActiveConnection, msg_id: int, serialized_states: list[str]
|
||||||
connection.send_message(response2)
|
) -> None:
|
||||||
|
"""Send handle get states response."""
|
||||||
|
connection.send_message(
|
||||||
|
_HANDLE_GET_STATES_TEMPLATE.replace(
|
||||||
|
messages.IDEN_JSON_TEMPLATE, str(msg_id), 1
|
||||||
|
).replace(
|
||||||
|
_STATES_JSON_TEMPLATE,
|
||||||
|
"[" + ",".join(serialized_states) + "]",
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
|
@ -304,42 +324,50 @@ def handle_subscribe_entities(
|
||||||
EVENT_STATE_CHANGED, forward_entity_changes, run_immediately=True
|
EVENT_STATE_CHANGED, forward_entity_changes, run_immediately=True
|
||||||
)
|
)
|
||||||
connection.send_result(msg["id"])
|
connection.send_result(msg["id"])
|
||||||
data: dict[str, dict[str, dict]] = {
|
|
||||||
messages.ENTITY_EVENT_ADD: {
|
|
||||||
state.entity_id: state.as_compressed_state()
|
|
||||||
for state in states
|
|
||||||
if not entity_ids or state.entity_id in entity_ids
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
# JSON serialize here so we can recover if it blows up due to the
|
# JSON serialize here so we can recover if it blows up due to the
|
||||||
# state machine containing unserializable data. This command is required
|
# state machine containing unserializable data. This command is required
|
||||||
# to succeed for the UI to show.
|
# to succeed for the UI to show.
|
||||||
response = messages.event_message(msg["id"], data)
|
|
||||||
try:
|
try:
|
||||||
connection.send_message(JSON_DUMP(response))
|
serialized_states = [
|
||||||
|
state.as_compressed_state_json()
|
||||||
|
for state in states
|
||||||
|
if not entity_ids or state.entity_id in entity_ids
|
||||||
|
]
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
_send_handle_entities_init_response(connection, msg["id"], serialized_states)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
serialized_states = []
|
||||||
|
for state in states:
|
||||||
|
try:
|
||||||
|
serialized_states.append(state.as_compressed_state_json())
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
connection.logger.error(
|
connection.logger.error(
|
||||||
"Unable to serialize to JSON. Bad data found at %s",
|
"Unable to serialize to JSON. Bad data found at %s",
|
||||||
format_unserializable_data(
|
format_unserializable_data(
|
||||||
find_paths_unserializable_data(response, dump=JSON_DUMP)
|
find_paths_unserializable_data(state, dump=JSON_DUMP)
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
del response
|
|
||||||
|
|
||||||
add_entities = data[messages.ENTITY_EVENT_ADD]
|
_send_handle_entities_init_response(connection, msg["id"], serialized_states)
|
||||||
cannot_serialize: list[str] = []
|
|
||||||
for entity_id, state_dict in add_entities.items():
|
|
||||||
try:
|
|
||||||
JSON_DUMP(state_dict)
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
cannot_serialize.append(entity_id)
|
|
||||||
|
|
||||||
for entity_id in cannot_serialize:
|
|
||||||
del add_entities[entity_id]
|
|
||||||
|
|
||||||
connection.send_message(JSON_DUMP(messages.event_message(msg["id"], data)))
|
def _send_handle_entities_init_response(
|
||||||
|
connection: ActiveConnection, msg_id: int, serialized_states: list[str]
|
||||||
|
) -> None:
|
||||||
|
"""Send handle entities init response."""
|
||||||
|
connection.send_message(
|
||||||
|
_HANDLE_SUBSCRIBE_ENTITIES_TEMPLATE.replace(
|
||||||
|
messages.IDEN_JSON_TEMPLATE, str(msg_id), 1
|
||||||
|
).replace(
|
||||||
|
_STATES_JSON_TEMPLATE,
|
||||||
|
"{" + ",".join(serialized_states) + "}",
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@decorators.websocket_command({vol.Required("type"): "get_services"})
|
@decorators.websocket_command({vol.Required("type"): "get_services"})
|
||||||
|
|
|
@ -44,7 +44,7 @@ ENTITY_EVENT_REMOVE = "r"
|
||||||
ENTITY_EVENT_CHANGE = "c"
|
ENTITY_EVENT_CHANGE = "c"
|
||||||
|
|
||||||
|
|
||||||
def result_message(iden: int, result: Any = None) -> dict[str, Any]:
|
def result_message(iden: JSON_TYPE | int, result: Any = None) -> dict[str, Any]:
|
||||||
"""Return a success result message."""
|
"""Return a success result message."""
|
||||||
return {"id": iden, "type": const.TYPE_RESULT, "success": True, "result": result}
|
return {"id": iden, "type": const.TYPE_RESULT, "success": True, "result": result}
|
||||||
|
|
||||||
|
|
|
@ -80,6 +80,7 @@ from .exceptions import (
|
||||||
Unauthorized,
|
Unauthorized,
|
||||||
)
|
)
|
||||||
from .helpers.aiohttp_compat import restore_original_aiohttp_cancel_behavior
|
from .helpers.aiohttp_compat import restore_original_aiohttp_cancel_behavior
|
||||||
|
from .helpers.json import json_dumps
|
||||||
from .util import dt as dt_util, location, ulid as ulid_util
|
from .util import dt as dt_util, location, ulid as ulid_util
|
||||||
from .util.async_ import run_callback_threadsafe, shutdown_run_callback_threadsafe
|
from .util.async_ import run_callback_threadsafe, shutdown_run_callback_threadsafe
|
||||||
from .util.read_only_dict import ReadOnlyDict
|
from .util.read_only_dict import ReadOnlyDict
|
||||||
|
@ -1224,6 +1225,8 @@ class State:
|
||||||
"object_id",
|
"object_id",
|
||||||
"_as_dict",
|
"_as_dict",
|
||||||
"_as_compressed_state",
|
"_as_compressed_state",
|
||||||
|
"_as_dict_json",
|
||||||
|
"_as_compressed_state_json",
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -1260,6 +1263,8 @@ class State:
|
||||||
self.domain, self.object_id = split_entity_id(self.entity_id)
|
self.domain, self.object_id = split_entity_id(self.entity_id)
|
||||||
self._as_dict: ReadOnlyDict[str, Collection[Any]] | None = None
|
self._as_dict: ReadOnlyDict[str, Collection[Any]] | None = None
|
||||||
self._as_compressed_state: dict[str, Any] | None = None
|
self._as_compressed_state: dict[str, Any] | None = None
|
||||||
|
self._as_dict_json: str | None = None
|
||||||
|
self._as_compressed_state_json: str | None = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
|
@ -1294,6 +1299,12 @@ class State:
|
||||||
)
|
)
|
||||||
return self._as_dict
|
return self._as_dict
|
||||||
|
|
||||||
|
def as_dict_json(self) -> str:
|
||||||
|
"""Return a JSON string of the State."""
|
||||||
|
if not self._as_dict_json:
|
||||||
|
self._as_dict_json = json_dumps(self.as_dict())
|
||||||
|
return self._as_dict_json
|
||||||
|
|
||||||
def as_compressed_state(self) -> dict[str, Any]:
|
def as_compressed_state(self) -> dict[str, Any]:
|
||||||
"""Build a compressed dict of a state for adds.
|
"""Build a compressed dict of a state for adds.
|
||||||
|
|
||||||
|
@ -1321,6 +1332,19 @@ class State:
|
||||||
self._as_compressed_state = compressed_state
|
self._as_compressed_state = compressed_state
|
||||||
return compressed_state
|
return compressed_state
|
||||||
|
|
||||||
|
def as_compressed_state_json(self) -> str:
|
||||||
|
"""Build a compressed JSON key value pair of a state for adds.
|
||||||
|
|
||||||
|
The JSON string is a key value pair of the entity_id and the compressed state.
|
||||||
|
|
||||||
|
It is used for sending multiple states in a single message.
|
||||||
|
"""
|
||||||
|
if not self._as_compressed_state_json:
|
||||||
|
self._as_compressed_state_json = json_dumps(
|
||||||
|
{self.entity_id: self.as_compressed_state()}
|
||||||
|
)[1:-1]
|
||||||
|
return self._as_compressed_state_json
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, json_dict: dict[str, Any]) -> Self | None:
|
def from_dict(cls, json_dict: dict[str, Any]) -> Self | None:
|
||||||
"""Initialize a state from a dict.
|
"""Initialize a state from a dict.
|
||||||
|
|
|
@ -9,7 +9,6 @@ from typing import Any, Final
|
||||||
|
|
||||||
import orjson
|
import orjson
|
||||||
|
|
||||||
from homeassistant.core import Event, State
|
|
||||||
from homeassistant.util.file import write_utf8_file, write_utf8_file_atomic
|
from homeassistant.util.file import write_utf8_file, write_utf8_file_atomic
|
||||||
from homeassistant.util.json import ( # pylint: disable=unused-import # noqa: F401
|
from homeassistant.util.json import ( # pylint: disable=unused-import # noqa: F401
|
||||||
JSON_DECODE_EXCEPTIONS,
|
JSON_DECODE_EXCEPTIONS,
|
||||||
|
@ -189,6 +188,11 @@ def find_paths_unserializable_data(
|
||||||
|
|
||||||
This method is slow! Only use for error handling.
|
This method is slow! Only use for error handling.
|
||||||
"""
|
"""
|
||||||
|
from homeassistant.core import ( # pylint: disable=import-outside-toplevel
|
||||||
|
Event,
|
||||||
|
State,
|
||||||
|
)
|
||||||
|
|
||||||
to_process = deque([(bad_data, "$")])
|
to_process = deque([(bad_data, "$")])
|
||||||
invalid = {}
|
invalid = {}
|
||||||
|
|
||||||
|
|
|
@ -188,10 +188,9 @@ async def test_non_json_message(
|
||||||
assert msg["type"] == const.TYPE_RESULT
|
assert msg["type"] == const.TYPE_RESULT
|
||||||
assert msg["success"]
|
assert msg["success"]
|
||||||
assert msg["result"] == []
|
assert msg["result"] == []
|
||||||
assert (
|
assert "Unable to serialize to JSON. Bad data found" in caplog.text
|
||||||
f"Unable to serialize to JSON. Bad data found at $.result[0](State: test_domain.entity).attributes.bad={bad_data}(<class 'object'>"
|
assert "State: test_domain.entity" in caplog.text
|
||||||
in caplog.text
|
assert "bad=<object" in caplog.text
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def test_prepare_fail(
|
async def test_prepare_fail(
|
||||||
|
|
|
@ -466,6 +466,29 @@ def test_state_as_dict() -> None:
|
||||||
assert state.as_dict() is as_dict_1
|
assert state.as_dict() is as_dict_1
|
||||||
|
|
||||||
|
|
||||||
|
def test_state_as_dict_json() -> None:
|
||||||
|
"""Test a State as JSON."""
|
||||||
|
last_time = datetime(1984, 12, 8, 12, 0, 0)
|
||||||
|
state = ha.State(
|
||||||
|
"happy.happy",
|
||||||
|
"on",
|
||||||
|
{"pig": "dog"},
|
||||||
|
last_updated=last_time,
|
||||||
|
last_changed=last_time,
|
||||||
|
context=ha.Context(id="01H0D6K3RFJAYAV2093ZW30PCW"),
|
||||||
|
)
|
||||||
|
expected = (
|
||||||
|
'{"entity_id":"happy.happy","state":"on","attributes":{"pig":"dog"},'
|
||||||
|
'"last_changed":"1984-12-08T12:00:00","last_updated":"1984-12-08T12:00:00",'
|
||||||
|
'"context":{"id":"01H0D6K3RFJAYAV2093ZW30PCW","parent_id":null,"user_id":null}}'
|
||||||
|
)
|
||||||
|
as_dict_json_1 = state.as_dict_json()
|
||||||
|
assert as_dict_json_1 == expected
|
||||||
|
# 2nd time to verify cache
|
||||||
|
assert state.as_dict_json() == expected
|
||||||
|
assert state.as_dict_json() is as_dict_json_1
|
||||||
|
|
||||||
|
|
||||||
def test_state_as_compressed_state() -> None:
|
def test_state_as_compressed_state() -> None:
|
||||||
"""Test a State as compressed state."""
|
"""Test a State as compressed state."""
|
||||||
last_time = datetime(1984, 12, 8, 12, 0, 0, tzinfo=dt_util.UTC)
|
last_time = datetime(1984, 12, 8, 12, 0, 0, tzinfo=dt_util.UTC)
|
||||||
|
@ -518,6 +541,27 @@ def test_state_as_compressed_state_unique_last_updated() -> None:
|
||||||
assert state.as_compressed_state() is as_compressed_state
|
assert state.as_compressed_state() is as_compressed_state
|
||||||
|
|
||||||
|
|
||||||
|
def test_state_as_compressed_state_json() -> None:
|
||||||
|
"""Test a State as a JSON compressed state."""
|
||||||
|
last_time = datetime(1984, 12, 8, 12, 0, 0, tzinfo=dt_util.UTC)
|
||||||
|
state = ha.State(
|
||||||
|
"happy.happy",
|
||||||
|
"on",
|
||||||
|
{"pig": "dog"},
|
||||||
|
last_updated=last_time,
|
||||||
|
last_changed=last_time,
|
||||||
|
context=ha.Context(id="01H0D6H5K3SZJ3XGDHED1TJ79N"),
|
||||||
|
)
|
||||||
|
expected = '"happy.happy":{"s":"on","a":{"pig":"dog"},"c":"01H0D6H5K3SZJ3XGDHED1TJ79N","lc":471355200.0}'
|
||||||
|
as_compressed_state = state.as_compressed_state_json()
|
||||||
|
# We are not too concerned about these being ReadOnlyDict
|
||||||
|
# since we don't expect them to be called by external callers
|
||||||
|
assert as_compressed_state == expected
|
||||||
|
# 2nd time to verify cache
|
||||||
|
assert state.as_compressed_state_json() == expected
|
||||||
|
assert state.as_compressed_state_json() is as_compressed_state
|
||||||
|
|
||||||
|
|
||||||
async def test_eventbus_add_remove_listener(hass: HomeAssistant) -> None:
|
async def test_eventbus_add_remove_listener(hass: HomeAssistant) -> None:
|
||||||
"""Test remove_listener method."""
|
"""Test remove_listener method."""
|
||||||
old_count = len(hass.bus.async_listeners())
|
old_count = len(hass.bus.async_listeners())
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue