Cache generation of the service descriptions (#93131)
This commit is contained in:
parent
6c56ceead0
commit
b993fe1c9d
6 changed files with 86 additions and 45 deletions
|
@ -33,6 +33,7 @@ from homeassistant.helpers.json import (
|
||||||
JSON_DUMP,
|
JSON_DUMP,
|
||||||
ExtendedJSONEncoder,
|
ExtendedJSONEncoder,
|
||||||
find_paths_unserializable_data,
|
find_paths_unserializable_data,
|
||||||
|
json_dumps,
|
||||||
)
|
)
|
||||||
from homeassistant.helpers.service import async_get_all_descriptions
|
from homeassistant.helpers.service import async_get_all_descriptions
|
||||||
from homeassistant.loader import (
|
from homeassistant.loader import (
|
||||||
|
@ -48,17 +49,9 @@ from homeassistant.util.json import format_unserializable_data
|
||||||
from . import const, decorators, messages
|
from . import const, decorators, messages
|
||||||
from .connection import ActiveConnection
|
from .connection import ActiveConnection
|
||||||
from .const import ERR_NOT_FOUND
|
from .const import ERR_NOT_FOUND
|
||||||
|
from .messages import construct_event_message, construct_result_message
|
||||||
|
|
||||||
_STATES_TEMPLATE = "__STATES__"
|
ALL_SERVICE_DESCRIPTIONS_JSON_CACHE = "websocket_api_all_service_descriptions_json"
|
||||||
_STATES_JSON_TEMPLATE = '"__STATES__"'
|
|
||||||
_HANDLE_SUBSCRIBE_ENTITIES_TEMPLATE = JSON_DUMP(
|
|
||||||
messages.event_message(
|
|
||||||
messages.IDEN_TEMPLATE, {messages.ENTITY_EVENT_ADD: _STATES_TEMPLATE}
|
|
||||||
)
|
|
||||||
)
|
|
||||||
_HANDLE_GET_STATES_TEMPLATE = JSON_DUMP(
|
|
||||||
messages.result_message(messages.IDEN_TEMPLATE, _STATES_TEMPLATE)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
|
@ -280,15 +273,8 @@ def _send_handle_get_states_response(
|
||||||
connection: ActiveConnection, msg_id: int, serialized_states: list[str]
|
connection: ActiveConnection, msg_id: int, serialized_states: list[str]
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Send handle get states response."""
|
"""Send handle get states response."""
|
||||||
connection.send_message(
|
joined_states = ",".join(serialized_states)
|
||||||
_HANDLE_GET_STATES_TEMPLATE.replace(
|
connection.send_message(construct_result_message(msg_id, f"[{joined_states}]"))
|
||||||
messages.IDEN_JSON_TEMPLATE, str(msg_id), 1
|
|
||||||
).replace(
|
|
||||||
_STATES_JSON_TEMPLATE,
|
|
||||||
"[" + ",".join(serialized_states) + "]",
|
|
||||||
1,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
|
@ -359,25 +345,35 @@ def _send_handle_entities_init_response(
|
||||||
connection: ActiveConnection, msg_id: int, serialized_states: list[str]
|
connection: ActiveConnection, msg_id: int, serialized_states: list[str]
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Send handle entities init response."""
|
"""Send handle entities init response."""
|
||||||
|
joined_states = ",".join(serialized_states)
|
||||||
connection.send_message(
|
connection.send_message(
|
||||||
_HANDLE_SUBSCRIBE_ENTITIES_TEMPLATE.replace(
|
construct_event_message(msg_id, f'{{"a":{{{joined_states}}}}}')
|
||||||
messages.IDEN_JSON_TEMPLATE, str(msg_id), 1
|
|
||||||
).replace(
|
|
||||||
_STATES_JSON_TEMPLATE,
|
|
||||||
"{" + ",".join(serialized_states) + "}",
|
|
||||||
1,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _async_get_all_descriptions_json(hass: HomeAssistant) -> str:
|
||||||
|
"""Return JSON of descriptions (i.e. user documentation) for all service calls."""
|
||||||
|
descriptions = await async_get_all_descriptions(hass)
|
||||||
|
if ALL_SERVICE_DESCRIPTIONS_JSON_CACHE in hass.data:
|
||||||
|
cached_descriptions, cached_json_payload = hass.data[
|
||||||
|
ALL_SERVICE_DESCRIPTIONS_JSON_CACHE
|
||||||
|
]
|
||||||
|
# If the descriptions are the same, return the cached JSON payload
|
||||||
|
if cached_descriptions is descriptions:
|
||||||
|
return cast(str, cached_json_payload)
|
||||||
|
json_payload = json_dumps(descriptions)
|
||||||
|
hass.data[ALL_SERVICE_DESCRIPTIONS_JSON_CACHE] = (descriptions, json_payload)
|
||||||
|
return json_payload
|
||||||
|
|
||||||
|
|
||||||
@decorators.websocket_command({vol.Required("type"): "get_services"})
|
@decorators.websocket_command({vol.Required("type"): "get_services"})
|
||||||
@decorators.async_response
|
@decorators.async_response
|
||||||
async def handle_get_services(
|
async def handle_get_services(
|
||||||
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Handle get services command."""
|
"""Handle get services command."""
|
||||||
descriptions = await async_get_all_descriptions(hass)
|
payload = await _async_get_all_descriptions_json(hass)
|
||||||
connection.send_result(msg["id"], descriptions)
|
connection.send_message(construct_result_message(msg["id"], payload))
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
|
|
|
@ -135,7 +135,8 @@ class WebSocketHandler:
|
||||||
)
|
)
|
||||||
messages_remaining -= 1
|
messages_remaining -= 1
|
||||||
|
|
||||||
coalesced_messages = "[" + ",".join(messages) + "]"
|
joined_messages = ",".join(messages)
|
||||||
|
coalesced_messages = f"[{joined_messages}]"
|
||||||
debug("Sending %s", coalesced_messages)
|
debug("Sending %s", coalesced_messages)
|
||||||
await send_str(coalesced_messages)
|
await send_str(coalesced_messages)
|
||||||
finally:
|
finally:
|
||||||
|
|
|
@ -18,7 +18,6 @@ from homeassistant.core import Event, State
|
||||||
from homeassistant.helpers import config_validation as cv
|
from homeassistant.helpers import config_validation as cv
|
||||||
from homeassistant.helpers.json import JSON_DUMP, find_paths_unserializable_data
|
from homeassistant.helpers.json import JSON_DUMP, find_paths_unserializable_data
|
||||||
from homeassistant.util.json import format_unserializable_data
|
from homeassistant.util.json import format_unserializable_data
|
||||||
from homeassistant.util.yaml.loader import JSON_TYPE
|
|
||||||
|
|
||||||
from . import const
|
from . import const
|
||||||
|
|
||||||
|
@ -44,11 +43,17 @@ ENTITY_EVENT_REMOVE = "r"
|
||||||
ENTITY_EVENT_CHANGE = "c"
|
ENTITY_EVENT_CHANGE = "c"
|
||||||
|
|
||||||
|
|
||||||
def result_message(iden: JSON_TYPE | int, result: Any = None) -> dict[str, Any]:
|
def result_message(iden: int, result: Any = None) -> dict[str, Any]:
|
||||||
"""Return a success result message."""
|
"""Return a success result message."""
|
||||||
return {"id": iden, "type": const.TYPE_RESULT, "success": True, "result": result}
|
return {"id": iden, "type": const.TYPE_RESULT, "success": True, "result": result}
|
||||||
|
|
||||||
|
|
||||||
|
def construct_result_message(iden: int, payload: str) -> str:
|
||||||
|
"""Construct a success result message JSON."""
|
||||||
|
iden_str = str(iden)
|
||||||
|
return f'{{"id":{iden_str},"type":"result","success":true,"result":{payload}}}'
|
||||||
|
|
||||||
|
|
||||||
def error_message(iden: int | None, code: str, message: str) -> dict[str, Any]:
|
def error_message(iden: int | None, code: str, message: str) -> dict[str, Any]:
|
||||||
"""Return an error result message."""
|
"""Return an error result message."""
|
||||||
return {
|
return {
|
||||||
|
@ -59,7 +64,13 @@ def error_message(iden: int | None, code: str, message: str) -> dict[str, Any]:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def event_message(iden: JSON_TYPE | int, event: Any) -> dict[str, Any]:
|
def construct_event_message(iden: int, payload: str) -> str:
|
||||||
|
"""Construct an event message JSON."""
|
||||||
|
iden_str = str(iden)
|
||||||
|
return f'{{"id":{iden_str},"type":"event","event":{payload}}}'
|
||||||
|
|
||||||
|
|
||||||
|
def event_message(iden: int, event: Any) -> dict[str, Any]:
|
||||||
"""Return an event message."""
|
"""Return an event message."""
|
||||||
return {"id": iden, "type": "event", "event": event}
|
return {"id": iden, "type": "event", "event": event}
|
||||||
|
|
||||||
|
@ -83,7 +94,7 @@ def _cached_event_message(event: Event) -> str:
|
||||||
The IDEN_TEMPLATE is used which will be replaced
|
The IDEN_TEMPLATE is used which will be replaced
|
||||||
with the actual iden in cached_event_message
|
with the actual iden in cached_event_message
|
||||||
"""
|
"""
|
||||||
return message_to_json(event_message(IDEN_TEMPLATE, event))
|
return message_to_json({"id": IDEN_TEMPLATE, "type": "event", "event": event})
|
||||||
|
|
||||||
|
|
||||||
def cached_state_diff_message(iden: int, event: Event) -> str:
|
def cached_state_diff_message(iden: int, event: Event) -> str:
|
||||||
|
@ -105,7 +116,9 @@ def _cached_state_diff_message(event: Event) -> str:
|
||||||
The IDEN_TEMPLATE is used which will be replaced
|
The IDEN_TEMPLATE is used which will be replaced
|
||||||
with the actual iden in cached_event_message
|
with the actual iden in cached_event_message
|
||||||
"""
|
"""
|
||||||
return message_to_json(event_message(IDEN_TEMPLATE, _state_diff_event(event)))
|
return message_to_json(
|
||||||
|
{"id": IDEN_TEMPLATE, "type": "event", "event": _state_diff_event(event)}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _state_diff_event(event: Event) -> dict:
|
def _state_diff_event(event: Event) -> dict:
|
||||||
|
|
|
@ -59,6 +59,7 @@ CONF_SERVICE_ENTITY_ID = "entity_id"
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
SERVICE_DESCRIPTION_CACHE = "service_description_cache"
|
SERVICE_DESCRIPTION_CACHE = "service_description_cache"
|
||||||
|
ALL_SERVICE_DESCRIPTIONS_CACHE = "all_service_descriptions_cache"
|
||||||
|
|
||||||
|
|
||||||
@cache
|
@cache
|
||||||
|
@ -559,17 +560,27 @@ async def async_get_all_descriptions(
|
||||||
) -> dict[str, dict[str, Any]]:
|
) -> dict[str, dict[str, Any]]:
|
||||||
"""Return descriptions (i.e. user documentation) for all service calls."""
|
"""Return descriptions (i.e. user documentation) for all service calls."""
|
||||||
descriptions_cache = hass.data.setdefault(SERVICE_DESCRIPTION_CACHE, {})
|
descriptions_cache = hass.data.setdefault(SERVICE_DESCRIPTION_CACHE, {})
|
||||||
format_cache_key = "{}.{}".format
|
|
||||||
services = hass.services.async_services()
|
services = hass.services.async_services()
|
||||||
|
|
||||||
# See if there are new services not seen before.
|
# See if there are new services not seen before.
|
||||||
# Any service that we saw before already has an entry in description_cache.
|
# Any service that we saw before already has an entry in description_cache.
|
||||||
missing = set()
|
missing = set()
|
||||||
|
all_services = []
|
||||||
for domain in services:
|
for domain in services:
|
||||||
for service in services[domain]:
|
for service in services[domain]:
|
||||||
if format_cache_key(domain, service) not in descriptions_cache:
|
cache_key = (domain, service)
|
||||||
|
all_services.append(cache_key)
|
||||||
|
if cache_key not in descriptions_cache:
|
||||||
missing.add(domain)
|
missing.add(domain)
|
||||||
break
|
|
||||||
|
# If we have a complete cache, check if it is still valid
|
||||||
|
if ALL_SERVICE_DESCRIPTIONS_CACHE in hass.data:
|
||||||
|
previous_all_services, previous_descriptions_cache = hass.data[
|
||||||
|
ALL_SERVICE_DESCRIPTIONS_CACHE
|
||||||
|
]
|
||||||
|
# If the services are the same, we can return the cache
|
||||||
|
if previous_all_services == all_services:
|
||||||
|
return cast(dict[str, dict[str, Any]], previous_descriptions_cache)
|
||||||
|
|
||||||
# Files we loaded for missing descriptions
|
# Files we loaded for missing descriptions
|
||||||
loaded = {}
|
loaded = {}
|
||||||
|
@ -595,7 +606,7 @@ async def async_get_all_descriptions(
|
||||||
descriptions[domain] = {}
|
descriptions[domain] = {}
|
||||||
|
|
||||||
for service in services[domain]:
|
for service in services[domain]:
|
||||||
cache_key = format_cache_key(domain, service)
|
cache_key = (domain, service)
|
||||||
description = descriptions_cache.get(cache_key)
|
description = descriptions_cache.get(cache_key)
|
||||||
|
|
||||||
# Cache missing descriptions
|
# Cache missing descriptions
|
||||||
|
@ -622,6 +633,7 @@ async def async_get_all_descriptions(
|
||||||
|
|
||||||
descriptions[domain][service] = description
|
descriptions[domain][service] = description
|
||||||
|
|
||||||
|
hass.data[ALL_SERVICE_DESCRIPTIONS_CACHE] = (all_services, descriptions)
|
||||||
return descriptions
|
return descriptions
|
||||||
|
|
||||||
|
|
||||||
|
@ -652,7 +664,8 @@ def async_set_service_schema(
|
||||||
if "target" in schema:
|
if "target" in schema:
|
||||||
description["target"] = schema["target"]
|
description["target"] = schema["target"]
|
||||||
|
|
||||||
hass.data[SERVICE_DESCRIPTION_CACHE][f"{domain}.{service}"] = description
|
hass.data.pop(ALL_SERVICE_DESCRIPTIONS_CACHE, None)
|
||||||
|
hass.data[SERVICE_DESCRIPTION_CACHE][(domain, service)] = description
|
||||||
|
|
||||||
|
|
||||||
@bind_hass
|
@bind_hass
|
||||||
|
|
|
@ -514,13 +514,14 @@ async def test_get_states(hass: HomeAssistant, websocket_client) -> None:
|
||||||
|
|
||||||
async def test_get_services(hass: HomeAssistant, websocket_client) -> None:
|
async def test_get_services(hass: HomeAssistant, websocket_client) -> None:
|
||||||
"""Test get_services command."""
|
"""Test get_services command."""
|
||||||
await websocket_client.send_json({"id": 5, "type": "get_services"})
|
for id_ in (5, 6):
|
||||||
|
await websocket_client.send_json({"id": id_, "type": "get_services"})
|
||||||
|
|
||||||
msg = await websocket_client.receive_json()
|
msg = await websocket_client.receive_json()
|
||||||
assert msg["id"] == 5
|
assert msg["id"] == id_
|
||||||
assert msg["type"] == const.TYPE_RESULT
|
assert msg["type"] == const.TYPE_RESULT
|
||||||
assert msg["success"]
|
assert msg["success"]
|
||||||
assert msg["result"] == hass.services.async_services()
|
assert msg["result"] == hass.services.async_services()
|
||||||
|
|
||||||
|
|
||||||
async def test_get_config(hass: HomeAssistant, websocket_client) -> None:
|
async def test_get_config(hass: HomeAssistant, websocket_client) -> None:
|
||||||
|
|
|
@ -564,6 +564,23 @@ async def test_async_get_all_descriptions(hass: HomeAssistant) -> None:
|
||||||
assert "description" in descriptions[logger.DOMAIN]["set_level"]
|
assert "description" in descriptions[logger.DOMAIN]["set_level"]
|
||||||
assert "fields" in descriptions[logger.DOMAIN]["set_level"]
|
assert "fields" in descriptions[logger.DOMAIN]["set_level"]
|
||||||
|
|
||||||
|
hass.services.async_register(logger.DOMAIN, "new_service", lambda x: None, None)
|
||||||
|
service.async_set_service_schema(
|
||||||
|
hass, logger.DOMAIN, "new_service", {"description": "new service"}
|
||||||
|
)
|
||||||
|
descriptions = await service.async_get_all_descriptions(hass)
|
||||||
|
assert "description" in descriptions[logger.DOMAIN]["new_service"]
|
||||||
|
assert descriptions[logger.DOMAIN]["new_service"]["description"] == "new service"
|
||||||
|
|
||||||
|
hass.services.async_register(
|
||||||
|
logger.DOMAIN, "another_new_service", lambda x: None, None
|
||||||
|
)
|
||||||
|
descriptions = await service.async_get_all_descriptions(hass)
|
||||||
|
assert "another_new_service" in descriptions[logger.DOMAIN]
|
||||||
|
|
||||||
|
# Verify the cache returns the same object
|
||||||
|
assert await service.async_get_all_descriptions(hass) is descriptions
|
||||||
|
|
||||||
|
|
||||||
async def test_call_with_required_features(hass: HomeAssistant, mock_entities) -> None:
|
async def test_call_with_required_features(hass: HomeAssistant, mock_entities) -> None:
|
||||||
"""Test service calls invoked only if entity has required features."""
|
"""Test service calls invoked only if entity has required features."""
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue