diff --git a/homeassistant/components/recorder/models.py b/homeassistant/components/recorder/models.py
index 8b5aef88738..28eff4d9d95 100644
--- a/homeassistant/components/recorder/models.py
+++ b/homeassistant/components/recorder/models.py
@@ -267,6 +267,7 @@ class Statistics(Base):  # type: ignore
 class StatisticMetaData(TypedDict, total=False):
     """Statistic meta data class."""
 
+    statistic_id: str
     unit_of_measurement: str | None
     has_mean: bool
     has_sum: bool
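Since `StatisticMetaData` is declared with `total=False`, every key stays optional; the hunk above only widens the TypedDict with `statistic_id`. A minimal sketch of how a fully populated instance is expected to look (the statistic id and values here are invented for illustration, not taken from the PR):

```python
# Sketch only: a fully populated StatisticMetaData. The import path matches
# the file patched above; the concrete values are invented for illustration.
from homeassistant.components.recorder.models import StatisticMetaData

meta: StatisticMetaData = {
    "statistic_id": "sensor.energy_meter",  # hypothetical statistic id
    "unit_of_measurement": "kWh",
    "has_mean": False,  # a total_increasing sensor has a sum, not a mean
    "has_sum": True,
}
```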
diff --git a/homeassistant/components/recorder/statistics.py b/homeassistant/components/recorder/statistics.py
index 06f7851b1a6..ddc542d23b7 100644
--- a/homeassistant/components/recorder/statistics.py
+++ b/homeassistant/components/recorder/statistics.py
@@ -53,6 +53,13 @@ QUERY_STATISTIC_META = [
     StatisticsMeta.id,
     StatisticsMeta.statistic_id,
     StatisticsMeta.unit_of_measurement,
+    StatisticsMeta.has_mean,
+    StatisticsMeta.has_sum,
+]
+
+QUERY_STATISTIC_META_ID = [
+    StatisticsMeta.id,
+    StatisticsMeta.statistic_id,
 ]
 
 STATISTICS_BAKERY = "recorder_statistics_bakery"
@@ -124,33 +131,61 @@ def _get_metadata_ids(
 ) -> list[str]:
     """Resolve metadata_id for a list of statistic_ids."""
     baked_query = hass.data[STATISTICS_META_BAKERY](
-        lambda session: session.query(*QUERY_STATISTIC_META)
+        lambda session: session.query(*QUERY_STATISTIC_META_ID)
     )
     baked_query += lambda q: q.filter(
         StatisticsMeta.statistic_id.in_(bindparam("statistic_ids"))
     )
     result = execute(baked_query(session).params(statistic_ids=statistic_ids))
 
-    return [id for id, _, _ in result] if result else []
+    return [id for id, _ in result] if result else []
 
 
-def _get_or_add_metadata_id(
+def _update_or_add_metadata(
     hass: HomeAssistant,
     session: scoped_session,
     statistic_id: str,
-    metadata: StatisticMetaData,
+    new_metadata: StatisticMetaData,
 ) -> str:
     """Get metadata_id for a statistic_id, add if it doesn't exist."""
-    metadata_id = _get_metadata_ids(hass, session, [statistic_id])
-    if not metadata_id:
-        unit = metadata["unit_of_measurement"]
-        has_mean = metadata["has_mean"]
-        has_sum = metadata["has_sum"]
+    old_metadata_dict = _get_metadata(hass, session, [statistic_id], None)
+    if not old_metadata_dict:
+        unit = new_metadata["unit_of_measurement"]
+        has_mean = new_metadata["has_mean"]
+        has_sum = new_metadata["has_sum"]
         session.add(
             StatisticsMeta.from_meta(DOMAIN, statistic_id, unit, has_mean, has_sum)
         )
-        metadata_id = _get_metadata_ids(hass, session, [statistic_id])
-    return metadata_id[0]
+        metadata_ids = _get_metadata_ids(hass, session, [statistic_id])
+        _LOGGER.debug(
+            "Added new statistics metadata for %s, new_metadata: %s",
+            statistic_id,
+            new_metadata,
+        )
+        return metadata_ids[0]
+
+    metadata_id, old_metadata = next(iter(old_metadata_dict.items()))
+    if (
+        old_metadata["has_mean"] != new_metadata["has_mean"]
+        or old_metadata["has_sum"] != new_metadata["has_sum"]
+        or old_metadata["unit_of_measurement"] != new_metadata["unit_of_measurement"]
+    ):
+        session.query(StatisticsMeta).filter_by(statistic_id=statistic_id).update(
+            {
+                StatisticsMeta.has_mean: new_metadata["has_mean"],
+                StatisticsMeta.has_sum: new_metadata["has_sum"],
+                StatisticsMeta.unit_of_measurement: new_metadata["unit_of_measurement"],
+            },
+            synchronize_session=False,
+        )
+        _LOGGER.debug(
+            "Updated statistics metadata for %s, old_metadata: %s, new_metadata: %s",
+            statistic_id,
+            old_metadata,
+            new_metadata,
+        )
+
+    return metadata_id
 
 
 @retryable_database_job("statistics")
@@ -177,7 +212,7 @@ def compile_statistics(instance: Recorder, start: datetime) -> bool:
     with session_scope(session=instance.get_session()) as session:  # type: ignore
         for stats in platform_stats:
             for entity_id, stat in stats.items():
-                metadata_id = _get_or_add_metadata_id(
+                metadata_id = _update_or_add_metadata(
                     instance.hass, session, entity_id, stat["meta"]
                 )
                 session.add(Statistics.from_stats(metadata_id, start, stat["stat"]))
@@ -191,14 +226,19 @@ def _get_metadata(
     session: scoped_session,
     statistic_ids: list[str] | None,
     statistic_type: str | None,
-) -> dict[str, dict[str, str]]:
+) -> dict[str, StatisticMetaData]:
     """Fetch meta data."""
 
-    def _meta(metas: list, wanted_metadata_id: str) -> dict[str, str] | None:
-        meta = None
-        for metadata_id, statistic_id, unit in metas:
+    def _meta(metas: list, wanted_metadata_id: str) -> StatisticMetaData | None:
+        meta: StatisticMetaData | None = None
+        for metadata_id, statistic_id, unit, has_mean, has_sum in metas:
             if metadata_id == wanted_metadata_id:
-                meta = {"unit_of_measurement": unit, "statistic_id": statistic_id}
+                meta = {
+                    "statistic_id": statistic_id,
+                    "unit_of_measurement": unit,
+                    "has_mean": has_mean,
+                    "has_sum": has_sum,
+                }
         return meta
 
     baked_query = hass.data[STATISTICS_META_BAKERY](
@@ -219,7 +259,7 @@ def _get_metadata(
         return {}
 
     metadata_ids = [metadata[0] for metadata in result]
-    metadata = {}
+    metadata: dict[str, StatisticMetaData] = {}
     for _id in metadata_ids:
         meta = _meta(result, _id)
         if meta:
@@ -230,7 +270,7 @@ def _get_metadata(
 def get_metadata(
     hass: HomeAssistant,
     statistic_id: str,
-) -> dict[str, str] | None:
+) -> StatisticMetaData | None:
     """Return metadata for a statistic_id."""
     statistic_ids = [statistic_id]
     with session_scope(hass=hass) as session:
@@ -255,7 +295,7 @@ def _configured_unit(unit: str, units: UnitSystem) -> str:
 
 def list_statistic_ids(
     hass: HomeAssistant, statistic_type: str | None = None
-) -> list[dict[str, str] | None]:
+) -> list[StatisticMetaData | None]:
     """Return statistic_ids and meta data."""
     units = hass.config.units
     statistic_ids = {}
@@ -263,7 +303,9 @@ def list_statistic_ids(
         metadata = _get_metadata(hass, session, None, statistic_type)
 
         for meta in metadata.values():
-            unit = _configured_unit(meta["unit_of_measurement"], units)
+            unit = meta["unit_of_measurement"]
+            if unit is not None:
+                unit = _configured_unit(unit, units)
             meta["unit_of_measurement"] = unit
 
         statistic_ids = {
@@ -277,7 +319,8 @@ def list_statistic_ids(
         platform_statistic_ids = platform.list_statistic_ids(hass, statistic_type)
 
         for statistic_id, unit in platform_statistic_ids.items():
-            unit = _configured_unit(unit, units)
+            if unit is not None:
+                unit = _configured_unit(unit, units)
             platform_statistic_ids[statistic_id] = unit
 
         statistic_ids = {**statistic_ids, **platform_statistic_ids}
@@ -367,7 +410,7 @@ def _sorted_statistics_to_dict(
     hass: HomeAssistant,
     stats: list,
     statistic_ids: list[str] | None,
-    metadata: dict[str, dict[str, str]],
+    metadata: dict[str, StatisticMetaData],
 ) -> dict[str, list[dict]]:
     """Convert SQL results into JSON friendly data structure."""
     result: dict = defaultdict(list)
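The core behavioural change in `statistics.py` is that `_update_or_add_metadata` no longer only resolves an id: it inserts a metadata row when none exists, rewrites the row when `has_mean`, `has_sum` or `unit_of_measurement` differ, and otherwise leaves it untouched. A minimal, self-contained sketch of that comparison step (the helper name is invented for illustration; plain dicts stand in for database rows):

```python
# Standalone sketch of the comparison guarding the UPDATE above; the helper
# name needs_metadata_update is invented, plain dicts stand in for DB rows.
def needs_metadata_update(old: dict, new: dict) -> bool:
    """Return True when any tracked metadata field differs."""
    keys = ("has_mean", "has_sum", "unit_of_measurement")
    return any(old[key] != new[key] for key in keys)

# A sensor switching from "measurement" to "total_increasing" flips the flags:
old = {"has_mean": True, "has_sum": False, "unit_of_measurement": None}
new = {"has_mean": False, "has_sum": True, "unit_of_measurement": None}
assert needs_metadata_update(old, new)
assert not needs_metadata_update(new, dict(new))
```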
last_stats[entity_id][0]["sum"] + _sum = last_stats[entity_id][0]["sum"] or 0 for fstate, state in fstates: diff --git a/tests/components/sensor/test_recorder.py b/tests/components/sensor/test_recorder.py index c7f356e49ee..2e300b9c748 100644 --- a/tests/components/sensor/test_recorder.py +++ b/tests/components/sensor/test_recorder.py @@ -10,6 +10,7 @@ from homeassistant.components.recorder import history from homeassistant.components.recorder.const import DATA_INSTANCE from homeassistant.components.recorder.models import process_timestamp_to_utc_isoformat from homeassistant.components.recorder.statistics import ( + get_metadata, list_statistic_ids, statistics_during_period, ) @@ -1037,6 +1038,95 @@ def test_compile_hourly_statistics_changing_units_2( assert "Error while processing event StatisticsTask" not in caplog.text +@pytest.mark.parametrize( + "device_class,unit,native_unit,mean,min,max", + [ + (None, None, None, 16.440677, 10, 30), + ], +) +def test_compile_hourly_statistics_changing_statistics( + hass_recorder, caplog, device_class, unit, native_unit, mean, min, max +): + """Test compiling hourly statistics where units change during an hour.""" + zero = dt_util.utcnow() + hass = hass_recorder() + recorder = hass.data[DATA_INSTANCE] + setup_component(hass, "sensor", {}) + attributes_1 = { + "device_class": device_class, + "state_class": "measurement", + "unit_of_measurement": unit, + } + attributes_2 = { + "device_class": device_class, + "state_class": "total_increasing", + "unit_of_measurement": unit, + } + four, states = record_states(hass, zero, "sensor.test1", attributes_1) + recorder.do_adhoc_statistics(period="hourly", start=zero) + wait_recording_done(hass) + statistic_ids = list_statistic_ids(hass) + assert statistic_ids == [ + {"statistic_id": "sensor.test1", "unit_of_measurement": None} + ] + metadata = get_metadata(hass, "sensor.test1") + assert metadata == { + "has_mean": True, + "has_sum": False, + "statistic_id": "sensor.test1", + "unit_of_measurement": None, + } + + # Add more states, with changed state class + four, _states = record_states( + hass, zero + timedelta(hours=1), "sensor.test1", attributes_2 + ) + states["sensor.test1"] += _states["sensor.test1"] + hist = history.get_significant_states(hass, zero, four) + assert dict(states) == dict(hist) + + recorder.do_adhoc_statistics(period="hourly", start=zero + timedelta(hours=1)) + wait_recording_done(hass) + statistic_ids = list_statistic_ids(hass) + assert statistic_ids == [ + {"statistic_id": "sensor.test1", "unit_of_measurement": None} + ] + metadata = get_metadata(hass, "sensor.test1") + assert metadata == { + "has_mean": False, + "has_sum": True, + "statistic_id": "sensor.test1", + "unit_of_measurement": None, + } + stats = statistics_during_period(hass, zero) + assert stats == { + "sensor.test1": [ + { + "statistic_id": "sensor.test1", + "start": process_timestamp_to_utc_isoformat(zero), + "mean": approx(mean), + "min": approx(min), + "max": approx(max), + "last_reset": None, + "state": None, + "sum": None, + }, + { + "statistic_id": "sensor.test1", + "start": process_timestamp_to_utc_isoformat(zero + timedelta(hours=1)), + "mean": None, + "min": None, + "max": None, + "last_reset": None, + "state": approx(30.0), + "sum": approx(30.0), + }, + ] + } + + assert "Error while processing event StatisticsTask" not in caplog.text + + def record_states(hass, zero, entity_id, attributes): """Record some test states.