Add domain filter support to async_all to match async_entity_ids (#39725)
This avoids copying all the states before applying the filter
This commit is contained in:
parent
19818d96b7
commit
251d8919ea
6 changed files with 45 additions and 16 deletions
|
@ -41,8 +41,7 @@ class HumidityHandler(intent.IntentHandler):
|
||||||
hass = intent_obj.hass
|
hass = intent_obj.hass
|
||||||
slots = self.async_validate_slots(intent_obj.slots)
|
slots = self.async_validate_slots(intent_obj.slots)
|
||||||
state = hass.helpers.intent.async_match_state(
|
state = hass.helpers.intent.async_match_state(
|
||||||
slots["name"]["value"],
|
slots["name"]["value"], hass.states.async_all(DOMAIN)
|
||||||
[state for state in hass.states.async_all() if state.domain == DOMAIN],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
service_data = {ATTR_ENTITY_ID: state.entity_id}
|
service_data = {ATTR_ENTITY_ID: state.entity_id}
|
||||||
|
@ -87,7 +86,7 @@ class SetModeHandler(intent.IntentHandler):
|
||||||
slots = self.async_validate_slots(intent_obj.slots)
|
slots = self.async_validate_slots(intent_obj.slots)
|
||||||
state = hass.helpers.intent.async_match_state(
|
state = hass.helpers.intent.async_match_state(
|
||||||
slots["name"]["value"],
|
slots["name"]["value"],
|
||||||
[state for state in hass.states.async_all() if state.domain == DOMAIN],
|
hass.states.async_all(DOMAIN),
|
||||||
)
|
)
|
||||||
|
|
||||||
service_data = {ATTR_ENTITY_ID: state.entity_id}
|
service_data = {ATTR_ENTITY_ID: state.entity_id}
|
||||||
|
|
|
@ -39,8 +39,7 @@ class SetIntentHandler(intent.IntentHandler):
|
||||||
hass = intent_obj.hass
|
hass = intent_obj.hass
|
||||||
slots = self.async_validate_slots(intent_obj.slots)
|
slots = self.async_validate_slots(intent_obj.slots)
|
||||||
state = hass.helpers.intent.async_match_state(
|
state = hass.helpers.intent.async_match_state(
|
||||||
slots["name"]["value"],
|
slots["name"]["value"], hass.states.async_all(DOMAIN)
|
||||||
[state for state in hass.states.async_all() if state.domain == DOMAIN],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
service_data = {ATTR_ENTITY_ID: state.entity_id}
|
service_data = {ATTR_ENTITY_ID: state.entity_id}
|
||||||
|
|
|
@ -183,10 +183,7 @@ async def handle_webhook(hass, webhook_id, request):
|
||||||
|
|
||||||
response = []
|
response = []
|
||||||
|
|
||||||
for person in hass.states.async_all():
|
for person in hass.states.async_all("person"):
|
||||||
if person.domain != "person":
|
|
||||||
continue
|
|
||||||
|
|
||||||
if "latitude" in person.attributes and "longitude" in person.attributes:
|
if "latitude" in person.attributes and "longitude" in person.attributes:
|
||||||
response.append(
|
response.append(
|
||||||
{
|
{
|
||||||
|
|
|
@ -918,17 +918,29 @@ class StateMachine:
|
||||||
if state.domain in domain_filter
|
if state.domain in domain_filter
|
||||||
]
|
]
|
||||||
|
|
||||||
def all(self) -> List[State]:
|
def all(self, domain_filter: Optional[Union[str, Iterable]] = None) -> List[State]:
|
||||||
"""Create a list of all states."""
|
"""Create a list of all states."""
|
||||||
return run_callback_threadsafe(self._loop, self.async_all).result()
|
return run_callback_threadsafe(
|
||||||
|
self._loop, self.async_all, domain_filter
|
||||||
|
).result()
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_all(self) -> List[State]:
|
def async_all(
|
||||||
"""Create a list of all states.
|
self, domain_filter: Optional[Union[str, Iterable]] = None
|
||||||
|
) -> List[State]:
|
||||||
|
"""Create a list of all states matching the filter.
|
||||||
|
|
||||||
This method must be run in the event loop.
|
This method must be run in the event loop.
|
||||||
"""
|
"""
|
||||||
return list(self._states.values())
|
if domain_filter is None:
|
||||||
|
return list(self._states.values())
|
||||||
|
|
||||||
|
if isinstance(domain_filter, str):
|
||||||
|
domain_filter = (domain_filter.lower(),)
|
||||||
|
|
||||||
|
return [
|
||||||
|
state for state in self._states.values() if state.domain in domain_filter
|
||||||
|
]
|
||||||
|
|
||||||
def get(self, entity_id: str) -> Optional[State]:
|
def get(self, entity_id: str) -> Optional[State]:
|
||||||
"""Retrieve state of entity_id or None if not found.
|
"""Retrieve state of entity_id or None if not found.
|
||||||
|
|
|
@ -459,8 +459,7 @@ class DomainStates:
|
||||||
sorted(
|
sorted(
|
||||||
(
|
(
|
||||||
_wrap_state(self._hass, state)
|
_wrap_state(self._hass, state)
|
||||||
for state in self._hass.states.async_all()
|
for state in self._hass.states.async_all(self._domain)
|
||||||
if state.domain == self._domain
|
|
||||||
),
|
),
|
||||||
key=lambda state: state.entity_id,
|
key=lambda state: state.entity_id,
|
||||||
)
|
)
|
||||||
|
|
|
@ -1454,3 +1454,26 @@ async def test_chained_logging_misses_log_timeout(hass, caplog):
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
assert "_task_chain_" not in caplog.text
|
assert "_task_chain_" not in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
async def test_async_all(hass):
|
||||||
|
"""Test async_all."""
|
||||||
|
|
||||||
|
hass.states.async_set("switch.link", "on")
|
||||||
|
hass.states.async_set("light.bowl", "on")
|
||||||
|
hass.states.async_set("light.frog", "on")
|
||||||
|
hass.states.async_set("vacuum.floor", "on")
|
||||||
|
|
||||||
|
assert {state.entity_id for state in hass.states.async_all()} == {
|
||||||
|
"switch.link",
|
||||||
|
"light.bowl",
|
||||||
|
"light.frog",
|
||||||
|
"vacuum.floor",
|
||||||
|
}
|
||||||
|
assert {state.entity_id for state in hass.states.async_all("light")} == {
|
||||||
|
"light.bowl",
|
||||||
|
"light.frog",
|
||||||
|
}
|
||||||
|
assert {
|
||||||
|
state.entity_id for state in hass.states.async_all(["light", "switch"])
|
||||||
|
} == {"light.bowl", "light.frog", "switch.link"}
|
||||||
|
|
Loading…
Add table
Reference in a new issue