From dace9b32de3bc33f0305790eb39b92a28f3c4aa1 Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Tue, 30 Apr 2024 11:29:43 +0200 Subject: [PATCH] Store runtime data inside ConfigEntry (#115669) --- homeassistant/components/adguard/__init__.py | 20 +++--- homeassistant/components/adguard/entity.py | 6 +- homeassistant/components/adguard/sensor.py | 9 ++- homeassistant/components/adguard/switch.py | 9 ++- homeassistant/config_entries.py | 7 +- pylint/plugins/hass_enforce_type_hints.py | 14 ++++ tests/pylint/test_enforce_type_hints.py | 76 ++++++++++++++++++++ 7 files changed, 118 insertions(+), 23 deletions(-) diff --git a/homeassistant/components/adguard/__init__.py b/homeassistant/components/adguard/__init__.py index 874a4cae963..d6274659f1d 100644 --- a/homeassistant/components/adguard/__init__.py +++ b/homeassistant/components/adguard/__init__.py @@ -7,7 +7,7 @@ from dataclasses import dataclass from adguardhome import AdGuardHome, AdGuardHomeConnectionError import voluptuous as vol -from homeassistant.config_entries import ConfigEntry +from homeassistant.config_entries import ConfigEntry, ConfigEntryState from homeassistant.const import ( CONF_HOST, CONF_NAME, @@ -43,6 +43,7 @@ SERVICE_REFRESH_SCHEMA = vol.Schema( ) PLATFORMS = [Platform.SENSOR, Platform.SWITCH] +AdGuardConfigEntry = ConfigEntry["AdGuardData"] @dataclass @@ -53,7 +54,7 @@ class AdGuardData: version: str -async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: +async def async_setup_entry(hass: HomeAssistant, entry: AdGuardConfigEntry) -> bool: """Set up AdGuard Home from a config entry.""" session = async_get_clientsession(hass, entry.data[CONF_VERIFY_SSL]) adguard = AdGuardHome( @@ -71,7 +72,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: except AdGuardHomeConnectionError as exception: raise ConfigEntryNotReady from exception - hass.data.setdefault(DOMAIN, {})[entry.entry_id] = AdGuardData(adguard, version) + entry.runtime_data = AdGuardData(adguard, version) await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) @@ -116,17 +117,20 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: return True -async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: +async def async_unload_entry(hass: HomeAssistant, entry: AdGuardConfigEntry) -> bool: """Unload AdGuard Home config entry.""" unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS) - if unload_ok: - hass.data[DOMAIN].pop(entry.entry_id) - if not hass.data[DOMAIN]: + loaded_entries = [ + entry + for entry in hass.config_entries.async_entries(DOMAIN) + if entry.state == ConfigEntryState.LOADED + ] + if len(loaded_entries) == 1: + # This is the last loaded instance of AdGuard, deregister any services hass.services.async_remove(DOMAIN, SERVICE_ADD_URL) hass.services.async_remove(DOMAIN, SERVICE_REMOVE_URL) hass.services.async_remove(DOMAIN, SERVICE_ENABLE_URL) hass.services.async_remove(DOMAIN, SERVICE_DISABLE_URL) hass.services.async_remove(DOMAIN, SERVICE_REFRESH) - del hass.data[DOMAIN] return unload_ok diff --git a/homeassistant/components/adguard/entity.py b/homeassistant/components/adguard/entity.py index a4e16f1b995..65d20a4e88c 100644 --- a/homeassistant/components/adguard/entity.py +++ b/homeassistant/components/adguard/entity.py @@ -4,11 +4,11 @@ from __future__ import annotations from adguardhome import AdGuardHomeError -from homeassistant.config_entries import SOURCE_HASSIO, ConfigEntry +from homeassistant.config_entries import SOURCE_HASSIO from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo from homeassistant.helpers.entity import Entity -from . import AdGuardData +from . import AdGuardConfigEntry, AdGuardData from .const import DOMAIN, LOGGER @@ -21,7 +21,7 @@ class AdGuardHomeEntity(Entity): def __init__( self, data: AdGuardData, - entry: ConfigEntry, + entry: AdGuardConfigEntry, ) -> None: """Initialize the AdGuard Home entity.""" self._entry = entry diff --git a/homeassistant/components/adguard/sensor.py b/homeassistant/components/adguard/sensor.py index ce112f49531..b2404a88278 100644 --- a/homeassistant/components/adguard/sensor.py +++ b/homeassistant/components/adguard/sensor.py @@ -10,12 +10,11 @@ from typing import Any from adguardhome import AdGuardHome from homeassistant.components.sensor import SensorEntity, SensorEntityDescription -from homeassistant.config_entries import ConfigEntry from homeassistant.const import PERCENTAGE, UnitOfTime from homeassistant.core import HomeAssistant from homeassistant.helpers.entity_platform import AddEntitiesCallback -from . import AdGuardData +from . import AdGuardConfigEntry, AdGuardData from .const import DOMAIN from .entity import AdGuardHomeEntity @@ -85,11 +84,11 @@ SENSORS: tuple[AdGuardHomeEntityDescription, ...] = ( async def async_setup_entry( hass: HomeAssistant, - entry: ConfigEntry, + entry: AdGuardConfigEntry, async_add_entities: AddEntitiesCallback, ) -> None: """Set up AdGuard Home sensor based on a config entry.""" - data: AdGuardData = hass.data[DOMAIN][entry.entry_id] + data = entry.runtime_data async_add_entities( [AdGuardHomeSensor(data, entry, description) for description in SENSORS], @@ -105,7 +104,7 @@ class AdGuardHomeSensor(AdGuardHomeEntity, SensorEntity): def __init__( self, data: AdGuardData, - entry: ConfigEntry, + entry: AdGuardConfigEntry, description: AdGuardHomeEntityDescription, ) -> None: """Initialize AdGuard Home sensor.""" diff --git a/homeassistant/components/adguard/switch.py b/homeassistant/components/adguard/switch.py index e084ed2f349..3ea4f9d1d93 100644 --- a/homeassistant/components/adguard/switch.py +++ b/homeassistant/components/adguard/switch.py @@ -10,11 +10,10 @@ from typing import Any from adguardhome import AdGuardHome, AdGuardHomeError from homeassistant.components.switch import SwitchEntity, SwitchEntityDescription -from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant from homeassistant.helpers.entity_platform import AddEntitiesCallback -from . import AdGuardData +from . import AdGuardConfigEntry, AdGuardData from .const import DOMAIN, LOGGER from .entity import AdGuardHomeEntity @@ -79,11 +78,11 @@ SWITCHES: tuple[AdGuardHomeSwitchEntityDescription, ...] = ( async def async_setup_entry( hass: HomeAssistant, - entry: ConfigEntry, + entry: AdGuardConfigEntry, async_add_entities: AddEntitiesCallback, ) -> None: """Set up AdGuard Home switch based on a config entry.""" - data: AdGuardData = hass.data[DOMAIN][entry.entry_id] + data = entry.runtime_data async_add_entities( [AdGuardHomeSwitch(data, entry, description) for description in SWITCHES], @@ -99,7 +98,7 @@ class AdGuardHomeSwitch(AdGuardHomeEntity, SwitchEntity): def __init__( self, data: AdGuardData, - entry: ConfigEntry, + entry: AdGuardConfigEntry, description: AdGuardHomeSwitchEntityDescription, ) -> None: """Initialize AdGuard Home switch.""" diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index 619b2a4b48a..123424108fc 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -21,9 +21,10 @@ from functools import cached_property import logging from random import randint from types import MappingProxyType -from typing import TYPE_CHECKING, Any, Self, TypeVar, cast +from typing import TYPE_CHECKING, Any, Generic, Self, cast from async_interrupt import interrupt +from typing_extensions import TypeVar from . import data_entry_flow, loader from .components import persistent_notification @@ -124,6 +125,7 @@ SAVE_DELAY = 1 DISCOVERY_COOLDOWN = 1 +_DataT = TypeVar("_DataT", default=Any) _R = TypeVar("_R") @@ -266,13 +268,14 @@ class ConfigFlowResult(FlowResult, total=False): version: int -class ConfigEntry: +class ConfigEntry(Generic[_DataT]): """Hold a configuration entry.""" entry_id: str domain: str title: str data: MappingProxyType[str, Any] + runtime_data: _DataT options: MappingProxyType[str, Any] unique_id: str | None state: ConfigEntryState diff --git a/pylint/plugins/hass_enforce_type_hints.py b/pylint/plugins/hass_enforce_type_hints.py index 7d48fa4b2e3..2f107fb1bf2 100644 --- a/pylint/plugins/hass_enforce_type_hints.py +++ b/pylint/plugins/hass_enforce_type_hints.py @@ -23,6 +23,10 @@ _COMMON_ARGUMENTS: dict[str, list[str]] = { "hass": ["HomeAssistant", "HomeAssistant | None"] } _PLATFORMS: set[str] = {platform.value for platform in Platform} +_KNOWN_GENERIC_TYPES: set[str] = { + "ConfigEntry", +} +_KNOWN_GENERIC_TYPES_TUPLE = tuple(_KNOWN_GENERIC_TYPES) class _Special(Enum): @@ -2977,6 +2981,16 @@ def _is_valid_type( ): return True + # Allow subscripts or type aliases for generic types + if ( + isinstance(node, nodes.Subscript) + and isinstance(node.value, nodes.Name) + and node.value.name in _KNOWN_GENERIC_TYPES + or isinstance(node, nodes.Name) + and node.name.endswith(_KNOWN_GENERIC_TYPES_TUPLE) + ): + return True + # Name occurs when a namespace is not used, eg. "HomeAssistant" if isinstance(node, nodes.Name) and node.name == expected_type: return True diff --git a/tests/pylint/test_enforce_type_hints.py b/tests/pylint/test_enforce_type_hints.py index 78eb682200a..ad3b7d62be9 100644 --- a/tests/pylint/test_enforce_type_hints.py +++ b/tests/pylint/test_enforce_type_hints.py @@ -1196,3 +1196,79 @@ def test_pytest_invalid_function( ), ): type_hint_checker.visit_asyncfunctiondef(func_node) + + +@pytest.mark.parametrize( + "entry_annotation", + [ + "ConfigEntry", + "ConfigEntry[AdGuardData]", + "AdGuardConfigEntry", # prefix allowed for type aliases + ], +) +def test_valid_generic( + linter: UnittestLinter, type_hint_checker: BaseChecker, entry_annotation: str +) -> None: + """Ensure valid hints are accepted for generic types.""" + func_node = astroid.extract_node( + f""" + async def async_setup_entry( #@ + hass: HomeAssistant, + entry: {entry_annotation}, + async_add_entities: AddEntitiesCallback, + ) -> None: + pass + """, + "homeassistant.components.pylint_test.notify", + ) + type_hint_checker.visit_module(func_node.parent) + + with assert_no_messages(linter): + type_hint_checker.visit_asyncfunctiondef(func_node) + + +@pytest.mark.parametrize( + ("entry_annotation", "end_col_offset"), + [ + ("Config", 17), # not generic + ("ConfigEntryXX[Data]", 30), # generic type needs to match exactly + ("ConfigEntryData", 26), # ConfigEntry should be the suffix + ], +) +def test_invalid_generic( + linter: UnittestLinter, + type_hint_checker: BaseChecker, + entry_annotation: str, + end_col_offset: int, +) -> None: + """Ensure invalid hints are rejected for generic types.""" + func_node, entry_node = astroid.extract_node( + f""" + async def async_setup_entry( #@ + hass: HomeAssistant, + entry: {entry_annotation}, #@ + async_add_entities: AddEntitiesCallback, + ) -> None: + pass + """, + "homeassistant.components.pylint_test.notify", + ) + type_hint_checker.visit_module(func_node.parent) + + with assert_adds_messages( + linter, + pylint.testutils.MessageTest( + msg_id="hass-argument-type", + node=entry_node, + args=( + 2, + "ConfigEntry", + "async_setup_entry", + ), + line=4, + col_offset=4, + end_line=4, + end_col_offset=end_col_offset, + ), + ): + type_hint_checker.visit_asyncfunctiondef(func_node)