Store automation and script traces (#56894)

* Store automation and script traces

* Pylint

* Deduplicate code

* Fix issues when no stored traces are available

* Store serialized data for restored traces

* Update WS API

* Update test

* Restore context

* Improve tests

* Add new test files

* Rename restore_traces to async_restore_traces

* Refactor trace.websocket_api

* Defer loading stored traces

* Lint

* Revert refactoring which is no longer needed

* Correct order when restoring traces

* Apply suggestion from code review

* Improve test coverage

* Apply suggestions from code review
This commit is contained in:
Erik Montnemery 2021-10-19 10:23:23 +02:00 committed by GitHub
parent 29c062fcc4
commit 961ee717ef
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 1256 additions and 191 deletions

View file

@ -1,14 +1,19 @@
"""Test Trace websocket API."""
import asyncio
import json
from typing import DefaultDict
from unittest.mock import patch
import pytest
from homeassistant.bootstrap import async_setup_component
from homeassistant.components.trace.const import DEFAULT_STORED_TRACES
from homeassistant.core import Context, callback
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
from homeassistant.core import Context, CoreState, callback
from homeassistant.helpers.typing import UNDEFINED
from homeassistant.util.uuid import random_uuid_hex
from tests.common import assert_lists_same
from tests.common import assert_lists_same, load_fixture
def _find_run_id(traces, trace_type, item_id):
@ -70,8 +75,12 @@ def _assert_raw_config(domain, config, trace):
assert trace["config"] == config
async def _assert_contexts(client, next_id, contexts):
await client.send_json({"id": next_id(), "type": "trace/contexts"})
async def _assert_contexts(client, next_id, contexts, domain=None, item_id=None):
request = {"id": next_id(), "type": "trace/contexts"}
if domain is not None:
request["domain"] = domain
request["item_id"] = item_id
await client.send_json(request)
response = await client.receive_json()
assert response["success"]
assert response["result"] == contexts
@ -101,6 +110,7 @@ async def _assert_contexts(client, next_id, contexts):
)
async def test_get_trace(
hass,
hass_storage,
hass_ws_client,
domain,
prefix,
@ -152,6 +162,8 @@ async def test_get_trace(
client = await hass_ws_client()
contexts = {}
contexts_sun = {}
contexts_moon = {}
# Trigger "sun" automation / run "sun" script
context = Context()
@ -195,6 +207,11 @@ async def test_get_trace(
"domain": domain,
"item_id": trace["item_id"],
}
contexts_sun[trace["context"]["id"]] = {
"run_id": trace["run_id"],
"domain": domain,
"item_id": trace["item_id"],
}
# Trigger "moon" automation, with passing condition / run "moon" script
await _run_automation_or_script(hass, domain, moon_config, "test_event2", context)
@ -244,10 +261,17 @@ async def test_get_trace(
"domain": domain,
"item_id": trace["item_id"],
}
contexts_moon[trace["context"]["id"]] = {
"run_id": trace["run_id"],
"domain": domain,
"item_id": trace["item_id"],
}
if len(extra_trace_keys) <= 2:
# Check contexts
await _assert_contexts(client, next_id, contexts)
await _assert_contexts(client, next_id, contexts_moon, domain, "moon")
await _assert_contexts(client, next_id, contexts_sun, domain, "sun")
return
# Trigger "moon" automation with failing condition
@ -291,6 +315,11 @@ async def test_get_trace(
"domain": domain,
"item_id": trace["item_id"],
}
contexts_moon[trace["context"]["id"]] = {
"run_id": trace["run_id"],
"domain": domain,
"item_id": trace["item_id"],
}
# Trigger "moon" automation with passing condition
hass.bus.async_fire("test_event2")
@ -336,9 +365,119 @@ async def test_get_trace(
"domain": domain,
"item_id": trace["item_id"],
}
contexts_moon[trace["context"]["id"]] = {
"run_id": trace["run_id"],
"domain": domain,
"item_id": trace["item_id"],
}
# Check contexts
await _assert_contexts(client, next_id, contexts)
await _assert_contexts(client, next_id, contexts_moon, domain, "moon")
await _assert_contexts(client, next_id, contexts_sun, domain, "sun")
# List traces
await client.send_json({"id": next_id(), "type": "trace/list", "domain": domain})
response = await client.receive_json()
assert response["success"]
trace_list = response["result"]
# Get all traces and generate expected stored traces
traces = DefaultDict(list)
for trace in trace_list:
item_id = trace["item_id"]
run_id = trace["run_id"]
await client.send_json(
{
"id": next_id(),
"type": "trace/get",
"domain": domain,
"item_id": item_id,
"run_id": run_id,
}
)
response = await client.receive_json()
assert response["success"]
traces[f"{domain}.{item_id}"].append(
{"short_dict": trace, "extended_dict": response["result"]}
)
# Fake stop
assert "trace.saved_traces" not in hass_storage
hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP)
await hass.async_block_till_done()
# Check that saved data is same as the serialized traces
assert "trace.saved_traces" in hass_storage
assert hass_storage["trace.saved_traces"]["data"] == traces
@pytest.mark.parametrize("domain", ["automation", "script"])
async def test_restore_traces(hass, hass_storage, hass_ws_client, domain):
"""Test restored traces."""
hass.state = CoreState.not_running
id = 1
def next_id():
nonlocal id
id += 1
return id
saved_traces = json.loads(load_fixture(f"trace/{domain}_saved_traces.json"))
hass_storage["trace.saved_traces"] = saved_traces
await _setup_automation_or_script(hass, domain, [])
await hass.async_start()
await hass.async_block_till_done()
client = await hass_ws_client()
# List traces
await client.send_json({"id": next_id(), "type": "trace/list", "domain": domain})
response = await client.receive_json()
assert response["success"]
trace_list = response["result"]
# Get all traces and generate expected stored traces
traces = DefaultDict(list)
contexts = {}
for trace in trace_list:
item_id = trace["item_id"]
run_id = trace["run_id"]
await client.send_json(
{
"id": next_id(),
"type": "trace/get",
"domain": domain,
"item_id": item_id,
"run_id": run_id,
}
)
response = await client.receive_json()
assert response["success"]
traces[f"{domain}.{item_id}"].append(
{"short_dict": trace, "extended_dict": response["result"]}
)
contexts[response["result"]["context"]["id"]] = {
"run_id": trace["run_id"],
"domain": domain,
"item_id": trace["item_id"],
}
# Check that loaded data is same as the serialized traces
assert hass_storage["trace.saved_traces"]["data"] == traces
# Check restored contexts
await _assert_contexts(client, next_id, contexts)
# Fake stop
hass_storage.pop("trace.saved_traces")
assert "trace.saved_traces" not in hass_storage
hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP)
await hass.async_block_till_done()
# Check that saved data is same as the serialized traces
assert "trace.saved_traces" in hass_storage
assert hass_storage["trace.saved_traces"] == saved_traces
@pytest.mark.parametrize("domain", ["automation", "script"])
@ -368,6 +507,13 @@ async def test_trace_overflow(hass, hass_ws_client, domain, stored_traces):
"""Test the number of stored traces per script or automation is limited."""
id = 1
trace_uuids = []
def mock_random_uuid_hex():
nonlocal trace_uuids
trace_uuids.append(random_uuid_hex())
return trace_uuids[-1]
def next_id():
nonlocal id
id += 1
@ -404,13 +550,16 @@ async def test_trace_overflow(hass, hass_ws_client, domain, stored_traces):
response = await client.receive_json()
assert response["success"]
assert len(_find_traces(response["result"], domain, "moon")) == 1
moon_run_id = _find_run_id(response["result"], domain, "moon")
assert len(_find_traces(response["result"], domain, "sun")) == 1
# Trigger "moon" enough times to overflow the max number of stored traces
for _ in range(stored_traces or DEFAULT_STORED_TRACES):
await _run_automation_or_script(hass, domain, moon_config, "test_event2")
await hass.async_block_till_done()
with patch(
"homeassistant.components.trace.uuid_util.random_uuid_hex",
wraps=mock_random_uuid_hex,
):
for _ in range(stored_traces or DEFAULT_STORED_TRACES):
await _run_automation_or_script(hass, domain, moon_config, "test_event2")
await hass.async_block_till_done()
await client.send_json({"id": next_id(), "type": "trace/list", "domain": domain})
response = await client.receive_json()
@ -418,10 +567,153 @@ async def test_trace_overflow(hass, hass_ws_client, domain, stored_traces):
moon_traces = _find_traces(response["result"], domain, "moon")
assert len(moon_traces) == stored_traces or DEFAULT_STORED_TRACES
assert moon_traces[0]
assert int(moon_traces[0]["run_id"]) == int(moon_run_id) + 1
assert int(moon_traces[-1]["run_id"]) == int(moon_run_id) + (
stored_traces or DEFAULT_STORED_TRACES
)
assert moon_traces[0]["run_id"] == trace_uuids[0]
assert moon_traces[-1]["run_id"] == trace_uuids[-1]
assert len(_find_traces(response["result"], domain, "sun")) == 1
@pytest.mark.parametrize(
"domain,num_restored_moon_traces", [("automation", 3), ("script", 1)]
)
async def test_restore_traces_overflow(
hass, hass_storage, hass_ws_client, domain, num_restored_moon_traces
):
"""Test restored traces are evicted first."""
hass.state = CoreState.not_running
id = 1
trace_uuids = []
def mock_random_uuid_hex():
nonlocal trace_uuids
trace_uuids.append(random_uuid_hex())
return trace_uuids[-1]
def next_id():
nonlocal id
id += 1
return id
saved_traces = json.loads(load_fixture(f"trace/{domain}_saved_traces.json"))
hass_storage["trace.saved_traces"] = saved_traces
sun_config = {
"id": "sun",
"trigger": {"platform": "event", "event_type": "test_event"},
"action": {"event": "some_event"},
}
moon_config = {
"id": "moon",
"trigger": {"platform": "event", "event_type": "test_event2"},
"action": {"event": "another_event"},
}
await _setup_automation_or_script(hass, domain, [sun_config, moon_config])
await hass.async_start()
await hass.async_block_till_done()
client = await hass_ws_client()
# Traces should not yet be restored
assert "trace_traces_restored" not in hass.data
# List traces
await client.send_json({"id": next_id(), "type": "trace/list", "domain": domain})
response = await client.receive_json()
assert response["success"]
restored_moon_traces = _find_traces(response["result"], domain, "moon")
assert len(restored_moon_traces) == num_restored_moon_traces
assert len(_find_traces(response["result"], domain, "sun")) == 1
# Traces should be restored
assert "trace_traces_restored" in hass.data
# Trigger "moon" enough times to overflow the max number of stored traces
with patch(
"homeassistant.components.trace.uuid_util.random_uuid_hex",
wraps=mock_random_uuid_hex,
):
for _ in range(DEFAULT_STORED_TRACES - num_restored_moon_traces + 1):
await _run_automation_or_script(hass, domain, moon_config, "test_event2")
await hass.async_block_till_done()
await client.send_json({"id": next_id(), "type": "trace/list", "domain": domain})
response = await client.receive_json()
assert response["success"]
moon_traces = _find_traces(response["result"], domain, "moon")
assert len(moon_traces) == DEFAULT_STORED_TRACES
if num_restored_moon_traces > 1:
assert moon_traces[0]["run_id"] == restored_moon_traces[1]["run_id"]
assert moon_traces[num_restored_moon_traces - 1]["run_id"] == trace_uuids[0]
assert moon_traces[-1]["run_id"] == trace_uuids[-1]
assert len(_find_traces(response["result"], domain, "sun")) == 1
@pytest.mark.parametrize(
"domain,num_restored_moon_traces,restored_run_id",
[("automation", 3, "e2c97432afe9b8a42d7983588ed5e6ef"), ("script", 1, "")],
)
async def test_restore_traces_late_overflow(
hass,
hass_storage,
hass_ws_client,
domain,
num_restored_moon_traces,
restored_run_id,
):
"""Test restored traces are evicted first."""
hass.state = CoreState.not_running
id = 1
trace_uuids = []
def mock_random_uuid_hex():
nonlocal trace_uuids
trace_uuids.append(random_uuid_hex())
return trace_uuids[-1]
def next_id():
nonlocal id
id += 1
return id
saved_traces = json.loads(load_fixture(f"trace/{domain}_saved_traces.json"))
hass_storage["trace.saved_traces"] = saved_traces
sun_config = {
"id": "sun",
"trigger": {"platform": "event", "event_type": "test_event"},
"action": {"event": "some_event"},
}
moon_config = {
"id": "moon",
"trigger": {"platform": "event", "event_type": "test_event2"},
"action": {"event": "another_event"},
}
await _setup_automation_or_script(hass, domain, [sun_config, moon_config])
await hass.async_start()
await hass.async_block_till_done()
client = await hass_ws_client()
# Traces should not yet be restored
assert "trace_traces_restored" not in hass.data
# Trigger "moon" enough times to overflow the max number of stored traces
with patch(
"homeassistant.components.trace.uuid_util.random_uuid_hex",
wraps=mock_random_uuid_hex,
):
for _ in range(DEFAULT_STORED_TRACES - num_restored_moon_traces + 1):
await _run_automation_or_script(hass, domain, moon_config, "test_event2")
await hass.async_block_till_done()
await client.send_json({"id": next_id(), "type": "trace/list", "domain": domain})
response = await client.receive_json()
assert response["success"]
moon_traces = _find_traces(response["result"], domain, "moon")
assert len(moon_traces) == DEFAULT_STORED_TRACES
if num_restored_moon_traces > 1:
assert moon_traces[0]["run_id"] == restored_run_id
assert moon_traces[num_restored_moon_traces - 1]["run_id"] == trace_uuids[0]
assert moon_traces[-1]["run_id"] == trace_uuids[-1]
assert len(_find_traces(response["result"], domain, "sun")) == 1