Improve performance of counting and iterating states in templates (#40250)

Co-authored-by: Anders Melchiorsen <amelchio@nogoto.net>
This commit is contained in:
J. Nick Koston 2020-09-26 11:36:47 -05:00 committed by GitHub
parent 1d41f024cf
commit 35533407fe
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 55 additions and 11 deletions

View file

@ -923,6 +923,24 @@ class StateMachine:
if state.domain in domain_filter
]
@callback
def async_entity_ids_count(
self, domain_filter: Optional[Union[str, Iterable]] = None
) -> int:
"""Count the entity ids that are being tracked.
This method must be run in the event loop.
"""
if domain_filter is None:
return len(self._states.keys())
if isinstance(domain_filter, str):
domain_filter = (domain_filter.lower(),)
return len(
[None for state in self._states.values() if state.domain in domain_filter]
)
def all(self, domain_filter: Optional[Union[str, Iterable]] = None) -> List[State]:
"""Create a list of all states."""
return run_callback_threadsafe(

View file

@ -9,7 +9,7 @@ import math
from operator import attrgetter
import random
import re
from typing import Any, Iterable, List, Optional, Union
from typing import Any, Generator, Iterable, List, Optional, Union
from urllib.parse import urlencode as urllib_urlencode
import weakref
@ -425,12 +425,12 @@ class AllStates:
def __iter__(self):
"""Return all states."""
self._collect_all()
return _state_iterator(self._hass, None)
return _state_generator(self._hass, None)
def __len__(self) -> int:
"""Return number of states."""
self._collect_all()
return len(self._hass.states.async_entity_ids())
return self._hass.states.async_entity_ids_count()
def __call__(self, entity_id):
"""Return the states."""
@ -465,12 +465,12 @@ class DomainStates:
def __iter__(self):
"""Return the iteration over all the states."""
self._collect_domain()
return _state_iterator(self._hass, self._domain)
return _state_generator(self._hass, self._domain)
def __len__(self) -> int:
"""Return number of states."""
self._collect_domain()
return len(self._hass.states.async_entity_ids(self._domain))
return self._hass.states.async_entity_ids_count(self._domain)
def __repr__(self) -> str:
"""Representation of Domain States."""
@ -537,12 +537,10 @@ def _collect_state(hass: HomeAssistantType, entity_id: str) -> None:
entity_collect.entities.add(entity_id)
def _state_iterator(hass: HomeAssistantType, domain: Optional[str]) -> Iterable:
"""Create an state iterator for a domain or all states."""
return iter(
TemplateState(hass, state)
for state in sorted(hass.states.async_all(domain), key=attrgetter("entity_id"))
)
def _state_generator(hass: HomeAssistantType, domain: Optional[str]) -> Generator:
"""State generator for a domain or all states."""
for state in sorted(hass.states.async_all(domain), key=attrgetter("entity_id")):
yield TemplateState(hass, state)
def _get_state(hass: HomeAssistantType, entity_id: str) -> Optional[TemplateState]:

View file

@ -2420,3 +2420,14 @@ For loop example getting 3 entity values:
assert "sensor0" in result
assert "sensor1" in result
assert "sun" in result
async def test_slice_states(hass):
"""Test iterating states with a slice."""
hass.states.async_set("sensor.test", "23")
tpl = template.Template(
"{% for states in states | slice(1) -%}{% set state = states | first %}{{ state.entity_id }}{%- endfor %}",
hass,
)
assert tpl.async_render() == "sensor.test"

View file

@ -1477,3 +1477,20 @@ async def test_async_all(hass):
assert {
state.entity_id for state in hass.states.async_all(["light", "switch"])
} == {"light.bowl", "light.frog", "switch.link"}
async def test_async_entity_ids_count(hass):
"""Test async_entity_ids_count."""
hass.states.async_set("switch.link", "on")
hass.states.async_set("light.bowl", "on")
hass.states.async_set("light.frog", "on")
hass.states.async_set("vacuum.floor", "on")
assert hass.states.async_entity_ids_count() == 4
assert hass.states.async_entity_ids_count("light") == 2
hass.states.async_set("light.cow", "on")
assert hass.states.async_entity_ids_count() == 5
assert hass.states.async_entity_ids_count("light") == 3