Cache generation of the service descriptions (#93131)
This commit is contained in:
parent
6c56ceead0
commit
b993fe1c9d
6 changed files with 86 additions and 45 deletions
|
@ -33,6 +33,7 @@ from homeassistant.helpers.json import (
|
|||
JSON_DUMP,
|
||||
ExtendedJSONEncoder,
|
||||
find_paths_unserializable_data,
|
||||
json_dumps,
|
||||
)
|
||||
from homeassistant.helpers.service import async_get_all_descriptions
|
||||
from homeassistant.loader import (
|
||||
|
@ -48,17 +49,9 @@ from homeassistant.util.json import format_unserializable_data
|
|||
from . import const, decorators, messages
|
||||
from .connection import ActiveConnection
|
||||
from .const import ERR_NOT_FOUND
|
||||
from .messages import construct_event_message, construct_result_message
|
||||
|
||||
_STATES_TEMPLATE = "__STATES__"
|
||||
_STATES_JSON_TEMPLATE = '"__STATES__"'
|
||||
_HANDLE_SUBSCRIBE_ENTITIES_TEMPLATE = JSON_DUMP(
|
||||
messages.event_message(
|
||||
messages.IDEN_TEMPLATE, {messages.ENTITY_EVENT_ADD: _STATES_TEMPLATE}
|
||||
)
|
||||
)
|
||||
_HANDLE_GET_STATES_TEMPLATE = JSON_DUMP(
|
||||
messages.result_message(messages.IDEN_TEMPLATE, _STATES_TEMPLATE)
|
||||
)
|
||||
ALL_SERVICE_DESCRIPTIONS_JSON_CACHE = "websocket_api_all_service_descriptions_json"
|
||||
|
||||
|
||||
@callback
|
||||
|
@ -280,15 +273,8 @@ def _send_handle_get_states_response(
|
|||
connection: ActiveConnection, msg_id: int, serialized_states: list[str]
|
||||
) -> None:
|
||||
"""Send handle get states response."""
|
||||
connection.send_message(
|
||||
_HANDLE_GET_STATES_TEMPLATE.replace(
|
||||
messages.IDEN_JSON_TEMPLATE, str(msg_id), 1
|
||||
).replace(
|
||||
_STATES_JSON_TEMPLATE,
|
||||
"[" + ",".join(serialized_states) + "]",
|
||||
1,
|
||||
)
|
||||
)
|
||||
joined_states = ",".join(serialized_states)
|
||||
connection.send_message(construct_result_message(msg_id, f"[{joined_states}]"))
|
||||
|
||||
|
||||
@callback
|
||||
|
@ -359,25 +345,35 @@ def _send_handle_entities_init_response(
|
|||
connection: ActiveConnection, msg_id: int, serialized_states: list[str]
|
||||
) -> None:
|
||||
"""Send handle entities init response."""
|
||||
joined_states = ",".join(serialized_states)
|
||||
connection.send_message(
|
||||
_HANDLE_SUBSCRIBE_ENTITIES_TEMPLATE.replace(
|
||||
messages.IDEN_JSON_TEMPLATE, str(msg_id), 1
|
||||
).replace(
|
||||
_STATES_JSON_TEMPLATE,
|
||||
"{" + ",".join(serialized_states) + "}",
|
||||
1,
|
||||
)
|
||||
construct_event_message(msg_id, f'{{"a":{{{joined_states}}}}}')
|
||||
)
|
||||
|
||||
|
||||
async def _async_get_all_descriptions_json(hass: HomeAssistant) -> str:
|
||||
"""Return JSON of descriptions (i.e. user documentation) for all service calls."""
|
||||
descriptions = await async_get_all_descriptions(hass)
|
||||
if ALL_SERVICE_DESCRIPTIONS_JSON_CACHE in hass.data:
|
||||
cached_descriptions, cached_json_payload = hass.data[
|
||||
ALL_SERVICE_DESCRIPTIONS_JSON_CACHE
|
||||
]
|
||||
# If the descriptions are the same, return the cached JSON payload
|
||||
if cached_descriptions is descriptions:
|
||||
return cast(str, cached_json_payload)
|
||||
json_payload = json_dumps(descriptions)
|
||||
hass.data[ALL_SERVICE_DESCRIPTIONS_JSON_CACHE] = (descriptions, json_payload)
|
||||
return json_payload
|
||||
|
||||
|
||||
@decorators.websocket_command({vol.Required("type"): "get_services"})
|
||||
@decorators.async_response
|
||||
async def handle_get_services(
|
||||
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||
) -> None:
|
||||
"""Handle get services command."""
|
||||
descriptions = await async_get_all_descriptions(hass)
|
||||
connection.send_result(msg["id"], descriptions)
|
||||
payload = await _async_get_all_descriptions_json(hass)
|
||||
connection.send_message(construct_result_message(msg["id"], payload))
|
||||
|
||||
|
||||
@callback
|
||||
|
|
|
@ -135,7 +135,8 @@ class WebSocketHandler:
|
|||
)
|
||||
messages_remaining -= 1
|
||||
|
||||
coalesced_messages = "[" + ",".join(messages) + "]"
|
||||
joined_messages = ",".join(messages)
|
||||
coalesced_messages = f"[{joined_messages}]"
|
||||
debug("Sending %s", coalesced_messages)
|
||||
await send_str(coalesced_messages)
|
||||
finally:
|
||||
|
|
|
@ -18,7 +18,6 @@ from homeassistant.core import Event, State
|
|||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers.json import JSON_DUMP, find_paths_unserializable_data
|
||||
from homeassistant.util.json import format_unserializable_data
|
||||
from homeassistant.util.yaml.loader import JSON_TYPE
|
||||
|
||||
from . import const
|
||||
|
||||
|
@ -44,11 +43,17 @@ ENTITY_EVENT_REMOVE = "r"
|
|||
ENTITY_EVENT_CHANGE = "c"
|
||||
|
||||
|
||||
def result_message(iden: JSON_TYPE | int, result: Any = None) -> dict[str, Any]:
|
||||
def result_message(iden: int, result: Any = None) -> dict[str, Any]:
|
||||
"""Return a success result message."""
|
||||
return {"id": iden, "type": const.TYPE_RESULT, "success": True, "result": result}
|
||||
|
||||
|
||||
def construct_result_message(iden: int, payload: str) -> str:
|
||||
"""Construct a success result message JSON."""
|
||||
iden_str = str(iden)
|
||||
return f'{{"id":{iden_str},"type":"result","success":true,"result":{payload}}}'
|
||||
|
||||
|
||||
def error_message(iden: int | None, code: str, message: str) -> dict[str, Any]:
|
||||
"""Return an error result message."""
|
||||
return {
|
||||
|
@ -59,7 +64,13 @@ def error_message(iden: int | None, code: str, message: str) -> dict[str, Any]:
|
|||
}
|
||||
|
||||
|
||||
def event_message(iden: JSON_TYPE | int, event: Any) -> dict[str, Any]:
|
||||
def construct_event_message(iden: int, payload: str) -> str:
|
||||
"""Construct an event message JSON."""
|
||||
iden_str = str(iden)
|
||||
return f'{{"id":{iden_str},"type":"event","event":{payload}}}'
|
||||
|
||||
|
||||
def event_message(iden: int, event: Any) -> dict[str, Any]:
|
||||
"""Return an event message."""
|
||||
return {"id": iden, "type": "event", "event": event}
|
||||
|
||||
|
@ -83,7 +94,7 @@ def _cached_event_message(event: Event) -> str:
|
|||
The IDEN_TEMPLATE is used which will be replaced
|
||||
with the actual iden in cached_event_message
|
||||
"""
|
||||
return message_to_json(event_message(IDEN_TEMPLATE, event))
|
||||
return message_to_json({"id": IDEN_TEMPLATE, "type": "event", "event": event})
|
||||
|
||||
|
||||
def cached_state_diff_message(iden: int, event: Event) -> str:
|
||||
|
@ -105,7 +116,9 @@ def _cached_state_diff_message(event: Event) -> str:
|
|||
The IDEN_TEMPLATE is used which will be replaced
|
||||
with the actual iden in cached_event_message
|
||||
"""
|
||||
return message_to_json(event_message(IDEN_TEMPLATE, _state_diff_event(event)))
|
||||
return message_to_json(
|
||||
{"id": IDEN_TEMPLATE, "type": "event", "event": _state_diff_event(event)}
|
||||
)
|
||||
|
||||
|
||||
def _state_diff_event(event: Event) -> dict:
|
||||
|
|
|
@ -59,6 +59,7 @@ CONF_SERVICE_ENTITY_ID = "entity_id"
|
|||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
SERVICE_DESCRIPTION_CACHE = "service_description_cache"
|
||||
ALL_SERVICE_DESCRIPTIONS_CACHE = "all_service_descriptions_cache"
|
||||
|
||||
|
||||
@cache
|
||||
|
@ -559,17 +560,27 @@ async def async_get_all_descriptions(
|
|||
) -> dict[str, dict[str, Any]]:
|
||||
"""Return descriptions (i.e. user documentation) for all service calls."""
|
||||
descriptions_cache = hass.data.setdefault(SERVICE_DESCRIPTION_CACHE, {})
|
||||
format_cache_key = "{}.{}".format
|
||||
services = hass.services.async_services()
|
||||
|
||||
# See if there are new services not seen before.
|
||||
# Any service that we saw before already has an entry in description_cache.
|
||||
missing = set()
|
||||
all_services = []
|
||||
for domain in services:
|
||||
for service in services[domain]:
|
||||
if format_cache_key(domain, service) not in descriptions_cache:
|
||||
cache_key = (domain, service)
|
||||
all_services.append(cache_key)
|
||||
if cache_key not in descriptions_cache:
|
||||
missing.add(domain)
|
||||
break
|
||||
|
||||
# If we have a complete cache, check if it is still valid
|
||||
if ALL_SERVICE_DESCRIPTIONS_CACHE in hass.data:
|
||||
previous_all_services, previous_descriptions_cache = hass.data[
|
||||
ALL_SERVICE_DESCRIPTIONS_CACHE
|
||||
]
|
||||
# If the services are the same, we can return the cache
|
||||
if previous_all_services == all_services:
|
||||
return cast(dict[str, dict[str, Any]], previous_descriptions_cache)
|
||||
|
||||
# Files we loaded for missing descriptions
|
||||
loaded = {}
|
||||
|
@ -595,7 +606,7 @@ async def async_get_all_descriptions(
|
|||
descriptions[domain] = {}
|
||||
|
||||
for service in services[domain]:
|
||||
cache_key = format_cache_key(domain, service)
|
||||
cache_key = (domain, service)
|
||||
description = descriptions_cache.get(cache_key)
|
||||
|
||||
# Cache missing descriptions
|
||||
|
@ -622,6 +633,7 @@ async def async_get_all_descriptions(
|
|||
|
||||
descriptions[domain][service] = description
|
||||
|
||||
hass.data[ALL_SERVICE_DESCRIPTIONS_CACHE] = (all_services, descriptions)
|
||||
return descriptions
|
||||
|
||||
|
||||
|
@ -652,7 +664,8 @@ def async_set_service_schema(
|
|||
if "target" in schema:
|
||||
description["target"] = schema["target"]
|
||||
|
||||
hass.data[SERVICE_DESCRIPTION_CACHE][f"{domain}.{service}"] = description
|
||||
hass.data.pop(ALL_SERVICE_DESCRIPTIONS_CACHE, None)
|
||||
hass.data[SERVICE_DESCRIPTION_CACHE][(domain, service)] = description
|
||||
|
||||
|
||||
@bind_hass
|
||||
|
|
|
@ -514,10 +514,11 @@ async def test_get_states(hass: HomeAssistant, websocket_client) -> None:
|
|||
|
||||
async def test_get_services(hass: HomeAssistant, websocket_client) -> None:
|
||||
"""Test get_services command."""
|
||||
await websocket_client.send_json({"id": 5, "type": "get_services"})
|
||||
for id_ in (5, 6):
|
||||
await websocket_client.send_json({"id": id_, "type": "get_services"})
|
||||
|
||||
msg = await websocket_client.receive_json()
|
||||
assert msg["id"] == 5
|
||||
assert msg["id"] == id_
|
||||
assert msg["type"] == const.TYPE_RESULT
|
||||
assert msg["success"]
|
||||
assert msg["result"] == hass.services.async_services()
|
||||
|
|
|
@ -564,6 +564,23 @@ async def test_async_get_all_descriptions(hass: HomeAssistant) -> None:
|
|||
assert "description" in descriptions[logger.DOMAIN]["set_level"]
|
||||
assert "fields" in descriptions[logger.DOMAIN]["set_level"]
|
||||
|
||||
hass.services.async_register(logger.DOMAIN, "new_service", lambda x: None, None)
|
||||
service.async_set_service_schema(
|
||||
hass, logger.DOMAIN, "new_service", {"description": "new service"}
|
||||
)
|
||||
descriptions = await service.async_get_all_descriptions(hass)
|
||||
assert "description" in descriptions[logger.DOMAIN]["new_service"]
|
||||
assert descriptions[logger.DOMAIN]["new_service"]["description"] == "new service"
|
||||
|
||||
hass.services.async_register(
|
||||
logger.DOMAIN, "another_new_service", lambda x: None, None
|
||||
)
|
||||
descriptions = await service.async_get_all_descriptions(hass)
|
||||
assert "another_new_service" in descriptions[logger.DOMAIN]
|
||||
|
||||
# Verify the cache returns the same object
|
||||
assert await service.async_get_all_descriptions(hass) is descriptions
|
||||
|
||||
|
||||
async def test_call_with_required_features(hass: HomeAssistant, mock_entities) -> None:
|
||||
"""Test service calls invoked only if entity has required features."""
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue