Convert homekit to use entry.runtime_data (#122533)

This commit is contained in:
J. Nick Koston 2024-07-25 04:06:55 -05:00 committed by GitHub
parent 8687b438f1
commit 7348a1fd0c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 40 additions and 25 deletions

View file

@ -127,7 +127,7 @@ from .const import (
SIGNAL_RELOAD_ENTITIES, SIGNAL_RELOAD_ENTITIES,
) )
from .iidmanager import AccessoryIIDStorage from .iidmanager import AccessoryIIDStorage
from .models import HomeKitEntryData from .models import HomeKitConfigEntry, HomeKitEntryData
from .type_triggers import DeviceTriggerAccessory from .type_triggers import DeviceTriggerAccessory
from .util import ( from .util import (
accessory_friendly_name, accessory_friendly_name,
@ -223,8 +223,12 @@ UNPAIR_SERVICE_SCHEMA = vol.All(
def _async_all_homekit_instances(hass: HomeAssistant) -> list[HomeKit]: def _async_all_homekit_instances(hass: HomeAssistant) -> list[HomeKit]:
"""All active HomeKit instances.""" """All active HomeKit instances."""
domain_data: dict[str, HomeKitEntryData] = hass.data[DOMAIN] hk_data: HomeKitEntryData | None
return [data.homekit for data in domain_data.values()] return [
hk_data.homekit
for entry in hass.config_entries.async_entries(DOMAIN)
if (hk_data := getattr(entry, "runtime_data", None))
]
def _async_get_imported_entries_indices( def _async_get_imported_entries_indices(
@ -246,7 +250,6 @@ def _async_get_imported_entries_indices(
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up the HomeKit from yaml.""" """Set up the HomeKit from yaml."""
hass.data[DOMAIN] = {}
hass.data[PERSIST_LOCK_DATA] = asyncio.Lock() hass.data[PERSIST_LOCK_DATA] = asyncio.Lock()
# Initialize the loader before loading entries to ensure # Initialize the loader before loading entries to ensure
@ -316,7 +319,7 @@ def _async_update_config_entry_from_yaml(
return True return True
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_setup_entry(hass: HomeAssistant, entry: HomeKitConfigEntry) -> bool:
"""Set up HomeKit from a config entry.""" """Set up HomeKit from a config entry."""
_async_import_options_from_data_if_missing(hass, entry) _async_import_options_from_data_if_missing(hass, entry)
@ -372,7 +375,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
entry_data = HomeKitEntryData( entry_data = HomeKitEntryData(
homekit=homekit, pairing_qr=None, pairing_qr_secret=None homekit=homekit, pairing_qr=None, pairing_qr_secret=None
) )
hass.data[DOMAIN][entry.entry_id] = entry_data entry.runtime_data = entry_data
async def _async_start_homekit(hass: HomeAssistant) -> None: async def _async_start_homekit(hass: HomeAssistant) -> None:
await homekit.async_start() await homekit.async_start()
@ -382,17 +385,19 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
return True return True
async def _async_update_listener(hass: HomeAssistant, entry: ConfigEntry) -> None: async def _async_update_listener(
hass: HomeAssistant, entry: HomeKitConfigEntry
) -> None:
"""Handle options update.""" """Handle options update."""
if entry.source == SOURCE_IMPORT: if entry.source == SOURCE_IMPORT:
return return
await hass.config_entries.async_reload(entry.entry_id) await hass.config_entries.async_reload(entry.entry_id)
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_unload_entry(hass: HomeAssistant, entry: HomeKitConfigEntry) -> bool:
"""Unload a config entry.""" """Unload a config entry."""
async_dismiss_setup_message(hass, entry.entry_id) async_dismiss_setup_message(hass, entry.entry_id)
entry_data: HomeKitEntryData = hass.data[DOMAIN][entry.entry_id] entry_data = entry.runtime_data
homekit = entry_data.homekit homekit = entry_data.homekit
if homekit.status == STATUS_RUNNING: if homekit.status == STATUS_RUNNING:
@ -409,12 +414,10 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
await asyncio.sleep(PORT_CLEANUP_CHECK_INTERVAL_SECS) await asyncio.sleep(PORT_CLEANUP_CHECK_INTERVAL_SECS)
hass.data[DOMAIN].pop(entry.entry_id)
return True return True
async def async_remove_entry(hass: HomeAssistant, entry: ConfigEntry) -> None: async def async_remove_entry(hass: HomeAssistant, entry: HomeKitConfigEntry) -> None:
"""Remove a config entry.""" """Remove a config entry."""
await hass.async_add_executor_job( await hass.async_add_executor_job(
remove_state_files_for_entry_id, hass, entry.entry_id remove_state_files_for_entry_id, hass, entry.entry_id
@ -423,7 +426,7 @@ async def async_remove_entry(hass: HomeAssistant, entry: ConfigEntry) -> None:
@callback @callback
def _async_import_options_from_data_if_missing( def _async_import_options_from_data_if_missing(
hass: HomeAssistant, entry: ConfigEntry hass: HomeAssistant, entry: HomeKitConfigEntry
) -> None: ) -> None:
options = deepcopy(dict(entry.options)) options = deepcopy(dict(entry.options))
data = deepcopy(dict(entry.data)) data = deepcopy(dict(entry.data))
@ -1198,9 +1201,10 @@ class HomeKitPairingQRView(HomeAssistantView):
raise Unauthorized raise Unauthorized
entry_id, secret = request.query_string.split("-") entry_id, secret = request.query_string.split("-")
hass = request.app[KEY_HASS] hass = request.app[KEY_HASS]
domain_data: dict[str, HomeKitEntryData] = hass.data[DOMAIN] entry_data: HomeKitEntryData | None
if ( if (
not (entry_data := domain_data.get(entry_id)) not (entry := hass.config_entries.async_get_entry(entry_id))
or not (entry_data := getattr(entry, "runtime_data", None))
or not secret or not secret
or not entry_data.pairing_qr_secret or not entry_data.pairing_qr_secret
or secret != entry_data.pairing_qr_secret or secret != entry_data.pairing_qr_secret

View file

@ -8,22 +8,19 @@ from pyhap.accessory_driver import AccessoryDriver
from pyhap.state import State from pyhap.state import State
from homeassistant.components.diagnostics import async_redact_data from homeassistant.components.diagnostics import async_redact_data
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from .accessories import HomeAccessory, HomeBridge from .accessories import HomeAccessory, HomeBridge
from .const import DOMAIN from .models import HomeKitConfigEntry
from .models import HomeKitEntryData
TO_REDACT = {"access_token", "entity_picture"} TO_REDACT = {"access_token", "entity_picture"}
async def async_get_config_entry_diagnostics( async def async_get_config_entry_diagnostics(
hass: HomeAssistant, entry: ConfigEntry hass: HomeAssistant, entry: HomeKitConfigEntry
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Return diagnostics for a config entry.""" """Return diagnostics for a config entry."""
entry_data: HomeKitEntryData = hass.data[DOMAIN][entry.entry_id] homekit = entry.runtime_data.homekit
homekit = entry_data.homekit
data: dict[str, Any] = { data: dict[str, Any] = {
"status": homekit.status, "status": homekit.status,
"config-entry": { "config-entry": {

View file

@ -5,9 +5,13 @@ from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from homeassistant.config_entries import ConfigEntry
if TYPE_CHECKING: if TYPE_CHECKING:
from . import HomeKit from . import HomeKit
type HomeKitConfigEntry = ConfigEntry[HomeKitEntryData]
@dataclass @dataclass
class HomeKitEntryData: class HomeKitEntryData:

View file

@ -106,7 +106,7 @@ from .const import (
VIDEO_CODEC_H264_V4L2M2M, VIDEO_CODEC_H264_V4L2M2M,
VIDEO_CODEC_LIBX264, VIDEO_CODEC_LIBX264,
) )
from .models import HomeKitEntryData from .models import HomeKitConfigEntry
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -366,7 +366,8 @@ def async_show_setup_message(
url.svg(buffer, scale=5, module_color="#000", background="#FFF") url.svg(buffer, scale=5, module_color="#000", background="#FFF")
pairing_secret = secrets.token_hex(32) pairing_secret = secrets.token_hex(32)
entry_data: HomeKitEntryData = hass.data[DOMAIN][entry_id] entry = cast(HomeKitConfigEntry, hass.config_entries.async_get_entry(entry_id))
entry_data = entry.runtime_data
entry_data.pairing_qr = buffer.getvalue() entry_data.pairing_qr = buffer.getvalue()
entry_data.pairing_qr_secret = pairing_secret entry_data.pairing_qr_secret = pairing_secret

View file

@ -1843,7 +1843,11 @@ async def test_homekit_uses_system_zeroconf(hass: HomeAssistant, hk_driver) -> N
entry.add_to_hass(hass) entry.add_to_hass(hass)
assert await hass.config_entries.async_setup(entry.entry_id) assert await hass.config_entries.async_setup(entry.entry_id)
await hass.async_block_till_done() await hass.async_block_till_done()
entry_data: HomeKitEntryData = hass.data[DOMAIN][entry.entry_id] # New tests should not access runtime data.
# Do not use this pattern for new tests.
entry_data: HomeKitEntryData = hass.config_entries.async_get_entry(
entry.entry_id
).runtime_data
assert entry_data.homekit.driver.advertiser == system_async_zc assert entry_data.homekit.driver.advertiser == system_async_zc
assert await hass.config_entries.async_unload(entry.entry_id) assert await hass.config_entries.async_unload(entry.entry_id)
await hass.async_block_till_done() await hass.async_block_till_done()

View file

@ -257,7 +257,12 @@ async def test_async_show_setup_msg(hass: HomeAssistant, hk_driver) -> None:
hass, entry.entry_id, "bridge_name", pincode, "X-HM://0" hass, entry.entry_id, "bridge_name", pincode, "X-HM://0"
) )
await hass.async_block_till_done() await hass.async_block_till_done()
entry_data: HomeKitEntryData = hass.data[DOMAIN][entry.entry_id]
# New tests should not access runtime data.
# Do not use this pattern for new tests.
entry_data: HomeKitEntryData = hass.config_entries.async_get_entry(
entry.entry_id
).runtime_data
assert entry_data.pairing_qr_secret assert entry_data.pairing_qr_secret
assert entry_data.pairing_qr assert entry_data.pairing_qr