Replace StateMachine._domain_index with a UserDict (#100270)

* Replace StateMachine._domain_index with a UserDict

* Access the UserDict's backing dict directly

* Optimize
This commit is contained in:
Erik Montnemery 2023-09-13 18:05:17 +02:00 committed by GitHub
parent d0feb063ec
commit 6057fe5926
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 94 additions and 35 deletions

View file

@ -6,7 +6,16 @@ of entities and react to changes.
from __future__ import annotations
import asyncio
from collections.abc import Callable, Collection, Coroutine, Iterable, Mapping
from collections import UserDict, defaultdict
from collections.abc import (
Callable,
Collection,
Coroutine,
Iterable,
KeysView,
Mapping,
ValuesView,
)
import concurrent.futures
from contextlib import suppress
import datetime
@ -1413,15 +1422,59 @@ class State:
)
class States(UserDict[str, State]):
"""Container for states, maps entity_id -> State.
Maintains an additional index:
- domain -> dict[str, State]
"""
def __init__(self) -> None:
"""Initialize the container."""
super().__init__()
self._domain_index: defaultdict[str, dict[str, State]] = defaultdict(dict)
def values(self) -> ValuesView[State]:
"""Return the underlying values to avoid __iter__ overhead."""
return self.data.values()
def __setitem__(self, key: str, entry: State) -> None:
"""Add an item."""
self.data[key] = entry
self._domain_index[entry.domain][entry.entity_id] = entry
def __delitem__(self, key: str) -> None:
"""Remove an item."""
entry = self[key]
del self._domain_index[entry.domain][entry.entity_id]
super().__delitem__(key)
def domain_entity_ids(self, key: str) -> KeysView[str] | tuple[()]:
"""Get all entity_ids for a domain."""
# Avoid polluting _domain_index with non-existing domains
if key not in self._domain_index:
return ()
return self._domain_index[key].keys()
def domain_states(self, key: str) -> ValuesView[State] | tuple[()]:
"""Get all states for a domain."""
# Avoid polluting _domain_index with non-existing domains
if key not in self._domain_index:
return ()
return self._domain_index[key].values()
class StateMachine:
"""Helper class that tracks the state of different entities."""
__slots__ = ("_states", "_domain_index", "_reservations", "_bus", "_loop")
__slots__ = ("_states", "_states_data", "_reservations", "_bus", "_loop")
def __init__(self, bus: EventBus, loop: asyncio.events.AbstractEventLoop) -> None:
"""Initialize state machine."""
self._states: dict[str, State] = {}
self._domain_index: dict[str, dict[str, State]] = {}
self._states = States()
# _states_data is used to access the States backing dict directly to speed
# up read operations
self._states_data = self._states.data
self._reservations: set[str] = set()
self._bus = bus
self._loop = loop
@ -1442,16 +1495,15 @@ class StateMachine:
This method must be run in the event loop.
"""
if domain_filter is None:
return list(self._states)
return list(self._states_data)
if isinstance(domain_filter, str):
return list(self._domain_index.get(domain_filter.lower(), ()))
return list(self._states.domain_entity_ids(domain_filter.lower()))
states: list[str] = []
entity_ids: list[str] = []
for domain in domain_filter:
if domain_index := self._domain_index.get(domain):
states.extend(domain_index)
return states
entity_ids.extend(self._states.domain_entity_ids(domain))
return entity_ids
@callback
def async_entity_ids_count(
@ -1462,12 +1514,14 @@ class StateMachine:
This method must be run in the event loop.
"""
if domain_filter is None:
return len(self._states)
return len(self._states_data)
if isinstance(domain_filter, str):
return len(self._domain_index.get(domain_filter.lower(), ()))
return len(self._states.domain_entity_ids(domain_filter.lower()))
return sum(len(self._domain_index.get(domain, ())) for domain in domain_filter)
return sum(
len(self._states.domain_entity_ids(domain)) for domain in domain_filter
)
def all(self, domain_filter: str | Iterable[str] | None = None) -> list[State]:
"""Create a list of all states."""
@ -1484,15 +1538,14 @@ class StateMachine:
This method must be run in the event loop.
"""
if domain_filter is None:
return list(self._states.values())
return list(self._states_data.values())
if isinstance(domain_filter, str):
return list(self._domain_index.get(domain_filter.lower(), {}).values())
return list(self._states.domain_states(domain_filter.lower()))
states: list[State] = []
for domain in domain_filter:
if domain_index := self._domain_index.get(domain):
states.extend(domain_index.values())
states.extend(self._states.domain_states(domain))
return states
def get(self, entity_id: str) -> State | None:
@ -1500,7 +1553,7 @@ class StateMachine:
Async friendly.
"""
return self._states.get(entity_id.lower())
return self._states_data.get(entity_id.lower())
def is_state(self, entity_id: str, state: str) -> bool:
"""Test if entity exists and is in specified state.
@ -1534,7 +1587,6 @@ class StateMachine:
if old_state is None:
return False
self._domain_index[old_state.domain].pop(entity_id)
old_state.expire()
self._bus.async_fire(
EVENT_STATE_CHANGED,
@ -1579,7 +1631,7 @@ class StateMachine:
entity_id are added.
"""
entity_id = entity_id.lower()
if entity_id in self._states or entity_id in self._reservations:
if entity_id in self._states_data or entity_id in self._reservations:
raise HomeAssistantError(
"async_reserve must not be called once the state is in the state"
" machine."
@ -1591,7 +1643,9 @@ class StateMachine:
def async_available(self, entity_id: str) -> bool:
"""Check to see if an entity_id is available to be used."""
entity_id = entity_id.lower()
return entity_id not in self._states and entity_id not in self._reservations
return (
entity_id not in self._states_data and entity_id not in self._reservations
)
@callback
def async_set(
@ -1614,7 +1668,7 @@ class StateMachine:
entity_id = entity_id.lower()
new_state = str(new_state)
attributes = attributes or {}
if (old_state := self._states.get(entity_id)) is None:
if (old_state := self._states_data.get(entity_id)) is None:
same_state = False
same_attr = False
last_changed = None
@ -1656,10 +1710,6 @@ class StateMachine:
if old_state is not None:
old_state.expire()
self._states[entity_id] = state
if not (domain_index := self._domain_index.get(state.domain)):
domain_index = {}
self._domain_index[state.domain] = domain_index
domain_index[entity_id] = state
self._bus.async_fire(
EVENT_STATE_CHANGED,
{"entity_id": entity_id, "old_state": old_state, "new_state": state},

View file

@ -15,6 +15,7 @@ from typing import Any
from unittest.mock import MagicMock, Mock, PropertyMock, patch
import pytest
from pytest_unordered import unordered
import voluptuous as vol
from homeassistant.const import (
@ -1031,17 +1032,18 @@ async def test_statemachine_is_state(hass: HomeAssistant) -> None:
async def test_statemachine_entity_ids(hass: HomeAssistant) -> None:
"""Test get_entity_ids method."""
"""Test async_entity_ids method."""
assert hass.states.async_entity_ids() == []
assert hass.states.async_entity_ids("light") == []
assert hass.states.async_entity_ids(("light", "switch", "other")) == []
hass.states.async_set("light.bowl", "on", {})
hass.states.async_set("SWITCH.AC", "off", {})
ent_ids = hass.states.async_entity_ids()
assert len(ent_ids) == 2
assert "light.bowl" in ent_ids
assert "switch.ac" in ent_ids
ent_ids = hass.states.async_entity_ids("light")
assert len(ent_ids) == 1
assert "light.bowl" in ent_ids
assert hass.states.async_entity_ids() == unordered(["light.bowl", "switch.ac"])
assert hass.states.async_entity_ids("light") == ["light.bowl"]
assert hass.states.async_entity_ids(("light", "switch", "other")) == unordered(
["light.bowl", "switch.ac"]
)
states = sorted(state.entity_id for state in hass.states.async_all())
assert states == ["light.bowl", "switch.ac"]
@ -1902,6 +1904,9 @@ async def test_chained_logging_misses_log_timeout(
async def test_async_all(hass: HomeAssistant) -> None:
"""Test async_all."""
assert hass.states.async_all() == []
assert hass.states.async_all("light") == []
assert hass.states.async_all(["light", "switch"]) == []
hass.states.async_set("switch.link", "on")
hass.states.async_set("light.bowl", "on")
@ -1926,6 +1931,10 @@ async def test_async_all(hass: HomeAssistant) -> None:
async def test_async_entity_ids_count(hass: HomeAssistant) -> None:
"""Test async_entity_ids_count."""
assert hass.states.async_entity_ids_count() == 0
assert hass.states.async_entity_ids_count("light") == 0
assert hass.states.async_entity_ids_count({"light", "vacuum"}) == 0
hass.states.async_set("switch.link", "on")
hass.states.async_set("light.bowl", "on")
hass.states.async_set("light.frog", "on")