Improve typing of entity.entity_sources (#99407)

* Improve typing of entity.entity_sources

* Calculate entity info source when generating WS response

* Adjust typing

* Update tests
This commit is contained in:
Erik Montnemery 2023-09-12 20:41:26 +02:00 committed by GitHub
parent cc252f705f
commit 51576b7214
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 39 additions and 26 deletions

View file

@ -707,7 +707,8 @@ class MediaPlayerCapabilities(AlexaEntity):
# AlexaEqualizerController is disabled for denonavr
# since it blocks alexa from discovering any devices.
domain = entity_sources(self.hass).get(self.entity_id, {}).get("domain")
entity_info = entity_sources(self.hass).get(self.entity_id)
domain = entity_info["domain"] if entity_info else None
if (
supported & media_player.MediaPlayerEntityFeature.SELECT_SOUND_MODE
and domain != "denonavr"

View file

@ -40,6 +40,7 @@ from homeassistant.const import (
MAX_LENGTH_STATE_STATE,
)
from homeassistant.core import Context, Event, EventOrigin, State, split_entity_id
from homeassistant.helpers.entity import EntityInfo
from homeassistant.helpers.json import JSON_DUMP, json_bytes, json_bytes_strip_null
import homeassistant.util.dt as dt_util
from homeassistant.util.json import (
@ -558,7 +559,7 @@ class StateAttributes(Base):
@staticmethod
def shared_attrs_bytes_from_event(
event: Event,
entity_sources: dict[str, dict[str, str]],
entity_sources: dict[str, EntityInfo],
exclude_attrs_by_domain: dict[str, set[str]],
dialect: SupportedDialect | None,
) -> bytes:

View file

@ -15,7 +15,10 @@ from homeassistant.helpers import (
device_registry as dr,
entity_registry as er,
)
from homeassistant.helpers.entity import entity_sources as get_entity_sources
from homeassistant.helpers.entity import (
EntityInfo,
entity_sources as get_entity_sources,
)
from homeassistant.helpers.typing import ConfigType
DOMAIN = "search"
@ -97,7 +100,7 @@ class Searcher:
hass: HomeAssistant,
device_reg: dr.DeviceRegistry,
entity_reg: er.EntityRegistry,
entity_sources: dict[str, dict[str, str]],
entity_sources: dict[str, EntityInfo],
) -> None:
"""Search results."""
self.hass = hass

View file

@ -262,8 +262,9 @@ def _normalize_states(
def _suggest_report_issue(hass: HomeAssistant, entity_id: str) -> str:
"""Suggest to report an issue."""
domain = entity_sources(hass).get(entity_id, {}).get("domain")
custom_component = entity_sources(hass).get(entity_id, {}).get("custom_component")
entity_info = entity_sources(hass).get(entity_id)
domain = entity_info["domain"] if entity_info else None
custom_component = entity_info["custom_component"] if entity_info else None
report_issue = ""
if custom_component:
report_issue = "report it to the custom integration author."
@ -296,7 +297,8 @@ def warn_dip(
hass.data[WARN_DIP] = set()
if entity_id not in hass.data[WARN_DIP]:
hass.data[WARN_DIP].add(entity_id)
domain = entity_sources(hass).get(entity_id, {}).get("domain")
entity_info = entity_sources(hass).get(entity_id)
domain = entity_info["domain"] if entity_info else None
if domain in ["energy", "growatt_server", "solaredge"]:
return
_LOGGER.warning(
@ -320,7 +322,8 @@ def warn_negative(hass: HomeAssistant, entity_id: str, state: State) -> None:
hass.data[WARN_NEGATIVE] = set()
if entity_id not in hass.data[WARN_NEGATIVE]:
hass.data[WARN_NEGATIVE].add(entity_id)
domain = entity_sources(hass).get(entity_id, {}).get("domain")
entity_info = entity_sources(hass).get(entity_id)
domain = entity_info["domain"] if entity_info else None
_LOGGER.warning(
(
"Entity %s %shas state class total_increasing, but its state is "

View file

@ -596,7 +596,7 @@ async def handle_render_template(
def _serialize_entity_sources(
entity_infos: dict[str, dict[str, str]]
entity_infos: dict[str, entity.EntityInfo]
) -> dict[str, Any]:
"""Prepare a websocket response from a dict of entity sources."""
result = {}

View file

@ -12,7 +12,16 @@ import logging
import math
import sys
from timeit import default_timer as timer
from typing import TYPE_CHECKING, Any, Final, Literal, TypeVar, final
from typing import (
TYPE_CHECKING,
Any,
Final,
Literal,
NotRequired,
TypedDict,
TypeVar,
final,
)
import voluptuous as vol
@ -60,8 +69,6 @@ _T = TypeVar("_T")
_LOGGER = logging.getLogger(__name__)
SLOW_UPDATE_WARNING = 10
DATA_ENTITY_SOURCE = "entity_info"
SOURCE_CONFIG_ENTRY = "config_entry"
SOURCE_PLATFORM_CONFIG = "platform_config"
# Used when converting float states to string: limit precision according to machine
# epsilon to make the string representation readable
@ -76,9 +83,9 @@ def async_setup(hass: HomeAssistant) -> None:
@callback
@bind_hass
def entity_sources(hass: HomeAssistant) -> dict[str, dict[str, str]]:
def entity_sources(hass: HomeAssistant) -> dict[str, EntityInfo]:
"""Get the entity sources."""
_entity_sources: dict[str, dict[str, str]] = hass.data[DATA_ENTITY_SOURCE]
_entity_sources: dict[str, EntityInfo] = hass.data[DATA_ENTITY_SOURCE]
return _entity_sources
@ -181,6 +188,14 @@ def get_unit_of_measurement(hass: HomeAssistant, entity_id: str) -> str | None:
ENTITY_CATEGORIES_SCHEMA: Final = vol.Coerce(EntityCategory)
class EntityInfo(TypedDict):
"""Entity info."""
domain: str
custom_component: bool
config_entry: NotRequired[str]
class EntityPlatformState(Enum):
"""The platform state of an entity."""
@ -1061,18 +1076,15 @@ class Entity(ABC):
Not to be extended by integrations.
"""
info = {
info: EntityInfo = {
"domain": self.platform.platform_name,
"custom_component": "custom_components" in type(self).__module__,
}
if self.platform.config_entry:
info["source"] = SOURCE_CONFIG_ENTRY
info["config_entry"] = self.platform.config_entry.entry_id
else:
info["source"] = SOURCE_PLATFORM_CONFIG
self.hass.data[DATA_ENTITY_SOURCE][self.entity_id] = info
entity_sources(self.hass)[self.entity_id] = info
if self.registry_entry is not None:
# This is an assert as it should never happen, but helps in tests

View file

@ -6,7 +6,6 @@ from homeassistant.core import HomeAssistant
from homeassistant.helpers import (
area_registry as ar,
device_registry as dr,
entity,
entity_registry as er,
)
from homeassistant.setup import async_setup_component
@ -22,11 +21,9 @@ def stub_blueprint_populate_autouse(stub_blueprint_populate: None) -> None:
MOCK_ENTITY_SOURCES = {
"light.platform_config_source": {
"source": entity.SOURCE_PLATFORM_CONFIG,
"domain": "wled",
},
"light.config_entry_source": {
"source": entity.SOURCE_CONFIG_ENTRY,
"config_entry": "config_entry_id",
"domain": "wled",
},
@ -73,11 +70,9 @@ async def test_search(
entity_sources = {
"light.wled_platform_config_source": {
"source": entity.SOURCE_PLATFORM_CONFIG,
"domain": "wled",
},
"light.wled_config_entry_source": {
"source": entity.SOURCE_CONFIG_ENTRY,
"config_entry": wled_config_entry.entry_id,
"domain": "wled",
},

View file

@ -795,13 +795,11 @@ async def test_setup_source(hass: HomeAssistant) -> None:
"test_domain.platform_config_source": {
"custom_component": False,
"domain": "test_platform",
"source": entity.SOURCE_PLATFORM_CONFIG,
},
"test_domain.config_entry_source": {
"config_entry": platform.config_entry.entry_id,
"custom_component": False,
"domain": "test_platform",
"source": entity.SOURCE_CONFIG_ENTRY,
},
}