Replace StateMachine._domain_index with a UserDict (#100270)
* Replace StateMachine._domain_index with a UserDict * Access the UserDict's backing dict directly * Optimize
This commit is contained in:
parent
d0feb063ec
commit
6057fe5926
2 changed files with 94 additions and 35 deletions
|
@ -6,7 +6,16 @@ of entities and react to changes.
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable, Collection, Coroutine, Iterable, Mapping
|
||||
from collections import UserDict, defaultdict
|
||||
from collections.abc import (
|
||||
Callable,
|
||||
Collection,
|
||||
Coroutine,
|
||||
Iterable,
|
||||
KeysView,
|
||||
Mapping,
|
||||
ValuesView,
|
||||
)
|
||||
import concurrent.futures
|
||||
from contextlib import suppress
|
||||
import datetime
|
||||
|
@ -1413,15 +1422,59 @@ class State:
|
|||
)
|
||||
|
||||
|
||||
class States(UserDict[str, State]):
|
||||
"""Container for states, maps entity_id -> State.
|
||||
|
||||
Maintains an additional index:
|
||||
- domain -> dict[str, State]
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the container."""
|
||||
super().__init__()
|
||||
self._domain_index: defaultdict[str, dict[str, State]] = defaultdict(dict)
|
||||
|
||||
def values(self) -> ValuesView[State]:
|
||||
"""Return the underlying values to avoid __iter__ overhead."""
|
||||
return self.data.values()
|
||||
|
||||
def __setitem__(self, key: str, entry: State) -> None:
|
||||
"""Add an item."""
|
||||
self.data[key] = entry
|
||||
self._domain_index[entry.domain][entry.entity_id] = entry
|
||||
|
||||
def __delitem__(self, key: str) -> None:
|
||||
"""Remove an item."""
|
||||
entry = self[key]
|
||||
del self._domain_index[entry.domain][entry.entity_id]
|
||||
super().__delitem__(key)
|
||||
|
||||
def domain_entity_ids(self, key: str) -> KeysView[str] | tuple[()]:
|
||||
"""Get all entity_ids for a domain."""
|
||||
# Avoid polluting _domain_index with non-existing domains
|
||||
if key not in self._domain_index:
|
||||
return ()
|
||||
return self._domain_index[key].keys()
|
||||
|
||||
def domain_states(self, key: str) -> ValuesView[State] | tuple[()]:
|
||||
"""Get all states for a domain."""
|
||||
# Avoid polluting _domain_index with non-existing domains
|
||||
if key not in self._domain_index:
|
||||
return ()
|
||||
return self._domain_index[key].values()
|
||||
|
||||
|
||||
class StateMachine:
|
||||
"""Helper class that tracks the state of different entities."""
|
||||
|
||||
__slots__ = ("_states", "_domain_index", "_reservations", "_bus", "_loop")
|
||||
__slots__ = ("_states", "_states_data", "_reservations", "_bus", "_loop")
|
||||
|
||||
def __init__(self, bus: EventBus, loop: asyncio.events.AbstractEventLoop) -> None:
|
||||
"""Initialize state machine."""
|
||||
self._states: dict[str, State] = {}
|
||||
self._domain_index: dict[str, dict[str, State]] = {}
|
||||
self._states = States()
|
||||
# _states_data is used to access the States backing dict directly to speed
|
||||
# up read operations
|
||||
self._states_data = self._states.data
|
||||
self._reservations: set[str] = set()
|
||||
self._bus = bus
|
||||
self._loop = loop
|
||||
|
@ -1442,16 +1495,15 @@ class StateMachine:
|
|||
This method must be run in the event loop.
|
||||
"""
|
||||
if domain_filter is None:
|
||||
return list(self._states)
|
||||
return list(self._states_data)
|
||||
|
||||
if isinstance(domain_filter, str):
|
||||
return list(self._domain_index.get(domain_filter.lower(), ()))
|
||||
return list(self._states.domain_entity_ids(domain_filter.lower()))
|
||||
|
||||
states: list[str] = []
|
||||
entity_ids: list[str] = []
|
||||
for domain in domain_filter:
|
||||
if domain_index := self._domain_index.get(domain):
|
||||
states.extend(domain_index)
|
||||
return states
|
||||
entity_ids.extend(self._states.domain_entity_ids(domain))
|
||||
return entity_ids
|
||||
|
||||
@callback
|
||||
def async_entity_ids_count(
|
||||
|
@ -1462,12 +1514,14 @@ class StateMachine:
|
|||
This method must be run in the event loop.
|
||||
"""
|
||||
if domain_filter is None:
|
||||
return len(self._states)
|
||||
return len(self._states_data)
|
||||
|
||||
if isinstance(domain_filter, str):
|
||||
return len(self._domain_index.get(domain_filter.lower(), ()))
|
||||
return len(self._states.domain_entity_ids(domain_filter.lower()))
|
||||
|
||||
return sum(len(self._domain_index.get(domain, ())) for domain in domain_filter)
|
||||
return sum(
|
||||
len(self._states.domain_entity_ids(domain)) for domain in domain_filter
|
||||
)
|
||||
|
||||
def all(self, domain_filter: str | Iterable[str] | None = None) -> list[State]:
|
||||
"""Create a list of all states."""
|
||||
|
@ -1484,15 +1538,14 @@ class StateMachine:
|
|||
This method must be run in the event loop.
|
||||
"""
|
||||
if domain_filter is None:
|
||||
return list(self._states.values())
|
||||
return list(self._states_data.values())
|
||||
|
||||
if isinstance(domain_filter, str):
|
||||
return list(self._domain_index.get(domain_filter.lower(), {}).values())
|
||||
return list(self._states.domain_states(domain_filter.lower()))
|
||||
|
||||
states: list[State] = []
|
||||
for domain in domain_filter:
|
||||
if domain_index := self._domain_index.get(domain):
|
||||
states.extend(domain_index.values())
|
||||
states.extend(self._states.domain_states(domain))
|
||||
return states
|
||||
|
||||
def get(self, entity_id: str) -> State | None:
|
||||
|
@ -1500,7 +1553,7 @@ class StateMachine:
|
|||
|
||||
Async friendly.
|
||||
"""
|
||||
return self._states.get(entity_id.lower())
|
||||
return self._states_data.get(entity_id.lower())
|
||||
|
||||
def is_state(self, entity_id: str, state: str) -> bool:
|
||||
"""Test if entity exists and is in specified state.
|
||||
|
@ -1534,7 +1587,6 @@ class StateMachine:
|
|||
if old_state is None:
|
||||
return False
|
||||
|
||||
self._domain_index[old_state.domain].pop(entity_id)
|
||||
old_state.expire()
|
||||
self._bus.async_fire(
|
||||
EVENT_STATE_CHANGED,
|
||||
|
@ -1579,7 +1631,7 @@ class StateMachine:
|
|||
entity_id are added.
|
||||
"""
|
||||
entity_id = entity_id.lower()
|
||||
if entity_id in self._states or entity_id in self._reservations:
|
||||
if entity_id in self._states_data or entity_id in self._reservations:
|
||||
raise HomeAssistantError(
|
||||
"async_reserve must not be called once the state is in the state"
|
||||
" machine."
|
||||
|
@ -1591,7 +1643,9 @@ class StateMachine:
|
|||
def async_available(self, entity_id: str) -> bool:
|
||||
"""Check to see if an entity_id is available to be used."""
|
||||
entity_id = entity_id.lower()
|
||||
return entity_id not in self._states and entity_id not in self._reservations
|
||||
return (
|
||||
entity_id not in self._states_data and entity_id not in self._reservations
|
||||
)
|
||||
|
||||
@callback
|
||||
def async_set(
|
||||
|
@ -1614,7 +1668,7 @@ class StateMachine:
|
|||
entity_id = entity_id.lower()
|
||||
new_state = str(new_state)
|
||||
attributes = attributes or {}
|
||||
if (old_state := self._states.get(entity_id)) is None:
|
||||
if (old_state := self._states_data.get(entity_id)) is None:
|
||||
same_state = False
|
||||
same_attr = False
|
||||
last_changed = None
|
||||
|
@ -1656,10 +1710,6 @@ class StateMachine:
|
|||
if old_state is not None:
|
||||
old_state.expire()
|
||||
self._states[entity_id] = state
|
||||
if not (domain_index := self._domain_index.get(state.domain)):
|
||||
domain_index = {}
|
||||
self._domain_index[state.domain] = domain_index
|
||||
domain_index[entity_id] = state
|
||||
self._bus.async_fire(
|
||||
EVENT_STATE_CHANGED,
|
||||
{"entity_id": entity_id, "old_state": old_state, "new_state": state},
|
||||
|
|
|
@ -15,6 +15,7 @@ from typing import Any
|
|||
from unittest.mock import MagicMock, Mock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from pytest_unordered import unordered
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.const import (
|
||||
|
@ -1031,17 +1032,18 @@ async def test_statemachine_is_state(hass: HomeAssistant) -> None:
|
|||
|
||||
|
||||
async def test_statemachine_entity_ids(hass: HomeAssistant) -> None:
|
||||
"""Test get_entity_ids method."""
|
||||
"""Test async_entity_ids method."""
|
||||
assert hass.states.async_entity_ids() == []
|
||||
assert hass.states.async_entity_ids("light") == []
|
||||
assert hass.states.async_entity_ids(("light", "switch", "other")) == []
|
||||
|
||||
hass.states.async_set("light.bowl", "on", {})
|
||||
hass.states.async_set("SWITCH.AC", "off", {})
|
||||
ent_ids = hass.states.async_entity_ids()
|
||||
assert len(ent_ids) == 2
|
||||
assert "light.bowl" in ent_ids
|
||||
assert "switch.ac" in ent_ids
|
||||
|
||||
ent_ids = hass.states.async_entity_ids("light")
|
||||
assert len(ent_ids) == 1
|
||||
assert "light.bowl" in ent_ids
|
||||
assert hass.states.async_entity_ids() == unordered(["light.bowl", "switch.ac"])
|
||||
assert hass.states.async_entity_ids("light") == ["light.bowl"]
|
||||
assert hass.states.async_entity_ids(("light", "switch", "other")) == unordered(
|
||||
["light.bowl", "switch.ac"]
|
||||
)
|
||||
|
||||
states = sorted(state.entity_id for state in hass.states.async_all())
|
||||
assert states == ["light.bowl", "switch.ac"]
|
||||
|
@ -1902,6 +1904,9 @@ async def test_chained_logging_misses_log_timeout(
|
|||
|
||||
async def test_async_all(hass: HomeAssistant) -> None:
|
||||
"""Test async_all."""
|
||||
assert hass.states.async_all() == []
|
||||
assert hass.states.async_all("light") == []
|
||||
assert hass.states.async_all(["light", "switch"]) == []
|
||||
|
||||
hass.states.async_set("switch.link", "on")
|
||||
hass.states.async_set("light.bowl", "on")
|
||||
|
@ -1926,6 +1931,10 @@ async def test_async_all(hass: HomeAssistant) -> None:
|
|||
async def test_async_entity_ids_count(hass: HomeAssistant) -> None:
|
||||
"""Test async_entity_ids_count."""
|
||||
|
||||
assert hass.states.async_entity_ids_count() == 0
|
||||
assert hass.states.async_entity_ids_count("light") == 0
|
||||
assert hass.states.async_entity_ids_count({"light", "vacuum"}) == 0
|
||||
|
||||
hass.states.async_set("switch.link", "on")
|
||||
hass.states.async_set("light.bowl", "on")
|
||||
hass.states.async_set("light.frog", "on")
|
||||
|
|
Loading…
Add table
Reference in a new issue