Fix recorder attribute excludes not being effective until after startup (#90198)

* Fix attribute excludes not being effective until after startup

fixes #90016

* reduce
This commit is contained in:
J. Nick Koston 2023-03-23 14:52:37 -10:00 committed by GitHub
parent dd0f05b980
commit d49fbc17df
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 75 additions and 18 deletions

View file

@ -28,6 +28,9 @@ from .const import ( # noqa: F401
EVENT_RECORDER_5MIN_STATISTICS_GENERATED,
EVENT_RECORDER_HOURLY_STATISTICS_GENERATED,
EXCLUDE_ATTRIBUTES,
INTEGRATION_PLATFORM_COMPILE_STATISTICS,
INTEGRATION_PLATFORM_EXCLUDE_ATTRIBUTES,
INTEGRATION_PLATFORMS_LOAD_IN_RECORDER_THREAD,
SQLITE_URL_PREFIX,
)
from .core import Recorder
@ -165,14 +168,40 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
async_register_services(hass, instance)
websocket_api.async_setup(hass)
entity_registry.async_setup(hass)
await async_process_integration_platforms(hass, DOMAIN, _process_recorder_platform)
await _async_setup_integration_platform(
hass, instance, exclude_attributes_by_domain
)
return await instance.async_db_ready
async def _async_setup_integration_platform(
hass: HomeAssistant,
instance: Recorder,
exclude_attributes_by_domain: dict[str, set[str]],
) -> None:
"""Set up a recorder integration platform."""
async def _process_recorder_platform(
hass: HomeAssistant, domain: str, platform: Any
) -> None:
"""Process a recorder platform."""
instance = get_instance(hass)
# We need to add this before as soon as the component is loaded
# to ensure by the time the state is recorded that the excluded
# attributes are known. This is safe to modify in the event loop
# since exclude_attributes_by_domain is never iterated over.
if exclude_attributes := getattr(
platform, INTEGRATION_PLATFORM_EXCLUDE_ATTRIBUTES, None
):
exclude_attributes_by_domain[domain] = exclude_attributes(hass)
# If the platform has a compile_statistics method, we need to
# add it to the recorder queue to be processed.
if any(
hasattr(platform, _attr)
for _attr in INTEGRATION_PLATFORMS_LOAD_IN_RECORDER_THREAD
):
instance.queue_task(AddRecorderPlatformTask(domain, platform))
await async_process_integration_platforms(hass, DOMAIN, _process_recorder_platform)

View file

@ -51,6 +51,19 @@ STATES_META_SCHEMA_VERSION = 38
LEGACY_STATES_EVENT_ID_INDEX_SCHEMA_VERSION = 28
INTEGRATION_PLATFORM_EXCLUDE_ATTRIBUTES = "exclude_attributes"
INTEGRATION_PLATFORM_COMPILE_STATISTICS = "compile_statistics"
INTEGRATION_PLATFORM_VALIDATE_STATISTICS = "validate_statistics"
INTEGRATION_PLATFORM_LIST_STATISTIC_IDS = "list_statistic_ids"
INTEGRATION_PLATFORMS_LOAD_IN_RECORDER_THREAD = {
INTEGRATION_PLATFORM_COMPILE_STATISTICS,
INTEGRATION_PLATFORM_VALIDATE_STATISTICS,
INTEGRATION_PLATFORM_LIST_STATISTIC_IDS,
}
class SupportedDialect(StrEnum):
"""Supported dialects."""

View file

@ -47,6 +47,9 @@ from .const import (
DOMAIN,
EVENT_RECORDER_5MIN_STATISTICS_GENERATED,
EVENT_RECORDER_HOURLY_STATISTICS_GENERATED,
INTEGRATION_PLATFORM_COMPILE_STATISTICS,
INTEGRATION_PLATFORM_LIST_STATISTIC_IDS,
INTEGRATION_PLATFORM_VALIDATE_STATISTICS,
SupportedDialect,
)
from .db_schema import (
@ -502,9 +505,13 @@ def _compile_statistics(
current_metadata: dict[str, tuple[int, StatisticMetaData]] = {}
# Collect statistics from all platforms implementing support
for domain, platform in instance.hass.data[DOMAIN].recorder_platforms.items():
if not hasattr(platform, "compile_statistics"):
if not (
platform_compile_statistics := getattr(
platform, INTEGRATION_PLATFORM_COMPILE_STATISTICS, None
)
):
continue
compiled: PlatformCompiledStatistics = platform.compile_statistics(
compiled: PlatformCompiledStatistics = platform_compile_statistics(
instance.hass, start, end
)
_LOGGER.debug(
@ -783,9 +790,13 @@ def list_statistic_ids(
#
# Query all integrations with a registered recorder platform
for platform in hass.data[DOMAIN].recorder_platforms.values():
if not hasattr(platform, "list_statistic_ids"):
if not (
platform_list_statistic_ids := getattr(
platform, INTEGRATION_PLATFORM_LIST_STATISTIC_IDS, None
)
):
continue
platform_statistic_ids = platform.list_statistic_ids(
platform_statistic_ids = platform_list_statistic_ids(
hass, statistic_ids=statistic_ids, statistic_type=statistic_type
)
@ -1931,9 +1942,10 @@ def validate_statistics(hass: HomeAssistant) -> dict[str, list[ValidationIssue]]
"""Validate statistics."""
platform_validation: dict[str, list[ValidationIssue]] = {}
for platform in hass.data[DOMAIN].recorder_platforms.values():
if not hasattr(platform, "validate_statistics"):
continue
platform_validation.update(platform.validate_statistics(hass))
if platform_validate_statistics := getattr(
platform, INTEGRATION_PLATFORM_VALIDATE_STATISTICS, None
):
platform_validation.update(platform_validate_statistics(hass))
return platform_validation

View file

@ -14,7 +14,7 @@ from homeassistant.core import Event
from homeassistant.helpers.typing import UndefinedType
from . import entity_registry, purge, statistics
from .const import DOMAIN, EXCLUDE_ATTRIBUTES
from .const import DOMAIN
from .db_schema import Statistics, StatisticsShortTerm
from .models import StatisticData, StatisticMetaData
from .util import periodic_db_cleanups
@ -317,11 +317,8 @@ class AddRecorderPlatformTask(RecorderTask):
hass = instance.hass
domain = self.domain
platform = self.platform
platforms: dict[str, Any] = hass.data[DOMAIN].recorder_platforms
platforms[domain] = platform
if hasattr(self.platform, "exclude_attributes"):
hass.data[EXCLUDE_ATTRIBUTES][domain] = platform.exclude_attributes(hass)
@dataclass

View file

@ -2112,10 +2112,15 @@ async def test_connect_args_priority(hass: HomeAssistant, config_url) -> None:
assert connect_params[0]["charset"] == "utf8mb4"
@pytest.mark.parametrize("core_state", [CoreState.starting, CoreState.running])
async def test_excluding_attributes_by_integration(
recorder_mock: Recorder, hass: HomeAssistant, entity_registry: er.EntityRegistry
recorder_mock: Recorder,
hass: HomeAssistant,
entity_registry: er.EntityRegistry,
core_state: CoreState,
) -> None:
"""Test that an integration's recorder platform can exclude attributes."""
hass.state = core_state
state = "restoring_from_db"
attributes = {"test_attr": 5, "excluded": 10}
mock_platform(
@ -2131,6 +2136,7 @@ async def test_excluding_attributes_by_integration(
platform = MockEntityPlatform(hass, platform_name="fake_integration")
entity_platform = MockEntity(entity_id=entity_id, extra_state_attributes=attributes)
await platform.async_add_entities([entity_platform])
await hass.async_block_till_done()
await async_wait_recording_done(hass)