From de2e9b6d77adb7f86c6ec4aa0a50428ec8606dc3 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 6 Jun 2022 09:50:52 -1000 Subject: [PATCH] Fix state_changes_during_period history query when no entities are passed (#73139) --- homeassistant/components/recorder/history.py | 17 ++++++------ tests/components/recorder/test_history.py | 29 ++++++++++++++++++++ 2 files changed, 38 insertions(+), 8 deletions(-) diff --git a/homeassistant/components/recorder/history.py b/homeassistant/components/recorder/history.py index 5dd5c0d3040..37285f66d1d 100644 --- a/homeassistant/components/recorder/history.py +++ b/homeassistant/components/recorder/history.py @@ -352,7 +352,8 @@ def _state_changed_during_period_stmt( ) if end_time: stmt += lambda q: q.filter(States.last_updated < end_time) - stmt += lambda q: q.filter(States.entity_id == entity_id) + if entity_id: + stmt += lambda q: q.filter(States.entity_id == entity_id) if join_attributes: stmt += lambda q: q.outerjoin( StateAttributes, States.attributes_id == StateAttributes.attributes_id @@ -378,6 +379,7 @@ def state_changes_during_period( ) -> MutableMapping[str, list[State]]: """Return states changes during UTC period start_time - end_time.""" entity_id = entity_id.lower() if entity_id is not None else None + entity_ids = [entity_id] if entity_id is not None else None with session_scope(hass=hass) as session: stmt = _state_changed_during_period_stmt( @@ -392,8 +394,6 @@ def state_changes_during_period( states = execute_stmt_lambda_element( session, stmt, None if entity_id else start_time, end_time ) - entity_ids = [entity_id] if entity_id is not None else None - return cast( MutableMapping[str, list[State]], _sorted_states_to_dict( @@ -408,14 +408,16 @@ def state_changes_during_period( def _get_last_state_changes_stmt( - schema_version: int, number_of_states: int, entity_id: str + schema_version: int, number_of_states: int, entity_id: str | None ) -> StatementLambdaElement: stmt, join_attributes = lambda_stmt_and_join_attributes( schema_version, False, include_last_changed=False ) stmt += lambda q: q.filter( (States.last_changed == States.last_updated) | States.last_changed.is_(None) - ).filter(States.entity_id == entity_id) + ) + if entity_id: + stmt += lambda q: q.filter(States.entity_id == entity_id) if join_attributes: stmt += lambda q: q.outerjoin( StateAttributes, States.attributes_id == StateAttributes.attributes_id @@ -427,19 +429,18 @@ def _get_last_state_changes_stmt( def get_last_state_changes( - hass: HomeAssistant, number_of_states: int, entity_id: str + hass: HomeAssistant, number_of_states: int, entity_id: str | None ) -> MutableMapping[str, list[State]]: """Return the last number_of_states.""" start_time = dt_util.utcnow() entity_id = entity_id.lower() if entity_id is not None else None + entity_ids = [entity_id] if entity_id is not None else None with session_scope(hass=hass) as session: stmt = _get_last_state_changes_stmt( _schema_version(hass), number_of_states, entity_id ) states = list(execute_stmt_lambda_element(session, stmt)) - entity_ids = [entity_id] if entity_id is not None else None - return cast( MutableMapping[str, list[State]], _sorted_states_to_dict( diff --git a/tests/components/recorder/test_history.py b/tests/components/recorder/test_history.py index da6c3a8af35..ee02ffbec49 100644 --- a/tests/components/recorder/test_history.py +++ b/tests/components/recorder/test_history.py @@ -878,3 +878,32 @@ async def test_get_full_significant_states_handles_empty_last_changed( assert db_sensor_one_states[0].last_updated is not None assert db_sensor_one_states[1].last_updated is not None assert db_sensor_one_states[0].last_updated != db_sensor_one_states[1].last_updated + + +def test_state_changes_during_period_multiple_entities_single_test(hass_recorder): + """Test state change during period with multiple entities in the same test. + + This test ensures the sqlalchemy query cache does not + generate incorrect results. + """ + hass = hass_recorder() + start = dt_util.utcnow() + test_entites = {f"sensor.{i}": str(i) for i in range(30)} + for entity_id, value in test_entites.items(): + hass.states.set(entity_id, value) + + wait_recording_done(hass) + end = dt_util.utcnow() + + hist = history.state_changes_during_period(hass, start, end, None) + for entity_id, value in test_entites.items(): + hist[entity_id][0].state == value + + for entity_id, value in test_entites.items(): + hist = history.state_changes_during_period(hass, start, end, entity_id) + assert len(hist) == 1 + hist[entity_id][0].state == value + + hist = history.state_changes_during_period(hass, start, end, None) + for entity_id, value in test_entites.items(): + hist[entity_id][0].state == value