Cache generation of the service descriptions (#93131)

This commit is contained in:
J. Nick Koston 2023-05-16 21:42:37 -05:00 committed by GitHub
parent 6c56ceead0
commit b993fe1c9d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 86 additions and 45 deletions

View file

@ -33,6 +33,7 @@ from homeassistant.helpers.json import (
JSON_DUMP, JSON_DUMP,
ExtendedJSONEncoder, ExtendedJSONEncoder,
find_paths_unserializable_data, find_paths_unserializable_data,
json_dumps,
) )
from homeassistant.helpers.service import async_get_all_descriptions from homeassistant.helpers.service import async_get_all_descriptions
from homeassistant.loader import ( from homeassistant.loader import (
@ -48,17 +49,9 @@ from homeassistant.util.json import format_unserializable_data
from . import const, decorators, messages from . import const, decorators, messages
from .connection import ActiveConnection from .connection import ActiveConnection
from .const import ERR_NOT_FOUND from .const import ERR_NOT_FOUND
from .messages import construct_event_message, construct_result_message
_STATES_TEMPLATE = "__STATES__" ALL_SERVICE_DESCRIPTIONS_JSON_CACHE = "websocket_api_all_service_descriptions_json"
_STATES_JSON_TEMPLATE = '"__STATES__"'
_HANDLE_SUBSCRIBE_ENTITIES_TEMPLATE = JSON_DUMP(
messages.event_message(
messages.IDEN_TEMPLATE, {messages.ENTITY_EVENT_ADD: _STATES_TEMPLATE}
)
)
_HANDLE_GET_STATES_TEMPLATE = JSON_DUMP(
messages.result_message(messages.IDEN_TEMPLATE, _STATES_TEMPLATE)
)
@callback @callback
@ -280,15 +273,8 @@ def _send_handle_get_states_response(
connection: ActiveConnection, msg_id: int, serialized_states: list[str] connection: ActiveConnection, msg_id: int, serialized_states: list[str]
) -> None: ) -> None:
"""Send handle get states response.""" """Send handle get states response."""
connection.send_message( joined_states = ",".join(serialized_states)
_HANDLE_GET_STATES_TEMPLATE.replace( connection.send_message(construct_result_message(msg_id, f"[{joined_states}]"))
messages.IDEN_JSON_TEMPLATE, str(msg_id), 1
).replace(
_STATES_JSON_TEMPLATE,
"[" + ",".join(serialized_states) + "]",
1,
)
)
@callback @callback
@ -359,25 +345,35 @@ def _send_handle_entities_init_response(
connection: ActiveConnection, msg_id: int, serialized_states: list[str] connection: ActiveConnection, msg_id: int, serialized_states: list[str]
) -> None: ) -> None:
"""Send handle entities init response.""" """Send handle entities init response."""
joined_states = ",".join(serialized_states)
connection.send_message( connection.send_message(
_HANDLE_SUBSCRIBE_ENTITIES_TEMPLATE.replace( construct_event_message(msg_id, f'{{"a":{{{joined_states}}}}}')
messages.IDEN_JSON_TEMPLATE, str(msg_id), 1
).replace(
_STATES_JSON_TEMPLATE,
"{" + ",".join(serialized_states) + "}",
1,
)
) )
async def _async_get_all_descriptions_json(hass: HomeAssistant) -> str:
"""Return JSON of descriptions (i.e. user documentation) for all service calls."""
descriptions = await async_get_all_descriptions(hass)
if ALL_SERVICE_DESCRIPTIONS_JSON_CACHE in hass.data:
cached_descriptions, cached_json_payload = hass.data[
ALL_SERVICE_DESCRIPTIONS_JSON_CACHE
]
# If the descriptions are the same, return the cached JSON payload
if cached_descriptions is descriptions:
return cast(str, cached_json_payload)
json_payload = json_dumps(descriptions)
hass.data[ALL_SERVICE_DESCRIPTIONS_JSON_CACHE] = (descriptions, json_payload)
return json_payload
@decorators.websocket_command({vol.Required("type"): "get_services"}) @decorators.websocket_command({vol.Required("type"): "get_services"})
@decorators.async_response @decorators.async_response
async def handle_get_services( async def handle_get_services(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None: ) -> None:
"""Handle get services command.""" """Handle get services command."""
descriptions = await async_get_all_descriptions(hass) payload = await _async_get_all_descriptions_json(hass)
connection.send_result(msg["id"], descriptions) connection.send_message(construct_result_message(msg["id"], payload))
@callback @callback

View file

@ -135,7 +135,8 @@ class WebSocketHandler:
) )
messages_remaining -= 1 messages_remaining -= 1
coalesced_messages = "[" + ",".join(messages) + "]" joined_messages = ",".join(messages)
coalesced_messages = f"[{joined_messages}]"
debug("Sending %s", coalesced_messages) debug("Sending %s", coalesced_messages)
await send_str(coalesced_messages) await send_str(coalesced_messages)
finally: finally:

View file

@ -18,7 +18,6 @@ from homeassistant.core import Event, State
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.json import JSON_DUMP, find_paths_unserializable_data from homeassistant.helpers.json import JSON_DUMP, find_paths_unserializable_data
from homeassistant.util.json import format_unserializable_data from homeassistant.util.json import format_unserializable_data
from homeassistant.util.yaml.loader import JSON_TYPE
from . import const from . import const
@ -44,11 +43,17 @@ ENTITY_EVENT_REMOVE = "r"
ENTITY_EVENT_CHANGE = "c" ENTITY_EVENT_CHANGE = "c"
def result_message(iden: JSON_TYPE | int, result: Any = None) -> dict[str, Any]: def result_message(iden: int, result: Any = None) -> dict[str, Any]:
"""Return a success result message.""" """Return a success result message."""
return {"id": iden, "type": const.TYPE_RESULT, "success": True, "result": result} return {"id": iden, "type": const.TYPE_RESULT, "success": True, "result": result}
def construct_result_message(iden: int, payload: str) -> str:
"""Construct a success result message JSON."""
iden_str = str(iden)
return f'{{"id":{iden_str},"type":"result","success":true,"result":{payload}}}'
def error_message(iden: int | None, code: str, message: str) -> dict[str, Any]: def error_message(iden: int | None, code: str, message: str) -> dict[str, Any]:
"""Return an error result message.""" """Return an error result message."""
return { return {
@ -59,7 +64,13 @@ def error_message(iden: int | None, code: str, message: str) -> dict[str, Any]:
} }
def event_message(iden: JSON_TYPE | int, event: Any) -> dict[str, Any]: def construct_event_message(iden: int, payload: str) -> str:
"""Construct an event message JSON."""
iden_str = str(iden)
return f'{{"id":{iden_str},"type":"event","event":{payload}}}'
def event_message(iden: int, event: Any) -> dict[str, Any]:
"""Return an event message.""" """Return an event message."""
return {"id": iden, "type": "event", "event": event} return {"id": iden, "type": "event", "event": event}
@ -83,7 +94,7 @@ def _cached_event_message(event: Event) -> str:
The IDEN_TEMPLATE is used which will be replaced The IDEN_TEMPLATE is used which will be replaced
with the actual iden in cached_event_message with the actual iden in cached_event_message
""" """
return message_to_json(event_message(IDEN_TEMPLATE, event)) return message_to_json({"id": IDEN_TEMPLATE, "type": "event", "event": event})
def cached_state_diff_message(iden: int, event: Event) -> str: def cached_state_diff_message(iden: int, event: Event) -> str:
@ -105,7 +116,9 @@ def _cached_state_diff_message(event: Event) -> str:
The IDEN_TEMPLATE is used which will be replaced The IDEN_TEMPLATE is used which will be replaced
with the actual iden in cached_event_message with the actual iden in cached_event_message
""" """
return message_to_json(event_message(IDEN_TEMPLATE, _state_diff_event(event))) return message_to_json(
{"id": IDEN_TEMPLATE, "type": "event", "event": _state_diff_event(event)}
)
def _state_diff_event(event: Event) -> dict: def _state_diff_event(event: Event) -> dict:

View file

@ -59,6 +59,7 @@ CONF_SERVICE_ENTITY_ID = "entity_id"
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
SERVICE_DESCRIPTION_CACHE = "service_description_cache" SERVICE_DESCRIPTION_CACHE = "service_description_cache"
ALL_SERVICE_DESCRIPTIONS_CACHE = "all_service_descriptions_cache"
@cache @cache
@ -559,17 +560,27 @@ async def async_get_all_descriptions(
) -> dict[str, dict[str, Any]]: ) -> dict[str, dict[str, Any]]:
"""Return descriptions (i.e. user documentation) for all service calls.""" """Return descriptions (i.e. user documentation) for all service calls."""
descriptions_cache = hass.data.setdefault(SERVICE_DESCRIPTION_CACHE, {}) descriptions_cache = hass.data.setdefault(SERVICE_DESCRIPTION_CACHE, {})
format_cache_key = "{}.{}".format
services = hass.services.async_services() services = hass.services.async_services()
# See if there are new services not seen before. # See if there are new services not seen before.
# Any service that we saw before already has an entry in description_cache. # Any service that we saw before already has an entry in description_cache.
missing = set() missing = set()
all_services = []
for domain in services: for domain in services:
for service in services[domain]: for service in services[domain]:
if format_cache_key(domain, service) not in descriptions_cache: cache_key = (domain, service)
all_services.append(cache_key)
if cache_key not in descriptions_cache:
missing.add(domain) missing.add(domain)
break
# If we have a complete cache, check if it is still valid
if ALL_SERVICE_DESCRIPTIONS_CACHE in hass.data:
previous_all_services, previous_descriptions_cache = hass.data[
ALL_SERVICE_DESCRIPTIONS_CACHE
]
# If the services are the same, we can return the cache
if previous_all_services == all_services:
return cast(dict[str, dict[str, Any]], previous_descriptions_cache)
# Files we loaded for missing descriptions # Files we loaded for missing descriptions
loaded = {} loaded = {}
@ -595,7 +606,7 @@ async def async_get_all_descriptions(
descriptions[domain] = {} descriptions[domain] = {}
for service in services[domain]: for service in services[domain]:
cache_key = format_cache_key(domain, service) cache_key = (domain, service)
description = descriptions_cache.get(cache_key) description = descriptions_cache.get(cache_key)
# Cache missing descriptions # Cache missing descriptions
@ -622,6 +633,7 @@ async def async_get_all_descriptions(
descriptions[domain][service] = description descriptions[domain][service] = description
hass.data[ALL_SERVICE_DESCRIPTIONS_CACHE] = (all_services, descriptions)
return descriptions return descriptions
@ -652,7 +664,8 @@ def async_set_service_schema(
if "target" in schema: if "target" in schema:
description["target"] = schema["target"] description["target"] = schema["target"]
hass.data[SERVICE_DESCRIPTION_CACHE][f"{domain}.{service}"] = description hass.data.pop(ALL_SERVICE_DESCRIPTIONS_CACHE, None)
hass.data[SERVICE_DESCRIPTION_CACHE][(domain, service)] = description
@bind_hass @bind_hass

View file

@ -514,13 +514,14 @@ async def test_get_states(hass: HomeAssistant, websocket_client) -> None:
async def test_get_services(hass: HomeAssistant, websocket_client) -> None: async def test_get_services(hass: HomeAssistant, websocket_client) -> None:
"""Test get_services command.""" """Test get_services command."""
await websocket_client.send_json({"id": 5, "type": "get_services"}) for id_ in (5, 6):
await websocket_client.send_json({"id": id_, "type": "get_services"})
msg = await websocket_client.receive_json() msg = await websocket_client.receive_json()
assert msg["id"] == 5 assert msg["id"] == id_
assert msg["type"] == const.TYPE_RESULT assert msg["type"] == const.TYPE_RESULT
assert msg["success"] assert msg["success"]
assert msg["result"] == hass.services.async_services() assert msg["result"] == hass.services.async_services()
async def test_get_config(hass: HomeAssistant, websocket_client) -> None: async def test_get_config(hass: HomeAssistant, websocket_client) -> None:

View file

@ -564,6 +564,23 @@ async def test_async_get_all_descriptions(hass: HomeAssistant) -> None:
assert "description" in descriptions[logger.DOMAIN]["set_level"] assert "description" in descriptions[logger.DOMAIN]["set_level"]
assert "fields" in descriptions[logger.DOMAIN]["set_level"] assert "fields" in descriptions[logger.DOMAIN]["set_level"]
hass.services.async_register(logger.DOMAIN, "new_service", lambda x: None, None)
service.async_set_service_schema(
hass, logger.DOMAIN, "new_service", {"description": "new service"}
)
descriptions = await service.async_get_all_descriptions(hass)
assert "description" in descriptions[logger.DOMAIN]["new_service"]
assert descriptions[logger.DOMAIN]["new_service"]["description"] == "new service"
hass.services.async_register(
logger.DOMAIN, "another_new_service", lambda x: None, None
)
descriptions = await service.async_get_all_descriptions(hass)
assert "another_new_service" in descriptions[logger.DOMAIN]
# Verify the cache returns the same object
assert await service.async_get_all_descriptions(hass) is descriptions
async def test_call_with_required_features(hass: HomeAssistant, mock_entities) -> None: async def test_call_with_required_features(hass: HomeAssistant, mock_entities) -> None:
"""Test service calls invoked only if entity has required features.""" """Test service calls invoked only if entity has required features."""