Use HassKey for helpers (1) (#117012)

This commit is contained in:
Marc Mueller 2024-05-07 18:25:16 +02:00 committed by GitHub
parent 8f614fb06d
commit 2db64c7e6d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 68 additions and 48 deletions

View file

@ -20,6 +20,7 @@ from homeassistant.const import APPLICATION_NAME, EVENT_HOMEASSISTANT_CLOSE, __v
from homeassistant.core import Event, HomeAssistant, callback from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
from homeassistant.util import ssl as ssl_util from homeassistant.util import ssl as ssl_util
from homeassistant.util.hass_dict import HassKey
from homeassistant.util.json import json_loads from homeassistant.util.json import json_loads
from .backports.aiohttp_resolver import AsyncResolver from .backports.aiohttp_resolver import AsyncResolver
@ -30,8 +31,12 @@ if TYPE_CHECKING:
from aiohttp.typedefs import JSONDecoder from aiohttp.typedefs import JSONDecoder
DATA_CONNECTOR = "aiohttp_connector" DATA_CONNECTOR: HassKey[dict[tuple[bool, int], aiohttp.BaseConnector]] = HassKey(
DATA_CLIENTSESSION = "aiohttp_clientsession" "aiohttp_connector"
)
DATA_CLIENTSESSION: HassKey[dict[tuple[bool, int], aiohttp.ClientSession]] = HassKey(
"aiohttp_clientsession"
)
SERVER_SOFTWARE = ( SERVER_SOFTWARE = (
f"{APPLICATION_NAME}/{__version__} " f"{APPLICATION_NAME}/{__version__} "
@ -84,11 +89,7 @@ def async_get_clientsession(
This method must be run in the event loop. This method must be run in the event loop.
""" """
session_key = _make_key(verify_ssl, family) session_key = _make_key(verify_ssl, family)
if DATA_CLIENTSESSION not in hass.data: sessions = hass.data.setdefault(DATA_CLIENTSESSION, {})
sessions: dict[tuple[bool, int], aiohttp.ClientSession] = {}
hass.data[DATA_CLIENTSESSION] = sessions
else:
sessions = hass.data[DATA_CLIENTSESSION]
if session_key not in sessions: if session_key not in sessions:
session = _async_create_clientsession( session = _async_create_clientsession(
@ -288,11 +289,7 @@ def _async_get_connector(
This method must be run in the event loop. This method must be run in the event loop.
""" """
connector_key = _make_key(verify_ssl, family) connector_key = _make_key(verify_ssl, family)
if DATA_CONNECTOR not in hass.data: connectors = hass.data.setdefault(DATA_CONNECTOR, {})
connectors: dict[tuple[bool, int], aiohttp.BaseConnector] = {}
hass.data[DATA_CONNECTOR] = connectors
else:
connectors = hass.data[DATA_CONNECTOR]
if connector_key in connectors: if connector_key in connectors:
return connectors[connector_key] return connectors[connector_key]

View file

@ -27,6 +27,7 @@ from homeassistant import config_entries
from homeassistant.components import http from homeassistant.components import http
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.loader import async_get_application_credentials from homeassistant.loader import async_get_application_credentials
from homeassistant.util.hass_dict import HassKey
from .aiohttp_client import async_get_clientsession from .aiohttp_client import async_get_clientsession
from .network import NoURLAvailableError from .network import NoURLAvailableError
@ -34,8 +35,15 @@ from .network import NoURLAvailableError
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
DATA_JWT_SECRET = "oauth2_jwt_secret" DATA_JWT_SECRET = "oauth2_jwt_secret"
DATA_IMPLEMENTATIONS = "oauth2_impl" DATA_IMPLEMENTATIONS: HassKey[dict[str, dict[str, AbstractOAuth2Implementation]]] = (
DATA_PROVIDERS = "oauth2_providers" HassKey("oauth2_impl")
)
DATA_PROVIDERS: HassKey[
dict[
str,
Callable[[HomeAssistant, str], Awaitable[list[AbstractOAuth2Implementation]]],
]
] = HassKey("oauth2_providers")
AUTH_CALLBACK_PATH = "/auth/external/callback" AUTH_CALLBACK_PATH = "/auth/external/callback"
HEADER_FRONTEND_BASE = "HA-Frontend-Base" HEADER_FRONTEND_BASE = "HA-Frontend-Base"
MY_AUTH_CALLBACK_PATH = "https://my.home-assistant.io/redirect/oauth" MY_AUTH_CALLBACK_PATH = "https://my.home-assistant.io/redirect/oauth"
@ -398,10 +406,7 @@ async def async_get_implementations(
hass: HomeAssistant, domain: str hass: HomeAssistant, domain: str
) -> dict[str, AbstractOAuth2Implementation]: ) -> dict[str, AbstractOAuth2Implementation]:
"""Return OAuth2 implementations for specified domain.""" """Return OAuth2 implementations for specified domain."""
registered = cast( registered = hass.data.setdefault(DATA_IMPLEMENTATIONS, {}).get(domain, {})
dict[str, AbstractOAuth2Implementation],
hass.data.setdefault(DATA_IMPLEMENTATIONS, {}).get(domain, {}),
)
if DATA_PROVIDERS not in hass.data: if DATA_PROVIDERS not in hass.data:
return registered return registered

View file

@ -10,9 +10,12 @@ from homeassistant.const import EVENT_HOMEASSISTANT_STARTED
from homeassistant.core import CoreState, Event, HomeAssistant, callback from homeassistant.core import CoreState, Event, HomeAssistant, callback
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
from homeassistant.util.async_ import gather_with_limited_concurrency from homeassistant.util.async_ import gather_with_limited_concurrency
from homeassistant.util.hass_dict import HassKey
FLOW_INIT_LIMIT = 20 FLOW_INIT_LIMIT = 20
DISCOVERY_FLOW_DISPATCHER = "discovery_flow_dispatcher" DISCOVERY_FLOW_DISPATCHER: HassKey[FlowDispatcher] = HassKey(
"discovery_flow_dispatcher"
)
@bind_hass @bind_hass

View file

@ -34,6 +34,7 @@ from homeassistant.exceptions import HomeAssistantError, PlatformNotReady
from homeassistant.generated import languages from homeassistant.generated import languages
from homeassistant.setup import SetupPhases, async_start_setup from homeassistant.setup import SetupPhases, async_start_setup
from homeassistant.util.async_ import create_eager_task from homeassistant.util.async_ import create_eager_task
from homeassistant.util.hass_dict import HassKey
from . import ( from . import (
config_validation as cv, config_validation as cv,
@ -57,9 +58,13 @@ SLOW_ADD_ENTITY_MAX_WAIT = 15 # Per Entity
SLOW_ADD_MIN_TIMEOUT = 500 SLOW_ADD_MIN_TIMEOUT = 500
PLATFORM_NOT_READY_RETRIES = 10 PLATFORM_NOT_READY_RETRIES = 10
DATA_ENTITY_PLATFORM = "entity_platform" DATA_ENTITY_PLATFORM: HassKey[dict[str, list[EntityPlatform]]] = HassKey(
DATA_DOMAIN_ENTITIES = "domain_entities" "entity_platform"
DATA_DOMAIN_PLATFORM_ENTITIES = "domain_platform_entities" )
DATA_DOMAIN_ENTITIES: HassKey[dict[str, dict[str, Entity]]] = HassKey("domain_entities")
DATA_DOMAIN_PLATFORM_ENTITIES: HassKey[dict[tuple[str, str], dict[str, Entity]]] = (
HassKey("domain_platform_entities")
)
PLATFORM_NOT_READY_BASE_WAIT_TIME = 30 # seconds PLATFORM_NOT_READY_BASE_WAIT_TIME = 30 # seconds
_LOGGER = getLogger(__name__) _LOGGER = getLogger(__name__)
@ -155,20 +160,18 @@ class EntityPlatform:
# with the child dict indexed by entity_id # with the child dict indexed by entity_id
# #
# This is usually media_player, light, switch, etc. # This is usually media_player, light, switch, etc.
domain_entities: dict[str, dict[str, Entity]] = hass.data.setdefault( self.domain_entities = hass.data.setdefault(
DATA_DOMAIN_ENTITIES, {} DATA_DOMAIN_ENTITIES, {}
) ).setdefault(domain, {})
self.domain_entities = domain_entities.setdefault(domain, {})
# Storage for entities indexed by domain and platform # Storage for entities indexed by domain and platform
# with the child dict indexed by entity_id # with the child dict indexed by entity_id
# #
# This is usually media_player.yamaha, light.hue, switch.tplink, etc. # This is usually media_player.yamaha, light.hue, switch.tplink, etc.
domain_platform_entities: dict[tuple[str, str], dict[str, Entity]] = (
hass.data.setdefault(DATA_DOMAIN_PLATFORM_ENTITIES, {})
)
key = (domain, platform_name) key = (domain, platform_name)
self.domain_platform_entities = domain_platform_entities.setdefault(key, {}) self.domain_platform_entities = hass.data.setdefault(
DATA_DOMAIN_PLATFORM_ENTITIES, {}
).setdefault(key, {})
def __repr__(self) -> str: def __repr__(self) -> str:
"""Represent an EntityPlatform.""" """Represent an EntityPlatform."""
@ -1063,6 +1066,4 @@ def async_get_platforms(
): ):
return [] return []
platforms: list[EntityPlatform] = hass.data[DATA_ENTITY_PLATFORM][integration_name] return hass.data[DATA_ENTITY_PLATFORM][integration_name]
return platforms

View file

@ -38,6 +38,7 @@ from homeassistant.loader import bind_hass
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
from homeassistant.util.async_ import run_callback_threadsafe from homeassistant.util.async_ import run_callback_threadsafe
from homeassistant.util.event_type import EventType from homeassistant.util.event_type import EventType
from homeassistant.util.hass_dict import HassKey
from . import frame from . import frame
from .device_registry import ( from .device_registry import (
@ -54,19 +55,29 @@ from .template import RenderInfo, Template, result_as_boolean
from .typing import TemplateVarsType from .typing import TemplateVarsType
TRACK_STATE_CHANGE_CALLBACKS = "track_state_change_callbacks" TRACK_STATE_CHANGE_CALLBACKS = "track_state_change_callbacks"
TRACK_STATE_CHANGE_LISTENER = "track_state_change_listener" TRACK_STATE_CHANGE_LISTENER: HassKey[Callable[[], None]] = HassKey(
"track_state_change_listener"
)
TRACK_STATE_ADDED_DOMAIN_CALLBACKS = "track_state_added_domain_callbacks" TRACK_STATE_ADDED_DOMAIN_CALLBACKS = "track_state_added_domain_callbacks"
TRACK_STATE_ADDED_DOMAIN_LISTENER = "track_state_added_domain_listener" TRACK_STATE_ADDED_DOMAIN_LISTENER: HassKey[Callable[[], None]] = HassKey(
"track_state_added_domain_listener"
)
TRACK_STATE_REMOVED_DOMAIN_CALLBACKS = "track_state_removed_domain_callbacks" TRACK_STATE_REMOVED_DOMAIN_CALLBACKS = "track_state_removed_domain_callbacks"
TRACK_STATE_REMOVED_DOMAIN_LISTENER = "track_state_removed_domain_listener" TRACK_STATE_REMOVED_DOMAIN_LISTENER: HassKey[Callable[[], None]] = HassKey(
"track_state_removed_domain_listener"
)
TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS = "track_entity_registry_updated_callbacks" TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS = "track_entity_registry_updated_callbacks"
TRACK_ENTITY_REGISTRY_UPDATED_LISTENER = "track_entity_registry_updated_listener" TRACK_ENTITY_REGISTRY_UPDATED_LISTENER: HassKey[Callable[[], None]] = HassKey(
"track_entity_registry_updated_listener"
)
TRACK_DEVICE_REGISTRY_UPDATED_CALLBACKS = "track_device_registry_updated_callbacks" TRACK_DEVICE_REGISTRY_UPDATED_CALLBACKS = "track_device_registry_updated_callbacks"
TRACK_DEVICE_REGISTRY_UPDATED_LISTENER = "track_device_registry_updated_listener" TRACK_DEVICE_REGISTRY_UPDATED_LISTENER: HassKey[Callable[[], None]] = HassKey(
"track_device_registry_updated_listener"
)
_ALL_LISTENER = "all" _ALL_LISTENER = "all"
_DOMAINS_LISTENER = "domains" _DOMAINS_LISTENER = "domains"
@ -89,7 +100,7 @@ _P = ParamSpec("_P")
class _KeyedEventTracker(Generic[_TypedDictT]): class _KeyedEventTracker(Generic[_TypedDictT]):
"""Class to track events by key.""" """Class to track events by key."""
listeners_key: str listeners_key: HassKey[Callable[[], None]]
callbacks_key: str callbacks_key: str
event_type: EventType[_TypedDictT] | str event_type: EventType[_TypedDictT] | str
dispatcher_callable: Callable[ dispatcher_callable: Callable[
@ -373,7 +384,7 @@ def _remove_empty_listener() -> None:
@callback # type: ignore[arg-type] # mypy bug? @callback # type: ignore[arg-type] # mypy bug?
def _remove_listener( def _remove_listener(
hass: HomeAssistant, hass: HomeAssistant,
listeners_key: str, listeners_key: HassKey[Callable[[], None]],
keys: Iterable[str], keys: Iterable[str],
job: HassJob[[Event[_TypedDictT]], Any], job: HassJob[[Event[_TypedDictT]], Any],
callbacks: dict[str, list[HassJob[[Event[_TypedDictT]], Any]]], callbacks: dict[str, list[HassJob[[Event[_TypedDictT]], Any]]],

View file

@ -11,6 +11,7 @@ import httpx
from homeassistant.const import APPLICATION_NAME, EVENT_HOMEASSISTANT_CLOSE, __version__ from homeassistant.const import APPLICATION_NAME, EVENT_HOMEASSISTANT_CLOSE, __version__
from homeassistant.core import Event, HomeAssistant, callback from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
from homeassistant.util.hass_dict import HassKey
from homeassistant.util.ssl import ( from homeassistant.util.ssl import (
SSLCipherList, SSLCipherList,
client_context, client_context,
@ -23,8 +24,10 @@ from .frame import warn_use
# and we want to keep the connection open for a while so we # and we want to keep the connection open for a while so we
# don't have to reconnect every time so we use 15s to match aiohttp. # don't have to reconnect every time so we use 15s to match aiohttp.
KEEP_ALIVE_TIMEOUT = 15 KEEP_ALIVE_TIMEOUT = 15
DATA_ASYNC_CLIENT = "httpx_async_client" DATA_ASYNC_CLIENT: HassKey[httpx.AsyncClient] = HassKey("httpx_async_client")
DATA_ASYNC_CLIENT_NOVERIFY = "httpx_async_client_noverify" DATA_ASYNC_CLIENT_NOVERIFY: HassKey[httpx.AsyncClient] = HassKey(
"httpx_async_client_noverify"
)
DEFAULT_LIMITS = limits = httpx.Limits(keepalive_expiry=KEEP_ALIVE_TIMEOUT) DEFAULT_LIMITS = limits = httpx.Limits(keepalive_expiry=KEEP_ALIVE_TIMEOUT)
SERVER_SOFTWARE = ( SERVER_SOFTWARE = (
f"{APPLICATION_NAME}/{__version__} " f"{APPLICATION_NAME}/{__version__} "
@ -42,9 +45,7 @@ def get_async_client(hass: HomeAssistant, verify_ssl: bool = True) -> httpx.Asyn
""" """
key = DATA_ASYNC_CLIENT if verify_ssl else DATA_ASYNC_CLIENT_NOVERIFY key = DATA_ASYNC_CLIENT if verify_ssl else DATA_ASYNC_CLIENT_NOVERIFY
client: httpx.AsyncClient | None = hass.data.get(key) if (client := hass.data.get(key)) is None:
if client is None:
client = hass.data[key] = create_async_httpx_client(hass, verify_ssl) client = hass.data[key] = create_async_httpx_client(hass, verify_ssl)
return client return client

View file

@ -11,11 +11,12 @@ from typing import Any
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.loader import Integration, async_get_integrations from homeassistant.loader import Integration, async_get_integrations
from homeassistant.util.hass_dict import HassKey
from homeassistant.util.json import load_json_object from homeassistant.util.json import load_json_object
from .translation import build_resources from .translation import build_resources
ICON_CACHE = "icon_cache" ICON_CACHE: HassKey[_IconsCache] = HassKey("icon_cache")
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -142,7 +143,7 @@ async def async_get_icons(
components = hass.config.top_level_components components = hass.config.top_level_components
if ICON_CACHE in hass.data: if ICON_CACHE in hass.data:
cache: _IconsCache = hass.data[ICON_CACHE] cache = hass.data[ICON_CACHE]
else: else:
cache = hass.data[ICON_CACHE] = _IconsCache(hass) cache = hass.data[ICON_CACHE] = _IconsCache(hass)

View file

@ -23,6 +23,7 @@ from homeassistant.const import (
from homeassistant.core import Context, HomeAssistant, State, callback from homeassistant.core import Context, HomeAssistant, State, callback
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
from homeassistant.util.hass_dict import HassKey
from . import ( from . import (
area_registry, area_registry,
@ -44,7 +45,7 @@ INTENT_SET_POSITION = "HassSetPosition"
SLOT_SCHEMA = vol.Schema({}, extra=vol.ALLOW_EXTRA) SLOT_SCHEMA = vol.Schema({}, extra=vol.ALLOW_EXTRA)
DATA_KEY = "intent" DATA_KEY: HassKey[dict[str, IntentHandler]] = HassKey("intent")
SPEECH_TYPE_PLAIN = "plain" SPEECH_TYPE_PLAIN = "plain"
SPEECH_TYPE_SSML = "ssml" SPEECH_TYPE_SSML = "ssml"
@ -89,7 +90,7 @@ async def async_handle(
assistant: str | None = None, assistant: str | None = None,
) -> IntentResponse: ) -> IntentResponse:
"""Handle an intent.""" """Handle an intent."""
handler: IntentHandler = hass.data.get(DATA_KEY, {}).get(intent_type) handler = hass.data.get(DATA_KEY, {}).get(intent_type)
if handler is None: if handler is None:
raise UnknownIntent(f"Unknown intent {intent_type}") raise UnknownIntent(f"Unknown intent {intent_type}")