Improve type hints in broadlink (#121285)

This commit is contained in:
epenet 2024-07-05 11:12:01 +02:00 committed by GitHub
parent 0088765268
commit d0c10c961d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 59 additions and 39 deletions

View file

@ -3,6 +3,7 @@
from contextlib import suppress from contextlib import suppress
from functools import partial from functools import partial
import logging import logging
from typing import Generic
import broadlink as blk import broadlink as blk
from broadlink.exceptions import ( from broadlink.exceptions import (
@ -12,6 +13,7 @@ from broadlink.exceptions import (
ConnectionClosedError, ConnectionClosedError,
NetworkTimeoutError, NetworkTimeoutError,
) )
from typing_extensions import TypeVar
from homeassistant.config_entries import SOURCE_REAUTH, ConfigEntry from homeassistant.config_entries import SOURCE_REAUTH, ConfigEntry
from homeassistant.const import ( from homeassistant.const import (
@ -27,7 +29,9 @@ from homeassistant.exceptions import ConfigEntryNotReady
from homeassistant.helpers import device_registry as dr from homeassistant.helpers import device_registry as dr
from .const import DEFAULT_PORT, DOMAIN, DOMAINS_AND_TYPES from .const import DEFAULT_PORT, DOMAIN, DOMAINS_AND_TYPES
from .updater import get_update_manager from .updater import BroadlinkUpdateManager, get_update_manager
_ApiT = TypeVar("_ApiT", bound=blk.Device, default=blk.Device)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -37,16 +41,16 @@ def get_domains(device_type: str) -> set[Platform]:
return {d for d, t in DOMAINS_AND_TYPES.items() if device_type in t} return {d for d, t in DOMAINS_AND_TYPES.items() if device_type in t}
class BroadlinkDevice: class BroadlinkDevice(Generic[_ApiT]):
"""Manages a Broadlink device.""" """Manages a Broadlink device."""
api: blk.Device api: _ApiT
def __init__(self, hass: HomeAssistant, config: ConfigEntry) -> None: def __init__(self, hass: HomeAssistant, config: ConfigEntry) -> None:
"""Initialize the device.""" """Initialize the device."""
self.hass = hass self.hass = hass
self.config = config self.config = config
self.update_manager = None self.update_manager: BroadlinkUpdateManager[_ApiT] | None = None
self.fw_version: int | None = None self.fw_version: int | None = None
self.authorized: bool | None = None self.authorized: bool | None = None
self.reset_jobs: list[CALLBACK_TYPE] = [] self.reset_jobs: list[CALLBACK_TYPE] = []

View file

@ -1,20 +1,30 @@
"""Support for fetching data from Broadlink devices.""" """Support for fetching data from Broadlink devices."""
from abc import ABC, abstractmethod from __future__ import annotations
from datetime import timedelta
import logging
from abc import ABC, abstractmethod
from datetime import datetime, timedelta
import logging
from typing import TYPE_CHECKING, Any, Generic
import broadlink as blk
from broadlink.exceptions import AuthorizationError, BroadlinkException from broadlink.exceptions import AuthorizationError, BroadlinkException
from typing_extensions import TypeVar
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
if TYPE_CHECKING:
from .device import BroadlinkDevice
_ApiT = TypeVar("_ApiT", bound=blk.Device)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
def get_update_manager(device): def get_update_manager(device: BroadlinkDevice[_ApiT]) -> BroadlinkUpdateManager[_ApiT]:
"""Return an update manager for a given Broadlink device.""" """Return an update manager for a given Broadlink device."""
update_managers = { update_managers: dict[str, type[BroadlinkUpdateManager]] = {
"A1": BroadlinkA1UpdateManager, "A1": BroadlinkA1UpdateManager,
"BG1": BroadlinkBG1UpdateManager, "BG1": BroadlinkBG1UpdateManager,
"HYS": BroadlinkThermostatUpdateManager, "HYS": BroadlinkThermostatUpdateManager,
@ -38,7 +48,7 @@ def get_update_manager(device):
return update_managers[device.api.type](device) return update_managers[device.api.type](device)
class BroadlinkUpdateManager(ABC): class BroadlinkUpdateManager(ABC, Generic[_ApiT]):
"""Representation of a Broadlink update manager. """Representation of a Broadlink update manager.
Implement this class to manage fetching data from the device and to Implement this class to manage fetching data from the device and to
@ -47,7 +57,7 @@ class BroadlinkUpdateManager(ABC):
SCAN_INTERVAL = timedelta(minutes=1) SCAN_INTERVAL = timedelta(minutes=1)
def __init__(self, device): def __init__(self, device: BroadlinkDevice[_ApiT]) -> None:
"""Initialize the update manager.""" """Initialize the update manager."""
self.device = device self.device = device
self.coordinator = DataUpdateCoordinator( self.coordinator = DataUpdateCoordinator(
@ -57,18 +67,22 @@ class BroadlinkUpdateManager(ABC):
update_method=self.async_update, update_method=self.async_update,
update_interval=self.SCAN_INTERVAL, update_interval=self.SCAN_INTERVAL,
) )
self.available = None self.available: bool | None = None
self.last_update = None self.last_update: datetime | None = None
async def async_update(self): async def async_update(self) -> dict[str, Any] | None:
"""Fetch data from the device and update availability.""" """Fetch data from the device and update availability."""
try: try:
data = await self.async_fetch_data() data = await self.async_fetch_data()
except (BroadlinkException, OSError) as err: except (BroadlinkException, OSError) as err:
if self.available and ( if (
dt_util.utcnow() - self.last_update > self.SCAN_INTERVAL * 3 self.available
or isinstance(err, (AuthorizationError, OSError)) and self.last_update
and (
dt_util.utcnow() - self.last_update > self.SCAN_INTERVAL * 3
or isinstance(err, (AuthorizationError, OSError))
)
): ):
self.available = False self.available = False
_LOGGER.warning( _LOGGER.warning(
@ -91,42 +105,42 @@ class BroadlinkUpdateManager(ABC):
return data return data
@abstractmethod @abstractmethod
async def async_fetch_data(self): async def async_fetch_data(self) -> dict[str, Any] | None:
"""Fetch data from the device.""" """Fetch data from the device."""
class BroadlinkA1UpdateManager(BroadlinkUpdateManager): class BroadlinkA1UpdateManager(BroadlinkUpdateManager[blk.a1]):
"""Manages updates for Broadlink A1 devices.""" """Manages updates for Broadlink A1 devices."""
SCAN_INTERVAL = timedelta(seconds=10) SCAN_INTERVAL = timedelta(seconds=10)
async def async_fetch_data(self): async def async_fetch_data(self) -> dict[str, Any]:
"""Fetch data from the device.""" """Fetch data from the device."""
return await self.device.async_request(self.device.api.check_sensors_raw) return await self.device.async_request(self.device.api.check_sensors_raw)
class BroadlinkMP1UpdateManager(BroadlinkUpdateManager): class BroadlinkMP1UpdateManager(BroadlinkUpdateManager[blk.mp1]):
"""Manages updates for Broadlink MP1 devices.""" """Manages updates for Broadlink MP1 devices."""
async def async_fetch_data(self): async def async_fetch_data(self) -> dict[str, Any]:
"""Fetch data from the device.""" """Fetch data from the device."""
return await self.device.async_request(self.device.api.check_power) return await self.device.async_request(self.device.api.check_power)
class BroadlinkMP1SUpdateManager(BroadlinkUpdateManager): class BroadlinkMP1SUpdateManager(BroadlinkUpdateManager[blk.mp1s]):
"""Manages updates for Broadlink MP1 devices.""" """Manages updates for Broadlink MP1 devices."""
async def async_fetch_data(self): async def async_fetch_data(self) -> dict[str, Any]:
"""Fetch data from the device.""" """Fetch data from the device."""
power = await self.device.async_request(self.device.api.check_power) power = await self.device.async_request(self.device.api.check_power)
sensors = await self.device.async_request(self.device.api.get_state) sensors = await self.device.async_request(self.device.api.get_state)
return {**power, **sensors} return {**power, **sensors}
class BroadlinkRMUpdateManager(BroadlinkUpdateManager): class BroadlinkRMUpdateManager(BroadlinkUpdateManager[blk.rm]):
"""Manages updates for Broadlink remotes.""" """Manages updates for Broadlink remotes."""
async def async_fetch_data(self): async def async_fetch_data(self) -> dict[str, Any]:
"""Fetch data from the device.""" """Fetch data from the device."""
device = self.device device = self.device
@ -138,7 +152,9 @@ class BroadlinkRMUpdateManager(BroadlinkUpdateManager):
return {} return {}
@staticmethod @staticmethod
def normalize(data, previous_data): def normalize(
data: dict[str, Any], previous_data: dict[str, Any] | None
) -> dict[str, Any]:
"""Fix firmware issue. """Fix firmware issue.
See https://github.com/home-assistant/core/issues/42100. See https://github.com/home-assistant/core/issues/42100.
@ -151,18 +167,18 @@ class BroadlinkRMUpdateManager(BroadlinkUpdateManager):
return data return data
class BroadlinkSP1UpdateManager(BroadlinkUpdateManager): class BroadlinkSP1UpdateManager(BroadlinkUpdateManager[blk.sp1]):
"""Manages updates for Broadlink SP1 devices.""" """Manages updates for Broadlink SP1 devices."""
async def async_fetch_data(self): async def async_fetch_data(self) -> None:
"""Fetch data from the device.""" """Fetch data from the device."""
return None return None
class BroadlinkSP2UpdateManager(BroadlinkUpdateManager): class BroadlinkSP2UpdateManager(BroadlinkUpdateManager[blk.sp2]):
"""Manages updates for Broadlink SP2 devices.""" """Manages updates for Broadlink SP2 devices."""
async def async_fetch_data(self): async def async_fetch_data(self) -> dict[str, Any]:
"""Fetch data from the device.""" """Fetch data from the device."""
device = self.device device = self.device
@ -175,33 +191,33 @@ class BroadlinkSP2UpdateManager(BroadlinkUpdateManager):
return data return data
class BroadlinkBG1UpdateManager(BroadlinkUpdateManager): class BroadlinkBG1UpdateManager(BroadlinkUpdateManager[blk.bg1]):
"""Manages updates for Broadlink BG1 devices.""" """Manages updates for Broadlink BG1 devices."""
async def async_fetch_data(self): async def async_fetch_data(self) -> dict[str, Any]:
"""Fetch data from the device.""" """Fetch data from the device."""
return await self.device.async_request(self.device.api.get_state) return await self.device.async_request(self.device.api.get_state)
class BroadlinkSP4UpdateManager(BroadlinkUpdateManager): class BroadlinkSP4UpdateManager(BroadlinkUpdateManager[blk.sp4]):
"""Manages updates for Broadlink SP4 devices.""" """Manages updates for Broadlink SP4 devices."""
async def async_fetch_data(self): async def async_fetch_data(self) -> dict[str, Any]:
"""Fetch data from the device.""" """Fetch data from the device."""
return await self.device.async_request(self.device.api.get_state) return await self.device.async_request(self.device.api.get_state)
class BroadlinkLB1UpdateManager(BroadlinkUpdateManager): class BroadlinkLB1UpdateManager(BroadlinkUpdateManager[blk.lb1]):
"""Manages updates for Broadlink LB1 devices.""" """Manages updates for Broadlink LB1 devices."""
async def async_fetch_data(self): async def async_fetch_data(self) -> dict[str, Any]:
"""Fetch data from the device.""" """Fetch data from the device."""
return await self.device.async_request(self.device.api.get_state) return await self.device.async_request(self.device.api.get_state)
class BroadlinkThermostatUpdateManager(BroadlinkUpdateManager): class BroadlinkThermostatUpdateManager(BroadlinkUpdateManager[blk.hysen]):
"""Manages updates for thermostats with Broadlink DNA.""" """Manages updates for thermostats with Broadlink DNA."""
async def async_fetch_data(self): async def async_fetch_data(self) -> dict[str, Any]:
"""Fetch data from the device.""" """Fetch data from the device."""
return await self.device.async_request(self.device.api.get_full_status) return await self.device.async_request(self.device.api.get_full_status)