Add domain filter support to async_all to match async_entity_ids (#39725)

This avoids copying all the states before applying
the filter
This commit is contained in:
J. Nick Koston 2020-09-06 16:20:32 -05:00 committed by GitHub
parent 19818d96b7
commit 251d8919ea
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 45 additions and 16 deletions

View file

@ -41,8 +41,7 @@ class HumidityHandler(intent.IntentHandler):
hass = intent_obj.hass hass = intent_obj.hass
slots = self.async_validate_slots(intent_obj.slots) slots = self.async_validate_slots(intent_obj.slots)
state = hass.helpers.intent.async_match_state( state = hass.helpers.intent.async_match_state(
slots["name"]["value"], slots["name"]["value"], hass.states.async_all(DOMAIN)
[state for state in hass.states.async_all() if state.domain == DOMAIN],
) )
service_data = {ATTR_ENTITY_ID: state.entity_id} service_data = {ATTR_ENTITY_ID: state.entity_id}
@ -87,7 +86,7 @@ class SetModeHandler(intent.IntentHandler):
slots = self.async_validate_slots(intent_obj.slots) slots = self.async_validate_slots(intent_obj.slots)
state = hass.helpers.intent.async_match_state( state = hass.helpers.intent.async_match_state(
slots["name"]["value"], slots["name"]["value"],
[state for state in hass.states.async_all() if state.domain == DOMAIN], hass.states.async_all(DOMAIN),
) )
service_data = {ATTR_ENTITY_ID: state.entity_id} service_data = {ATTR_ENTITY_ID: state.entity_id}

View file

@ -39,8 +39,7 @@ class SetIntentHandler(intent.IntentHandler):
hass = intent_obj.hass hass = intent_obj.hass
slots = self.async_validate_slots(intent_obj.slots) slots = self.async_validate_slots(intent_obj.slots)
state = hass.helpers.intent.async_match_state( state = hass.helpers.intent.async_match_state(
slots["name"]["value"], slots["name"]["value"], hass.states.async_all(DOMAIN)
[state for state in hass.states.async_all() if state.domain == DOMAIN],
) )
service_data = {ATTR_ENTITY_ID: state.entity_id} service_data = {ATTR_ENTITY_ID: state.entity_id}

View file

@ -183,10 +183,7 @@ async def handle_webhook(hass, webhook_id, request):
response = [] response = []
for person in hass.states.async_all(): for person in hass.states.async_all("person"):
if person.domain != "person":
continue
if "latitude" in person.attributes and "longitude" in person.attributes: if "latitude" in person.attributes and "longitude" in person.attributes:
response.append( response.append(
{ {

View file

@ -918,17 +918,29 @@ class StateMachine:
if state.domain in domain_filter if state.domain in domain_filter
] ]
def all(self) -> List[State]: def all(self, domain_filter: Optional[Union[str, Iterable]] = None) -> List[State]:
"""Create a list of all states.""" """Create a list of all states."""
return run_callback_threadsafe(self._loop, self.async_all).result() return run_callback_threadsafe(
self._loop, self.async_all, domain_filter
).result()
@callback @callback
def async_all(self) -> List[State]: def async_all(
"""Create a list of all states. self, domain_filter: Optional[Union[str, Iterable]] = None
) -> List[State]:
"""Create a list of all states matching the filter.
This method must be run in the event loop. This method must be run in the event loop.
""" """
return list(self._states.values()) if domain_filter is None:
return list(self._states.values())
if isinstance(domain_filter, str):
domain_filter = (domain_filter.lower(),)
return [
state for state in self._states.values() if state.domain in domain_filter
]
def get(self, entity_id: str) -> Optional[State]: def get(self, entity_id: str) -> Optional[State]:
"""Retrieve state of entity_id or None if not found. """Retrieve state of entity_id or None if not found.

View file

@ -459,8 +459,7 @@ class DomainStates:
sorted( sorted(
( (
_wrap_state(self._hass, state) _wrap_state(self._hass, state)
for state in self._hass.states.async_all() for state in self._hass.states.async_all(self._domain)
if state.domain == self._domain
), ),
key=lambda state: state.entity_id, key=lambda state: state.entity_id,
) )

View file

@ -1454,3 +1454,26 @@ async def test_chained_logging_misses_log_timeout(hass, caplog):
await hass.async_block_till_done() await hass.async_block_till_done()
assert "_task_chain_" not in caplog.text assert "_task_chain_" not in caplog.text
async def test_async_all(hass):
"""Test async_all."""
hass.states.async_set("switch.link", "on")
hass.states.async_set("light.bowl", "on")
hass.states.async_set("light.frog", "on")
hass.states.async_set("vacuum.floor", "on")
assert {state.entity_id for state in hass.states.async_all()} == {
"switch.link",
"light.bowl",
"light.frog",
"vacuum.floor",
}
assert {state.entity_id for state in hass.states.async_all("light")} == {
"light.bowl",
"light.frog",
}
assert {
state.entity_id for state in hass.states.async_all(["light", "switch"])
} == {"light.bowl", "light.frog", "switch.link"}