From 1826795d3771f08c86d66e4effebe463e07179e3 Mon Sep 17 00:00:00 2001 From: epenet <6771947+epenet@users.noreply.github.com> Date: Mon, 14 Nov 2022 14:32:10 +0100 Subject: [PATCH] Add TagProtocol for type checking (#81086) * Add TagProtocol for type checking * Adjust type hints --- homeassistant/components/esphome/__init__.py | 7 +++++-- homeassistant/components/tag/__init__.py | 15 ++++++++++++++- homeassistant/components/tag/trigger.py | 8 ++++++-- 3 files changed, 25 insertions(+), 5 deletions(-) diff --git a/homeassistant/components/esphome/__init__.py b/homeassistant/components/esphome/__init__.py index 23b6a6550e4..ea909725b9f 100644 --- a/homeassistant/components/esphome/__init__.py +++ b/homeassistant/components/esphome/__init__.py @@ -5,7 +5,7 @@ from collections.abc import Callable import functools import logging import math -from typing import Any, Generic, NamedTuple, TypeVar, cast, overload +from typing import TYPE_CHECKING, Any, Generic, NamedTuple, TypeVar, cast, overload from aioesphomeapi import ( APIClient, @@ -55,6 +55,9 @@ from .domain_data import DOMAIN, DomainData # Import config flow so that it's added to the registry from .entry_data import RuntimeEntryData +if TYPE_CHECKING: + from homeassistant.components.tag import TagProtocol + CONF_NOISE_PSK = "noise_psk" _LOGGER = logging.getLogger(__name__) _R = TypeVar("_R") @@ -133,7 +136,7 @@ async def async_setup_entry( # noqa: C901 if service_name == "tag_scanned" and device_id is not None: # Importing tag via hass.components in case it is overridden # in a custom_components (custom_components.tag) - tag = hass.components.tag + tag: TagProtocol = hass.components.tag tag_id = service_data["tag_id"] hass.async_create_task(tag.async_scan_tag(tag_id, device_id)) return diff --git a/homeassistant/components/tag/__init__.py b/homeassistant/components/tag/__init__.py index c05b4416343..1b3ee9b646b 100644 --- a/homeassistant/components/tag/__init__.py +++ b/homeassistant/components/tag/__init__.py @@ -2,6 +2,7 @@ from __future__ import annotations import logging +from typing import Protocol import uuid import voluptuous as vol @@ -106,9 +107,21 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: return True +class TagProtocol(Protocol): + """Protocol for type checking.""" + + async def async_scan_tag( + self, tag_id: str, device_id: str | None, context: Context | None = None + ) -> None: + """Handle when a tag is scanned.""" + + @bind_hass async def async_scan_tag( - hass: HomeAssistant, tag_id: str, device_id: str, context: Context | None = None + hass: HomeAssistant, + tag_id: str, + device_id: str | None, + context: Context | None = None, ) -> None: """Handle when a tag is scanned.""" if DOMAIN not in hass.config.components: diff --git a/homeassistant/components/tag/trigger.py b/homeassistant/components/tag/trigger.py index 146521dfba9..b6d77737eab 100644 --- a/homeassistant/components/tag/trigger.py +++ b/homeassistant/components/tag/trigger.py @@ -1,4 +1,6 @@ """Support for tag triggers.""" +from __future__ import annotations + import voluptuous as vol from homeassistant.const import CONF_PLATFORM @@ -26,8 +28,10 @@ async def async_attach_trigger( ) -> CALLBACK_TYPE: """Listen for tag_scanned events based on configuration.""" trigger_data = trigger_info["trigger_data"] - tag_ids = set(config[TAG_ID]) - device_ids = set(config[DEVICE_ID]) if DEVICE_ID in config else None + tag_ids: set[str] = set(config[TAG_ID]) + device_ids: set[str] | None = ( + set(config[DEVICE_ID]) if DEVICE_ID in config else None + ) job = HassJob(action)