Cache generation of the service descriptions (#93131)

This commit is contained in:
J. Nick Koston 2023-05-16 21:42:37 -05:00 committed by GitHub
parent 6c56ceead0
commit b993fe1c9d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 86 additions and 45 deletions

View file

@ -33,6 +33,7 @@ from homeassistant.helpers.json import (
JSON_DUMP,
ExtendedJSONEncoder,
find_paths_unserializable_data,
json_dumps,
)
from homeassistant.helpers.service import async_get_all_descriptions
from homeassistant.loader import (
@ -48,17 +49,9 @@ from homeassistant.util.json import format_unserializable_data
from . import const, decorators, messages
from .connection import ActiveConnection
from .const import ERR_NOT_FOUND
from .messages import construct_event_message, construct_result_message
_STATES_TEMPLATE = "__STATES__"
_STATES_JSON_TEMPLATE = '"__STATES__"'
_HANDLE_SUBSCRIBE_ENTITIES_TEMPLATE = JSON_DUMP(
messages.event_message(
messages.IDEN_TEMPLATE, {messages.ENTITY_EVENT_ADD: _STATES_TEMPLATE}
)
)
_HANDLE_GET_STATES_TEMPLATE = JSON_DUMP(
messages.result_message(messages.IDEN_TEMPLATE, _STATES_TEMPLATE)
)
ALL_SERVICE_DESCRIPTIONS_JSON_CACHE = "websocket_api_all_service_descriptions_json"
@callback
@ -280,15 +273,8 @@ def _send_handle_get_states_response(
connection: ActiveConnection, msg_id: int, serialized_states: list[str]
) -> None:
"""Send handle get states response."""
connection.send_message(
_HANDLE_GET_STATES_TEMPLATE.replace(
messages.IDEN_JSON_TEMPLATE, str(msg_id), 1
).replace(
_STATES_JSON_TEMPLATE,
"[" + ",".join(serialized_states) + "]",
1,
)
)
joined_states = ",".join(serialized_states)
connection.send_message(construct_result_message(msg_id, f"[{joined_states}]"))
@callback
@ -359,25 +345,35 @@ def _send_handle_entities_init_response(
connection: ActiveConnection, msg_id: int, serialized_states: list[str]
) -> None:
"""Send handle entities init response."""
joined_states = ",".join(serialized_states)
connection.send_message(
_HANDLE_SUBSCRIBE_ENTITIES_TEMPLATE.replace(
messages.IDEN_JSON_TEMPLATE, str(msg_id), 1
).replace(
_STATES_JSON_TEMPLATE,
"{" + ",".join(serialized_states) + "}",
1,
)
construct_event_message(msg_id, f'{{"a":{{{joined_states}}}}}')
)
async def _async_get_all_descriptions_json(hass: HomeAssistant) -> str:
"""Return JSON of descriptions (i.e. user documentation) for all service calls."""
descriptions = await async_get_all_descriptions(hass)
if ALL_SERVICE_DESCRIPTIONS_JSON_CACHE in hass.data:
cached_descriptions, cached_json_payload = hass.data[
ALL_SERVICE_DESCRIPTIONS_JSON_CACHE
]
# If the descriptions are the same, return the cached JSON payload
if cached_descriptions is descriptions:
return cast(str, cached_json_payload)
json_payload = json_dumps(descriptions)
hass.data[ALL_SERVICE_DESCRIPTIONS_JSON_CACHE] = (descriptions, json_payload)
return json_payload
@decorators.websocket_command({vol.Required("type"): "get_services"})
@decorators.async_response
async def handle_get_services(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle get services command."""
descriptions = await async_get_all_descriptions(hass)
connection.send_result(msg["id"], descriptions)
payload = await _async_get_all_descriptions_json(hass)
connection.send_message(construct_result_message(msg["id"], payload))
@callback

View file

@ -135,7 +135,8 @@ class WebSocketHandler:
)
messages_remaining -= 1
coalesced_messages = "[" + ",".join(messages) + "]"
joined_messages = ",".join(messages)
coalesced_messages = f"[{joined_messages}]"
debug("Sending %s", coalesced_messages)
await send_str(coalesced_messages)
finally:

View file

@ -18,7 +18,6 @@ from homeassistant.core import Event, State
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.json import JSON_DUMP, find_paths_unserializable_data
from homeassistant.util.json import format_unserializable_data
from homeassistant.util.yaml.loader import JSON_TYPE
from . import const
@ -44,11 +43,17 @@ ENTITY_EVENT_REMOVE = "r"
ENTITY_EVENT_CHANGE = "c"
def result_message(iden: JSON_TYPE | int, result: Any = None) -> dict[str, Any]:
def result_message(iden: int, result: Any = None) -> dict[str, Any]:
"""Return a success result message."""
return {"id": iden, "type": const.TYPE_RESULT, "success": True, "result": result}
def construct_result_message(iden: int, payload: str) -> str:
"""Construct a success result message JSON."""
iden_str = str(iden)
return f'{{"id":{iden_str},"type":"result","success":true,"result":{payload}}}'
def error_message(iden: int | None, code: str, message: str) -> dict[str, Any]:
"""Return an error result message."""
return {
@ -59,7 +64,13 @@ def error_message(iden: int | None, code: str, message: str) -> dict[str, Any]:
}
def event_message(iden: JSON_TYPE | int, event: Any) -> dict[str, Any]:
def construct_event_message(iden: int, payload: str) -> str:
"""Construct an event message JSON."""
iden_str = str(iden)
return f'{{"id":{iden_str},"type":"event","event":{payload}}}'
def event_message(iden: int, event: Any) -> dict[str, Any]:
"""Return an event message."""
return {"id": iden, "type": "event", "event": event}
@ -83,7 +94,7 @@ def _cached_event_message(event: Event) -> str:
The IDEN_TEMPLATE is used which will be replaced
with the actual iden in cached_event_message
"""
return message_to_json(event_message(IDEN_TEMPLATE, event))
return message_to_json({"id": IDEN_TEMPLATE, "type": "event", "event": event})
def cached_state_diff_message(iden: int, event: Event) -> str:
@ -105,7 +116,9 @@ def _cached_state_diff_message(event: Event) -> str:
The IDEN_TEMPLATE is used which will be replaced
with the actual iden in cached_event_message
"""
return message_to_json(event_message(IDEN_TEMPLATE, _state_diff_event(event)))
return message_to_json(
{"id": IDEN_TEMPLATE, "type": "event", "event": _state_diff_event(event)}
)
def _state_diff_event(event: Event) -> dict:

View file

@ -59,6 +59,7 @@ CONF_SERVICE_ENTITY_ID = "entity_id"
_LOGGER = logging.getLogger(__name__)
SERVICE_DESCRIPTION_CACHE = "service_description_cache"
ALL_SERVICE_DESCRIPTIONS_CACHE = "all_service_descriptions_cache"
@cache
@ -559,17 +560,27 @@ async def async_get_all_descriptions(
) -> dict[str, dict[str, Any]]:
"""Return descriptions (i.e. user documentation) for all service calls."""
descriptions_cache = hass.data.setdefault(SERVICE_DESCRIPTION_CACHE, {})
format_cache_key = "{}.{}".format
services = hass.services.async_services()
# See if there are new services not seen before.
# Any service that we saw before already has an entry in description_cache.
missing = set()
all_services = []
for domain in services:
for service in services[domain]:
if format_cache_key(domain, service) not in descriptions_cache:
cache_key = (domain, service)
all_services.append(cache_key)
if cache_key not in descriptions_cache:
missing.add(domain)
break
# If we have a complete cache, check if it is still valid
if ALL_SERVICE_DESCRIPTIONS_CACHE in hass.data:
previous_all_services, previous_descriptions_cache = hass.data[
ALL_SERVICE_DESCRIPTIONS_CACHE
]
# If the services are the same, we can return the cache
if previous_all_services == all_services:
return cast(dict[str, dict[str, Any]], previous_descriptions_cache)
# Files we loaded for missing descriptions
loaded = {}
@ -595,7 +606,7 @@ async def async_get_all_descriptions(
descriptions[domain] = {}
for service in services[domain]:
cache_key = format_cache_key(domain, service)
cache_key = (domain, service)
description = descriptions_cache.get(cache_key)
# Cache missing descriptions
@ -622,6 +633,7 @@ async def async_get_all_descriptions(
descriptions[domain][service] = description
hass.data[ALL_SERVICE_DESCRIPTIONS_CACHE] = (all_services, descriptions)
return descriptions
@ -652,7 +664,8 @@ def async_set_service_schema(
if "target" in schema:
description["target"] = schema["target"]
hass.data[SERVICE_DESCRIPTION_CACHE][f"{domain}.{service}"] = description
hass.data.pop(ALL_SERVICE_DESCRIPTIONS_CACHE, None)
hass.data[SERVICE_DESCRIPTION_CACHE][(domain, service)] = description
@bind_hass

View file

@ -514,10 +514,11 @@ async def test_get_states(hass: HomeAssistant, websocket_client) -> None:
async def test_get_services(hass: HomeAssistant, websocket_client) -> None:
"""Test get_services command."""
await websocket_client.send_json({"id": 5, "type": "get_services"})
for id_ in (5, 6):
await websocket_client.send_json({"id": id_, "type": "get_services"})
msg = await websocket_client.receive_json()
assert msg["id"] == 5
assert msg["id"] == id_
assert msg["type"] == const.TYPE_RESULT
assert msg["success"]
assert msg["result"] == hass.services.async_services()

View file

@ -564,6 +564,23 @@ async def test_async_get_all_descriptions(hass: HomeAssistant) -> None:
assert "description" in descriptions[logger.DOMAIN]["set_level"]
assert "fields" in descriptions[logger.DOMAIN]["set_level"]
hass.services.async_register(logger.DOMAIN, "new_service", lambda x: None, None)
service.async_set_service_schema(
hass, logger.DOMAIN, "new_service", {"description": "new service"}
)
descriptions = await service.async_get_all_descriptions(hass)
assert "description" in descriptions[logger.DOMAIN]["new_service"]
assert descriptions[logger.DOMAIN]["new_service"]["description"] == "new service"
hass.services.async_register(
logger.DOMAIN, "another_new_service", lambda x: None, None
)
descriptions = await service.async_get_all_descriptions(hass)
assert "another_new_service" in descriptions[logger.DOMAIN]
# Verify the cache returns the same object
assert await service.async_get_all_descriptions(hass) is descriptions
async def test_call_with_required_features(hass: HomeAssistant, mock_entities) -> None:
"""Test service calls invoked only if entity has required features."""