Fix state_changes_during_period history query when no entities are passed (#73139)

This commit is contained in:
J. Nick Koston 2022-06-06 09:50:52 -10:00 committed by GitHub
parent 861de5c0f0
commit de2e9b6d77
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 38 additions and 8 deletions

View file

@ -352,7 +352,8 @@ def _state_changed_during_period_stmt(
)
if end_time:
stmt += lambda q: q.filter(States.last_updated < end_time)
stmt += lambda q: q.filter(States.entity_id == entity_id)
if entity_id:
stmt += lambda q: q.filter(States.entity_id == entity_id)
if join_attributes:
stmt += lambda q: q.outerjoin(
StateAttributes, States.attributes_id == StateAttributes.attributes_id
@ -378,6 +379,7 @@ def state_changes_during_period(
) -> MutableMapping[str, list[State]]:
"""Return states changes during UTC period start_time - end_time."""
entity_id = entity_id.lower() if entity_id is not None else None
entity_ids = [entity_id] if entity_id is not None else None
with session_scope(hass=hass) as session:
stmt = _state_changed_during_period_stmt(
@ -392,8 +394,6 @@ def state_changes_during_period(
states = execute_stmt_lambda_element(
session, stmt, None if entity_id else start_time, end_time
)
entity_ids = [entity_id] if entity_id is not None else None
return cast(
MutableMapping[str, list[State]],
_sorted_states_to_dict(
@ -408,14 +408,16 @@ def state_changes_during_period(
def _get_last_state_changes_stmt(
schema_version: int, number_of_states: int, entity_id: str
schema_version: int, number_of_states: int, entity_id: str | None
) -> StatementLambdaElement:
stmt, join_attributes = lambda_stmt_and_join_attributes(
schema_version, False, include_last_changed=False
)
stmt += lambda q: q.filter(
(States.last_changed == States.last_updated) | States.last_changed.is_(None)
).filter(States.entity_id == entity_id)
)
if entity_id:
stmt += lambda q: q.filter(States.entity_id == entity_id)
if join_attributes:
stmt += lambda q: q.outerjoin(
StateAttributes, States.attributes_id == StateAttributes.attributes_id
@ -427,19 +429,18 @@ def _get_last_state_changes_stmt(
def get_last_state_changes(
hass: HomeAssistant, number_of_states: int, entity_id: str
hass: HomeAssistant, number_of_states: int, entity_id: str | None
) -> MutableMapping[str, list[State]]:
"""Return the last number_of_states."""
start_time = dt_util.utcnow()
entity_id = entity_id.lower() if entity_id is not None else None
entity_ids = [entity_id] if entity_id is not None else None
with session_scope(hass=hass) as session:
stmt = _get_last_state_changes_stmt(
_schema_version(hass), number_of_states, entity_id
)
states = list(execute_stmt_lambda_element(session, stmt))
entity_ids = [entity_id] if entity_id is not None else None
return cast(
MutableMapping[str, list[State]],
_sorted_states_to_dict(

View file

@ -878,3 +878,32 @@ async def test_get_full_significant_states_handles_empty_last_changed(
assert db_sensor_one_states[0].last_updated is not None
assert db_sensor_one_states[1].last_updated is not None
assert db_sensor_one_states[0].last_updated != db_sensor_one_states[1].last_updated
def test_state_changes_during_period_multiple_entities_single_test(hass_recorder):
"""Test state change during period with multiple entities in the same test.
This test ensures the sqlalchemy query cache does not
generate incorrect results.
"""
hass = hass_recorder()
start = dt_util.utcnow()
test_entites = {f"sensor.{i}": str(i) for i in range(30)}
for entity_id, value in test_entites.items():
hass.states.set(entity_id, value)
wait_recording_done(hass)
end = dt_util.utcnow()
hist = history.state_changes_during_period(hass, start, end, None)
for entity_id, value in test_entites.items():
hist[entity_id][0].state == value
for entity_id, value in test_entites.items():
hist = history.state_changes_during_period(hass, start, end, entity_id)
assert len(hist) == 1
hist[entity_id][0].state == value
hist = history.state_changes_during_period(hass, start, end, None)
for entity_id, value in test_entites.items():
hist[entity_id][0].state == value