Use HassKey for helpers (1) (#117012)

This commit is contained in:
Marc Mueller 2024-05-07 18:25:16 +02:00 committed by GitHub
parent 8f614fb06d
commit 2db64c7e6d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 68 additions and 48 deletions

View file

@ -20,6 +20,7 @@ from homeassistant.const import APPLICATION_NAME, EVENT_HOMEASSISTANT_CLOSE, __v
from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.loader import bind_hass
from homeassistant.util import ssl as ssl_util
from homeassistant.util.hass_dict import HassKey
from homeassistant.util.json import json_loads
from .backports.aiohttp_resolver import AsyncResolver
@ -30,8 +31,12 @@ if TYPE_CHECKING:
from aiohttp.typedefs import JSONDecoder
DATA_CONNECTOR = "aiohttp_connector"
DATA_CLIENTSESSION = "aiohttp_clientsession"
DATA_CONNECTOR: HassKey[dict[tuple[bool, int], aiohttp.BaseConnector]] = HassKey(
"aiohttp_connector"
)
DATA_CLIENTSESSION: HassKey[dict[tuple[bool, int], aiohttp.ClientSession]] = HassKey(
"aiohttp_clientsession"
)
SERVER_SOFTWARE = (
f"{APPLICATION_NAME}/{__version__} "
@ -84,11 +89,7 @@ def async_get_clientsession(
This method must be run in the event loop.
"""
session_key = _make_key(verify_ssl, family)
if DATA_CLIENTSESSION not in hass.data:
sessions: dict[tuple[bool, int], aiohttp.ClientSession] = {}
hass.data[DATA_CLIENTSESSION] = sessions
else:
sessions = hass.data[DATA_CLIENTSESSION]
sessions = hass.data.setdefault(DATA_CLIENTSESSION, {})
if session_key not in sessions:
session = _async_create_clientsession(
@ -288,11 +289,7 @@ def _async_get_connector(
This method must be run in the event loop.
"""
connector_key = _make_key(verify_ssl, family)
if DATA_CONNECTOR not in hass.data:
connectors: dict[tuple[bool, int], aiohttp.BaseConnector] = {}
hass.data[DATA_CONNECTOR] = connectors
else:
connectors = hass.data[DATA_CONNECTOR]
connectors = hass.data.setdefault(DATA_CONNECTOR, {})
if connector_key in connectors:
return connectors[connector_key]

View file

@ -27,6 +27,7 @@ from homeassistant import config_entries
from homeassistant.components import http
from homeassistant.core import HomeAssistant, callback
from homeassistant.loader import async_get_application_credentials
from homeassistant.util.hass_dict import HassKey
from .aiohttp_client import async_get_clientsession
from .network import NoURLAvailableError
@ -34,8 +35,15 @@ from .network import NoURLAvailableError
_LOGGER = logging.getLogger(__name__)
DATA_JWT_SECRET = "oauth2_jwt_secret"
DATA_IMPLEMENTATIONS = "oauth2_impl"
DATA_PROVIDERS = "oauth2_providers"
DATA_IMPLEMENTATIONS: HassKey[dict[str, dict[str, AbstractOAuth2Implementation]]] = (
HassKey("oauth2_impl")
)
DATA_PROVIDERS: HassKey[
dict[
str,
Callable[[HomeAssistant, str], Awaitable[list[AbstractOAuth2Implementation]]],
]
] = HassKey("oauth2_providers")
AUTH_CALLBACK_PATH = "/auth/external/callback"
HEADER_FRONTEND_BASE = "HA-Frontend-Base"
MY_AUTH_CALLBACK_PATH = "https://my.home-assistant.io/redirect/oauth"
@ -398,10 +406,7 @@ async def async_get_implementations(
hass: HomeAssistant, domain: str
) -> dict[str, AbstractOAuth2Implementation]:
"""Return OAuth2 implementations for specified domain."""
registered = cast(
dict[str, AbstractOAuth2Implementation],
hass.data.setdefault(DATA_IMPLEMENTATIONS, {}).get(domain, {}),
)
registered = hass.data.setdefault(DATA_IMPLEMENTATIONS, {}).get(domain, {})
if DATA_PROVIDERS not in hass.data:
return registered

View file

@ -10,9 +10,12 @@ from homeassistant.const import EVENT_HOMEASSISTANT_STARTED
from homeassistant.core import CoreState, Event, HomeAssistant, callback
from homeassistant.loader import bind_hass
from homeassistant.util.async_ import gather_with_limited_concurrency
from homeassistant.util.hass_dict import HassKey
FLOW_INIT_LIMIT = 20
DISCOVERY_FLOW_DISPATCHER = "discovery_flow_dispatcher"
DISCOVERY_FLOW_DISPATCHER: HassKey[FlowDispatcher] = HassKey(
"discovery_flow_dispatcher"
)
@bind_hass

View file

@ -34,6 +34,7 @@ from homeassistant.exceptions import HomeAssistantError, PlatformNotReady
from homeassistant.generated import languages
from homeassistant.setup import SetupPhases, async_start_setup
from homeassistant.util.async_ import create_eager_task
from homeassistant.util.hass_dict import HassKey
from . import (
config_validation as cv,
@ -57,9 +58,13 @@ SLOW_ADD_ENTITY_MAX_WAIT = 15 # Per Entity
SLOW_ADD_MIN_TIMEOUT = 500
PLATFORM_NOT_READY_RETRIES = 10
DATA_ENTITY_PLATFORM = "entity_platform"
DATA_DOMAIN_ENTITIES = "domain_entities"
DATA_DOMAIN_PLATFORM_ENTITIES = "domain_platform_entities"
DATA_ENTITY_PLATFORM: HassKey[dict[str, list[EntityPlatform]]] = HassKey(
"entity_platform"
)
DATA_DOMAIN_ENTITIES: HassKey[dict[str, dict[str, Entity]]] = HassKey("domain_entities")
DATA_DOMAIN_PLATFORM_ENTITIES: HassKey[dict[tuple[str, str], dict[str, Entity]]] = (
HassKey("domain_platform_entities")
)
PLATFORM_NOT_READY_BASE_WAIT_TIME = 30 # seconds
_LOGGER = getLogger(__name__)
@ -155,20 +160,18 @@ class EntityPlatform:
# with the child dict indexed by entity_id
#
# This is usually media_player, light, switch, etc.
domain_entities: dict[str, dict[str, Entity]] = hass.data.setdefault(
self.domain_entities = hass.data.setdefault(
DATA_DOMAIN_ENTITIES, {}
)
self.domain_entities = domain_entities.setdefault(domain, {})
).setdefault(domain, {})
# Storage for entities indexed by domain and platform
# with the child dict indexed by entity_id
#
# This is usually media_player.yamaha, light.hue, switch.tplink, etc.
domain_platform_entities: dict[tuple[str, str], dict[str, Entity]] = (
hass.data.setdefault(DATA_DOMAIN_PLATFORM_ENTITIES, {})
)
key = (domain, platform_name)
self.domain_platform_entities = domain_platform_entities.setdefault(key, {})
self.domain_platform_entities = hass.data.setdefault(
DATA_DOMAIN_PLATFORM_ENTITIES, {}
).setdefault(key, {})
def __repr__(self) -> str:
"""Represent an EntityPlatform."""
@ -1063,6 +1066,4 @@ def async_get_platforms(
):
return []
platforms: list[EntityPlatform] = hass.data[DATA_ENTITY_PLATFORM][integration_name]
return platforms
return hass.data[DATA_ENTITY_PLATFORM][integration_name]

View file

@ -38,6 +38,7 @@ from homeassistant.loader import bind_hass
from homeassistant.util import dt as dt_util
from homeassistant.util.async_ import run_callback_threadsafe
from homeassistant.util.event_type import EventType
from homeassistant.util.hass_dict import HassKey
from . import frame
from .device_registry import (
@ -54,19 +55,29 @@ from .template import RenderInfo, Template, result_as_boolean
from .typing import TemplateVarsType
TRACK_STATE_CHANGE_CALLBACKS = "track_state_change_callbacks"
TRACK_STATE_CHANGE_LISTENER = "track_state_change_listener"
TRACK_STATE_CHANGE_LISTENER: HassKey[Callable[[], None]] = HassKey(
"track_state_change_listener"
)
TRACK_STATE_ADDED_DOMAIN_CALLBACKS = "track_state_added_domain_callbacks"
TRACK_STATE_ADDED_DOMAIN_LISTENER = "track_state_added_domain_listener"
TRACK_STATE_ADDED_DOMAIN_LISTENER: HassKey[Callable[[], None]] = HassKey(
"track_state_added_domain_listener"
)
TRACK_STATE_REMOVED_DOMAIN_CALLBACKS = "track_state_removed_domain_callbacks"
TRACK_STATE_REMOVED_DOMAIN_LISTENER = "track_state_removed_domain_listener"
TRACK_STATE_REMOVED_DOMAIN_LISTENER: HassKey[Callable[[], None]] = HassKey(
"track_state_removed_domain_listener"
)
TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS = "track_entity_registry_updated_callbacks"
TRACK_ENTITY_REGISTRY_UPDATED_LISTENER = "track_entity_registry_updated_listener"
TRACK_ENTITY_REGISTRY_UPDATED_LISTENER: HassKey[Callable[[], None]] = HassKey(
"track_entity_registry_updated_listener"
)
TRACK_DEVICE_REGISTRY_UPDATED_CALLBACKS = "track_device_registry_updated_callbacks"
TRACK_DEVICE_REGISTRY_UPDATED_LISTENER = "track_device_registry_updated_listener"
TRACK_DEVICE_REGISTRY_UPDATED_LISTENER: HassKey[Callable[[], None]] = HassKey(
"track_device_registry_updated_listener"
)
_ALL_LISTENER = "all"
_DOMAINS_LISTENER = "domains"
@ -89,7 +100,7 @@ _P = ParamSpec("_P")
class _KeyedEventTracker(Generic[_TypedDictT]):
"""Class to track events by key."""
listeners_key: str
listeners_key: HassKey[Callable[[], None]]
callbacks_key: str
event_type: EventType[_TypedDictT] | str
dispatcher_callable: Callable[
@ -373,7 +384,7 @@ def _remove_empty_listener() -> None:
@callback # type: ignore[arg-type] # mypy bug?
def _remove_listener(
hass: HomeAssistant,
listeners_key: str,
listeners_key: HassKey[Callable[[], None]],
keys: Iterable[str],
job: HassJob[[Event[_TypedDictT]], Any],
callbacks: dict[str, list[HassJob[[Event[_TypedDictT]], Any]]],

View file

@ -11,6 +11,7 @@ import httpx
from homeassistant.const import APPLICATION_NAME, EVENT_HOMEASSISTANT_CLOSE, __version__
from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.loader import bind_hass
from homeassistant.util.hass_dict import HassKey
from homeassistant.util.ssl import (
SSLCipherList,
client_context,
@ -23,8 +24,10 @@ from .frame import warn_use
# and we want to keep the connection open for a while so we
# don't have to reconnect every time so we use 15s to match aiohttp.
KEEP_ALIVE_TIMEOUT = 15
DATA_ASYNC_CLIENT = "httpx_async_client"
DATA_ASYNC_CLIENT_NOVERIFY = "httpx_async_client_noverify"
DATA_ASYNC_CLIENT: HassKey[httpx.AsyncClient] = HassKey("httpx_async_client")
DATA_ASYNC_CLIENT_NOVERIFY: HassKey[httpx.AsyncClient] = HassKey(
"httpx_async_client_noverify"
)
DEFAULT_LIMITS = limits = httpx.Limits(keepalive_expiry=KEEP_ALIVE_TIMEOUT)
SERVER_SOFTWARE = (
f"{APPLICATION_NAME}/{__version__} "
@ -42,9 +45,7 @@ def get_async_client(hass: HomeAssistant, verify_ssl: bool = True) -> httpx.Asyn
"""
key = DATA_ASYNC_CLIENT if verify_ssl else DATA_ASYNC_CLIENT_NOVERIFY
client: httpx.AsyncClient | None = hass.data.get(key)
if client is None:
if (client := hass.data.get(key)) is None:
client = hass.data[key] = create_async_httpx_client(hass, verify_ssl)
return client

View file

@ -11,11 +11,12 @@ from typing import Any
from homeassistant.core import HomeAssistant, callback
from homeassistant.loader import Integration, async_get_integrations
from homeassistant.util.hass_dict import HassKey
from homeassistant.util.json import load_json_object
from .translation import build_resources
ICON_CACHE = "icon_cache"
ICON_CACHE: HassKey[_IconsCache] = HassKey("icon_cache")
_LOGGER = logging.getLogger(__name__)
@ -142,7 +143,7 @@ async def async_get_icons(
components = hass.config.top_level_components
if ICON_CACHE in hass.data:
cache: _IconsCache = hass.data[ICON_CACHE]
cache = hass.data[ICON_CACHE]
else:
cache = hass.data[ICON_CACHE] = _IconsCache(hass)

View file

@ -23,6 +23,7 @@ from homeassistant.const import (
from homeassistant.core import Context, HomeAssistant, State, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.loader import bind_hass
from homeassistant.util.hass_dict import HassKey
from . import (
area_registry,
@ -44,7 +45,7 @@ INTENT_SET_POSITION = "HassSetPosition"
SLOT_SCHEMA = vol.Schema({}, extra=vol.ALLOW_EXTRA)
DATA_KEY = "intent"
DATA_KEY: HassKey[dict[str, IntentHandler]] = HassKey("intent")
SPEECH_TYPE_PLAIN = "plain"
SPEECH_TYPE_SSML = "ssml"
@ -89,7 +90,7 @@ async def async_handle(
assistant: str | None = None,
) -> IntentResponse:
"""Handle an intent."""
handler: IntentHandler = hass.data.get(DATA_KEY, {}).get(intent_type)
handler = hass.data.get(DATA_KEY, {}).get(intent_type)
if handler is None:
raise UnknownIntent(f"Unknown intent {intent_type}")