From e065c7096999f43977365db74406378c430d4105 Mon Sep 17 00:00:00 2001 From: Rami Mosleh Date: Wed, 12 Jun 2024 16:38:35 +0300 Subject: [PATCH] Store transmission coordinator in runtime_data (#119502) store transmission coordinator in runtime_data --- .../components/transmission/__init__.py | 27 ++++++++++++------- .../components/transmission/sensor.py | 8 +++--- .../components/transmission/switch.py | 11 +++----- tests/components/transmission/test_init.py | 1 - 4 files changed, 23 insertions(+), 24 deletions(-) diff --git a/homeassistant/components/transmission/__init__.py b/homeassistant/components/transmission/__init__.py index 681b4438099..06f27a1e605 100644 --- a/homeassistant/components/transmission/__init__.py +++ b/homeassistant/components/transmission/__init__.py @@ -15,7 +15,7 @@ from transmission_rpc.error import ( ) import voluptuous as vol -from homeassistant.config_entries import ConfigEntry +from homeassistant.config_entries import ConfigEntry, ConfigEntryState from homeassistant.const import ( CONF_HOST, CONF_ID, @@ -102,8 +102,12 @@ SERVICE_STOP_TORRENT_SCHEMA = vol.All( ) ) +type TransmissionConfigEntry = ConfigEntry[TransmissionDataUpdateCoordinator] -async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool: + +async def async_setup_entry( + hass: HomeAssistant, config_entry: TransmissionConfigEntry +) -> bool: """Set up the Transmission Component.""" @callback @@ -135,7 +139,7 @@ async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> b await hass.async_add_executor_job(coordinator.init_torrent_list) await coordinator.async_config_entry_first_refresh() - hass.data.setdefault(DOMAIN, {})[config_entry.entry_id] = coordinator + config_entry.runtime_data = coordinator await hass.config_entries.async_forward_entry_setups(config_entry, PLATFORMS) @@ -204,13 +208,16 @@ async def async_unload_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> if unload_ok := await hass.config_entries.async_unload_platforms( config_entry, PLATFORMS ): - hass.data[DOMAIN].pop(config_entry.entry_id) - - if not hass.data[DOMAIN]: - hass.services.async_remove(DOMAIN, SERVICE_ADD_TORRENT) - hass.services.async_remove(DOMAIN, SERVICE_REMOVE_TORRENT) - hass.services.async_remove(DOMAIN, SERVICE_START_TORRENT) - hass.services.async_remove(DOMAIN, SERVICE_STOP_TORRENT) + loaded_entries = [ + entry + for entry in hass.config_entries.async_entries(DOMAIN) + if entry.state == ConfigEntryState.LOADED + ] + if len(loaded_entries) == 1: + hass.services.async_remove(DOMAIN, SERVICE_ADD_TORRENT) + hass.services.async_remove(DOMAIN, SERVICE_REMOVE_TORRENT) + hass.services.async_remove(DOMAIN, SERVICE_START_TORRENT) + hass.services.async_remove(DOMAIN, SERVICE_STOP_TORRENT) return unload_ok diff --git a/homeassistant/components/transmission/sensor.py b/homeassistant/components/transmission/sensor.py index 9ee42045aab..737520adb5f 100644 --- a/homeassistant/components/transmission/sensor.py +++ b/homeassistant/components/transmission/sensor.py @@ -14,7 +14,6 @@ from homeassistant.components.sensor import ( SensorEntity, SensorEntityDescription, ) -from homeassistant.config_entries import ConfigEntry from homeassistant.const import STATE_IDLE, UnitOfDataRate from homeassistant.core import HomeAssistant from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo @@ -22,6 +21,7 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.typing import StateType from homeassistant.helpers.update_coordinator import CoordinatorEntity +from . import TransmissionConfigEntry from .const import ( DOMAIN, STATE_ATTR_TORRENT_INFO, @@ -134,14 +134,12 @@ SENSOR_TYPES: tuple[TransmissionSensorEntityDescription, ...] = ( async def async_setup_entry( hass: HomeAssistant, - config_entry: ConfigEntry, + config_entry: TransmissionConfigEntry, async_add_entities: AddEntitiesCallback, ) -> None: """Set up the Transmission sensors.""" - coordinator: TransmissionDataUpdateCoordinator = hass.data[DOMAIN][ - config_entry.entry_id - ] + coordinator = config_entry.runtime_data async_add_entities( TransmissionSensor(coordinator, description) for description in SENSOR_TYPES diff --git a/homeassistant/components/transmission/switch.py b/homeassistant/components/transmission/switch.py index 8e79d8246e0..d88f794cb10 100644 --- a/homeassistant/components/transmission/switch.py +++ b/homeassistant/components/transmission/switch.py @@ -2,21 +2,18 @@ from collections.abc import Callable from dataclasses import dataclass -import logging from typing import Any from homeassistant.components.switch import SwitchEntity, SwitchEntityDescription -from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.update_coordinator import CoordinatorEntity +from . import TransmissionConfigEntry from .const import DOMAIN from .coordinator import TransmissionDataUpdateCoordinator -_LOGGING = logging.getLogger(__name__) - @dataclass(frozen=True, kw_only=True) class TransmissionSwitchEntityDescription(SwitchEntityDescription): @@ -47,14 +44,12 @@ SWITCH_TYPES: tuple[TransmissionSwitchEntityDescription, ...] = ( async def async_setup_entry( hass: HomeAssistant, - config_entry: ConfigEntry, + config_entry: TransmissionConfigEntry, async_add_entities: AddEntitiesCallback, ) -> None: """Set up the Transmission switch.""" - coordinator: TransmissionDataUpdateCoordinator = hass.data[DOMAIN][ - config_entry.entry_id - ] + coordinator = config_entry.runtime_data async_add_entities( TransmissionSwitch(coordinator, description) for description in SWITCH_TYPES diff --git a/tests/components/transmission/test_init.py b/tests/components/transmission/test_init.py index 307576ffdea..38d941c3779 100644 --- a/tests/components/transmission/test_init.py +++ b/tests/components/transmission/test_init.py @@ -119,7 +119,6 @@ async def test_unload_entry(hass: HomeAssistant) -> None: await hass.async_block_till_done() assert entry.state is ConfigEntryState.NOT_LOADED - assert not hass.data[DOMAIN] @pytest.mark.parametrize(