diff --git a/homeassistant/components/broadlink/device.py b/homeassistant/components/broadlink/device.py index 8f5cf43ad7e..2518cd65bd3 100644 --- a/homeassistant/components/broadlink/device.py +++ b/homeassistant/components/broadlink/device.py @@ -3,6 +3,7 @@ from contextlib import suppress from functools import partial import logging +from typing import Generic import broadlink as blk from broadlink.exceptions import ( @@ -12,6 +13,7 @@ from broadlink.exceptions import ( ConnectionClosedError, NetworkTimeoutError, ) +from typing_extensions import TypeVar from homeassistant.config_entries import SOURCE_REAUTH, ConfigEntry from homeassistant.const import ( @@ -27,7 +29,9 @@ from homeassistant.exceptions import ConfigEntryNotReady from homeassistant.helpers import device_registry as dr from .const import DEFAULT_PORT, DOMAIN, DOMAINS_AND_TYPES -from .updater import get_update_manager +from .updater import BroadlinkUpdateManager, get_update_manager + +_ApiT = TypeVar("_ApiT", bound=blk.Device, default=blk.Device) _LOGGER = logging.getLogger(__name__) @@ -37,16 +41,16 @@ def get_domains(device_type: str) -> set[Platform]: return {d for d, t in DOMAINS_AND_TYPES.items() if device_type in t} -class BroadlinkDevice: +class BroadlinkDevice(Generic[_ApiT]): """Manages a Broadlink device.""" - api: blk.Device + api: _ApiT def __init__(self, hass: HomeAssistant, config: ConfigEntry) -> None: """Initialize the device.""" self.hass = hass self.config = config - self.update_manager = None + self.update_manager: BroadlinkUpdateManager[_ApiT] | None = None self.fw_version: int | None = None self.authorized: bool | None = None self.reset_jobs: list[CALLBACK_TYPE] = [] diff --git a/homeassistant/components/broadlink/updater.py b/homeassistant/components/broadlink/updater.py index f678af0105f..4faa84dbbee 100644 --- a/homeassistant/components/broadlink/updater.py +++ b/homeassistant/components/broadlink/updater.py @@ -1,20 +1,30 @@ """Support for fetching data from Broadlink devices.""" -from abc import ABC, abstractmethod -from datetime import timedelta -import logging +from __future__ import annotations +from abc import ABC, abstractmethod +from datetime import datetime, timedelta +import logging +from typing import TYPE_CHECKING, Any, Generic + +import broadlink as blk from broadlink.exceptions import AuthorizationError, BroadlinkException +from typing_extensions import TypeVar from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed from homeassistant.util import dt as dt_util +if TYPE_CHECKING: + from .device import BroadlinkDevice + +_ApiT = TypeVar("_ApiT", bound=blk.Device) + _LOGGER = logging.getLogger(__name__) -def get_update_manager(device): +def get_update_manager(device: BroadlinkDevice[_ApiT]) -> BroadlinkUpdateManager[_ApiT]: """Return an update manager for a given Broadlink device.""" - update_managers = { + update_managers: dict[str, type[BroadlinkUpdateManager]] = { "A1": BroadlinkA1UpdateManager, "BG1": BroadlinkBG1UpdateManager, "HYS": BroadlinkThermostatUpdateManager, @@ -38,7 +48,7 @@ def get_update_manager(device): return update_managers[device.api.type](device) -class BroadlinkUpdateManager(ABC): +class BroadlinkUpdateManager(ABC, Generic[_ApiT]): """Representation of a Broadlink update manager. Implement this class to manage fetching data from the device and to @@ -47,7 +57,7 @@ class BroadlinkUpdateManager(ABC): SCAN_INTERVAL = timedelta(minutes=1) - def __init__(self, device): + def __init__(self, device: BroadlinkDevice[_ApiT]) -> None: """Initialize the update manager.""" self.device = device self.coordinator = DataUpdateCoordinator( @@ -57,18 +67,22 @@ class BroadlinkUpdateManager(ABC): update_method=self.async_update, update_interval=self.SCAN_INTERVAL, ) - self.available = None - self.last_update = None + self.available: bool | None = None + self.last_update: datetime | None = None - async def async_update(self): + async def async_update(self) -> dict[str, Any] | None: """Fetch data from the device and update availability.""" try: data = await self.async_fetch_data() except (BroadlinkException, OSError) as err: - if self.available and ( - dt_util.utcnow() - self.last_update > self.SCAN_INTERVAL * 3 - or isinstance(err, (AuthorizationError, OSError)) + if ( + self.available + and self.last_update + and ( + dt_util.utcnow() - self.last_update > self.SCAN_INTERVAL * 3 + or isinstance(err, (AuthorizationError, OSError)) + ) ): self.available = False _LOGGER.warning( @@ -91,42 +105,42 @@ class BroadlinkUpdateManager(ABC): return data @abstractmethod - async def async_fetch_data(self): + async def async_fetch_data(self) -> dict[str, Any] | None: """Fetch data from the device.""" -class BroadlinkA1UpdateManager(BroadlinkUpdateManager): +class BroadlinkA1UpdateManager(BroadlinkUpdateManager[blk.a1]): """Manages updates for Broadlink A1 devices.""" SCAN_INTERVAL = timedelta(seconds=10) - async def async_fetch_data(self): + async def async_fetch_data(self) -> dict[str, Any]: """Fetch data from the device.""" return await self.device.async_request(self.device.api.check_sensors_raw) -class BroadlinkMP1UpdateManager(BroadlinkUpdateManager): +class BroadlinkMP1UpdateManager(BroadlinkUpdateManager[blk.mp1]): """Manages updates for Broadlink MP1 devices.""" - async def async_fetch_data(self): + async def async_fetch_data(self) -> dict[str, Any]: """Fetch data from the device.""" return await self.device.async_request(self.device.api.check_power) -class BroadlinkMP1SUpdateManager(BroadlinkUpdateManager): +class BroadlinkMP1SUpdateManager(BroadlinkUpdateManager[blk.mp1s]): """Manages updates for Broadlink MP1 devices.""" - async def async_fetch_data(self): + async def async_fetch_data(self) -> dict[str, Any]: """Fetch data from the device.""" power = await self.device.async_request(self.device.api.check_power) sensors = await self.device.async_request(self.device.api.get_state) return {**power, **sensors} -class BroadlinkRMUpdateManager(BroadlinkUpdateManager): +class BroadlinkRMUpdateManager(BroadlinkUpdateManager[blk.rm]): """Manages updates for Broadlink remotes.""" - async def async_fetch_data(self): + async def async_fetch_data(self) -> dict[str, Any]: """Fetch data from the device.""" device = self.device @@ -138,7 +152,9 @@ class BroadlinkRMUpdateManager(BroadlinkUpdateManager): return {} @staticmethod - def normalize(data, previous_data): + def normalize( + data: dict[str, Any], previous_data: dict[str, Any] | None + ) -> dict[str, Any]: """Fix firmware issue. See https://github.com/home-assistant/core/issues/42100. @@ -151,18 +167,18 @@ class BroadlinkRMUpdateManager(BroadlinkUpdateManager): return data -class BroadlinkSP1UpdateManager(BroadlinkUpdateManager): +class BroadlinkSP1UpdateManager(BroadlinkUpdateManager[blk.sp1]): """Manages updates for Broadlink SP1 devices.""" - async def async_fetch_data(self): + async def async_fetch_data(self) -> None: """Fetch data from the device.""" return None -class BroadlinkSP2UpdateManager(BroadlinkUpdateManager): +class BroadlinkSP2UpdateManager(BroadlinkUpdateManager[blk.sp2]): """Manages updates for Broadlink SP2 devices.""" - async def async_fetch_data(self): + async def async_fetch_data(self) -> dict[str, Any]: """Fetch data from the device.""" device = self.device @@ -175,33 +191,33 @@ class BroadlinkSP2UpdateManager(BroadlinkUpdateManager): return data -class BroadlinkBG1UpdateManager(BroadlinkUpdateManager): +class BroadlinkBG1UpdateManager(BroadlinkUpdateManager[blk.bg1]): """Manages updates for Broadlink BG1 devices.""" - async def async_fetch_data(self): + async def async_fetch_data(self) -> dict[str, Any]: """Fetch data from the device.""" return await self.device.async_request(self.device.api.get_state) -class BroadlinkSP4UpdateManager(BroadlinkUpdateManager): +class BroadlinkSP4UpdateManager(BroadlinkUpdateManager[blk.sp4]): """Manages updates for Broadlink SP4 devices.""" - async def async_fetch_data(self): + async def async_fetch_data(self) -> dict[str, Any]: """Fetch data from the device.""" return await self.device.async_request(self.device.api.get_state) -class BroadlinkLB1UpdateManager(BroadlinkUpdateManager): +class BroadlinkLB1UpdateManager(BroadlinkUpdateManager[blk.lb1]): """Manages updates for Broadlink LB1 devices.""" - async def async_fetch_data(self): + async def async_fetch_data(self) -> dict[str, Any]: """Fetch data from the device.""" return await self.device.async_request(self.device.api.get_state) -class BroadlinkThermostatUpdateManager(BroadlinkUpdateManager): +class BroadlinkThermostatUpdateManager(BroadlinkUpdateManager[blk.hysen]): """Manages updates for thermostats with Broadlink DNA.""" - async def async_fetch_data(self): + async def async_fetch_data(self) -> dict[str, Any]: """Fetch data from the device.""" return await self.device.async_request(self.device.api.get_full_status)