diff --git a/homeassistant/components/recorder/__init__.py b/homeassistant/components/recorder/__init__.py index 2621db9cb70..750f504d096 100644 --- a/homeassistant/components/recorder/__init__.py +++ b/homeassistant/components/recorder/__init__.py @@ -28,6 +28,9 @@ from .const import ( # noqa: F401 EVENT_RECORDER_5MIN_STATISTICS_GENERATED, EVENT_RECORDER_HOURLY_STATISTICS_GENERATED, EXCLUDE_ATTRIBUTES, + INTEGRATION_PLATFORM_COMPILE_STATISTICS, + INTEGRATION_PLATFORM_EXCLUDE_ATTRIBUTES, + INTEGRATION_PLATFORMS_LOAD_IN_RECORDER_THREAD, SQLITE_URL_PREFIX, ) from .core import Recorder @@ -165,14 +168,40 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async_register_services(hass, instance) websocket_api.async_setup(hass) entity_registry.async_setup(hass) - await async_process_integration_platforms(hass, DOMAIN, _process_recorder_platform) + + await _async_setup_integration_platform( + hass, instance, exclude_attributes_by_domain + ) return await instance.async_db_ready -async def _process_recorder_platform( - hass: HomeAssistant, domain: str, platform: Any +async def _async_setup_integration_platform( + hass: HomeAssistant, + instance: Recorder, + exclude_attributes_by_domain: dict[str, set[str]], ) -> None: - """Process a recorder platform.""" - instance = get_instance(hass) - instance.queue_task(AddRecorderPlatformTask(domain, platform)) + """Set up a recorder integration platform.""" + + async def _process_recorder_platform( + hass: HomeAssistant, domain: str, platform: Any + ) -> None: + """Process a recorder platform.""" + # We need to add this before as soon as the component is loaded + # to ensure by the time the state is recorded that the excluded + # attributes are known. This is safe to modify in the event loop + # since exclude_attributes_by_domain is never iterated over. + if exclude_attributes := getattr( + platform, INTEGRATION_PLATFORM_EXCLUDE_ATTRIBUTES, None + ): + exclude_attributes_by_domain[domain] = exclude_attributes(hass) + + # If the platform has a compile_statistics method, we need to + # add it to the recorder queue to be processed. + if any( + hasattr(platform, _attr) + for _attr in INTEGRATION_PLATFORMS_LOAD_IN_RECORDER_THREAD + ): + instance.queue_task(AddRecorderPlatformTask(domain, platform)) + + await async_process_integration_platforms(hass, DOMAIN, _process_recorder_platform) diff --git a/homeassistant/components/recorder/const.py b/homeassistant/components/recorder/const.py index 6bf46efd360..fbec19a2d1e 100644 --- a/homeassistant/components/recorder/const.py +++ b/homeassistant/components/recorder/const.py @@ -51,6 +51,19 @@ STATES_META_SCHEMA_VERSION = 38 LEGACY_STATES_EVENT_ID_INDEX_SCHEMA_VERSION = 28 +INTEGRATION_PLATFORM_EXCLUDE_ATTRIBUTES = "exclude_attributes" + +INTEGRATION_PLATFORM_COMPILE_STATISTICS = "compile_statistics" +INTEGRATION_PLATFORM_VALIDATE_STATISTICS = "validate_statistics" +INTEGRATION_PLATFORM_LIST_STATISTIC_IDS = "list_statistic_ids" + +INTEGRATION_PLATFORMS_LOAD_IN_RECORDER_THREAD = { + INTEGRATION_PLATFORM_COMPILE_STATISTICS, + INTEGRATION_PLATFORM_VALIDATE_STATISTICS, + INTEGRATION_PLATFORM_LIST_STATISTIC_IDS, +} + + class SupportedDialect(StrEnum): """Supported dialects.""" diff --git a/homeassistant/components/recorder/statistics.py b/homeassistant/components/recorder/statistics.py index 82fbf7798f9..8025616d246 100644 --- a/homeassistant/components/recorder/statistics.py +++ b/homeassistant/components/recorder/statistics.py @@ -47,6 +47,9 @@ from .const import ( DOMAIN, EVENT_RECORDER_5MIN_STATISTICS_GENERATED, EVENT_RECORDER_HOURLY_STATISTICS_GENERATED, + INTEGRATION_PLATFORM_COMPILE_STATISTICS, + INTEGRATION_PLATFORM_LIST_STATISTIC_IDS, + INTEGRATION_PLATFORM_VALIDATE_STATISTICS, SupportedDialect, ) from .db_schema import ( @@ -502,9 +505,13 @@ def _compile_statistics( current_metadata: dict[str, tuple[int, StatisticMetaData]] = {} # Collect statistics from all platforms implementing support for domain, platform in instance.hass.data[DOMAIN].recorder_platforms.items(): - if not hasattr(platform, "compile_statistics"): + if not ( + platform_compile_statistics := getattr( + platform, INTEGRATION_PLATFORM_COMPILE_STATISTICS, None + ) + ): continue - compiled: PlatformCompiledStatistics = platform.compile_statistics( + compiled: PlatformCompiledStatistics = platform_compile_statistics( instance.hass, start, end ) _LOGGER.debug( @@ -783,9 +790,13 @@ def list_statistic_ids( # # Query all integrations with a registered recorder platform for platform in hass.data[DOMAIN].recorder_platforms.values(): - if not hasattr(platform, "list_statistic_ids"): + if not ( + platform_list_statistic_ids := getattr( + platform, INTEGRATION_PLATFORM_LIST_STATISTIC_IDS, None + ) + ): continue - platform_statistic_ids = platform.list_statistic_ids( + platform_statistic_ids = platform_list_statistic_ids( hass, statistic_ids=statistic_ids, statistic_type=statistic_type ) @@ -1931,9 +1942,10 @@ def validate_statistics(hass: HomeAssistant) -> dict[str, list[ValidationIssue]] """Validate statistics.""" platform_validation: dict[str, list[ValidationIssue]] = {} for platform in hass.data[DOMAIN].recorder_platforms.values(): - if not hasattr(platform, "validate_statistics"): - continue - platform_validation.update(platform.validate_statistics(hass)) + if platform_validate_statistics := getattr( + platform, INTEGRATION_PLATFORM_VALIDATE_STATISTICS, None + ): + platform_validation.update(platform_validate_statistics(hass)) return platform_validation diff --git a/homeassistant/components/recorder/tasks.py b/homeassistant/components/recorder/tasks.py index 7b8fa4867b6..ef118857059 100644 --- a/homeassistant/components/recorder/tasks.py +++ b/homeassistant/components/recorder/tasks.py @@ -14,7 +14,7 @@ from homeassistant.core import Event from homeassistant.helpers.typing import UndefinedType from . import entity_registry, purge, statistics -from .const import DOMAIN, EXCLUDE_ATTRIBUTES +from .const import DOMAIN from .db_schema import Statistics, StatisticsShortTerm from .models import StatisticData, StatisticMetaData from .util import periodic_db_cleanups @@ -317,11 +317,8 @@ class AddRecorderPlatformTask(RecorderTask): hass = instance.hass domain = self.domain platform = self.platform - platforms: dict[str, Any] = hass.data[DOMAIN].recorder_platforms platforms[domain] = platform - if hasattr(self.platform, "exclude_attributes"): - hass.data[EXCLUDE_ATTRIBUTES][domain] = platform.exclude_attributes(hass) @dataclass diff --git a/tests/components/recorder/test_init.py b/tests/components/recorder/test_init.py index 3232b10fdce..8fb45cb3d4d 100644 --- a/tests/components/recorder/test_init.py +++ b/tests/components/recorder/test_init.py @@ -2112,10 +2112,15 @@ async def test_connect_args_priority(hass: HomeAssistant, config_url) -> None: assert connect_params[0]["charset"] == "utf8mb4" +@pytest.mark.parametrize("core_state", [CoreState.starting, CoreState.running]) async def test_excluding_attributes_by_integration( - recorder_mock: Recorder, hass: HomeAssistant, entity_registry: er.EntityRegistry + recorder_mock: Recorder, + hass: HomeAssistant, + entity_registry: er.EntityRegistry, + core_state: CoreState, ) -> None: """Test that an integration's recorder platform can exclude attributes.""" + hass.state = core_state state = "restoring_from_db" attributes = {"test_attr": 5, "excluded": 10} mock_platform( @@ -2131,6 +2136,7 @@ async def test_excluding_attributes_by_integration( platform = MockEntityPlatform(hass, platform_name="fake_integration") entity_platform = MockEntity(entity_id=entity_id, extra_state_attributes=attributes) await platform.async_add_entities([entity_platform]) + await hass.async_block_till_done() await async_wait_recording_done(hass)