Make recorder WS command recorder/clear_statistics wait (#127120)

This commit is contained in:
Erik Montnemery 2024-10-02 10:43:40 +02:00 committed by Franck Nijhof
parent 565203047c
commit 9c28a4e8a0
No known key found for this signature in database
GPG key ID: D62583BA8AB11CA3
4 changed files with 51 additions and 5 deletions

View file

@ -570,9 +570,11 @@ class Recorder(threading.Thread):
)
@callback
def async_clear_statistics(self, statistic_ids: list[str]) -> None:
def async_clear_statistics(
self, statistic_ids: list[str], *, on_done: Callable[[], None] | None = None
) -> None:
"""Clear statistics for a list of statistic_ids."""
self.queue_task(ClearStatisticsTask(statistic_ids))
self.queue_task(ClearStatisticsTask(on_done, statistic_ids))
@callback
def async_update_statistics_metadata(

View file

@ -60,11 +60,14 @@ class ChangeStatisticsUnitTask(RecorderTask):
class ClearStatisticsTask(RecorderTask):
"""Object to store statistics_ids which for which to remove statistics."""
on_done: Callable[[], None] | None
statistic_ids: list[str]
def run(self, instance: Recorder) -> None:
"""Handle the task."""
statistics.clear_statistics(instance, self.statistic_ids)
if self.on_done:
self.on_done()
@dataclass(slots=True)

View file

@ -49,6 +49,7 @@ from .statistics import (
)
from .util import PERIOD_SCHEMA, get_instance, resolve_period
CLEAR_STATISTICS_TIME_OUT = 10
UPDATE_STATISTICS_METADATA_TIME_OUT = 10
UNIT_SCHEMA = vol.Schema(
@ -322,8 +323,8 @@ async def ws_update_statistics_issues(
vol.Required("statistic_ids"): [str],
}
)
@callback
def ws_clear_statistics(
@websocket_api.async_response
async def ws_clear_statistics(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any]
) -> None:
"""Clear statistics for a list of statistic_ids.
@ -331,7 +332,23 @@ def ws_clear_statistics(
Note: The WS call posts a job to the recorder's queue and then returns, it doesn't
wait until the job is completed.
"""
get_instance(hass).async_clear_statistics(msg["statistic_ids"])
done_event = asyncio.Event()
def clear_statistics_done() -> None:
hass.loop.call_soon_threadsafe(done_event.set)
get_instance(hass).async_clear_statistics(
msg["statistic_ids"], on_done=clear_statistics_done
)
try:
async with asyncio.timeout(CLEAR_STATISTICS_TIME_OUT):
await done_event.wait()
except TimeoutError:
connection.send_error(
msg["id"], websocket_api.ERR_TIMEOUT, "clear_statistics timed out"
)
return
connection.send_result(msg["id"])

View file

@ -2116,6 +2116,30 @@ async def test_clear_statistics(
assert response["result"] == {"sensor.test2": expected_response["sensor.test2"]}
async def test_clear_statistics_time_out(
recorder_mock: Recorder, hass: HomeAssistant, hass_ws_client: WebSocketGenerator
) -> None:
"""Test removing statistics with time-out error."""
client = await hass_ws_client()
with (
patch.object(recorder.tasks.ClearStatisticsTask, "run"),
patch.object(recorder.websocket_api, "CLEAR_STATISTICS_TIME_OUT", 0),
):
await client.send_json_auto_id(
{
"type": "recorder/clear_statistics",
"statistic_ids": ["sensor.test"],
}
)
response = await client.receive_json()
assert not response["success"]
assert response["error"] == {
"code": "timeout",
"message": "clear_statistics timed out",
}
@pytest.mark.parametrize(
("new_unit", "new_unit_class", "new_display_unit"),
[("dogs", None, "dogs"), (None, "unitless", None), ("W", "power", "kW")],