Use HassKey for helpers (1) (#117012)
This commit is contained in:
parent
8f614fb06d
commit
2db64c7e6d
8 changed files with 68 additions and 48 deletions
|
@ -20,6 +20,7 @@ from homeassistant.const import APPLICATION_NAME, EVENT_HOMEASSISTANT_CLOSE, __v
|
|||
from homeassistant.core import Event, HomeAssistant, callback
|
||||
from homeassistant.loader import bind_hass
|
||||
from homeassistant.util import ssl as ssl_util
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
from homeassistant.util.json import json_loads
|
||||
|
||||
from .backports.aiohttp_resolver import AsyncResolver
|
||||
|
@ -30,8 +31,12 @@ if TYPE_CHECKING:
|
|||
from aiohttp.typedefs import JSONDecoder
|
||||
|
||||
|
||||
DATA_CONNECTOR = "aiohttp_connector"
|
||||
DATA_CLIENTSESSION = "aiohttp_clientsession"
|
||||
DATA_CONNECTOR: HassKey[dict[tuple[bool, int], aiohttp.BaseConnector]] = HassKey(
|
||||
"aiohttp_connector"
|
||||
)
|
||||
DATA_CLIENTSESSION: HassKey[dict[tuple[bool, int], aiohttp.ClientSession]] = HassKey(
|
||||
"aiohttp_clientsession"
|
||||
)
|
||||
|
||||
SERVER_SOFTWARE = (
|
||||
f"{APPLICATION_NAME}/{__version__} "
|
||||
|
@ -84,11 +89,7 @@ def async_get_clientsession(
|
|||
This method must be run in the event loop.
|
||||
"""
|
||||
session_key = _make_key(verify_ssl, family)
|
||||
if DATA_CLIENTSESSION not in hass.data:
|
||||
sessions: dict[tuple[bool, int], aiohttp.ClientSession] = {}
|
||||
hass.data[DATA_CLIENTSESSION] = sessions
|
||||
else:
|
||||
sessions = hass.data[DATA_CLIENTSESSION]
|
||||
sessions = hass.data.setdefault(DATA_CLIENTSESSION, {})
|
||||
|
||||
if session_key not in sessions:
|
||||
session = _async_create_clientsession(
|
||||
|
@ -288,11 +289,7 @@ def _async_get_connector(
|
|||
This method must be run in the event loop.
|
||||
"""
|
||||
connector_key = _make_key(verify_ssl, family)
|
||||
if DATA_CONNECTOR not in hass.data:
|
||||
connectors: dict[tuple[bool, int], aiohttp.BaseConnector] = {}
|
||||
hass.data[DATA_CONNECTOR] = connectors
|
||||
else:
|
||||
connectors = hass.data[DATA_CONNECTOR]
|
||||
connectors = hass.data.setdefault(DATA_CONNECTOR, {})
|
||||
|
||||
if connector_key in connectors:
|
||||
return connectors[connector_key]
|
||||
|
|
|
@ -27,6 +27,7 @@ from homeassistant import config_entries
|
|||
from homeassistant.components import http
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.loader import async_get_application_credentials
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
|
||||
from .aiohttp_client import async_get_clientsession
|
||||
from .network import NoURLAvailableError
|
||||
|
@ -34,8 +35,15 @@ from .network import NoURLAvailableError
|
|||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
DATA_JWT_SECRET = "oauth2_jwt_secret"
|
||||
DATA_IMPLEMENTATIONS = "oauth2_impl"
|
||||
DATA_PROVIDERS = "oauth2_providers"
|
||||
DATA_IMPLEMENTATIONS: HassKey[dict[str, dict[str, AbstractOAuth2Implementation]]] = (
|
||||
HassKey("oauth2_impl")
|
||||
)
|
||||
DATA_PROVIDERS: HassKey[
|
||||
dict[
|
||||
str,
|
||||
Callable[[HomeAssistant, str], Awaitable[list[AbstractOAuth2Implementation]]],
|
||||
]
|
||||
] = HassKey("oauth2_providers")
|
||||
AUTH_CALLBACK_PATH = "/auth/external/callback"
|
||||
HEADER_FRONTEND_BASE = "HA-Frontend-Base"
|
||||
MY_AUTH_CALLBACK_PATH = "https://my.home-assistant.io/redirect/oauth"
|
||||
|
@ -398,10 +406,7 @@ async def async_get_implementations(
|
|||
hass: HomeAssistant, domain: str
|
||||
) -> dict[str, AbstractOAuth2Implementation]:
|
||||
"""Return OAuth2 implementations for specified domain."""
|
||||
registered = cast(
|
||||
dict[str, AbstractOAuth2Implementation],
|
||||
hass.data.setdefault(DATA_IMPLEMENTATIONS, {}).get(domain, {}),
|
||||
)
|
||||
registered = hass.data.setdefault(DATA_IMPLEMENTATIONS, {}).get(domain, {})
|
||||
|
||||
if DATA_PROVIDERS not in hass.data:
|
||||
return registered
|
||||
|
|
|
@ -10,9 +10,12 @@ from homeassistant.const import EVENT_HOMEASSISTANT_STARTED
|
|||
from homeassistant.core import CoreState, Event, HomeAssistant, callback
|
||||
from homeassistant.loader import bind_hass
|
||||
from homeassistant.util.async_ import gather_with_limited_concurrency
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
|
||||
FLOW_INIT_LIMIT = 20
|
||||
DISCOVERY_FLOW_DISPATCHER = "discovery_flow_dispatcher"
|
||||
DISCOVERY_FLOW_DISPATCHER: HassKey[FlowDispatcher] = HassKey(
|
||||
"discovery_flow_dispatcher"
|
||||
)
|
||||
|
||||
|
||||
@bind_hass
|
||||
|
|
|
@ -34,6 +34,7 @@ from homeassistant.exceptions import HomeAssistantError, PlatformNotReady
|
|||
from homeassistant.generated import languages
|
||||
from homeassistant.setup import SetupPhases, async_start_setup
|
||||
from homeassistant.util.async_ import create_eager_task
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
|
||||
from . import (
|
||||
config_validation as cv,
|
||||
|
@ -57,9 +58,13 @@ SLOW_ADD_ENTITY_MAX_WAIT = 15 # Per Entity
|
|||
SLOW_ADD_MIN_TIMEOUT = 500
|
||||
|
||||
PLATFORM_NOT_READY_RETRIES = 10
|
||||
DATA_ENTITY_PLATFORM = "entity_platform"
|
||||
DATA_DOMAIN_ENTITIES = "domain_entities"
|
||||
DATA_DOMAIN_PLATFORM_ENTITIES = "domain_platform_entities"
|
||||
DATA_ENTITY_PLATFORM: HassKey[dict[str, list[EntityPlatform]]] = HassKey(
|
||||
"entity_platform"
|
||||
)
|
||||
DATA_DOMAIN_ENTITIES: HassKey[dict[str, dict[str, Entity]]] = HassKey("domain_entities")
|
||||
DATA_DOMAIN_PLATFORM_ENTITIES: HassKey[dict[tuple[str, str], dict[str, Entity]]] = (
|
||||
HassKey("domain_platform_entities")
|
||||
)
|
||||
PLATFORM_NOT_READY_BASE_WAIT_TIME = 30 # seconds
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
@ -155,20 +160,18 @@ class EntityPlatform:
|
|||
# with the child dict indexed by entity_id
|
||||
#
|
||||
# This is usually media_player, light, switch, etc.
|
||||
domain_entities: dict[str, dict[str, Entity]] = hass.data.setdefault(
|
||||
self.domain_entities = hass.data.setdefault(
|
||||
DATA_DOMAIN_ENTITIES, {}
|
||||
)
|
||||
self.domain_entities = domain_entities.setdefault(domain, {})
|
||||
).setdefault(domain, {})
|
||||
|
||||
# Storage for entities indexed by domain and platform
|
||||
# with the child dict indexed by entity_id
|
||||
#
|
||||
# This is usually media_player.yamaha, light.hue, switch.tplink, etc.
|
||||
domain_platform_entities: dict[tuple[str, str], dict[str, Entity]] = (
|
||||
hass.data.setdefault(DATA_DOMAIN_PLATFORM_ENTITIES, {})
|
||||
)
|
||||
key = (domain, platform_name)
|
||||
self.domain_platform_entities = domain_platform_entities.setdefault(key, {})
|
||||
self.domain_platform_entities = hass.data.setdefault(
|
||||
DATA_DOMAIN_PLATFORM_ENTITIES, {}
|
||||
).setdefault(key, {})
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Represent an EntityPlatform."""
|
||||
|
@ -1063,6 +1066,4 @@ def async_get_platforms(
|
|||
):
|
||||
return []
|
||||
|
||||
platforms: list[EntityPlatform] = hass.data[DATA_ENTITY_PLATFORM][integration_name]
|
||||
|
||||
return platforms
|
||||
return hass.data[DATA_ENTITY_PLATFORM][integration_name]
|
||||
|
|
|
@ -38,6 +38,7 @@ from homeassistant.loader import bind_hass
|
|||
from homeassistant.util import dt as dt_util
|
||||
from homeassistant.util.async_ import run_callback_threadsafe
|
||||
from homeassistant.util.event_type import EventType
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
|
||||
from . import frame
|
||||
from .device_registry import (
|
||||
|
@ -54,19 +55,29 @@ from .template import RenderInfo, Template, result_as_boolean
|
|||
from .typing import TemplateVarsType
|
||||
|
||||
TRACK_STATE_CHANGE_CALLBACKS = "track_state_change_callbacks"
|
||||
TRACK_STATE_CHANGE_LISTENER = "track_state_change_listener"
|
||||
TRACK_STATE_CHANGE_LISTENER: HassKey[Callable[[], None]] = HassKey(
|
||||
"track_state_change_listener"
|
||||
)
|
||||
|
||||
TRACK_STATE_ADDED_DOMAIN_CALLBACKS = "track_state_added_domain_callbacks"
|
||||
TRACK_STATE_ADDED_DOMAIN_LISTENER = "track_state_added_domain_listener"
|
||||
TRACK_STATE_ADDED_DOMAIN_LISTENER: HassKey[Callable[[], None]] = HassKey(
|
||||
"track_state_added_domain_listener"
|
||||
)
|
||||
|
||||
TRACK_STATE_REMOVED_DOMAIN_CALLBACKS = "track_state_removed_domain_callbacks"
|
||||
TRACK_STATE_REMOVED_DOMAIN_LISTENER = "track_state_removed_domain_listener"
|
||||
TRACK_STATE_REMOVED_DOMAIN_LISTENER: HassKey[Callable[[], None]] = HassKey(
|
||||
"track_state_removed_domain_listener"
|
||||
)
|
||||
|
||||
TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS = "track_entity_registry_updated_callbacks"
|
||||
TRACK_ENTITY_REGISTRY_UPDATED_LISTENER = "track_entity_registry_updated_listener"
|
||||
TRACK_ENTITY_REGISTRY_UPDATED_LISTENER: HassKey[Callable[[], None]] = HassKey(
|
||||
"track_entity_registry_updated_listener"
|
||||
)
|
||||
|
||||
TRACK_DEVICE_REGISTRY_UPDATED_CALLBACKS = "track_device_registry_updated_callbacks"
|
||||
TRACK_DEVICE_REGISTRY_UPDATED_LISTENER = "track_device_registry_updated_listener"
|
||||
TRACK_DEVICE_REGISTRY_UPDATED_LISTENER: HassKey[Callable[[], None]] = HassKey(
|
||||
"track_device_registry_updated_listener"
|
||||
)
|
||||
|
||||
_ALL_LISTENER = "all"
|
||||
_DOMAINS_LISTENER = "domains"
|
||||
|
@ -89,7 +100,7 @@ _P = ParamSpec("_P")
|
|||
class _KeyedEventTracker(Generic[_TypedDictT]):
|
||||
"""Class to track events by key."""
|
||||
|
||||
listeners_key: str
|
||||
listeners_key: HassKey[Callable[[], None]]
|
||||
callbacks_key: str
|
||||
event_type: EventType[_TypedDictT] | str
|
||||
dispatcher_callable: Callable[
|
||||
|
@ -373,7 +384,7 @@ def _remove_empty_listener() -> None:
|
|||
@callback # type: ignore[arg-type] # mypy bug?
|
||||
def _remove_listener(
|
||||
hass: HomeAssistant,
|
||||
listeners_key: str,
|
||||
listeners_key: HassKey[Callable[[], None]],
|
||||
keys: Iterable[str],
|
||||
job: HassJob[[Event[_TypedDictT]], Any],
|
||||
callbacks: dict[str, list[HassJob[[Event[_TypedDictT]], Any]]],
|
||||
|
|
|
@ -11,6 +11,7 @@ import httpx
|
|||
from homeassistant.const import APPLICATION_NAME, EVENT_HOMEASSISTANT_CLOSE, __version__
|
||||
from homeassistant.core import Event, HomeAssistant, callback
|
||||
from homeassistant.loader import bind_hass
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
from homeassistant.util.ssl import (
|
||||
SSLCipherList,
|
||||
client_context,
|
||||
|
@ -23,8 +24,10 @@ from .frame import warn_use
|
|||
# and we want to keep the connection open for a while so we
|
||||
# don't have to reconnect every time so we use 15s to match aiohttp.
|
||||
KEEP_ALIVE_TIMEOUT = 15
|
||||
DATA_ASYNC_CLIENT = "httpx_async_client"
|
||||
DATA_ASYNC_CLIENT_NOVERIFY = "httpx_async_client_noverify"
|
||||
DATA_ASYNC_CLIENT: HassKey[httpx.AsyncClient] = HassKey("httpx_async_client")
|
||||
DATA_ASYNC_CLIENT_NOVERIFY: HassKey[httpx.AsyncClient] = HassKey(
|
||||
"httpx_async_client_noverify"
|
||||
)
|
||||
DEFAULT_LIMITS = limits = httpx.Limits(keepalive_expiry=KEEP_ALIVE_TIMEOUT)
|
||||
SERVER_SOFTWARE = (
|
||||
f"{APPLICATION_NAME}/{__version__} "
|
||||
|
@ -42,9 +45,7 @@ def get_async_client(hass: HomeAssistant, verify_ssl: bool = True) -> httpx.Asyn
|
|||
"""
|
||||
key = DATA_ASYNC_CLIENT if verify_ssl else DATA_ASYNC_CLIENT_NOVERIFY
|
||||
|
||||
client: httpx.AsyncClient | None = hass.data.get(key)
|
||||
|
||||
if client is None:
|
||||
if (client := hass.data.get(key)) is None:
|
||||
client = hass.data[key] = create_async_httpx_client(hass, verify_ssl)
|
||||
|
||||
return client
|
||||
|
|
|
@ -11,11 +11,12 @@ from typing import Any
|
|||
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.loader import Integration, async_get_integrations
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
from homeassistant.util.json import load_json_object
|
||||
|
||||
from .translation import build_resources
|
||||
|
||||
ICON_CACHE = "icon_cache"
|
||||
ICON_CACHE: HassKey[_IconsCache] = HassKey("icon_cache")
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -142,7 +143,7 @@ async def async_get_icons(
|
|||
components = hass.config.top_level_components
|
||||
|
||||
if ICON_CACHE in hass.data:
|
||||
cache: _IconsCache = hass.data[ICON_CACHE]
|
||||
cache = hass.data[ICON_CACHE]
|
||||
else:
|
||||
cache = hass.data[ICON_CACHE] = _IconsCache(hass)
|
||||
|
||||
|
|
|
@ -23,6 +23,7 @@ from homeassistant.const import (
|
|||
from homeassistant.core import Context, HomeAssistant, State, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.loader import bind_hass
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
|
||||
from . import (
|
||||
area_registry,
|
||||
|
@ -44,7 +45,7 @@ INTENT_SET_POSITION = "HassSetPosition"
|
|||
|
||||
SLOT_SCHEMA = vol.Schema({}, extra=vol.ALLOW_EXTRA)
|
||||
|
||||
DATA_KEY = "intent"
|
||||
DATA_KEY: HassKey[dict[str, IntentHandler]] = HassKey("intent")
|
||||
|
||||
SPEECH_TYPE_PLAIN = "plain"
|
||||
SPEECH_TYPE_SSML = "ssml"
|
||||
|
@ -89,7 +90,7 @@ async def async_handle(
|
|||
assistant: str | None = None,
|
||||
) -> IntentResponse:
|
||||
"""Handle an intent."""
|
||||
handler: IntentHandler = hass.data.get(DATA_KEY, {}).get(intent_type)
|
||||
handler = hass.data.get(DATA_KEY, {}).get(intent_type)
|
||||
|
||||
if handler is None:
|
||||
raise UnknownIntent(f"Unknown intent {intent_type}")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue