From 51576b7214e25693309252ffe77b20d1c682679a Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Tue, 12 Sep 2023 20:41:26 +0200 Subject: [PATCH] Improve typing of entity.entity_sources (#99407) * Improve typing of entity.entity_sources * Calculate entity info source when generating WS response * Adjust typing * Update tests --- homeassistant/components/alexa/entities.py | 3 +- .../components/recorder/db_schema.py | 3 +- homeassistant/components/search/__init__.py | 7 ++-- homeassistant/components/sensor/recorder.py | 11 ++++--- .../components/websocket_api/commands.py | 2 +- homeassistant/helpers/entity.py | 32 +++++++++++++------ tests/components/search/test_init.py | 5 --- tests/helpers/test_entity.py | 2 -- 8 files changed, 39 insertions(+), 26 deletions(-) diff --git a/homeassistant/components/alexa/entities.py b/homeassistant/components/alexa/entities.py index 7f6331515c6..da0bd8b36aa 100644 --- a/homeassistant/components/alexa/entities.py +++ b/homeassistant/components/alexa/entities.py @@ -707,7 +707,8 @@ class MediaPlayerCapabilities(AlexaEntity): # AlexaEqualizerController is disabled for denonavr # since it blocks alexa from discovering any devices. - domain = entity_sources(self.hass).get(self.entity_id, {}).get("domain") + entity_info = entity_sources(self.hass).get(self.entity_id) + domain = entity_info["domain"] if entity_info else None if ( supported & media_player.MediaPlayerEntityFeature.SELECT_SOUND_MODE and domain != "denonavr" diff --git a/homeassistant/components/recorder/db_schema.py b/homeassistant/components/recorder/db_schema.py index 508874c54e5..e25c6d6dd5f 100644 --- a/homeassistant/components/recorder/db_schema.py +++ b/homeassistant/components/recorder/db_schema.py @@ -40,6 +40,7 @@ from homeassistant.const import ( MAX_LENGTH_STATE_STATE, ) from homeassistant.core import Context, Event, EventOrigin, State, split_entity_id +from homeassistant.helpers.entity import EntityInfo from homeassistant.helpers.json import JSON_DUMP, json_bytes, json_bytes_strip_null import homeassistant.util.dt as dt_util from homeassistant.util.json import ( @@ -558,7 +559,7 @@ class StateAttributes(Base): @staticmethod def shared_attrs_bytes_from_event( event: Event, - entity_sources: dict[str, dict[str, str]], + entity_sources: dict[str, EntityInfo], exclude_attrs_by_domain: dict[str, set[str]], dialect: SupportedDialect | None, ) -> bytes: diff --git a/homeassistant/components/search/__init__.py b/homeassistant/components/search/__init__.py index 69796800e61..ac9a13850d6 100644 --- a/homeassistant/components/search/__init__.py +++ b/homeassistant/components/search/__init__.py @@ -15,7 +15,10 @@ from homeassistant.helpers import ( device_registry as dr, entity_registry as er, ) -from homeassistant.helpers.entity import entity_sources as get_entity_sources +from homeassistant.helpers.entity import ( + EntityInfo, + entity_sources as get_entity_sources, +) from homeassistant.helpers.typing import ConfigType DOMAIN = "search" @@ -97,7 +100,7 @@ class Searcher: hass: HomeAssistant, device_reg: dr.DeviceRegistry, entity_reg: er.EntityRegistry, - entity_sources: dict[str, dict[str, str]], + entity_sources: dict[str, EntityInfo], ) -> None: """Search results.""" self.hass = hass diff --git a/homeassistant/components/sensor/recorder.py b/homeassistant/components/sensor/recorder.py index e5a35187c99..63096b16cd8 100644 --- a/homeassistant/components/sensor/recorder.py +++ b/homeassistant/components/sensor/recorder.py @@ -262,8 +262,9 @@ def _normalize_states( def _suggest_report_issue(hass: HomeAssistant, entity_id: str) -> str: """Suggest to report an issue.""" - domain = entity_sources(hass).get(entity_id, {}).get("domain") - custom_component = entity_sources(hass).get(entity_id, {}).get("custom_component") + entity_info = entity_sources(hass).get(entity_id) + domain = entity_info["domain"] if entity_info else None + custom_component = entity_info["custom_component"] if entity_info else None report_issue = "" if custom_component: report_issue = "report it to the custom integration author." @@ -296,7 +297,8 @@ def warn_dip( hass.data[WARN_DIP] = set() if entity_id not in hass.data[WARN_DIP]: hass.data[WARN_DIP].add(entity_id) - domain = entity_sources(hass).get(entity_id, {}).get("domain") + entity_info = entity_sources(hass).get(entity_id) + domain = entity_info["domain"] if entity_info else None if domain in ["energy", "growatt_server", "solaredge"]: return _LOGGER.warning( @@ -320,7 +322,8 @@ def warn_negative(hass: HomeAssistant, entity_id: str, state: State) -> None: hass.data[WARN_NEGATIVE] = set() if entity_id not in hass.data[WARN_NEGATIVE]: hass.data[WARN_NEGATIVE].add(entity_id) - domain = entity_sources(hass).get(entity_id, {}).get("domain") + entity_info = entity_sources(hass).get(entity_id) + domain = entity_info["domain"] if entity_info else None _LOGGER.warning( ( "Entity %s %shas state class total_increasing, but its state is " diff --git a/homeassistant/components/websocket_api/commands.py b/homeassistant/components/websocket_api/commands.py index 66866197081..e140fef861e 100644 --- a/homeassistant/components/websocket_api/commands.py +++ b/homeassistant/components/websocket_api/commands.py @@ -596,7 +596,7 @@ async def handle_render_template( def _serialize_entity_sources( - entity_infos: dict[str, dict[str, str]] + entity_infos: dict[str, entity.EntityInfo] ) -> dict[str, Any]: """Prepare a websocket response from a dict of entity sources.""" result = {} diff --git a/homeassistant/helpers/entity.py b/homeassistant/helpers/entity.py index 7bd510b6fa1..99c71e2cc86 100644 --- a/homeassistant/helpers/entity.py +++ b/homeassistant/helpers/entity.py @@ -12,7 +12,16 @@ import logging import math import sys from timeit import default_timer as timer -from typing import TYPE_CHECKING, Any, Final, Literal, TypeVar, final +from typing import ( + TYPE_CHECKING, + Any, + Final, + Literal, + NotRequired, + TypedDict, + TypeVar, + final, +) import voluptuous as vol @@ -60,8 +69,6 @@ _T = TypeVar("_T") _LOGGER = logging.getLogger(__name__) SLOW_UPDATE_WARNING = 10 DATA_ENTITY_SOURCE = "entity_info" -SOURCE_CONFIG_ENTRY = "config_entry" -SOURCE_PLATFORM_CONFIG = "platform_config" # Used when converting float states to string: limit precision according to machine # epsilon to make the string representation readable @@ -76,9 +83,9 @@ def async_setup(hass: HomeAssistant) -> None: @callback @bind_hass -def entity_sources(hass: HomeAssistant) -> dict[str, dict[str, str]]: +def entity_sources(hass: HomeAssistant) -> dict[str, EntityInfo]: """Get the entity sources.""" - _entity_sources: dict[str, dict[str, str]] = hass.data[DATA_ENTITY_SOURCE] + _entity_sources: dict[str, EntityInfo] = hass.data[DATA_ENTITY_SOURCE] return _entity_sources @@ -181,6 +188,14 @@ def get_unit_of_measurement(hass: HomeAssistant, entity_id: str) -> str | None: ENTITY_CATEGORIES_SCHEMA: Final = vol.Coerce(EntityCategory) +class EntityInfo(TypedDict): + """Entity info.""" + + domain: str + custom_component: bool + config_entry: NotRequired[str] + + class EntityPlatformState(Enum): """The platform state of an entity.""" @@ -1061,18 +1076,15 @@ class Entity(ABC): Not to be extended by integrations. """ - info = { + info: EntityInfo = { "domain": self.platform.platform_name, "custom_component": "custom_components" in type(self).__module__, } if self.platform.config_entry: - info["source"] = SOURCE_CONFIG_ENTRY info["config_entry"] = self.platform.config_entry.entry_id - else: - info["source"] = SOURCE_PLATFORM_CONFIG - self.hass.data[DATA_ENTITY_SOURCE][self.entity_id] = info + entity_sources(self.hass)[self.entity_id] = info if self.registry_entry is not None: # This is an assert as it should never happen, but helps in tests diff --git a/tests/components/search/test_init.py b/tests/components/search/test_init.py index 40ec9c22afe..ebf70a6239c 100644 --- a/tests/components/search/test_init.py +++ b/tests/components/search/test_init.py @@ -6,7 +6,6 @@ from homeassistant.core import HomeAssistant from homeassistant.helpers import ( area_registry as ar, device_registry as dr, - entity, entity_registry as er, ) from homeassistant.setup import async_setup_component @@ -22,11 +21,9 @@ def stub_blueprint_populate_autouse(stub_blueprint_populate: None) -> None: MOCK_ENTITY_SOURCES = { "light.platform_config_source": { - "source": entity.SOURCE_PLATFORM_CONFIG, "domain": "wled", }, "light.config_entry_source": { - "source": entity.SOURCE_CONFIG_ENTRY, "config_entry": "config_entry_id", "domain": "wled", }, @@ -73,11 +70,9 @@ async def test_search( entity_sources = { "light.wled_platform_config_source": { - "source": entity.SOURCE_PLATFORM_CONFIG, "domain": "wled", }, "light.wled_config_entry_source": { - "source": entity.SOURCE_CONFIG_ENTRY, "config_entry": wled_config_entry.entry_id, "domain": "wled", }, diff --git a/tests/helpers/test_entity.py b/tests/helpers/test_entity.py index 20bea6a98eb..68eed5b6e32 100644 --- a/tests/helpers/test_entity.py +++ b/tests/helpers/test_entity.py @@ -795,13 +795,11 @@ async def test_setup_source(hass: HomeAssistant) -> None: "test_domain.platform_config_source": { "custom_component": False, "domain": "test_platform", - "source": entity.SOURCE_PLATFORM_CONFIG, }, "test_domain.config_entry_source": { "config_entry": platform.config_entry.entry_id, "custom_component": False, "domain": "test_platform", - "source": entity.SOURCE_CONFIG_ENTRY, }, }