Update typing 02 (#48014)

This commit is contained in:
Marc Mueller 2021-03-17 18:34:19 +01:00 committed by GitHub
parent 86d3baa34e
commit 6fb2e63e49
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
47 changed files with 717 additions and 706 deletions

View file

@ -1,6 +1,8 @@
"""Helper methods for components within Home Assistant.""" """Helper methods for components within Home Assistant."""
from __future__ import annotations
import re import re
from typing import TYPE_CHECKING, Any, Iterable, Sequence, Tuple from typing import TYPE_CHECKING, Any, Iterable, Sequence
from homeassistant.const import CONF_PLATFORM from homeassistant.const import CONF_PLATFORM
@ -8,7 +10,7 @@ if TYPE_CHECKING:
from .typing import ConfigType from .typing import ConfigType
def config_per_platform(config: "ConfigType", domain: str) -> Iterable[Tuple[Any, Any]]: def config_per_platform(config: "ConfigType", domain: str) -> Iterable[tuple[Any, Any]]:
"""Break a component config into different platforms. """Break a component config into different platforms.
For example, will find 'switch', 'switch 2', 'switch 3', .. etc For example, will find 'switch', 'switch 2', 'switch 3', .. etc

View file

@ -1,8 +1,10 @@
"""Helper for aiohttp webclient stuff.""" """Helper for aiohttp webclient stuff."""
from __future__ import annotations
import asyncio import asyncio
from ssl import SSLContext from ssl import SSLContext
import sys import sys
from typing import Any, Awaitable, Optional, Union, cast from typing import Any, Awaitable, cast
import aiohttp import aiohttp
from aiohttp import web from aiohttp import web
@ -87,7 +89,7 @@ async def async_aiohttp_proxy_web(
web_coro: Awaitable[aiohttp.ClientResponse], web_coro: Awaitable[aiohttp.ClientResponse],
buffer_size: int = 102400, buffer_size: int = 102400,
timeout: int = 10, timeout: int = 10,
) -> Optional[web.StreamResponse]: ) -> web.StreamResponse | None:
"""Stream websession request to aiohttp web response.""" """Stream websession request to aiohttp web response."""
try: try:
with async_timeout.timeout(timeout): with async_timeout.timeout(timeout):
@ -118,7 +120,7 @@ async def async_aiohttp_proxy_stream(
hass: HomeAssistantType, hass: HomeAssistantType,
request: web.BaseRequest, request: web.BaseRequest,
stream: aiohttp.StreamReader, stream: aiohttp.StreamReader,
content_type: Optional[str], content_type: str | None,
buffer_size: int = 102400, buffer_size: int = 102400,
timeout: int = 10, timeout: int = 10,
) -> web.StreamResponse: ) -> web.StreamResponse:
@ -175,7 +177,7 @@ def _async_get_connector(
return cast(aiohttp.BaseConnector, hass.data[key]) return cast(aiohttp.BaseConnector, hass.data[key])
if verify_ssl: if verify_ssl:
ssl_context: Union[bool, SSLContext] = ssl_util.client_context() ssl_context: bool | SSLContext = ssl_util.client_context()
else: else:
ssl_context = False ssl_context = False

View file

@ -1,6 +1,8 @@
"""Provide a way to connect devices to one physical location.""" """Provide a way to connect devices to one physical location."""
from __future__ import annotations
from collections import OrderedDict from collections import OrderedDict
from typing import Container, Dict, Iterable, List, MutableMapping, Optional, cast from typing import Container, Iterable, MutableMapping, cast
import attr import attr
@ -26,7 +28,7 @@ class AreaEntry:
name: str = attr.ib() name: str = attr.ib()
normalized_name: str = attr.ib() normalized_name: str = attr.ib()
id: Optional[str] = attr.ib(default=None) id: str | None = attr.ib(default=None)
def generate_id(self, existing_ids: Container[str]) -> None: def generate_id(self, existing_ids: Container[str]) -> None:
"""Initialize ID.""" """Initialize ID."""
@ -46,15 +48,15 @@ class AreaRegistry:
self.hass = hass self.hass = hass
self.areas: MutableMapping[str, AreaEntry] = {} self.areas: MutableMapping[str, AreaEntry] = {}
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY) self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
self._normalized_name_area_idx: Dict[str, str] = {} self._normalized_name_area_idx: dict[str, str] = {}
@callback @callback
def async_get_area(self, area_id: str) -> Optional[AreaEntry]: def async_get_area(self, area_id: str) -> AreaEntry | None:
"""Get area by id.""" """Get area by id."""
return self.areas.get(area_id) return self.areas.get(area_id)
@callback @callback
def async_get_area_by_name(self, name: str) -> Optional[AreaEntry]: def async_get_area_by_name(self, name: str) -> AreaEntry | None:
"""Get area by name.""" """Get area by name."""
normalized_name = normalize_area_name(name) normalized_name = normalize_area_name(name)
if normalized_name not in self._normalized_name_area_idx: if normalized_name not in self._normalized_name_area_idx:
@ -171,7 +173,7 @@ class AreaRegistry:
self._store.async_delay_save(self._data_to_save, SAVE_DELAY) self._store.async_delay_save(self._data_to_save, SAVE_DELAY)
@callback @callback
def _data_to_save(self) -> Dict[str, List[Dict[str, Optional[str]]]]: def _data_to_save(self) -> dict[str, list[dict[str, str | None]]]:
"""Return data of area registry to store in a file.""" """Return data of area registry to store in a file."""
data = {} data = {}

View file

@ -5,7 +5,7 @@ from collections import OrderedDict
import logging import logging
import os import os
from pathlib import Path from pathlib import Path
from typing import List, NamedTuple, Optional from typing import NamedTuple
import voluptuous as vol import voluptuous as vol
@ -35,8 +35,8 @@ class CheckConfigError(NamedTuple):
"""Configuration check error.""" """Configuration check error."""
message: str message: str
domain: Optional[str] domain: str | None
config: Optional[ConfigType] config: ConfigType | None
class HomeAssistantConfig(OrderedDict): class HomeAssistantConfig(OrderedDict):
@ -45,13 +45,13 @@ class HomeAssistantConfig(OrderedDict):
def __init__(self) -> None: def __init__(self) -> None:
"""Initialize HA config.""" """Initialize HA config."""
super().__init__() super().__init__()
self.errors: List[CheckConfigError] = [] self.errors: list[CheckConfigError] = []
def add_error( def add_error(
self, self,
message: str, message: str,
domain: Optional[str] = None, domain: str | None = None,
config: Optional[ConfigType] = None, config: ConfigType | None = None,
) -> HomeAssistantConfig: ) -> HomeAssistantConfig:
"""Add a single error.""" """Add a single error."""
self.errors.append(CheckConfigError(str(message), domain, config)) self.errors.append(CheckConfigError(str(message), domain, config))

View file

@ -1,9 +1,11 @@
"""Helper to deal with YAML + storage.""" """Helper to deal with YAML + storage."""
from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import asyncio import asyncio
from dataclasses import dataclass from dataclasses import dataclass
import logging import logging
from typing import Any, Awaitable, Callable, Dict, Iterable, List, Optional, cast from typing import Any, Awaitable, Callable, Iterable, Optional, cast
import voluptuous as vol import voluptuous as vol
from voluptuous.humanize import humanize_error from voluptuous.humanize import humanize_error
@ -72,9 +74,9 @@ class IDManager:
def __init__(self) -> None: def __init__(self) -> None:
"""Initiate the ID manager.""" """Initiate the ID manager."""
self.collections: List[Dict[str, Any]] = [] self.collections: list[dict[str, Any]] = []
def add_collection(self, collection: Dict[str, Any]) -> None: def add_collection(self, collection: dict[str, Any]) -> None:
"""Add a collection to check for ID usage.""" """Add a collection to check for ID usage."""
self.collections.append(collection) self.collections.append(collection)
@ -98,17 +100,17 @@ class IDManager:
class ObservableCollection(ABC): class ObservableCollection(ABC):
"""Base collection type that can be observed.""" """Base collection type that can be observed."""
def __init__(self, logger: logging.Logger, id_manager: Optional[IDManager] = None): def __init__(self, logger: logging.Logger, id_manager: IDManager | None = None):
"""Initialize the base collection.""" """Initialize the base collection."""
self.logger = logger self.logger = logger
self.id_manager = id_manager or IDManager() self.id_manager = id_manager or IDManager()
self.data: Dict[str, dict] = {} self.data: dict[str, dict] = {}
self.listeners: List[ChangeListener] = [] self.listeners: list[ChangeListener] = []
self.id_manager.add_collection(self.data) self.id_manager.add_collection(self.data)
@callback @callback
def async_items(self) -> List[dict]: def async_items(self) -> list[dict]:
"""Return list of items in collection.""" """Return list of items in collection."""
return list(self.data.values()) return list(self.data.values())
@ -134,7 +136,7 @@ class ObservableCollection(ABC):
class YamlCollection(ObservableCollection): class YamlCollection(ObservableCollection):
"""Offer a collection based on static data.""" """Offer a collection based on static data."""
async def async_load(self, data: List[dict]) -> None: async def async_load(self, data: list[dict]) -> None:
"""Load the YAML collection. Overrides existing data.""" """Load the YAML collection. Overrides existing data."""
old_ids = set(self.data) old_ids = set(self.data)
@ -171,7 +173,7 @@ class StorageCollection(ObservableCollection):
self, self,
store: Store, store: Store,
logger: logging.Logger, logger: logging.Logger,
id_manager: Optional[IDManager] = None, id_manager: IDManager | None = None,
): ):
"""Initialize the storage collection.""" """Initialize the storage collection."""
super().__init__(logger, id_manager) super().__init__(logger, id_manager)
@ -182,7 +184,7 @@ class StorageCollection(ObservableCollection):
"""Home Assistant object.""" """Home Assistant object."""
return self.store.hass return self.store.hass
async def _async_load_data(self) -> Optional[dict]: async def _async_load_data(self) -> dict | None:
"""Load the data.""" """Load the data."""
return cast(Optional[dict], await self.store.async_load()) return cast(Optional[dict], await self.store.async_load())
@ -274,7 +276,7 @@ class IDLessCollection(ObservableCollection):
counter = 0 counter = 0
async def async_load(self, data: List[dict]) -> None: async def async_load(self, data: list[dict]) -> None:
"""Load the collection. Overrides existing data.""" """Load the collection. Overrides existing data."""
await self.notify_changes( await self.notify_changes(
[ [

View file

@ -1,4 +1,6 @@
"""Offer reusable conditions.""" """Offer reusable conditions."""
from __future__ import annotations
import asyncio import asyncio
from collections import deque from collections import deque
from contextlib import contextmanager from contextlib import contextmanager
@ -7,7 +9,7 @@ import functools as ft
import logging import logging
import re import re
import sys import sys
from typing import Any, Callable, Container, Generator, List, Optional, Set, Union, cast from typing import Any, Callable, Container, Generator, cast
from homeassistant.components import zone as zone_cmp from homeassistant.components import zone as zone_cmp
from homeassistant.components.device_automation import ( from homeassistant.components.device_automation import (
@ -124,7 +126,7 @@ def trace_condition_function(condition: ConditionCheckerType) -> ConditionChecke
async def async_from_config( async def async_from_config(
hass: HomeAssistant, hass: HomeAssistant,
config: Union[ConfigType, Template], config: ConfigType | Template,
config_validation: bool = True, config_validation: bool = True,
) -> ConditionCheckerType: ) -> ConditionCheckerType:
"""Turn a condition configuration into a method. """Turn a condition configuration into a method.
@ -267,10 +269,10 @@ async def async_not_from_config(
def numeric_state( def numeric_state(
hass: HomeAssistant, hass: HomeAssistant,
entity: Union[None, str, State], entity: None | str | State,
below: Optional[Union[float, str]] = None, below: float | str | None = None,
above: Optional[Union[float, str]] = None, above: float | str | None = None,
value_template: Optional[Template] = None, value_template: Template | None = None,
variables: TemplateVarsType = None, variables: TemplateVarsType = None,
) -> bool: ) -> bool:
"""Test a numeric state condition.""" """Test a numeric state condition."""
@ -288,12 +290,12 @@ def numeric_state(
def async_numeric_state( def async_numeric_state(
hass: HomeAssistant, hass: HomeAssistant,
entity: Union[None, str, State], entity: None | str | State,
below: Optional[Union[float, str]] = None, below: float | str | None = None,
above: Optional[Union[float, str]] = None, above: float | str | None = None,
value_template: Optional[Template] = None, value_template: Template | None = None,
variables: TemplateVarsType = None, variables: TemplateVarsType = None,
attribute: Optional[str] = None, attribute: str | None = None,
) -> bool: ) -> bool:
"""Test a numeric state condition.""" """Test a numeric state condition."""
if entity is None: if entity is None:
@ -456,10 +458,10 @@ def async_numeric_state_from_config(
def state( def state(
hass: HomeAssistant, hass: HomeAssistant,
entity: Union[None, str, State], entity: None | str | State,
req_state: Any, req_state: Any,
for_period: Optional[timedelta] = None, for_period: timedelta | None = None,
attribute: Optional[str] = None, attribute: str | None = None,
) -> bool: ) -> bool:
"""Test if state matches requirements. """Test if state matches requirements.
@ -526,7 +528,7 @@ def state_from_config(
if config_validation: if config_validation:
config = cv.STATE_CONDITION_SCHEMA(config) config = cv.STATE_CONDITION_SCHEMA(config)
entity_ids = config.get(CONF_ENTITY_ID, []) entity_ids = config.get(CONF_ENTITY_ID, [])
req_states: Union[str, List[str]] = config.get(CONF_STATE, []) req_states: str | list[str] = config.get(CONF_STATE, [])
for_period = config.get("for") for_period = config.get("for")
attribute = config.get(CONF_ATTRIBUTE) attribute = config.get(CONF_ATTRIBUTE)
@ -560,10 +562,10 @@ def state_from_config(
def sun( def sun(
hass: HomeAssistant, hass: HomeAssistant,
before: Optional[str] = None, before: str | None = None,
after: Optional[str] = None, after: str | None = None,
before_offset: Optional[timedelta] = None, before_offset: timedelta | None = None,
after_offset: Optional[timedelta] = None, after_offset: timedelta | None = None,
) -> bool: ) -> bool:
"""Test if current time matches sun requirements.""" """Test if current time matches sun requirements."""
utcnow = dt_util.utcnow() utcnow = dt_util.utcnow()
@ -673,9 +675,9 @@ def async_template_from_config(
def time( def time(
hass: HomeAssistant, hass: HomeAssistant,
before: Optional[Union[dt_util.dt.time, str]] = None, before: dt_util.dt.time | str | None = None,
after: Optional[Union[dt_util.dt.time, str]] = None, after: dt_util.dt.time | str | None = None,
weekday: Union[None, str, Container[str]] = None, weekday: None | str | Container[str] = None,
) -> bool: ) -> bool:
"""Test if local time condition matches. """Test if local time condition matches.
@ -752,8 +754,8 @@ def time_from_config(
def zone( def zone(
hass: HomeAssistant, hass: HomeAssistant,
zone_ent: Union[None, str, State], zone_ent: None | str | State,
entity: Union[None, str, State], entity: None | str | State,
) -> bool: ) -> bool:
"""Test if zone-condition matches. """Test if zone-condition matches.
@ -858,8 +860,8 @@ async def async_device_from_config(
async def async_validate_condition_config( async def async_validate_condition_config(
hass: HomeAssistant, config: Union[ConfigType, Template] hass: HomeAssistant, config: ConfigType | Template
) -> Union[ConfigType, Template]: ) -> ConfigType | Template:
"""Validate config.""" """Validate config."""
if isinstance(config, Template): if isinstance(config, Template):
return config return config
@ -884,9 +886,9 @@ async def async_validate_condition_config(
@callback @callback
def async_extract_entities(config: Union[ConfigType, Template]) -> Set[str]: def async_extract_entities(config: ConfigType | Template) -> set[str]:
"""Extract entities from a condition.""" """Extract entities from a condition."""
referenced: Set[str] = set() referenced: set[str] = set()
to_process = deque([config]) to_process = deque([config])
while to_process: while to_process:
@ -912,7 +914,7 @@ def async_extract_entities(config: Union[ConfigType, Template]) -> Set[str]:
@callback @callback
def async_extract_devices(config: Union[ConfigType, Template]) -> Set[str]: def async_extract_devices(config: ConfigType | Template) -> set[str]:
"""Extract devices from a condition.""" """Extract devices from a condition."""
referenced = set() referenced = set()
to_process = deque([config]) to_process = deque([config])

View file

@ -1,5 +1,7 @@
"""Helpers for data entry flows for config entries.""" """Helpers for data entry flows for config entries."""
from typing import Any, Awaitable, Callable, Dict, Optional, Union from __future__ import annotations
from typing import Any, Awaitable, Callable, Union
from homeassistant import config_entries from homeassistant import config_entries
@ -27,8 +29,8 @@ class DiscoveryFlowHandler(config_entries.ConfigFlow):
self.CONNECTION_CLASS = connection_class # pylint: disable=invalid-name self.CONNECTION_CLASS = connection_class # pylint: disable=invalid-name
async def async_step_user( async def async_step_user(
self, user_input: Optional[Dict[str, Any]] = None self, user_input: dict[str, Any] | None = None
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Handle a flow initialized by the user.""" """Handle a flow initialized by the user."""
if self._async_current_entries(): if self._async_current_entries():
return self.async_abort(reason="single_instance_allowed") return self.async_abort(reason="single_instance_allowed")
@ -38,8 +40,8 @@ class DiscoveryFlowHandler(config_entries.ConfigFlow):
return await self.async_step_confirm() return await self.async_step_confirm()
async def async_step_confirm( async def async_step_confirm(
self, user_input: Optional[Dict[str, Any]] = None self, user_input: dict[str, Any] | None = None
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Confirm setup.""" """Confirm setup."""
if user_input is None: if user_input is None:
self._set_confirm_only() self._set_confirm_only()
@ -68,8 +70,8 @@ class DiscoveryFlowHandler(config_entries.ConfigFlow):
return self.async_create_entry(title=self._title, data={}) return self.async_create_entry(title=self._title, data={})
async def async_step_discovery( async def async_step_discovery(
self, discovery_info: Dict[str, Any] self, discovery_info: dict[str, Any]
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Handle a flow initialized by discovery.""" """Handle a flow initialized by discovery."""
if self._async_in_progress() or self._async_current_entries(): if self._async_in_progress() or self._async_current_entries():
return self.async_abort(reason="single_instance_allowed") return self.async_abort(reason="single_instance_allowed")
@ -84,7 +86,7 @@ class DiscoveryFlowHandler(config_entries.ConfigFlow):
async_step_homekit = async_step_discovery async_step_homekit = async_step_discovery
async_step_dhcp = async_step_discovery async_step_dhcp = async_step_discovery
async def async_step_import(self, _: Optional[Dict[str, Any]]) -> Dict[str, Any]: async def async_step_import(self, _: dict[str, Any] | None) -> dict[str, Any]:
"""Handle a flow initialized by import.""" """Handle a flow initialized by import."""
if self._async_current_entries(): if self._async_current_entries():
return self.async_abort(reason="single_instance_allowed") return self.async_abort(reason="single_instance_allowed")
@ -133,8 +135,8 @@ class WebhookFlowHandler(config_entries.ConfigFlow):
self._allow_multiple = allow_multiple self._allow_multiple = allow_multiple
async def async_step_user( async def async_step_user(
self, user_input: Optional[Dict[str, Any]] = None self, user_input: dict[str, Any] | None = None
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Handle a user initiated set up flow to create a webhook.""" """Handle a user initiated set up flow to create a webhook."""
if not self._allow_multiple and self._async_current_entries(): if not self._allow_multiple and self._async_current_entries():
return self.async_abort(reason="single_instance_allowed") return self.async_abort(reason="single_instance_allowed")

View file

@ -5,12 +5,14 @@ This module exists of the following parts:
- OAuth2 implementation that works with local provided client ID/secret - OAuth2 implementation that works with local provided client ID/secret
""" """
from __future__ import annotations
from abc import ABC, ABCMeta, abstractmethod from abc import ABC, ABCMeta, abstractmethod
import asyncio import asyncio
import logging import logging
import secrets import secrets
import time import time
from typing import Any, Awaitable, Callable, Dict, Optional, cast from typing import Any, Awaitable, Callable, Dict, cast
from aiohttp import client, web from aiohttp import client, web
import async_timeout import async_timeout
@ -231,7 +233,7 @@ class AbstractOAuth2FlowHandler(config_entries.ConfigFlow, metaclass=ABCMeta):
return {} return {}
async def async_step_pick_implementation( async def async_step_pick_implementation(
self, user_input: Optional[dict] = None self, user_input: dict | None = None
) -> dict: ) -> dict:
"""Handle a flow start.""" """Handle a flow start."""
implementations = await async_get_implementations(self.hass, self.DOMAIN) implementations = await async_get_implementations(self.hass, self.DOMAIN)
@ -260,8 +262,8 @@ class AbstractOAuth2FlowHandler(config_entries.ConfigFlow, metaclass=ABCMeta):
) )
async def async_step_auth( async def async_step_auth(
self, user_input: Optional[Dict[str, Any]] = None self, user_input: dict[str, Any] | None = None
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Create an entry for auth.""" """Create an entry for auth."""
# Flow has been triggered by external data # Flow has been triggered by external data
if user_input: if user_input:
@ -286,8 +288,8 @@ class AbstractOAuth2FlowHandler(config_entries.ConfigFlow, metaclass=ABCMeta):
return self.async_external_step(step_id="auth", url=url) return self.async_external_step(step_id="auth", url=url)
async def async_step_creation( async def async_step_creation(
self, user_input: Optional[Dict[str, Any]] = None self, user_input: dict[str, Any] | None = None
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Create config entry from external data.""" """Create config entry from external data."""
token = await self.flow_impl.async_resolve_external_data(self.external_data) token = await self.flow_impl.async_resolve_external_data(self.external_data)
# Force int for non-compliant oauth2 providers # Force int for non-compliant oauth2 providers
@ -312,8 +314,8 @@ class AbstractOAuth2FlowHandler(config_entries.ConfigFlow, metaclass=ABCMeta):
return self.async_create_entry(title=self.flow_impl.name, data=data) return self.async_create_entry(title=self.flow_impl.name, data=data)
async def async_step_discovery( async def async_step_discovery(
self, discovery_info: Dict[str, Any] self, discovery_info: dict[str, Any]
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Handle a flow initialized by discovery.""" """Handle a flow initialized by discovery."""
await self.async_set_unique_id(self.DOMAIN) await self.async_set_unique_id(self.DOMAIN)
@ -354,7 +356,7 @@ def async_register_implementation(
async def async_get_implementations( async def async_get_implementations(
hass: HomeAssistant, domain: str hass: HomeAssistant, domain: str
) -> Dict[str, AbstractOAuth2Implementation]: ) -> dict[str, AbstractOAuth2Implementation]:
"""Return OAuth2 implementations for specified domain.""" """Return OAuth2 implementations for specified domain."""
registered = cast( registered = cast(
Dict[str, AbstractOAuth2Implementation], Dict[str, AbstractOAuth2Implementation],
@ -392,7 +394,7 @@ def async_add_implementation_provider(
hass: HomeAssistant, hass: HomeAssistant,
provider_domain: str, provider_domain: str,
async_provide_implementation: Callable[ async_provide_implementation: Callable[
[HomeAssistant, str], Awaitable[Optional[AbstractOAuth2Implementation]] [HomeAssistant, str], Awaitable[AbstractOAuth2Implementation | None]
], ],
) -> None: ) -> None:
"""Add an implementation provider. """Add an implementation provider.
@ -516,7 +518,7 @@ def _encode_jwt(hass: HomeAssistant, data: dict) -> str:
@callback @callback
def _decode_jwt(hass: HomeAssistant, encoded: str) -> Optional[dict]: def _decode_jwt(hass: HomeAssistant, encoded: str) -> dict | None:
"""JWT encode data.""" """JWT encode data."""
secret = cast(str, hass.data.get(DATA_JWT_SECRET)) secret = cast(str, hass.data.get(DATA_JWT_SECRET))

View file

@ -1,4 +1,6 @@
"""Helpers for config validation using voluptuous.""" """Helpers for config validation using voluptuous."""
from __future__ import annotations
from datetime import ( from datetime import (
date as date_sys, date as date_sys,
datetime as datetime_sys, datetime as datetime_sys,
@ -12,19 +14,7 @@ from numbers import Number
import os import os
import re import re
from socket import _GLOBAL_DEFAULT_TIMEOUT # type: ignore # private, not in typeshed from socket import _GLOBAL_DEFAULT_TIMEOUT # type: ignore # private, not in typeshed
from typing import ( from typing import Any, Callable, Dict, Hashable, Pattern, TypeVar, cast
Any,
Callable,
Dict,
Hashable,
List,
Optional,
Pattern,
Type,
TypeVar,
Union,
cast,
)
from urllib.parse import urlparse from urllib.parse import urlparse
from uuid import UUID from uuid import UUID
@ -131,7 +121,7 @@ def path(value: Any) -> str:
def has_at_least_one_key(*keys: str) -> Callable: def has_at_least_one_key(*keys: str) -> Callable:
"""Validate that at least one key exists.""" """Validate that at least one key exists."""
def validate(obj: Dict) -> Dict: def validate(obj: dict) -> dict:
"""Test keys exist in dict.""" """Test keys exist in dict."""
if not isinstance(obj, dict): if not isinstance(obj, dict):
raise vol.Invalid("expected dictionary") raise vol.Invalid("expected dictionary")
@ -144,10 +134,10 @@ def has_at_least_one_key(*keys: str) -> Callable:
return validate return validate
def has_at_most_one_key(*keys: str) -> Callable[[Dict], Dict]: def has_at_most_one_key(*keys: str) -> Callable[[dict], dict]:
"""Validate that zero keys exist or one key exists.""" """Validate that zero keys exist or one key exists."""
def validate(obj: Dict) -> Dict: def validate(obj: dict) -> dict:
"""Test zero keys exist or one key exists in dict.""" """Test zero keys exist or one key exists in dict."""
if not isinstance(obj, dict): if not isinstance(obj, dict):
raise vol.Invalid("expected dictionary") raise vol.Invalid("expected dictionary")
@ -253,7 +243,7 @@ def isdir(value: Any) -> str:
return dir_in return dir_in
def ensure_list(value: Union[T, List[T], None]) -> List[T]: def ensure_list(value: T | list[T] | None) -> list[T]:
"""Wrap value in list if it is not one.""" """Wrap value in list if it is not one."""
if value is None: if value is None:
return [] return []
@ -269,7 +259,7 @@ def entity_id(value: Any) -> str:
raise vol.Invalid(f"Entity ID {value} is an invalid entity ID") raise vol.Invalid(f"Entity ID {value} is an invalid entity ID")
def entity_ids(value: Union[str, List]) -> List[str]: def entity_ids(value: str | list) -> list[str]:
"""Validate Entity IDs.""" """Validate Entity IDs."""
if value is None: if value is None:
raise vol.Invalid("Entity IDs can not be None") raise vol.Invalid("Entity IDs can not be None")
@ -284,7 +274,7 @@ comp_entity_ids = vol.Any(
) )
def entity_domain(domain: Union[str, List[str]]) -> Callable[[Any], str]: def entity_domain(domain: str | list[str]) -> Callable[[Any], str]:
"""Validate that entity belong to domain.""" """Validate that entity belong to domain."""
ent_domain = entities_domain(domain) ent_domain = entities_domain(domain)
@ -298,9 +288,7 @@ def entity_domain(domain: Union[str, List[str]]) -> Callable[[Any], str]:
return validate return validate
def entities_domain( def entities_domain(domain: str | list[str]) -> Callable[[str | list], list[str]]:
domain: Union[str, List[str]]
) -> Callable[[Union[str, List]], List[str]]:
"""Validate that entities belong to domain.""" """Validate that entities belong to domain."""
if isinstance(domain, str): if isinstance(domain, str):
@ -312,7 +300,7 @@ def entities_domain(
def check_invalid(val: str) -> bool: def check_invalid(val: str) -> bool:
return val not in domain return val not in domain
def validate(values: Union[str, List]) -> List[str]: def validate(values: str | list) -> list[str]:
"""Test if entity domain is domain.""" """Test if entity domain is domain."""
values = entity_ids(values) values = entity_ids(values)
for ent_id in values: for ent_id in values:
@ -325,7 +313,7 @@ def entities_domain(
return validate return validate
def enum(enumClass: Type[Enum]) -> vol.All: def enum(enumClass: type[Enum]) -> vol.All:
"""Create validator for specified enum.""" """Create validator for specified enum."""
return vol.All(vol.In(enumClass.__members__), enumClass.__getitem__) return vol.All(vol.In(enumClass.__members__), enumClass.__getitem__)
@ -423,7 +411,7 @@ def time_period_str(value: str) -> timedelta:
return offset return offset
def time_period_seconds(value: Union[float, str]) -> timedelta: def time_period_seconds(value: float | str) -> timedelta:
"""Validate and transform seconds to a time offset.""" """Validate and transform seconds to a time offset."""
try: try:
return timedelta(seconds=float(value)) return timedelta(seconds=float(value))
@ -450,7 +438,7 @@ positive_time_period_dict = vol.All(time_period_dict, positive_timedelta)
positive_time_period = vol.All(time_period, positive_timedelta) positive_time_period = vol.All(time_period, positive_timedelta)
def remove_falsy(value: List[T]) -> List[T]: def remove_falsy(value: list[T]) -> list[T]:
"""Remove falsy values from a list.""" """Remove falsy values from a list."""
return [v for v in value if v] return [v for v in value if v]
@ -477,7 +465,7 @@ def slug(value: Any) -> str:
def schema_with_slug_keys( def schema_with_slug_keys(
value_schema: Union[T, Callable], *, slug_validator: Callable[[Any], str] = slug value_schema: T | Callable, *, slug_validator: Callable[[Any], str] = slug
) -> Callable: ) -> Callable:
"""Ensure dicts have slugs as keys. """Ensure dicts have slugs as keys.
@ -486,7 +474,7 @@ def schema_with_slug_keys(
""" """
schema = vol.Schema({str: value_schema}) schema = vol.Schema({str: value_schema})
def verify(value: Dict) -> Dict: def verify(value: dict) -> dict:
"""Validate all keys are slugs and then the value_schema.""" """Validate all keys are slugs and then the value_schema."""
if not isinstance(value, dict): if not isinstance(value, dict):
raise vol.Invalid("expected dictionary") raise vol.Invalid("expected dictionary")
@ -547,7 +535,7 @@ unit_system = vol.All(
) )
def template(value: Optional[Any]) -> template_helper.Template: def template(value: Any | None) -> template_helper.Template:
"""Validate a jinja2 template.""" """Validate a jinja2 template."""
if value is None: if value is None:
raise vol.Invalid("template value is None") raise vol.Invalid("template value is None")
@ -563,7 +551,7 @@ def template(value: Optional[Any]) -> template_helper.Template:
raise vol.Invalid(f"invalid template ({ex})") from ex raise vol.Invalid(f"invalid template ({ex})") from ex
def dynamic_template(value: Optional[Any]) -> template_helper.Template: def dynamic_template(value: Any | None) -> template_helper.Template:
"""Validate a dynamic (non static) jinja2 template.""" """Validate a dynamic (non static) jinja2 template."""
if value is None: if value is None:
raise vol.Invalid("template value is None") raise vol.Invalid("template value is None")
@ -632,7 +620,7 @@ def time_zone(value: str) -> str:
weekdays = vol.All(ensure_list, [vol.In(WEEKDAYS)]) weekdays = vol.All(ensure_list, [vol.In(WEEKDAYS)])
def socket_timeout(value: Optional[Any]) -> object: def socket_timeout(value: Any | None) -> object:
"""Validate timeout float > 0.0. """Validate timeout float > 0.0.
None coerced to socket._GLOBAL_DEFAULT_TIMEOUT bare object. None coerced to socket._GLOBAL_DEFAULT_TIMEOUT bare object.
@ -681,7 +669,7 @@ def uuid4_hex(value: Any) -> str:
return result.hex return result.hex
def ensure_list_csv(value: Any) -> List: def ensure_list_csv(value: Any) -> list:
"""Ensure that input is a list or make one from comma-separated string.""" """Ensure that input is a list or make one from comma-separated string."""
if isinstance(value, str): if isinstance(value, str):
return [member.strip() for member in value.split(",")] return [member.strip() for member in value.split(",")]
@ -709,9 +697,9 @@ class multi_select:
def deprecated( def deprecated(
key: str, key: str,
replacement_key: Optional[str] = None, replacement_key: str | None = None,
default: Optional[Any] = None, default: Any | None = None,
) -> Callable[[Dict], Dict]: ) -> Callable[[dict], dict]:
""" """
Log key as deprecated and provide a replacement (if exists). Log key as deprecated and provide a replacement (if exists).
@ -743,7 +731,7 @@ def deprecated(
" please remove it from your configuration" " please remove it from your configuration"
) )
def validator(config: Dict) -> Dict: def validator(config: dict) -> dict:
"""Check if key is in config and log warning.""" """Check if key is in config and log warning."""
if key in config: if key in config:
try: try:
@ -781,14 +769,14 @@ def deprecated(
def key_value_schemas( def key_value_schemas(
key: str, value_schemas: Dict[str, vol.Schema] key: str, value_schemas: dict[str, vol.Schema]
) -> Callable[[Any], Dict[str, Any]]: ) -> Callable[[Any], dict[str, Any]]:
"""Create a validator that validates based on a value for specific key. """Create a validator that validates based on a value for specific key.
This gives better error messages. This gives better error messages.
""" """
def key_value_validator(value: Any) -> Dict[str, Any]: def key_value_validator(value: Any) -> dict[str, Any]:
if not isinstance(value, dict): if not isinstance(value, dict):
raise vol.Invalid("Expected a dictionary") raise vol.Invalid("Expected a dictionary")
@ -809,10 +797,10 @@ def key_value_schemas(
def key_dependency( def key_dependency(
key: Hashable, dependency: Hashable key: Hashable, dependency: Hashable
) -> Callable[[Dict[Hashable, Any]], Dict[Hashable, Any]]: ) -> Callable[[dict[Hashable, Any]], dict[Hashable, Any]]:
"""Validate that all dependencies exist for key.""" """Validate that all dependencies exist for key."""
def validator(value: Dict[Hashable, Any]) -> Dict[Hashable, Any]: def validator(value: dict[Hashable, Any]) -> dict[Hashable, Any]:
"""Test dependencies.""" """Test dependencies."""
if not isinstance(value, dict): if not isinstance(value, dict):
raise vol.Invalid("key dependencies require a dict") raise vol.Invalid("key dependencies require a dict")
@ -1247,7 +1235,7 @@ def determine_script_action(action: dict) -> str:
return SCRIPT_ACTION_CALL_SERVICE return SCRIPT_ACTION_CALL_SERVICE
ACTION_TYPE_SCHEMAS: Dict[str, Callable[[Any], dict]] = { ACTION_TYPE_SCHEMAS: dict[str, Callable[[Any], dict]] = {
SCRIPT_ACTION_CALL_SERVICE: SERVICE_SCHEMA, SCRIPT_ACTION_CALL_SERVICE: SERVICE_SCHEMA,
SCRIPT_ACTION_DELAY: _SCRIPT_DELAY_SCHEMA, SCRIPT_ACTION_DELAY: _SCRIPT_DELAY_SCHEMA,
SCRIPT_ACTION_WAIT_TEMPLATE: _SCRIPT_WAIT_TEMPLATE_SCHEMA, SCRIPT_ACTION_WAIT_TEMPLATE: _SCRIPT_WAIT_TEMPLATE_SCHEMA,

View file

@ -1,6 +1,7 @@
"""Helpers for the data entry flow.""" """Helpers for the data entry flow."""
from __future__ import annotations
from typing import Any, Dict from typing import Any
from aiohttp import web from aiohttp import web
import voluptuous as vol import voluptuous as vol
@ -20,7 +21,7 @@ class _BaseFlowManagerView(HomeAssistantView):
self._flow_mgr = flow_mgr self._flow_mgr = flow_mgr
# pylint: disable=no-self-use # pylint: disable=no-self-use
def _prepare_result_json(self, result: Dict[str, Any]) -> Dict[str, Any]: def _prepare_result_json(self, result: dict[str, Any]) -> dict[str, Any]:
"""Convert result to JSON.""" """Convert result to JSON."""
if result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY: if result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
data = result.copy() data = result.copy()
@ -58,7 +59,7 @@ class FlowManagerIndexView(_BaseFlowManagerView):
extra=vol.ALLOW_EXTRA, extra=vol.ALLOW_EXTRA,
) )
) )
async def post(self, request: web.Request, data: Dict[str, Any]) -> web.Response: async def post(self, request: web.Request, data: dict[str, Any]) -> web.Response:
"""Handle a POST request.""" """Handle a POST request."""
if isinstance(data["handler"], list): if isinstance(data["handler"], list):
handler = tuple(data["handler"]) handler = tuple(data["handler"])
@ -99,7 +100,7 @@ class FlowManagerResourceView(_BaseFlowManagerView):
@RequestDataValidator(vol.Schema(dict), allow_empty=True) @RequestDataValidator(vol.Schema(dict), allow_empty=True)
async def post( async def post(
self, request: web.Request, flow_id: str, data: Dict[str, Any] self, request: web.Request, flow_id: str, data: dict[str, Any]
) -> web.Response: ) -> web.Response:
"""Handle a POST request.""" """Handle a POST request."""
try: try:

View file

@ -1,7 +1,9 @@
"""Debounce helper.""" """Debounce helper."""
from __future__ import annotations
import asyncio import asyncio
from logging import Logger from logging import Logger
from typing import Any, Awaitable, Callable, Optional from typing import Any, Awaitable, Callable
from homeassistant.core import HassJob, HomeAssistant, callback from homeassistant.core import HassJob, HomeAssistant, callback
@ -16,7 +18,7 @@ class Debouncer:
*, *,
cooldown: float, cooldown: float,
immediate: bool, immediate: bool,
function: Optional[Callable[..., Awaitable[Any]]] = None, function: Callable[..., Awaitable[Any]] | None = None,
): ):
"""Initialize debounce. """Initialize debounce.
@ -29,13 +31,13 @@ class Debouncer:
self._function = function self._function = function
self.cooldown = cooldown self.cooldown = cooldown
self.immediate = immediate self.immediate = immediate
self._timer_task: Optional[asyncio.TimerHandle] = None self._timer_task: asyncio.TimerHandle | None = None
self._execute_at_end_of_timer: bool = False self._execute_at_end_of_timer: bool = False
self._execute_lock = asyncio.Lock() self._execute_lock = asyncio.Lock()
self._job: Optional[HassJob] = None if function is None else HassJob(function) self._job: HassJob | None = None if function is None else HassJob(function)
@property @property
def function(self) -> Optional[Callable[..., Awaitable[Any]]]: def function(self) -> Callable[..., Awaitable[Any]] | None:
"""Return the function being wrapped by the Debouncer.""" """Return the function being wrapped by the Debouncer."""
return self._function return self._function

View file

@ -1,8 +1,10 @@
"""Deprecation helpers for Home Assistant.""" """Deprecation helpers for Home Assistant."""
from __future__ import annotations
import functools import functools
import inspect import inspect
import logging import logging
from typing import Any, Callable, Dict, Optional from typing import Any, Callable
from ..helpers.frame import MissingIntegrationFrame, get_integration_frame from ..helpers.frame import MissingIntegrationFrame, get_integration_frame
@ -49,8 +51,8 @@ def deprecated_substitute(substitute_name: str) -> Callable[..., Callable]:
def get_deprecated( def get_deprecated(
config: Dict[str, Any], new_name: str, old_name: str, default: Optional[Any] = None config: dict[str, Any], new_name: str, old_name: str, default: Any | None = None
) -> Optional[Any]: ) -> Any | None:
"""Allow an old config name to be deprecated with a replacement. """Allow an old config name to be deprecated with a replacement.
If the new config isn't found, but the old one is, the old value is used If the new config isn't found, but the old one is, the old value is used
@ -85,7 +87,7 @@ def deprecated_function(replacement: str) -> Callable[..., Callable]:
"""Decorate function as deprecated.""" """Decorate function as deprecated."""
@functools.wraps(func) @functools.wraps(func)
def deprecated_func(*args: tuple, **kwargs: Dict[str, Any]) -> Any: def deprecated_func(*args: tuple, **kwargs: dict[str, Any]) -> Any:
"""Wrap for the original function.""" """Wrap for the original function."""
logger = logging.getLogger(func.__module__) logger = logging.getLogger(func.__module__)
try: try:

View file

@ -1,8 +1,10 @@
"""Provide a way to connect entities belonging to one device.""" """Provide a way to connect entities belonging to one device."""
from __future__ import annotations
from collections import OrderedDict from collections import OrderedDict
import logging import logging
import time import time
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union, cast from typing import TYPE_CHECKING, Any, cast
import attr import attr
@ -50,21 +52,21 @@ ORPHANED_DEVICE_KEEP_SECONDS = 86400 * 30
class DeviceEntry: class DeviceEntry:
"""Device Registry Entry.""" """Device Registry Entry."""
config_entries: Set[str] = attr.ib(converter=set, factory=set) config_entries: set[str] = attr.ib(converter=set, factory=set)
connections: Set[Tuple[str, str]] = attr.ib(converter=set, factory=set) connections: set[tuple[str, str]] = attr.ib(converter=set, factory=set)
identifiers: Set[Tuple[str, str]] = attr.ib(converter=set, factory=set) identifiers: set[tuple[str, str]] = attr.ib(converter=set, factory=set)
manufacturer: Optional[str] = attr.ib(default=None) manufacturer: str | None = attr.ib(default=None)
model: Optional[str] = attr.ib(default=None) model: str | None = attr.ib(default=None)
name: Optional[str] = attr.ib(default=None) name: str | None = attr.ib(default=None)
sw_version: Optional[str] = attr.ib(default=None) sw_version: str | None = attr.ib(default=None)
via_device_id: Optional[str] = attr.ib(default=None) via_device_id: str | None = attr.ib(default=None)
area_id: Optional[str] = attr.ib(default=None) area_id: str | None = attr.ib(default=None)
name_by_user: Optional[str] = attr.ib(default=None) name_by_user: str | None = attr.ib(default=None)
entry_type: Optional[str] = attr.ib(default=None) entry_type: str | None = attr.ib(default=None)
id: str = attr.ib(factory=uuid_util.random_uuid_hex) id: str = attr.ib(factory=uuid_util.random_uuid_hex)
# This value is not stored, just used to keep track of events to fire. # This value is not stored, just used to keep track of events to fire.
is_new: bool = attr.ib(default=False) is_new: bool = attr.ib(default=False)
disabled_by: Optional[str] = attr.ib( disabled_by: str | None = attr.ib(
default=None, default=None,
validator=attr.validators.in_( validator=attr.validators.in_(
( (
@ -75,7 +77,7 @@ class DeviceEntry:
) )
), ),
) )
suggested_area: Optional[str] = attr.ib(default=None) suggested_area: str | None = attr.ib(default=None)
@property @property
def disabled(self) -> bool: def disabled(self) -> bool:
@ -87,17 +89,17 @@ class DeviceEntry:
class DeletedDeviceEntry: class DeletedDeviceEntry:
"""Deleted Device Registry Entry.""" """Deleted Device Registry Entry."""
config_entries: Set[str] = attr.ib() config_entries: set[str] = attr.ib()
connections: Set[Tuple[str, str]] = attr.ib() connections: set[tuple[str, str]] = attr.ib()
identifiers: Set[Tuple[str, str]] = attr.ib() identifiers: set[tuple[str, str]] = attr.ib()
id: str = attr.ib() id: str = attr.ib()
orphaned_timestamp: Optional[float] = attr.ib() orphaned_timestamp: float | None = attr.ib()
def to_device_entry( def to_device_entry(
self, self,
config_entry_id: str, config_entry_id: str,
connections: Set[Tuple[str, str]], connections: set[tuple[str, str]],
identifiers: Set[Tuple[str, str]], identifiers: set[tuple[str, str]],
) -> DeviceEntry: ) -> DeviceEntry:
"""Create DeviceEntry from DeletedDeviceEntry.""" """Create DeviceEntry from DeletedDeviceEntry."""
return DeviceEntry( return DeviceEntry(
@ -133,9 +135,9 @@ def format_mac(mac: str) -> str:
class DeviceRegistry: class DeviceRegistry:
"""Class to hold a registry of devices.""" """Class to hold a registry of devices."""
devices: Dict[str, DeviceEntry] devices: dict[str, DeviceEntry]
deleted_devices: Dict[str, DeletedDeviceEntry] deleted_devices: dict[str, DeletedDeviceEntry]
_devices_index: Dict[str, Dict[str, Dict[Tuple[str, str], str]]] _devices_index: dict[str, dict[str, dict[tuple[str, str], str]]]
def __init__(self, hass: HomeAssistantType) -> None: def __init__(self, hass: HomeAssistantType) -> None:
"""Initialize the device registry.""" """Initialize the device registry."""
@ -144,16 +146,16 @@ class DeviceRegistry:
self._clear_index() self._clear_index()
@callback @callback
def async_get(self, device_id: str) -> Optional[DeviceEntry]: def async_get(self, device_id: str) -> DeviceEntry | None:
"""Get device.""" """Get device."""
return self.devices.get(device_id) return self.devices.get(device_id)
@callback @callback
def async_get_device( def async_get_device(
self, self,
identifiers: Set[Tuple[str, str]], identifiers: set[tuple[str, str]],
connections: Optional[Set[Tuple[str, str]]] = None, connections: set[tuple[str, str]] | None = None,
) -> Optional[DeviceEntry]: ) -> DeviceEntry | None:
"""Check if device is registered.""" """Check if device is registered."""
device_id = self._async_get_device_id_from_index( device_id = self._async_get_device_id_from_index(
REGISTERED_DEVICE, identifiers, connections REGISTERED_DEVICE, identifiers, connections
@ -164,9 +166,9 @@ class DeviceRegistry:
def _async_get_deleted_device( def _async_get_deleted_device(
self, self,
identifiers: Set[Tuple[str, str]], identifiers: set[tuple[str, str]],
connections: Optional[Set[Tuple[str, str]]], connections: set[tuple[str, str]] | None,
) -> Optional[DeletedDeviceEntry]: ) -> DeletedDeviceEntry | None:
"""Check if device is deleted.""" """Check if device is deleted."""
device_id = self._async_get_device_id_from_index( device_id = self._async_get_device_id_from_index(
DELETED_DEVICE, identifiers, connections DELETED_DEVICE, identifiers, connections
@ -178,9 +180,9 @@ class DeviceRegistry:
def _async_get_device_id_from_index( def _async_get_device_id_from_index(
self, self,
index: str, index: str,
identifiers: Set[Tuple[str, str]], identifiers: set[tuple[str, str]],
connections: Optional[Set[Tuple[str, str]]], connections: set[tuple[str, str]] | None,
) -> Optional[str]: ) -> str | None:
"""Check if device has previously been registered.""" """Check if device has previously been registered."""
devices_index = self._devices_index[index] devices_index = self._devices_index[index]
for identifier in identifiers: for identifier in identifiers:
@ -193,7 +195,7 @@ class DeviceRegistry:
return devices_index[IDX_CONNECTIONS][connection] return devices_index[IDX_CONNECTIONS][connection]
return None return None
def _add_device(self, device: Union[DeviceEntry, DeletedDeviceEntry]) -> None: def _add_device(self, device: DeviceEntry | DeletedDeviceEntry) -> None:
"""Add a device and index it.""" """Add a device and index it."""
if isinstance(device, DeletedDeviceEntry): if isinstance(device, DeletedDeviceEntry):
devices_index = self._devices_index[DELETED_DEVICE] devices_index = self._devices_index[DELETED_DEVICE]
@ -204,7 +206,7 @@ class DeviceRegistry:
_add_device_to_index(devices_index, device) _add_device_to_index(devices_index, device)
def _remove_device(self, device: Union[DeviceEntry, DeletedDeviceEntry]) -> None: def _remove_device(self, device: DeviceEntry | DeletedDeviceEntry) -> None:
"""Remove a device and remove it from the index.""" """Remove a device and remove it from the index."""
if isinstance(device, DeletedDeviceEntry): if isinstance(device, DeletedDeviceEntry):
devices_index = self._devices_index[DELETED_DEVICE] devices_index = self._devices_index[DELETED_DEVICE]
@ -243,21 +245,21 @@ class DeviceRegistry:
self, self,
*, *,
config_entry_id: str, config_entry_id: str,
connections: Optional[Set[Tuple[str, str]]] = None, connections: set[tuple[str, str]] | None = None,
identifiers: Optional[Set[Tuple[str, str]]] = None, identifiers: set[tuple[str, str]] | None = None,
manufacturer: Union[str, None, UndefinedType] = UNDEFINED, manufacturer: str | None | UndefinedType = UNDEFINED,
model: Union[str, None, UndefinedType] = UNDEFINED, model: str | None | UndefinedType = UNDEFINED,
name: Union[str, None, UndefinedType] = UNDEFINED, name: str | None | UndefinedType = UNDEFINED,
default_manufacturer: Union[str, None, UndefinedType] = UNDEFINED, default_manufacturer: str | None | UndefinedType = UNDEFINED,
default_model: Union[str, None, UndefinedType] = UNDEFINED, default_model: str | None | UndefinedType = UNDEFINED,
default_name: Union[str, None, UndefinedType] = UNDEFINED, default_name: str | None | UndefinedType = UNDEFINED,
sw_version: Union[str, None, UndefinedType] = UNDEFINED, sw_version: str | None | UndefinedType = UNDEFINED,
entry_type: Union[str, None, UndefinedType] = UNDEFINED, entry_type: str | None | UndefinedType = UNDEFINED,
via_device: Optional[Tuple[str, str]] = None, via_device: tuple[str, str] | None = None,
# To disable a device if it gets created # To disable a device if it gets created
disabled_by: Union[str, None, UndefinedType] = UNDEFINED, disabled_by: str | None | UndefinedType = UNDEFINED,
suggested_area: Union[str, None, UndefinedType] = UNDEFINED, suggested_area: str | None | UndefinedType = UNDEFINED,
) -> Optional[DeviceEntry]: ) -> DeviceEntry | None:
"""Get device. Create if it doesn't exist.""" """Get device. Create if it doesn't exist."""
if not identifiers and not connections: if not identifiers and not connections:
return None return None
@ -294,7 +296,7 @@ class DeviceRegistry:
if via_device is not None: if via_device is not None:
via = self.async_get_device({via_device}) via = self.async_get_device({via_device})
via_device_id: Union[str, UndefinedType] = via.id if via else UNDEFINED via_device_id: str | UndefinedType = via.id if via else UNDEFINED
else: else:
via_device_id = UNDEFINED via_device_id = UNDEFINED
@ -318,18 +320,18 @@ class DeviceRegistry:
self, self,
device_id: str, device_id: str,
*, *,
area_id: Union[str, None, UndefinedType] = UNDEFINED, area_id: str | None | UndefinedType = UNDEFINED,
manufacturer: Union[str, None, UndefinedType] = UNDEFINED, manufacturer: str | None | UndefinedType = UNDEFINED,
model: Union[str, None, UndefinedType] = UNDEFINED, model: str | None | UndefinedType = UNDEFINED,
name: Union[str, None, UndefinedType] = UNDEFINED, name: str | None | UndefinedType = UNDEFINED,
name_by_user: Union[str, None, UndefinedType] = UNDEFINED, name_by_user: str | None | UndefinedType = UNDEFINED,
new_identifiers: Union[Set[Tuple[str, str]], UndefinedType] = UNDEFINED, new_identifiers: set[tuple[str, str]] | UndefinedType = UNDEFINED,
sw_version: Union[str, None, UndefinedType] = UNDEFINED, sw_version: str | None | UndefinedType = UNDEFINED,
via_device_id: Union[str, None, UndefinedType] = UNDEFINED, via_device_id: str | None | UndefinedType = UNDEFINED,
remove_config_entry_id: Union[str, UndefinedType] = UNDEFINED, remove_config_entry_id: str | UndefinedType = UNDEFINED,
disabled_by: Union[str, None, UndefinedType] = UNDEFINED, disabled_by: str | None | UndefinedType = UNDEFINED,
suggested_area: Union[str, None, UndefinedType] = UNDEFINED, suggested_area: str | None | UndefinedType = UNDEFINED,
) -> Optional[DeviceEntry]: ) -> DeviceEntry | None:
"""Update properties of a device.""" """Update properties of a device."""
return self._async_update_device( return self._async_update_device(
device_id, device_id,
@ -351,26 +353,26 @@ class DeviceRegistry:
self, self,
device_id: str, device_id: str,
*, *,
add_config_entry_id: Union[str, UndefinedType] = UNDEFINED, add_config_entry_id: str | UndefinedType = UNDEFINED,
remove_config_entry_id: Union[str, UndefinedType] = UNDEFINED, remove_config_entry_id: str | UndefinedType = UNDEFINED,
merge_connections: Union[Set[Tuple[str, str]], UndefinedType] = UNDEFINED, merge_connections: set[tuple[str, str]] | UndefinedType = UNDEFINED,
merge_identifiers: Union[Set[Tuple[str, str]], UndefinedType] = UNDEFINED, merge_identifiers: set[tuple[str, str]] | UndefinedType = UNDEFINED,
new_identifiers: Union[Set[Tuple[str, str]], UndefinedType] = UNDEFINED, new_identifiers: set[tuple[str, str]] | UndefinedType = UNDEFINED,
manufacturer: Union[str, None, UndefinedType] = UNDEFINED, manufacturer: str | None | UndefinedType = UNDEFINED,
model: Union[str, None, UndefinedType] = UNDEFINED, model: str | None | UndefinedType = UNDEFINED,
name: Union[str, None, UndefinedType] = UNDEFINED, name: str | None | UndefinedType = UNDEFINED,
sw_version: Union[str, None, UndefinedType] = UNDEFINED, sw_version: str | None | UndefinedType = UNDEFINED,
entry_type: Union[str, None, UndefinedType] = UNDEFINED, entry_type: str | None | UndefinedType = UNDEFINED,
via_device_id: Union[str, None, UndefinedType] = UNDEFINED, via_device_id: str | None | UndefinedType = UNDEFINED,
area_id: Union[str, None, UndefinedType] = UNDEFINED, area_id: str | None | UndefinedType = UNDEFINED,
name_by_user: Union[str, None, UndefinedType] = UNDEFINED, name_by_user: str | None | UndefinedType = UNDEFINED,
disabled_by: Union[str, None, UndefinedType] = UNDEFINED, disabled_by: str | None | UndefinedType = UNDEFINED,
suggested_area: Union[str, None, UndefinedType] = UNDEFINED, suggested_area: str | None | UndefinedType = UNDEFINED,
) -> Optional[DeviceEntry]: ) -> DeviceEntry | None:
"""Update device attributes.""" """Update device attributes."""
old = self.devices[device_id] old = self.devices[device_id]
changes: Dict[str, Any] = {} changes: dict[str, Any] = {}
config_entries = old.config_entries config_entries = old.config_entries
@ -529,7 +531,7 @@ class DeviceRegistry:
self._store.async_delay_save(self._data_to_save, SAVE_DELAY) self._store.async_delay_save(self._data_to_save, SAVE_DELAY)
@callback @callback
def _data_to_save(self) -> Dict[str, List[Dict[str, Any]]]: def _data_to_save(self) -> dict[str, list[dict[str, Any]]]:
"""Return data of device registry to store in a file.""" """Return data of device registry to store in a file."""
data = {} data = {}
@ -637,7 +639,7 @@ async def async_get_registry(hass: HomeAssistantType) -> DeviceRegistry:
@callback @callback
def async_entries_for_area(registry: DeviceRegistry, area_id: str) -> List[DeviceEntry]: def async_entries_for_area(registry: DeviceRegistry, area_id: str) -> list[DeviceEntry]:
"""Return entries that match an area.""" """Return entries that match an area."""
return [device for device in registry.devices.values() if device.area_id == area_id] return [device for device in registry.devices.values() if device.area_id == area_id]
@ -645,7 +647,7 @@ def async_entries_for_area(registry: DeviceRegistry, area_id: str) -> List[Devic
@callback @callback
def async_entries_for_config_entry( def async_entries_for_config_entry(
registry: DeviceRegistry, config_entry_id: str registry: DeviceRegistry, config_entry_id: str
) -> List[DeviceEntry]: ) -> list[DeviceEntry]:
"""Return entries that match a config entry.""" """Return entries that match a config entry."""
return [ return [
device device
@ -769,7 +771,7 @@ def async_setup_cleanup(hass: HomeAssistantType, dev_reg: DeviceRegistry) -> Non
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STARTED, startup_clean) hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STARTED, startup_clean)
def _normalize_connections(connections: Set[Tuple[str, str]]) -> Set[Tuple[str, str]]: def _normalize_connections(connections: set[tuple[str, str]]) -> set[tuple[str, str]]:
"""Normalize connections to ensure we can match mac addresses.""" """Normalize connections to ensure we can match mac addresses."""
return { return {
(key, format_mac(value)) if key == CONNECTION_NETWORK_MAC else (key, value) (key, format_mac(value)) if key == CONNECTION_NETWORK_MAC else (key, value)
@ -778,8 +780,8 @@ def _normalize_connections(connections: Set[Tuple[str, str]]) -> Set[Tuple[str,
def _add_device_to_index( def _add_device_to_index(
devices_index: Dict[str, Dict[Tuple[str, str], str]], devices_index: dict[str, dict[tuple[str, str], str]],
device: Union[DeviceEntry, DeletedDeviceEntry], device: DeviceEntry | DeletedDeviceEntry,
) -> None: ) -> None:
"""Add a device to the index.""" """Add a device to the index."""
for identifier in device.identifiers: for identifier in device.identifiers:
@ -789,8 +791,8 @@ def _add_device_to_index(
def _remove_device_from_index( def _remove_device_from_index(
devices_index: Dict[str, Dict[Tuple[str, str], str]], devices_index: dict[str, dict[tuple[str, str], str]],
device: Union[DeviceEntry, DeletedDeviceEntry], device: DeviceEntry | DeletedDeviceEntry,
) -> None: ) -> None:
"""Remove a device from the index.""" """Remove a device from the index."""
for identifier in device.identifiers: for identifier in device.identifiers:

View file

@ -5,7 +5,9 @@ There are two different types of discoveries that can be fired/listened for.
- listen_platform/discover_platform is for platforms. These are used by - listen_platform/discover_platform is for platforms. These are used by
components to allow discovery of their platforms. components to allow discovery of their platforms.
""" """
from typing import Any, Callable, Dict, Optional, TypedDict from __future__ import annotations
from typing import Any, Callable, TypedDict
from homeassistant import core, setup from homeassistant import core, setup
from homeassistant.core import CALLBACK_TYPE from homeassistant.core import CALLBACK_TYPE
@ -26,8 +28,8 @@ class DiscoveryDict(TypedDict):
"""Discovery data.""" """Discovery data."""
service: str service: str
platform: Optional[str] platform: str | None
discovered: Optional[DiscoveryInfoType] discovered: DiscoveryInfoType | None
@core.callback @core.callback
@ -76,8 +78,8 @@ def discover(
async def async_discover( async def async_discover(
hass: core.HomeAssistant, hass: core.HomeAssistant,
service: str, service: str,
discovered: Optional[DiscoveryInfoType], discovered: DiscoveryInfoType | None,
component: Optional[str], component: str | None,
hass_config: ConfigType, hass_config: ConfigType,
) -> None: ) -> None:
"""Fire discovery event. Can ensure a component is loaded.""" """Fire discovery event. Can ensure a component is loaded."""
@ -97,7 +99,7 @@ async def async_discover(
def async_listen_platform( def async_listen_platform(
hass: core.HomeAssistant, hass: core.HomeAssistant,
component: str, component: str,
callback: Callable[[str, Optional[Dict[str, Any]]], Any], callback: Callable[[str, dict[str, Any] | None], Any],
) -> None: ) -> None:
"""Register a platform loader listener. """Register a platform loader listener.

View file

@ -1,11 +1,13 @@
"""An abstract class for entities.""" """An abstract class for entities."""
from __future__ import annotations
from abc import ABC from abc import ABC
import asyncio import asyncio
from datetime import datetime, timedelta from datetime import datetime, timedelta
import functools as ft import functools as ft
import logging import logging
from timeit import default_timer as timer from timeit import default_timer as timer
from typing import Any, Awaitable, Dict, Iterable, List, Optional from typing import Any, Awaitable, Iterable
from homeassistant.config import DATA_CUSTOMIZE from homeassistant.config import DATA_CUSTOMIZE
from homeassistant.const import ( from homeassistant.const import (
@ -42,16 +44,16 @@ SOURCE_PLATFORM_CONFIG = "platform_config"
@callback @callback
@bind_hass @bind_hass
def entity_sources(hass: HomeAssistant) -> Dict[str, Dict[str, str]]: def entity_sources(hass: HomeAssistant) -> dict[str, dict[str, str]]:
"""Get the entity sources.""" """Get the entity sources."""
return hass.data.get(DATA_ENTITY_SOURCE, {}) return hass.data.get(DATA_ENTITY_SOURCE, {})
def generate_entity_id( def generate_entity_id(
entity_id_format: str, entity_id_format: str,
name: Optional[str], name: str | None,
current_ids: Optional[List[str]] = None, current_ids: list[str] | None = None,
hass: Optional[HomeAssistant] = None, hass: HomeAssistant | None = None,
) -> str: ) -> str:
"""Generate a unique entity ID based on given entity IDs or used IDs.""" """Generate a unique entity ID based on given entity IDs or used IDs."""
return async_generate_entity_id(entity_id_format, name, current_ids, hass) return async_generate_entity_id(entity_id_format, name, current_ids, hass)
@ -60,9 +62,9 @@ def generate_entity_id(
@callback @callback
def async_generate_entity_id( def async_generate_entity_id(
entity_id_format: str, entity_id_format: str,
name: Optional[str], name: str | None,
current_ids: Optional[Iterable[str]] = None, current_ids: Iterable[str] | None = None,
hass: Optional[HomeAssistant] = None, hass: HomeAssistant | None = None,
) -> str: ) -> str:
"""Generate a unique entity ID based on given entity IDs or used IDs.""" """Generate a unique entity ID based on given entity IDs or used IDs."""
name = (name or DEVICE_DEFAULT_NAME).lower() name = (name or DEVICE_DEFAULT_NAME).lower()
@ -98,7 +100,7 @@ class Entity(ABC):
hass: HomeAssistant = None # type: ignore hass: HomeAssistant = None # type: ignore
# Owning platform instance. Will be set by EntityPlatform # Owning platform instance. Will be set by EntityPlatform
platform: Optional[EntityPlatform] = None platform: EntityPlatform | None = None
# If we reported if this entity was slow # If we reported if this entity was slow
_slow_reported = False _slow_reported = False
@ -110,17 +112,17 @@ class Entity(ABC):
_update_staged = False _update_staged = False
# Process updates in parallel # Process updates in parallel
parallel_updates: Optional[asyncio.Semaphore] = None parallel_updates: asyncio.Semaphore | None = None
# Entry in the entity registry # Entry in the entity registry
registry_entry: Optional[RegistryEntry] = None registry_entry: RegistryEntry | None = None
# Hold list for functions to call on remove. # Hold list for functions to call on remove.
_on_remove: Optional[List[CALLBACK_TYPE]] = None _on_remove: list[CALLBACK_TYPE] | None = None
# Context # Context
_context: Optional[Context] = None _context: Context | None = None
_context_set: Optional[datetime] = None _context_set: datetime | None = None
# If entity is added to an entity platform # If entity is added to an entity platform
_added = False _added = False
@ -134,12 +136,12 @@ class Entity(ABC):
return True return True
@property @property
def unique_id(self) -> Optional[str]: def unique_id(self) -> str | None:
"""Return a unique ID.""" """Return a unique ID."""
return None return None
@property @property
def name(self) -> Optional[str]: def name(self) -> str | None:
"""Return the name of the entity.""" """Return the name of the entity."""
return None return None
@ -149,7 +151,7 @@ class Entity(ABC):
return STATE_UNKNOWN return STATE_UNKNOWN
@property @property
def capability_attributes(self) -> Optional[Dict[str, Any]]: def capability_attributes(self) -> dict[str, Any] | None:
"""Return the capability attributes. """Return the capability attributes.
Attributes that explain the capabilities of an entity. Attributes that explain the capabilities of an entity.
@ -160,7 +162,7 @@ class Entity(ABC):
return None return None
@property @property
def state_attributes(self) -> Optional[Dict[str, Any]]: def state_attributes(self) -> dict[str, Any] | None:
"""Return the state attributes. """Return the state attributes.
Implemented by component base class, should not be extended by integrations. Implemented by component base class, should not be extended by integrations.
@ -169,7 +171,7 @@ class Entity(ABC):
return None return None
@property @property
def device_state_attributes(self) -> Optional[Dict[str, Any]]: def device_state_attributes(self) -> dict[str, Any] | None:
"""Return entity specific state attributes. """Return entity specific state attributes.
This method is deprecated, platform classes should implement This method is deprecated, platform classes should implement
@ -178,7 +180,7 @@ class Entity(ABC):
return None return None
@property @property
def extra_state_attributes(self) -> Optional[Dict[str, Any]]: def extra_state_attributes(self) -> dict[str, Any] | None:
"""Return entity specific state attributes. """Return entity specific state attributes.
Implemented by platform classes. Convention for attribute names Implemented by platform classes. Convention for attribute names
@ -187,7 +189,7 @@ class Entity(ABC):
return None return None
@property @property
def device_info(self) -> Optional[Dict[str, Any]]: def device_info(self) -> dict[str, Any] | None:
"""Return device specific attributes. """Return device specific attributes.
Implemented by platform classes. Implemented by platform classes.
@ -195,22 +197,22 @@ class Entity(ABC):
return None return None
@property @property
def device_class(self) -> Optional[str]: def device_class(self) -> str | None:
"""Return the class of this device, from component DEVICE_CLASSES.""" """Return the class of this device, from component DEVICE_CLASSES."""
return None return None
@property @property
def unit_of_measurement(self) -> Optional[str]: def unit_of_measurement(self) -> str | None:
"""Return the unit of measurement of this entity, if any.""" """Return the unit of measurement of this entity, if any."""
return None return None
@property @property
def icon(self) -> Optional[str]: def icon(self) -> str | None:
"""Return the icon to use in the frontend, if any.""" """Return the icon to use in the frontend, if any."""
return None return None
@property @property
def entity_picture(self) -> Optional[str]: def entity_picture(self) -> str | None:
"""Return the entity picture to use in the frontend, if any.""" """Return the entity picture to use in the frontend, if any."""
return None return None
@ -234,7 +236,7 @@ class Entity(ABC):
return False return False
@property @property
def supported_features(self) -> Optional[int]: def supported_features(self) -> int | None:
"""Flag supported features.""" """Flag supported features."""
return None return None
@ -516,7 +518,7 @@ class Entity(ABC):
self, self,
hass: HomeAssistant, hass: HomeAssistant,
platform: EntityPlatform, platform: EntityPlatform,
parallel_updates: Optional[asyncio.Semaphore], parallel_updates: asyncio.Semaphore | None,
) -> None: ) -> None:
"""Start adding an entity to a platform.""" """Start adding an entity to a platform."""
if self._added: if self._added:

View file

@ -1,10 +1,12 @@
"""Helpers for components that manage entities.""" """Helpers for components that manage entities."""
from __future__ import annotations
import asyncio import asyncio
from datetime import timedelta from datetime import timedelta
from itertools import chain from itertools import chain
import logging import logging
from types import ModuleType from types import ModuleType
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union from typing import Any, Callable, Iterable
import voluptuous as vol import voluptuous as vol
@ -76,10 +78,10 @@ class EntityComponent:
self.domain = domain self.domain = domain
self.scan_interval = scan_interval self.scan_interval = scan_interval
self.config: Optional[ConfigType] = None self.config: ConfigType | None = None
self._platforms: Dict[ self._platforms: dict[
Union[str, Tuple[str, Optional[timedelta], Optional[str]]], EntityPlatform str | tuple[str, timedelta | None, str | None], EntityPlatform
] = {domain: self._async_init_entity_platform(domain, None)} ] = {domain: self._async_init_entity_platform(domain, None)}
self.async_add_entities = self._platforms[domain].async_add_entities self.async_add_entities = self._platforms[domain].async_add_entities
self.add_entities = self._platforms[domain].add_entities self.add_entities = self._platforms[domain].add_entities
@ -93,7 +95,7 @@ class EntityComponent:
platform.entities.values() for platform in self._platforms.values() platform.entities.values() for platform in self._platforms.values()
) )
def get_entity(self, entity_id: str) -> Optional[entity.Entity]: def get_entity(self, entity_id: str) -> entity.Entity | None:
"""Get an entity.""" """Get an entity."""
for platform in self._platforms.values(): for platform in self._platforms.values():
entity_obj = platform.entities.get(entity_id) entity_obj = platform.entities.get(entity_id)
@ -125,7 +127,7 @@ class EntityComponent:
# Generic discovery listener for loading platform dynamically # Generic discovery listener for loading platform dynamically
# Refer to: homeassistant.helpers.discovery.async_load_platform() # Refer to: homeassistant.helpers.discovery.async_load_platform()
async def component_platform_discovered( async def component_platform_discovered(
platform: str, info: Optional[Dict[str, Any]] platform: str, info: dict[str, Any] | None
) -> None: ) -> None:
"""Handle the loading of a platform.""" """Handle the loading of a platform."""
await self.async_setup_platform(platform, {}, info) await self.async_setup_platform(platform, {}, info)
@ -176,7 +178,7 @@ class EntityComponent:
async def async_extract_from_service( async def async_extract_from_service(
self, service_call: ServiceCall, expand_group: bool = True self, service_call: ServiceCall, expand_group: bool = True
) -> List[entity.Entity]: ) -> list[entity.Entity]:
"""Extract all known and available entities from a service call. """Extract all known and available entities from a service call.
Will return an empty list if entities specified but unknown. Will return an empty list if entities specified but unknown.
@ -191,9 +193,9 @@ class EntityComponent:
def async_register_entity_service( def async_register_entity_service(
self, self,
name: str, name: str,
schema: Union[Dict[str, Any], vol.Schema], schema: dict[str, Any] | vol.Schema,
func: Union[str, Callable[..., Any]], func: str | Callable[..., Any],
required_features: Optional[List[int]] = None, required_features: list[int] | None = None,
) -> None: ) -> None:
"""Register an entity service.""" """Register an entity service."""
if isinstance(schema, dict): if isinstance(schema, dict):
@ -211,7 +213,7 @@ class EntityComponent:
self, self,
platform_type: str, platform_type: str,
platform_config: ConfigType, platform_config: ConfigType,
discovery_info: Optional[DiscoveryInfoType] = None, discovery_info: DiscoveryInfoType | None = None,
) -> None: ) -> None:
"""Set up a platform for this component.""" """Set up a platform for this component."""
if self.config is None: if self.config is None:
@ -274,7 +276,7 @@ class EntityComponent:
async def async_prepare_reload( async def async_prepare_reload(
self, *, skip_reset: bool = False self, *, skip_reset: bool = False
) -> Optional[ConfigType]: ) -> ConfigType | None:
"""Prepare reloading this entity component. """Prepare reloading this entity component.
This method must be run in the event loop. This method must be run in the event loop.
@ -303,9 +305,9 @@ class EntityComponent:
def _async_init_entity_platform( def _async_init_entity_platform(
self, self,
platform_type: str, platform_type: str,
platform: Optional[ModuleType], platform: ModuleType | None,
scan_interval: Optional[timedelta] = None, scan_interval: timedelta | None = None,
entity_namespace: Optional[str] = None, entity_namespace: str | None = None,
) -> EntityPlatform: ) -> EntityPlatform:
"""Initialize an entity platform.""" """Initialize an entity platform."""
if scan_interval is None: if scan_interval is None:

View file

@ -6,7 +6,7 @@ from contextvars import ContextVar
from datetime import datetime, timedelta from datetime import datetime, timedelta
from logging import Logger from logging import Logger
from types import ModuleType from types import ModuleType
from typing import TYPE_CHECKING, Callable, Coroutine, Dict, Iterable, List, Optional from typing import TYPE_CHECKING, Callable, Coroutine, Iterable
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.const import ATTR_RESTORED, DEVICE_DEFAULT_NAME from homeassistant.const import ATTR_RESTORED, DEVICE_DEFAULT_NAME
@ -49,9 +49,9 @@ class EntityPlatform:
logger: Logger, logger: Logger,
domain: str, domain: str,
platform_name: str, platform_name: str,
platform: Optional[ModuleType], platform: ModuleType | None,
scan_interval: timedelta, scan_interval: timedelta,
entity_namespace: Optional[str], entity_namespace: str | None,
): ):
"""Initialize the entity platform.""" """Initialize the entity platform."""
self.hass = hass self.hass = hass
@ -61,18 +61,18 @@ class EntityPlatform:
self.platform = platform self.platform = platform
self.scan_interval = scan_interval self.scan_interval = scan_interval
self.entity_namespace = entity_namespace self.entity_namespace = entity_namespace
self.config_entry: Optional[config_entries.ConfigEntry] = None self.config_entry: config_entries.ConfigEntry | None = None
self.entities: Dict[str, Entity] = {} self.entities: dict[str, Entity] = {}
self._tasks: List[asyncio.Future] = [] self._tasks: list[asyncio.Future] = []
# Stop tracking tasks after setup is completed # Stop tracking tasks after setup is completed
self._setup_complete = False self._setup_complete = False
# Method to cancel the state change listener # Method to cancel the state change listener
self._async_unsub_polling: Optional[CALLBACK_TYPE] = None self._async_unsub_polling: CALLBACK_TYPE | None = None
# Method to cancel the retry of setup # Method to cancel the retry of setup
self._async_cancel_retry_setup: Optional[CALLBACK_TYPE] = None self._async_cancel_retry_setup: CALLBACK_TYPE | None = None
self._process_updates: Optional[asyncio.Lock] = None self._process_updates: asyncio.Lock | None = None
self.parallel_updates: Optional[asyncio.Semaphore] = None self.parallel_updates: asyncio.Semaphore | None = None
# Platform is None for the EntityComponent "catch-all" EntityPlatform # Platform is None for the EntityComponent "catch-all" EntityPlatform
# which powers entity_component.add_entities # which powers entity_component.add_entities
@ -89,7 +89,7 @@ class EntityPlatform:
@callback @callback
def _get_parallel_updates_semaphore( def _get_parallel_updates_semaphore(
self, entity_has_async_update: bool self, entity_has_async_update: bool
) -> Optional[asyncio.Semaphore]: ) -> asyncio.Semaphore | None:
"""Get or create a semaphore for parallel updates. """Get or create a semaphore for parallel updates.
Semaphore will be created on demand because we base it off if update method is async or not. Semaphore will be created on demand because we base it off if update method is async or not.
@ -364,7 +364,7 @@ class EntityPlatform:
return return
requested_entity_id = None requested_entity_id = None
suggested_object_id: Optional[str] = None suggested_object_id: str | None = None
# Get entity_id from unique ID registration # Get entity_id from unique ID registration
if entity.unique_id is not None: if entity.unique_id is not None:
@ -378,7 +378,7 @@ class EntityPlatform:
suggested_object_id = f"{self.entity_namespace} {suggested_object_id}" suggested_object_id = f"{self.entity_namespace} {suggested_object_id}"
if self.config_entry is not None: if self.config_entry is not None:
config_entry_id: Optional[str] = self.config_entry.entry_id config_entry_id: str | None = self.config_entry.entry_id
else: else:
config_entry_id = None config_entry_id = None
@ -408,7 +408,7 @@ class EntityPlatform:
if device: if device:
device_id = device.id device_id = device.id
disabled_by: Optional[str] = None disabled_by: str | None = None
if not entity.entity_registry_enabled_default: if not entity.entity_registry_enabled_default:
disabled_by = DISABLED_INTEGRATION disabled_by = DISABLED_INTEGRATION
@ -550,7 +550,7 @@ class EntityPlatform:
async def async_extract_from_service( async def async_extract_from_service(
self, service_call: ServiceCall, expand_group: bool = True self, service_call: ServiceCall, expand_group: bool = True
) -> List[Entity]: ) -> list[Entity]:
"""Extract all known and available entities from a service call. """Extract all known and available entities from a service call.
Will return an empty list if entities specified but unknown. Will return an empty list if entities specified but unknown.
@ -621,7 +621,7 @@ class EntityPlatform:
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
current_platform: ContextVar[Optional[EntityPlatform]] = ContextVar( current_platform: ContextVar[EntityPlatform | None] = ContextVar(
"current_platform", default=None "current_platform", default=None
) )
@ -629,7 +629,7 @@ current_platform: ContextVar[Optional[EntityPlatform]] = ContextVar(
@callback @callback
def async_get_platforms( def async_get_platforms(
hass: HomeAssistantType, integration_name: str hass: HomeAssistantType, integration_name: str
) -> List[EntityPlatform]: ) -> list[EntityPlatform]:
"""Find existing platforms.""" """Find existing platforms."""
if ( if (
DATA_ENTITY_PLATFORM not in hass.data DATA_ENTITY_PLATFORM not in hass.data
@ -637,6 +637,6 @@ def async_get_platforms(
): ):
return [] return []
platforms: List[EntityPlatform] = hass.data[DATA_ENTITY_PLATFORM][integration_name] platforms: list[EntityPlatform] = hass.data[DATA_ENTITY_PLATFORM][integration_name]
return platforms return platforms

View file

@ -7,20 +7,11 @@ The Entity Registry will persist itself 10 seconds after a new entity is
registered. Registering a new entity while a timer is in progress resets the registered. Registering a new entity while a timer is in progress resets the
timer. timer.
""" """
from __future__ import annotations
from collections import OrderedDict from collections import OrderedDict
import logging import logging
from typing import ( from typing import TYPE_CHECKING, Any, Callable, Iterable, cast
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Tuple,
Union,
cast,
)
import attr import attr
@ -80,12 +71,12 @@ class RegistryEntry:
entity_id: str = attr.ib() entity_id: str = attr.ib()
unique_id: str = attr.ib() unique_id: str = attr.ib()
platform: str = attr.ib() platform: str = attr.ib()
name: Optional[str] = attr.ib(default=None) name: str | None = attr.ib(default=None)
icon: Optional[str] = attr.ib(default=None) icon: str | None = attr.ib(default=None)
device_id: Optional[str] = attr.ib(default=None) device_id: str | None = attr.ib(default=None)
area_id: Optional[str] = attr.ib(default=None) area_id: str | None = attr.ib(default=None)
config_entry_id: Optional[str] = attr.ib(default=None) config_entry_id: str | None = attr.ib(default=None)
disabled_by: Optional[str] = attr.ib( disabled_by: str | None = attr.ib(
default=None, default=None,
validator=attr.validators.in_( validator=attr.validators.in_(
( (
@ -98,13 +89,13 @@ class RegistryEntry:
) )
), ),
) )
capabilities: Optional[Dict[str, Any]] = attr.ib(default=None) capabilities: dict[str, Any] | None = attr.ib(default=None)
supported_features: int = attr.ib(default=0) supported_features: int = attr.ib(default=0)
device_class: Optional[str] = attr.ib(default=None) device_class: str | None = attr.ib(default=None)
unit_of_measurement: Optional[str] = attr.ib(default=None) unit_of_measurement: str | None = attr.ib(default=None)
# As set by integration # As set by integration
original_name: Optional[str] = attr.ib(default=None) original_name: str | None = attr.ib(default=None)
original_icon: Optional[str] = attr.ib(default=None) original_icon: str | None = attr.ib(default=None)
domain: str = attr.ib(init=False, repr=False) domain: str = attr.ib(init=False, repr=False)
@domain.default @domain.default
@ -120,7 +111,7 @@ class RegistryEntry:
@callback @callback
def write_unavailable_state(self, hass: HomeAssistantType) -> None: def write_unavailable_state(self, hass: HomeAssistantType) -> None:
"""Write the unavailable state to the state machine.""" """Write the unavailable state to the state machine."""
attrs: Dict[str, Any] = {ATTR_RESTORED: True} attrs: dict[str, Any] = {ATTR_RESTORED: True}
if self.capabilities is not None: if self.capabilities is not None:
attrs.update(self.capabilities) attrs.update(self.capabilities)
@ -151,8 +142,8 @@ class EntityRegistry:
def __init__(self, hass: HomeAssistantType): def __init__(self, hass: HomeAssistantType):
"""Initialize the registry.""" """Initialize the registry."""
self.hass = hass self.hass = hass
self.entities: Dict[str, RegistryEntry] self.entities: dict[str, RegistryEntry]
self._index: Dict[Tuple[str, str, str], str] = {} self._index: dict[tuple[str, str, str], str] = {}
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY) self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
self.hass.bus.async_listen( self.hass.bus.async_listen(
EVENT_DEVICE_REGISTRY_UPDATED, self.async_device_modified EVENT_DEVICE_REGISTRY_UPDATED, self.async_device_modified
@ -161,7 +152,7 @@ class EntityRegistry:
@callback @callback
def async_get_device_class_lookup(self, domain_device_classes: set) -> dict: def async_get_device_class_lookup(self, domain_device_classes: set) -> dict:
"""Return a lookup for the device class by domain.""" """Return a lookup for the device class by domain."""
lookup: Dict[str, Dict[Tuple[Any, Any], str]] = {} lookup: dict[str, dict[tuple[Any, Any], str]] = {}
for entity in self.entities.values(): for entity in self.entities.values():
if not entity.device_id: if not entity.device_id:
continue continue
@ -180,14 +171,14 @@ class EntityRegistry:
return entity_id in self.entities return entity_id in self.entities
@callback @callback
def async_get(self, entity_id: str) -> Optional[RegistryEntry]: def async_get(self, entity_id: str) -> RegistryEntry | None:
"""Get EntityEntry for an entity_id.""" """Get EntityEntry for an entity_id."""
return self.entities.get(entity_id) return self.entities.get(entity_id)
@callback @callback
def async_get_entity_id( def async_get_entity_id(
self, domain: str, platform: str, unique_id: str self, domain: str, platform: str, unique_id: str
) -> Optional[str]: ) -> str | None:
"""Check if an entity_id is currently registered.""" """Check if an entity_id is currently registered."""
return self._index.get((domain, platform, unique_id)) return self._index.get((domain, platform, unique_id))
@ -196,7 +187,7 @@ class EntityRegistry:
self, self,
domain: str, domain: str,
suggested_object_id: str, suggested_object_id: str,
known_object_ids: Optional[Iterable[str]] = None, known_object_ids: Iterable[str] | None = None,
) -> str: ) -> str:
"""Generate an entity ID that does not conflict. """Generate an entity ID that does not conflict.
@ -226,20 +217,20 @@ class EntityRegistry:
unique_id: str, unique_id: str,
*, *,
# To influence entity ID generation # To influence entity ID generation
suggested_object_id: Optional[str] = None, suggested_object_id: str | None = None,
known_object_ids: Optional[Iterable[str]] = None, known_object_ids: Iterable[str] | None = None,
# To disable an entity if it gets created # To disable an entity if it gets created
disabled_by: Optional[str] = None, disabled_by: str | None = None,
# Data that we want entry to have # Data that we want entry to have
config_entry: Optional["ConfigEntry"] = None, config_entry: "ConfigEntry" | None = None,
device_id: Optional[str] = None, device_id: str | None = None,
area_id: Optional[str] = None, area_id: str | None = None,
capabilities: Optional[Dict[str, Any]] = None, capabilities: dict[str, Any] | None = None,
supported_features: Optional[int] = None, supported_features: int | None = None,
device_class: Optional[str] = None, device_class: str | None = None,
unit_of_measurement: Optional[str] = None, unit_of_measurement: str | None = None,
original_name: Optional[str] = None, original_name: str | None = None,
original_icon: Optional[str] = None, original_icon: str | None = None,
) -> RegistryEntry: ) -> RegistryEntry:
"""Get entity. Create if it doesn't exist.""" """Get entity. Create if it doesn't exist."""
config_entry_id = None config_entry_id = None
@ -363,12 +354,12 @@ class EntityRegistry:
self, self,
entity_id: str, entity_id: str,
*, *,
name: Union[str, None, UndefinedType] = UNDEFINED, name: str | None | UndefinedType = UNDEFINED,
icon: Union[str, None, UndefinedType] = UNDEFINED, icon: str | None | UndefinedType = UNDEFINED,
area_id: Union[str, None, UndefinedType] = UNDEFINED, area_id: str | None | UndefinedType = UNDEFINED,
new_entity_id: Union[str, UndefinedType] = UNDEFINED, new_entity_id: str | UndefinedType = UNDEFINED,
new_unique_id: Union[str, UndefinedType] = UNDEFINED, new_unique_id: str | UndefinedType = UNDEFINED,
disabled_by: Union[str, None, UndefinedType] = UNDEFINED, disabled_by: str | None | UndefinedType = UNDEFINED,
) -> RegistryEntry: ) -> RegistryEntry:
"""Update properties of an entity.""" """Update properties of an entity."""
return self._async_update_entity( return self._async_update_entity(
@ -386,20 +377,20 @@ class EntityRegistry:
self, self,
entity_id: str, entity_id: str,
*, *,
name: Union[str, None, UndefinedType] = UNDEFINED, name: str | None | UndefinedType = UNDEFINED,
icon: Union[str, None, UndefinedType] = UNDEFINED, icon: str | None | UndefinedType = UNDEFINED,
config_entry_id: Union[str, None, UndefinedType] = UNDEFINED, config_entry_id: str | None | UndefinedType = UNDEFINED,
new_entity_id: Union[str, UndefinedType] = UNDEFINED, new_entity_id: str | UndefinedType = UNDEFINED,
device_id: Union[str, None, UndefinedType] = UNDEFINED, device_id: str | None | UndefinedType = UNDEFINED,
area_id: Union[str, None, UndefinedType] = UNDEFINED, area_id: str | None | UndefinedType = UNDEFINED,
new_unique_id: Union[str, UndefinedType] = UNDEFINED, new_unique_id: str | UndefinedType = UNDEFINED,
disabled_by: Union[str, None, UndefinedType] = UNDEFINED, disabled_by: str | None | UndefinedType = UNDEFINED,
capabilities: Union[Dict[str, Any], None, UndefinedType] = UNDEFINED, capabilities: dict[str, Any] | None | UndefinedType = UNDEFINED,
supported_features: Union[int, UndefinedType] = UNDEFINED, supported_features: int | UndefinedType = UNDEFINED,
device_class: Union[str, None, UndefinedType] = UNDEFINED, device_class: str | None | UndefinedType = UNDEFINED,
unit_of_measurement: Union[str, None, UndefinedType] = UNDEFINED, unit_of_measurement: str | None | UndefinedType = UNDEFINED,
original_name: Union[str, None, UndefinedType] = UNDEFINED, original_name: str | None | UndefinedType = UNDEFINED,
original_icon: Union[str, None, UndefinedType] = UNDEFINED, original_icon: str | None | UndefinedType = UNDEFINED,
) -> RegistryEntry: ) -> RegistryEntry:
"""Private facing update properties method.""" """Private facing update properties method."""
old = self.entities[entity_id] old = self.entities[entity_id]
@ -479,7 +470,7 @@ class EntityRegistry:
old_conf_load_func=load_yaml, old_conf_load_func=load_yaml,
old_conf_migrate_func=_async_migrate, old_conf_migrate_func=_async_migrate,
) )
entities: Dict[str, RegistryEntry] = OrderedDict() entities: dict[str, RegistryEntry] = OrderedDict()
if data is not None: if data is not None:
for entity in data["entities"]: for entity in data["entities"]:
@ -516,7 +507,7 @@ class EntityRegistry:
self._store.async_delay_save(self._data_to_save, SAVE_DELAY) self._store.async_delay_save(self._data_to_save, SAVE_DELAY)
@callback @callback
def _data_to_save(self) -> Dict[str, Any]: def _data_to_save(self) -> dict[str, Any]:
"""Return data of entity registry to store in a file.""" """Return data of entity registry to store in a file."""
data = {} data = {}
@ -605,7 +596,7 @@ async def async_get_registry(hass: HomeAssistantType) -> EntityRegistry:
@callback @callback
def async_entries_for_device( def async_entries_for_device(
registry: EntityRegistry, device_id: str, include_disabled_entities: bool = False registry: EntityRegistry, device_id: str, include_disabled_entities: bool = False
) -> List[RegistryEntry]: ) -> list[RegistryEntry]:
"""Return entries that match a device.""" """Return entries that match a device."""
return [ return [
entry entry
@ -618,7 +609,7 @@ def async_entries_for_device(
@callback @callback
def async_entries_for_area( def async_entries_for_area(
registry: EntityRegistry, area_id: str registry: EntityRegistry, area_id: str
) -> List[RegistryEntry]: ) -> list[RegistryEntry]:
"""Return entries that match an area.""" """Return entries that match an area."""
return [entry for entry in registry.entities.values() if entry.area_id == area_id] return [entry for entry in registry.entities.values() if entry.area_id == area_id]
@ -626,7 +617,7 @@ def async_entries_for_area(
@callback @callback
def async_entries_for_config_entry( def async_entries_for_config_entry(
registry: EntityRegistry, config_entry_id: str registry: EntityRegistry, config_entry_id: str
) -> List[RegistryEntry]: ) -> list[RegistryEntry]:
"""Return entries that match a config entry.""" """Return entries that match a config entry."""
return [ return [
entry entry
@ -665,7 +656,7 @@ def async_config_entry_disabled_by_changed(
) )
async def _async_migrate(entities: Dict[str, Any]) -> Dict[str, List[Dict[str, Any]]]: async def _async_migrate(entities: dict[str, Any]) -> dict[str, list[dict[str, Any]]]:
"""Migrate the YAML config file to storage helper format.""" """Migrate the YAML config file to storage helper format."""
return { return {
"entities": [ "entities": [
@ -721,7 +712,7 @@ def async_setup_entity_restore(
async def async_migrate_entries( async def async_migrate_entries(
hass: HomeAssistantType, hass: HomeAssistantType,
config_entry_id: str, config_entry_id: str,
entry_callback: Callable[[RegistryEntry], Optional[dict]], entry_callback: Callable[[RegistryEntry], dict | None],
) -> None: ) -> None:
"""Migrator of unique IDs.""" """Migrator of unique IDs."""
ent_reg = await async_get_registry(hass) ent_reg = await async_get_registry(hass)

View file

@ -1,8 +1,10 @@
"""A class to hold entity values.""" """A class to hold entity values."""
from __future__ import annotations
from collections import OrderedDict from collections import OrderedDict
import fnmatch import fnmatch
import re import re
from typing import Any, Dict, Optional, Pattern from typing import Any, Pattern
from homeassistant.core import split_entity_id from homeassistant.core import split_entity_id
@ -14,17 +16,17 @@ class EntityValues:
def __init__( def __init__(
self, self,
exact: Optional[Dict[str, Dict[str, str]]] = None, exact: dict[str, dict[str, str]] | None = None,
domain: Optional[Dict[str, Dict[str, str]]] = None, domain: dict[str, dict[str, str]] | None = None,
glob: Optional[Dict[str, Dict[str, str]]] = None, glob: dict[str, dict[str, str]] | None = None,
) -> None: ) -> None:
"""Initialize an EntityConfigDict.""" """Initialize an EntityConfigDict."""
self._cache: Dict[str, Dict[str, str]] = {} self._cache: dict[str, dict[str, str]] = {}
self._exact = exact self._exact = exact
self._domain = domain self._domain = domain
if glob is None: if glob is None:
compiled: Optional[Dict[Pattern[str], Any]] = None compiled: dict[Pattern[str], Any] | None = None
else: else:
compiled = OrderedDict() compiled = OrderedDict()
for key, value in glob.items(): for key, value in glob.items():
@ -32,7 +34,7 @@ class EntityValues:
self._glob = compiled self._glob = compiled
def get(self, entity_id: str) -> Dict[str, str]: def get(self, entity_id: str) -> dict[str, str]:
"""Get config for an entity id.""" """Get config for an entity id."""
if entity_id in self._cache: if entity_id in self._cache:
return self._cache[entity_id] return self._cache[entity_id]

View file

@ -1,7 +1,9 @@
"""Helper class to implement include/exclude of entities and domains.""" """Helper class to implement include/exclude of entities and domains."""
from __future__ import annotations
import fnmatch import fnmatch
import re import re
from typing import Callable, Dict, List, Pattern from typing import Callable, Pattern
import voluptuous as vol import voluptuous as vol
@ -19,7 +21,7 @@ CONF_EXCLUDE_ENTITIES = "exclude_entities"
CONF_ENTITY_GLOBS = "entity_globs" CONF_ENTITY_GLOBS = "entity_globs"
def convert_filter(config: Dict[str, List[str]]) -> Callable[[str], bool]: def convert_filter(config: dict[str, list[str]]) -> Callable[[str], bool]:
"""Convert the filter schema into a filter.""" """Convert the filter schema into a filter."""
filt = generate_filter( filt = generate_filter(
config[CONF_INCLUDE_DOMAINS], config[CONF_INCLUDE_DOMAINS],
@ -57,7 +59,7 @@ FILTER_SCHEMA = vol.All(BASE_FILTER_SCHEMA, convert_filter)
def convert_include_exclude_filter( def convert_include_exclude_filter(
config: Dict[str, Dict[str, List[str]]] config: dict[str, dict[str, list[str]]]
) -> Callable[[str], bool]: ) -> Callable[[str], bool]:
"""Convert the include exclude filter schema into a filter.""" """Convert the include exclude filter schema into a filter."""
include = config[CONF_INCLUDE] include = config[CONF_INCLUDE]
@ -107,7 +109,7 @@ def _glob_to_re(glob: str) -> Pattern[str]:
return re.compile(fnmatch.translate(glob)) return re.compile(fnmatch.translate(glob))
def _test_against_patterns(patterns: List[Pattern[str]], entity_id: str) -> bool: def _test_against_patterns(patterns: list[Pattern[str]], entity_id: str) -> bool:
"""Test entity against list of patterns, true if any match.""" """Test entity against list of patterns, true if any match."""
for pattern in patterns: for pattern in patterns:
if pattern.match(entity_id): if pattern.match(entity_id):
@ -119,12 +121,12 @@ def _test_against_patterns(patterns: List[Pattern[str]], entity_id: str) -> bool
# It's safe since we don't modify it. And None causes typing warnings # It's safe since we don't modify it. And None causes typing warnings
# pylint: disable=dangerous-default-value # pylint: disable=dangerous-default-value
def generate_filter( def generate_filter(
include_domains: List[str], include_domains: list[str],
include_entities: List[str], include_entities: list[str],
exclude_domains: List[str], exclude_domains: list[str],
exclude_entities: List[str], exclude_entities: list[str],
include_entity_globs: List[str] = [], include_entity_globs: list[str] = [],
exclude_entity_globs: List[str] = [], exclude_entity_globs: list[str] = [],
) -> Callable[[str], bool]: ) -> Callable[[str], bool]:
"""Return a function that will filter entities based on the args.""" """Return a function that will filter entities based on the args."""
include_d = set(include_domains) include_d = set(include_domains)

View file

@ -1,4 +1,6 @@
"""Helpers for listening to events.""" """Helpers for listening to events."""
from __future__ import annotations
import asyncio import asyncio
import copy import copy
from dataclasses import dataclass from dataclasses import dataclass
@ -6,18 +8,7 @@ from datetime import datetime, timedelta
import functools as ft import functools as ft
import logging import logging
import time import time
from typing import ( from typing import Any, Awaitable, Callable, Iterable, List
Any,
Awaitable,
Callable,
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
Union,
)
import attr import attr
@ -79,8 +70,8 @@ class TrackStates:
""" """
all_states: bool all_states: bool
entities: Set entities: set
domains: Set domains: set
@dataclass @dataclass
@ -94,7 +85,7 @@ class TrackTemplate:
template: Template template: Template
variables: TemplateVarsType variables: TemplateVarsType
rate_limit: Optional[timedelta] = None rate_limit: timedelta | None = None
@dataclass @dataclass
@ -146,10 +137,10 @@ def threaded_listener_factory(
@bind_hass @bind_hass
def async_track_state_change( def async_track_state_change(
hass: HomeAssistant, hass: HomeAssistant,
entity_ids: Union[str, Iterable[str]], entity_ids: str | Iterable[str],
action: Callable[[str, State, State], None], action: Callable[[str, State, State], None],
from_state: Union[None, str, Iterable[str]] = None, from_state: None | str | Iterable[str] = None,
to_state: Union[None, str, Iterable[str]] = None, to_state: None | str | Iterable[str] = None,
) -> CALLBACK_TYPE: ) -> CALLBACK_TYPE:
"""Track specific state changes. """Track specific state changes.
@ -240,7 +231,7 @@ track_state_change = threaded_listener_factory(async_track_state_change)
@bind_hass @bind_hass
def async_track_state_change_event( def async_track_state_change_event(
hass: HomeAssistant, hass: HomeAssistant,
entity_ids: Union[str, Iterable[str]], entity_ids: str | Iterable[str],
action: Callable[[Event], Any], action: Callable[[Event], Any],
) -> Callable[[], None]: ) -> Callable[[], None]:
"""Track specific state change events indexed by entity_id. """Track specific state change events indexed by entity_id.
@ -337,7 +328,7 @@ def _async_remove_indexed_listeners(
@bind_hass @bind_hass
def async_track_entity_registry_updated_event( def async_track_entity_registry_updated_event(
hass: HomeAssistant, hass: HomeAssistant,
entity_ids: Union[str, Iterable[str]], entity_ids: str | Iterable[str],
action: Callable[[Event], Any], action: Callable[[Event], Any],
) -> Callable[[], None]: ) -> Callable[[], None]:
"""Track specific entity registry updated events indexed by entity_id. """Track specific entity registry updated events indexed by entity_id.
@ -402,7 +393,7 @@ def async_track_entity_registry_updated_event(
@callback @callback
def _async_dispatch_domain_event( def _async_dispatch_domain_event(
hass: HomeAssistant, event: Event, callbacks: Dict[str, List] hass: HomeAssistant, event: Event, callbacks: dict[str, list]
) -> None: ) -> None:
domain = split_entity_id(event.data["entity_id"])[0] domain = split_entity_id(event.data["entity_id"])[0]
@ -423,7 +414,7 @@ def _async_dispatch_domain_event(
@bind_hass @bind_hass
def async_track_state_added_domain( def async_track_state_added_domain(
hass: HomeAssistant, hass: HomeAssistant,
domains: Union[str, Iterable[str]], domains: str | Iterable[str],
action: Callable[[Event], Any], action: Callable[[Event], Any],
) -> Callable[[], None]: ) -> Callable[[], None]:
"""Track state change events when an entity is added to domains.""" """Track state change events when an entity is added to domains."""
@ -476,7 +467,7 @@ def async_track_state_added_domain(
@bind_hass @bind_hass
def async_track_state_removed_domain( def async_track_state_removed_domain(
hass: HomeAssistant, hass: HomeAssistant,
domains: Union[str, Iterable[str]], domains: str | Iterable[str],
action: Callable[[Event], Any], action: Callable[[Event], Any],
) -> Callable[[], None]: ) -> Callable[[], None]:
"""Track state change events when an entity is removed from domains.""" """Track state change events when an entity is removed from domains."""
@ -527,7 +518,7 @@ def async_track_state_removed_domain(
@callback @callback
def _async_string_to_lower_list(instr: Union[str, Iterable[str]]) -> List[str]: def _async_string_to_lower_list(instr: str | Iterable[str]) -> list[str]:
if isinstance(instr, str): if isinstance(instr, str):
return [instr.lower()] return [instr.lower()]
@ -546,7 +537,7 @@ class _TrackStateChangeFiltered:
"""Handle removal / refresh of tracker init.""" """Handle removal / refresh of tracker init."""
self.hass = hass self.hass = hass
self._action = action self._action = action
self._listeners: Dict[str, Callable] = {} self._listeners: dict[str, Callable] = {}
self._last_track_states: TrackStates = track_states self._last_track_states: TrackStates = track_states
@callback @callback
@ -569,7 +560,7 @@ class _TrackStateChangeFiltered:
self._setup_entities_listener(track_states.domains, track_states.entities) self._setup_entities_listener(track_states.domains, track_states.entities)
@property @property
def listeners(self) -> Dict: def listeners(self) -> dict:
"""State changes that will cause a re-render.""" """State changes that will cause a re-render."""
track_states = self._last_track_states track_states = self._last_track_states
return { return {
@ -628,7 +619,7 @@ class _TrackStateChangeFiltered:
self._listeners.pop(listener_name)() self._listeners.pop(listener_name)()
@callback @callback
def _setup_entities_listener(self, domains: Set, entities: Set) -> None: def _setup_entities_listener(self, domains: set, entities: set) -> None:
if domains: if domains:
entities = entities.copy() entities = entities.copy()
entities.update(self.hass.states.async_entity_ids(domains)) entities.update(self.hass.states.async_entity_ids(domains))
@ -642,7 +633,7 @@ class _TrackStateChangeFiltered:
) )
@callback @callback
def _setup_domains_listener(self, domains: Set) -> None: def _setup_domains_listener(self, domains: set) -> None:
if not domains: if not domains:
return return
@ -691,8 +682,8 @@ def async_track_state_change_filtered(
def async_track_template( def async_track_template(
hass: HomeAssistant, hass: HomeAssistant,
template: Template, template: Template,
action: Callable[[str, Optional[State], Optional[State]], None], action: Callable[[str, State | None, State | None], None],
variables: Optional[TemplateVarsType] = None, variables: TemplateVarsType | None = None,
) -> Callable[[], None]: ) -> Callable[[], None]:
"""Add a listener that fires when a a template evaluates to 'true'. """Add a listener that fires when a a template evaluates to 'true'.
@ -734,7 +725,7 @@ def async_track_template(
@callback @callback
def _template_changed_listener( def _template_changed_listener(
event: Event, updates: List[TrackTemplateResult] event: Event, updates: list[TrackTemplateResult]
) -> None: ) -> None:
"""Check if condition is correct and run action.""" """Check if condition is correct and run action."""
track_result = updates.pop() track_result = updates.pop()
@ -792,12 +783,12 @@ class _TrackTemplateResultInfo:
track_template_.template.hass = hass track_template_.template.hass = hass
self._track_templates = track_templates self._track_templates = track_templates
self._last_result: Dict[Template, Union[str, TemplateError]] = {} self._last_result: dict[Template, str | TemplateError] = {}
self._rate_limit = KeyedRateLimit(hass) self._rate_limit = KeyedRateLimit(hass)
self._info: Dict[Template, RenderInfo] = {} self._info: dict[Template, RenderInfo] = {}
self._track_state_changes: Optional[_TrackStateChangeFiltered] = None self._track_state_changes: _TrackStateChangeFiltered | None = None
self._time_listeners: Dict[Template, Callable] = {} self._time_listeners: dict[Template, Callable] = {}
def async_setup(self, raise_on_template_error: bool) -> None: def async_setup(self, raise_on_template_error: bool) -> None:
"""Activation of template tracking.""" """Activation of template tracking."""
@ -826,7 +817,7 @@ class _TrackTemplateResultInfo:
) )
@property @property
def listeners(self) -> Dict: def listeners(self) -> dict:
"""State changes that will cause a re-render.""" """State changes that will cause a re-render."""
assert self._track_state_changes assert self._track_state_changes
return { return {
@ -882,8 +873,8 @@ class _TrackTemplateResultInfo:
self, self,
track_template_: TrackTemplate, track_template_: TrackTemplate,
now: datetime, now: datetime,
event: Optional[Event], event: Event | None,
) -> Union[bool, TrackTemplateResult]: ) -> bool | TrackTemplateResult:
"""Re-render the template if conditions match. """Re-render the template if conditions match.
Returns False if the template was not be re-rendered Returns False if the template was not be re-rendered
@ -927,7 +918,7 @@ class _TrackTemplateResultInfo:
) )
try: try:
result: Union[str, TemplateError] = info.result() result: str | TemplateError = info.result()
except TemplateError as ex: except TemplateError as ex:
result = ex result = ex
@ -945,9 +936,9 @@ class _TrackTemplateResultInfo:
@callback @callback
def _refresh( def _refresh(
self, self,
event: Optional[Event], event: Event | None,
track_templates: Optional[Iterable[TrackTemplate]] = None, track_templates: Iterable[TrackTemplate] | None = None,
replayed: Optional[bool] = False, replayed: bool | None = False,
) -> None: ) -> None:
"""Refresh the template. """Refresh the template.
@ -1076,16 +1067,16 @@ def async_track_same_state(
hass: HomeAssistant, hass: HomeAssistant,
period: timedelta, period: timedelta,
action: Callable[..., None], action: Callable[..., None],
async_check_same_func: Callable[[str, Optional[State], Optional[State]], bool], async_check_same_func: Callable[[str, State | None, State | None], bool],
entity_ids: Union[str, Iterable[str]] = MATCH_ALL, entity_ids: str | Iterable[str] = MATCH_ALL,
) -> CALLBACK_TYPE: ) -> CALLBACK_TYPE:
"""Track the state of entities for a period and run an action. """Track the state of entities for a period and run an action.
If async_check_func is None it use the state of orig_value. If async_check_func is None it use the state of orig_value.
Without entity_ids we track all state changes. Without entity_ids we track all state changes.
""" """
async_remove_state_for_cancel: Optional[CALLBACK_TYPE] = None async_remove_state_for_cancel: CALLBACK_TYPE | None = None
async_remove_state_for_listener: Optional[CALLBACK_TYPE] = None async_remove_state_for_listener: CALLBACK_TYPE | None = None
job = HassJob(action) job = HassJob(action)
@ -1113,8 +1104,8 @@ def async_track_same_state(
def state_for_cancel_listener(event: Event) -> None: def state_for_cancel_listener(event: Event) -> None:
"""Fire on changes and cancel for listener if changed.""" """Fire on changes and cancel for listener if changed."""
entity: str = event.data["entity_id"] entity: str = event.data["entity_id"]
from_state: Optional[State] = event.data.get("old_state") from_state: State | None = event.data.get("old_state")
to_state: Optional[State] = event.data.get("new_state") to_state: State | None = event.data.get("new_state")
if not async_check_same_func(entity, from_state, to_state): if not async_check_same_func(entity, from_state, to_state):
clear_listener() clear_listener()
@ -1144,7 +1135,7 @@ track_same_state = threaded_listener_factory(async_track_same_state)
@bind_hass @bind_hass
def async_track_point_in_time( def async_track_point_in_time(
hass: HomeAssistant, hass: HomeAssistant,
action: Union[HassJob, Callable[..., None]], action: HassJob | Callable[..., None],
point_in_time: datetime, point_in_time: datetime,
) -> CALLBACK_TYPE: ) -> CALLBACK_TYPE:
"""Add a listener that fires once after a specific point in time.""" """Add a listener that fires once after a specific point in time."""
@ -1165,7 +1156,7 @@ track_point_in_time = threaded_listener_factory(async_track_point_in_time)
@bind_hass @bind_hass
def async_track_point_in_utc_time( def async_track_point_in_utc_time(
hass: HomeAssistant, hass: HomeAssistant,
action: Union[HassJob, Callable[..., None]], action: HassJob | Callable[..., None],
point_in_time: datetime, point_in_time: datetime,
) -> CALLBACK_TYPE: ) -> CALLBACK_TYPE:
"""Add a listener that fires once after a specific point in UTC time.""" """Add a listener that fires once after a specific point in UTC time."""
@ -1176,7 +1167,7 @@ def async_track_point_in_utc_time(
# having to figure out how to call the action every time its called. # having to figure out how to call the action every time its called.
job = action if isinstance(action, HassJob) else HassJob(action) job = action if isinstance(action, HassJob) else HassJob(action)
cancel_callback: Optional[asyncio.TimerHandle] = None cancel_callback: asyncio.TimerHandle | None = None
@callback @callback
def run_action() -> None: def run_action() -> None:
@ -1217,7 +1208,7 @@ track_point_in_utc_time = threaded_listener_factory(async_track_point_in_utc_tim
@callback @callback
@bind_hass @bind_hass
def async_call_later( def async_call_later(
hass: HomeAssistant, delay: float, action: Union[HassJob, Callable[..., None]] hass: HomeAssistant, delay: float, action: HassJob | Callable[..., None]
) -> CALLBACK_TYPE: ) -> CALLBACK_TYPE:
"""Add a listener that is called in <delay>.""" """Add a listener that is called in <delay>."""
return async_track_point_in_utc_time( return async_track_point_in_utc_time(
@ -1232,7 +1223,7 @@ call_later = threaded_listener_factory(async_call_later)
@bind_hass @bind_hass
def async_track_time_interval( def async_track_time_interval(
hass: HomeAssistant, hass: HomeAssistant,
action: Callable[..., Union[None, Awaitable]], action: Callable[..., None | Awaitable],
interval: timedelta, interval: timedelta,
) -> CALLBACK_TYPE: ) -> CALLBACK_TYPE:
"""Add a listener that fires repetitively at every timedelta interval.""" """Add a listener that fires repetitively at every timedelta interval."""
@ -1276,9 +1267,9 @@ class SunListener:
hass: HomeAssistant = attr.ib() hass: HomeAssistant = attr.ib()
job: HassJob = attr.ib() job: HassJob = attr.ib()
event: str = attr.ib() event: str = attr.ib()
offset: Optional[timedelta] = attr.ib() offset: timedelta | None = attr.ib()
_unsub_sun: Optional[CALLBACK_TYPE] = attr.ib(default=None) _unsub_sun: CALLBACK_TYPE | None = attr.ib(default=None)
_unsub_config: Optional[CALLBACK_TYPE] = attr.ib(default=None) _unsub_config: CALLBACK_TYPE | None = attr.ib(default=None)
@callback @callback
def async_attach(self) -> None: def async_attach(self) -> None:
@ -1332,7 +1323,7 @@ class SunListener:
@callback @callback
@bind_hass @bind_hass
def async_track_sunrise( def async_track_sunrise(
hass: HomeAssistant, action: Callable[..., None], offset: Optional[timedelta] = None hass: HomeAssistant, action: Callable[..., None], offset: timedelta | None = None
) -> CALLBACK_TYPE: ) -> CALLBACK_TYPE:
"""Add a listener that will fire a specified offset from sunrise daily.""" """Add a listener that will fire a specified offset from sunrise daily."""
listener = SunListener(hass, HassJob(action), SUN_EVENT_SUNRISE, offset) listener = SunListener(hass, HassJob(action), SUN_EVENT_SUNRISE, offset)
@ -1346,7 +1337,7 @@ track_sunrise = threaded_listener_factory(async_track_sunrise)
@callback @callback
@bind_hass @bind_hass
def async_track_sunset( def async_track_sunset(
hass: HomeAssistant, action: Callable[..., None], offset: Optional[timedelta] = None hass: HomeAssistant, action: Callable[..., None], offset: timedelta | None = None
) -> CALLBACK_TYPE: ) -> CALLBACK_TYPE:
"""Add a listener that will fire a specified offset from sunset daily.""" """Add a listener that will fire a specified offset from sunset daily."""
listener = SunListener(hass, HassJob(action), SUN_EVENT_SUNSET, offset) listener = SunListener(hass, HassJob(action), SUN_EVENT_SUNSET, offset)
@ -1365,9 +1356,9 @@ time_tracker_utcnow = dt_util.utcnow
def async_track_utc_time_change( def async_track_utc_time_change(
hass: HomeAssistant, hass: HomeAssistant,
action: Callable[..., None], action: Callable[..., None],
hour: Optional[Any] = None, hour: Any | None = None,
minute: Optional[Any] = None, minute: Any | None = None,
second: Optional[Any] = None, second: Any | None = None,
local: bool = False, local: bool = False,
) -> CALLBACK_TYPE: ) -> CALLBACK_TYPE:
"""Add a listener that will fire if time matches a pattern.""" """Add a listener that will fire if time matches a pattern."""
@ -1394,7 +1385,7 @@ def async_track_utc_time_change(
localized_now, matching_seconds, matching_minutes, matching_hours localized_now, matching_seconds, matching_minutes, matching_hours
) )
time_listener: Optional[CALLBACK_TYPE] = None time_listener: CALLBACK_TYPE | None = None
@callback @callback
def pattern_time_change_listener(_: datetime) -> None: def pattern_time_change_listener(_: datetime) -> None:
@ -1431,9 +1422,9 @@ track_utc_time_change = threaded_listener_factory(async_track_utc_time_change)
def async_track_time_change( def async_track_time_change(
hass: HomeAssistant, hass: HomeAssistant,
action: Callable[..., None], action: Callable[..., None],
hour: Optional[Any] = None, hour: Any | None = None,
minute: Optional[Any] = None, minute: Any | None = None,
second: Optional[Any] = None, second: Any | None = None,
) -> CALLBACK_TYPE: ) -> CALLBACK_TYPE:
"""Add a listener that will fire if UTC time matches a pattern.""" """Add a listener that will fire if UTC time matches a pattern."""
return async_track_utc_time_change(hass, action, hour, minute, second, local=True) return async_track_utc_time_change(hass, action, hour, minute, second, local=True)
@ -1442,9 +1433,7 @@ def async_track_time_change(
track_time_change = threaded_listener_factory(async_track_time_change) track_time_change = threaded_listener_factory(async_track_time_change)
def process_state_match( def process_state_match(parameter: None | str | Iterable[str]) -> Callable[[str], bool]:
parameter: Union[None, str, Iterable[str]]
) -> Callable[[str], bool]:
"""Convert parameter to function that matches input against parameter.""" """Convert parameter to function that matches input against parameter."""
if parameter is None or parameter == MATCH_ALL: if parameter is None or parameter == MATCH_ALL:
return lambda _: True return lambda _: True
@ -1459,7 +1448,7 @@ def process_state_match(
@callback @callback
def _entities_domains_from_render_infos( def _entities_domains_from_render_infos(
render_infos: Iterable[RenderInfo], render_infos: Iterable[RenderInfo],
) -> Tuple[Set, Set]: ) -> tuple[set, set]:
"""Combine from multiple RenderInfo.""" """Combine from multiple RenderInfo."""
entities = set() entities = set()
domains = set() domains = set()
@ -1520,7 +1509,7 @@ def _event_triggers_rerender(event: Event, info: RenderInfo) -> bool:
@callback @callback
def _rate_limit_for_event( def _rate_limit_for_event(
event: Event, info: RenderInfo, track_template_: TrackTemplate event: Event, info: RenderInfo, track_template_: TrackTemplate
) -> Optional[timedelta]: ) -> timedelta | None:
"""Determine the rate limit for an event.""" """Determine the rate limit for an event."""
entity_id = event.data.get(ATTR_ENTITY_ID) entity_id = event.data.get(ATTR_ENTITY_ID)
@ -1532,7 +1521,7 @@ def _rate_limit_for_event(
if track_template_.rate_limit is not None: if track_template_.rate_limit is not None:
return track_template_.rate_limit return track_template_.rate_limit
rate_limit: Optional[timedelta] = info.rate_limit rate_limit: timedelta | None = info.rate_limit
return rate_limit return rate_limit

View file

@ -1,9 +1,11 @@
"""Provide frame helper for finding the current frame context.""" """Provide frame helper for finding the current frame context."""
from __future__ import annotations
import asyncio import asyncio
import functools import functools
import logging import logging
from traceback import FrameSummary, extract_stack from traceback import FrameSummary, extract_stack
from typing import Any, Callable, Optional, Tuple, TypeVar, cast from typing import Any, Callable, TypeVar, cast
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
@ -13,8 +15,8 @@ CALLABLE_T = TypeVar("CALLABLE_T", bound=Callable) # pylint: disable=invalid-na
def get_integration_frame( def get_integration_frame(
exclude_integrations: Optional[set] = None, exclude_integrations: set | None = None,
) -> Tuple[FrameSummary, str, str]: ) -> tuple[FrameSummary, str, str]:
"""Return the frame, integration and integration path of the current stack frame.""" """Return the frame, integration and integration path of the current stack frame."""
found_frame = None found_frame = None
if not exclude_integrations: if not exclude_integrations:
@ -64,7 +66,7 @@ def report(what: str) -> None:
def report_integration( def report_integration(
what: str, integration_frame: Tuple[FrameSummary, str, str] what: str, integration_frame: tuple[FrameSummary, str, str]
) -> None: ) -> None:
"""Report incorrect usage in an integration. """Report incorrect usage in an integration.

View file

@ -1,6 +1,8 @@
"""Helper for httpx.""" """Helper for httpx."""
from __future__ import annotations
import sys import sys
from typing import Any, Callable, Optional from typing import Any, Callable
import httpx import httpx
@ -29,7 +31,7 @@ def get_async_client(
""" """
key = DATA_ASYNC_CLIENT if verify_ssl else DATA_ASYNC_CLIENT_NOVERIFY key = DATA_ASYNC_CLIENT if verify_ssl else DATA_ASYNC_CLIENT_NOVERIFY
client: Optional[httpx.AsyncClient] = hass.data.get(key) client: httpx.AsyncClient | None = hass.data.get(key)
if client is None: if client is None:
client = hass.data[key] = create_async_httpx_client(hass, verify_ssl) client = hass.data[key] = create_async_httpx_client(hass, verify_ssl)

View file

@ -1,9 +1,9 @@
"""Icon helper methods.""" """Icon helper methods."""
from typing import Optional from __future__ import annotations
def icon_for_battery_level( def icon_for_battery_level(
battery_level: Optional[int] = None, charging: bool = False battery_level: int | None = None, charging: bool = False
) -> str: ) -> str:
"""Return a battery icon valid identifier.""" """Return a battery icon valid identifier."""
icon = "mdi:battery" icon = "mdi:battery"
@ -20,7 +20,7 @@ def icon_for_battery_level(
return icon return icon
def icon_for_signal_level(signal_level: Optional[int] = None) -> str: def icon_for_signal_level(signal_level: int | None = None) -> str:
"""Return a signal icon valid identifier.""" """Return a signal icon valid identifier."""
if signal_level is None or signal_level == 0: if signal_level is None or signal_level == 0:
return "mdi:signal-cellular-outline" return "mdi:signal-cellular-outline"

View file

@ -1,5 +1,6 @@
"""Helper to create a unique instance ID.""" """Helper to create a unique instance ID."""
from typing import Dict, Optional from __future__ import annotations
import uuid import uuid
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
@ -17,7 +18,7 @@ async def async_get(hass: HomeAssistant) -> str:
"""Get unique ID for the hass instance.""" """Get unique ID for the hass instance."""
store = storage.Store(hass, DATA_VERSION, DATA_KEY, True) store = storage.Store(hass, DATA_VERSION, DATA_KEY, True)
data: Optional[Dict[str, str]] = await storage.async_migrator( # type: ignore data: dict[str, str] | None = await storage.async_migrator( # type: ignore
hass, hass,
hass.config.path(LEGACY_UUID_FILE), hass.config.path(LEGACY_UUID_FILE),
store, store,

View file

@ -3,7 +3,7 @@ from __future__ import annotations
import logging import logging
import re import re
from typing import Any, Callable, Dict, Iterable, Optional from typing import Any, Callable, Dict, Iterable
import voluptuous as vol import voluptuous as vol
@ -52,9 +52,9 @@ async def async_handle(
hass: HomeAssistantType, hass: HomeAssistantType,
platform: str, platform: str,
intent_type: str, intent_type: str,
slots: Optional[_SlotsType] = None, slots: _SlotsType | None = None,
text_input: Optional[str] = None, text_input: str | None = None,
context: Optional[Context] = None, context: Context | None = None,
) -> IntentResponse: ) -> IntentResponse:
"""Handle an intent.""" """Handle an intent."""
handler: IntentHandler = hass.data.get(DATA_KEY, {}).get(intent_type) handler: IntentHandler = hass.data.get(DATA_KEY, {}).get(intent_type)
@ -103,7 +103,7 @@ class IntentUnexpectedError(IntentError):
@callback @callback
@bind_hass @bind_hass
def async_match_state( def async_match_state(
hass: HomeAssistantType, name: str, states: Optional[Iterable[State]] = None hass: HomeAssistantType, name: str, states: Iterable[State] | None = None
) -> State: ) -> State:
"""Find a state that matches the name.""" """Find a state that matches the name."""
if states is None: if states is None:
@ -127,10 +127,10 @@ def async_test_feature(state: State, feature: int, feature_name: str) -> None:
class IntentHandler: class IntentHandler:
"""Intent handler registration.""" """Intent handler registration."""
intent_type: Optional[str] = None intent_type: str | None = None
slot_schema: Optional[vol.Schema] = None slot_schema: vol.Schema | None = None
_slot_schema: Optional[vol.Schema] = None _slot_schema: vol.Schema | None = None
platforms: Optional[Iterable[str]] = [] platforms: Iterable[str] | None = []
@callback @callback
def async_can_handle(self, intent_obj: Intent) -> bool: def async_can_handle(self, intent_obj: Intent) -> bool:
@ -163,7 +163,7 @@ class IntentHandler:
return f"<{self.__class__.__name__} - {self.intent_type}>" return f"<{self.__class__.__name__} - {self.intent_type}>"
def _fuzzymatch(name: str, items: Iterable[T], key: Callable[[T], str]) -> Optional[T]: def _fuzzymatch(name: str, items: Iterable[T], key: Callable[[T], str]) -> T | None:
"""Fuzzy matching function.""" """Fuzzy matching function."""
matches = [] matches = []
pattern = ".*?".join(name) pattern = ".*?".join(name)
@ -226,7 +226,7 @@ class Intent:
platform: str, platform: str,
intent_type: str, intent_type: str,
slots: _SlotsType, slots: _SlotsType,
text_input: Optional[str], text_input: str | None,
context: Context, context: Context,
) -> None: ) -> None:
"""Initialize an intent.""" """Initialize an intent."""
@ -246,15 +246,15 @@ class Intent:
class IntentResponse: class IntentResponse:
"""Response to an intent.""" """Response to an intent."""
def __init__(self, intent: Optional[Intent] = None) -> None: def __init__(self, intent: Intent | None = None) -> None:
"""Initialize an IntentResponse.""" """Initialize an IntentResponse."""
self.intent = intent self.intent = intent
self.speech: Dict[str, Dict[str, Any]] = {} self.speech: dict[str, dict[str, Any]] = {}
self.card: Dict[str, Dict[str, str]] = {} self.card: dict[str, dict[str, str]] = {}
@callback @callback
def async_set_speech( def async_set_speech(
self, speech: str, speech_type: str = "plain", extra_data: Optional[Any] = None self, speech: str, speech_type: str = "plain", extra_data: Any | None = None
) -> None: ) -> None:
"""Set speech response.""" """Set speech response."""
self.speech[speech_type] = {"speech": speech, "extra_data": extra_data} self.speech[speech_type] = {"speech": speech, "extra_data": extra_data}
@ -267,6 +267,6 @@ class IntentResponse:
self.card[card_type] = {"title": title, "content": content} self.card[card_type] = {"title": title, "content": content}
@callback @callback
def as_dict(self) -> Dict[str, Dict[str, Dict[str, Any]]]: def as_dict(self) -> dict[str, dict[str, dict[str, Any]]]:
"""Return a dictionary representation of an intent response.""" """Return a dictionary representation of an intent response."""
return {"speech": self.speech, "card": self.card} return {"speech": self.speech, "card": self.card}

View file

@ -1,7 +1,8 @@
"""Location helpers for Home Assistant.""" """Location helpers for Home Assistant."""
from __future__ import annotations
import logging import logging
from typing import Optional, Sequence from typing import Sequence
import voluptuous as vol import voluptuous as vol
@ -25,9 +26,7 @@ def has_location(state: State) -> bool:
) )
def closest( def closest(latitude: float, longitude: float, states: Sequence[State]) -> State | None:
latitude: float, longitude: float, states: Sequence[State]
) -> Optional[State]:
"""Return closest state to point. """Return closest state to point.
Async friendly. Async friendly.
@ -50,8 +49,8 @@ def closest(
def find_coordinates( def find_coordinates(
hass: HomeAssistantType, entity_id: str, recursion_history: Optional[list] = None hass: HomeAssistantType, entity_id: str, recursion_history: list | None = None
) -> Optional[str]: ) -> str | None:
"""Find the gps coordinates of the entity in the form of '90.000,180.000'.""" """Find the gps coordinates of the entity in the form of '90.000,180.000'."""
entity_state = hass.states.get(entity_id) entity_state = hass.states.get(entity_id)

View file

@ -1,7 +1,9 @@
"""Helpers for logging allowing more advanced logging styles to be used.""" """Helpers for logging allowing more advanced logging styles to be used."""
from __future__ import annotations
import inspect import inspect
import logging import logging
from typing import Any, Mapping, MutableMapping, Optional, Tuple from typing import Any, Mapping, MutableMapping
class KeywordMessage: class KeywordMessage:
@ -26,7 +28,7 @@ class KeywordStyleAdapter(logging.LoggerAdapter):
"""Represents an adapter wrapping the logger allowing KeywordMessages.""" """Represents an adapter wrapping the logger allowing KeywordMessages."""
def __init__( def __init__(
self, logger: logging.Logger, extra: Optional[Mapping[str, Any]] = None self, logger: logging.Logger, extra: Mapping[str, Any] | None = None
) -> None: ) -> None:
"""Initialize a new StyleAdapter for the provided logger.""" """Initialize a new StyleAdapter for the provided logger."""
super().__init__(logger, extra or {}) super().__init__(logger, extra or {})
@ -41,7 +43,7 @@ class KeywordStyleAdapter(logging.LoggerAdapter):
def process( def process(
self, msg: Any, kwargs: MutableMapping[str, Any] self, msg: Any, kwargs: MutableMapping[str, Any]
) -> Tuple[Any, MutableMapping[str, Any]]: ) -> tuple[Any, MutableMapping[str, Any]]:
"""Process the keyword args in preparation for logging.""" """Process the keyword args in preparation for logging."""
return ( return (
msg, msg,

View file

@ -1,6 +1,8 @@
"""Network helpers.""" """Network helpers."""
from __future__ import annotations
from ipaddress import ip_address from ipaddress import ip_address
from typing import Optional, cast from typing import cast
import yarl import yarl
@ -117,7 +119,7 @@ def get_url(
raise NoURLAvailableError raise NoURLAvailableError
def _get_request_host() -> Optional[str]: def _get_request_host() -> str | None:
"""Get the host address of the current request.""" """Get the host address of the current request."""
request = http.current_request.get() request = http.current_request.get()
if request is None: if request is None:

View file

@ -1,8 +1,10 @@
"""Ratelimit helper.""" """Ratelimit helper."""
from __future__ import annotations
import asyncio import asyncio
from datetime import datetime, timedelta from datetime import datetime, timedelta
import logging import logging
from typing import Any, Callable, Dict, Hashable, Optional from typing import Any, Callable, Hashable
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
@ -19,8 +21,8 @@ class KeyedRateLimit:
): ):
"""Initialize ratelimit tracker.""" """Initialize ratelimit tracker."""
self.hass = hass self.hass = hass
self._last_triggered: Dict[Hashable, datetime] = {} self._last_triggered: dict[Hashable, datetime] = {}
self._rate_limit_timers: Dict[Hashable, asyncio.TimerHandle] = {} self._rate_limit_timers: dict[Hashable, asyncio.TimerHandle] = {}
@callback @callback
def async_has_timer(self, key: Hashable) -> bool: def async_has_timer(self, key: Hashable) -> bool:
@ -30,7 +32,7 @@ class KeyedRateLimit:
return key in self._rate_limit_timers return key in self._rate_limit_timers
@callback @callback
def async_triggered(self, key: Hashable, now: Optional[datetime] = None) -> None: def async_triggered(self, key: Hashable, now: datetime | None = None) -> None:
"""Call when the action we are tracking was triggered.""" """Call when the action we are tracking was triggered."""
self.async_cancel_timer(key) self.async_cancel_timer(key)
self._last_triggered[key] = now or dt_util.utcnow() self._last_triggered[key] = now or dt_util.utcnow()
@ -54,11 +56,11 @@ class KeyedRateLimit:
def async_schedule_action( def async_schedule_action(
self, self,
key: Hashable, key: Hashable,
rate_limit: Optional[timedelta], rate_limit: timedelta | None,
now: datetime, now: datetime,
action: Callable, action: Callable,
*args: Any, *args: Any,
) -> Optional[datetime]: ) -> datetime | None:
"""Check rate limits and schedule an action if we hit the limit. """Check rate limits and schedule an action if we hit the limit.
If the rate limit is hit: If the rate limit is hit:

View file

@ -1,8 +1,9 @@
"""Class to reload platforms.""" """Class to reload platforms."""
from __future__ import annotations
import asyncio import asyncio
import logging import logging
from typing import Dict, Iterable, List, Optional from typing import Iterable
from homeassistant import config as conf_util from homeassistant import config as conf_util
from homeassistant.const import SERVICE_RELOAD from homeassistant.const import SERVICE_RELOAD
@ -61,7 +62,7 @@ async def _resetup_platform(
if not conf: if not conf:
return return
root_config: Dict = {integration_platform: []} root_config: dict = {integration_platform: []}
# Extract only the config for template, ignore the rest. # Extract only the config for template, ignore the rest.
for p_type, p_config in config_per_platform(conf, integration_platform): for p_type, p_config in config_per_platform(conf, integration_platform):
if p_type != integration_name: if p_type != integration_name:
@ -101,7 +102,7 @@ async def _async_setup_platform(
hass: HomeAssistantType, hass: HomeAssistantType,
integration_name: str, integration_name: str,
integration_platform: str, integration_platform: str,
platform_configs: List[Dict], platform_configs: list[dict],
) -> None: ) -> None:
"""Platform for the first time when new configuration is added.""" """Platform for the first time when new configuration is added."""
if integration_platform not in hass.data: if integration_platform not in hass.data:
@ -119,7 +120,7 @@ async def _async_setup_platform(
async def _async_reconfig_platform( async def _async_reconfig_platform(
platform: EntityPlatform, platform_configs: List[Dict] platform: EntityPlatform, platform_configs: list[dict]
) -> None: ) -> None:
"""Reconfigure an already loaded platform.""" """Reconfigure an already loaded platform."""
await platform.async_reset() await platform.async_reset()
@ -129,7 +130,7 @@ async def _async_reconfig_platform(
async def async_integration_yaml_config( async def async_integration_yaml_config(
hass: HomeAssistantType, integration_name: str hass: HomeAssistantType, integration_name: str
) -> Optional[ConfigType]: ) -> ConfigType | None:
"""Fetch the latest yaml configuration for an integration.""" """Fetch the latest yaml configuration for an integration."""
integration = await async_get_integration(hass, integration_name) integration = await async_get_integration(hass, integration_name)
@ -141,7 +142,7 @@ async def async_integration_yaml_config(
@callback @callback
def async_get_platform_without_config_entry( def async_get_platform_without_config_entry(
hass: HomeAssistantType, integration_name: str, integration_platform_name: str hass: HomeAssistantType, integration_name: str, integration_platform_name: str
) -> Optional[EntityPlatform]: ) -> EntityPlatform | None:
"""Find an existing platform that is not a config entry.""" """Find an existing platform that is not a config entry."""
for integration_platform in async_get_platforms(hass, integration_name): for integration_platform in async_get_platforms(hass, integration_name):
if integration_platform.config_entry is not None: if integration_platform.config_entry is not None:

View file

@ -4,7 +4,7 @@ from __future__ import annotations
import asyncio import asyncio
from datetime import datetime, timedelta from datetime import datetime, timedelta
import logging import logging
from typing import Any, Dict, List, Optional, Set, cast from typing import Any, cast
from homeassistant.const import EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP from homeassistant.const import EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP
from homeassistant.core import ( from homeassistant.core import (
@ -45,12 +45,12 @@ class StoredState:
self.state = state self.state = state
self.last_seen = last_seen self.last_seen = last_seen
def as_dict(self) -> Dict[str, Any]: def as_dict(self) -> dict[str, Any]:
"""Return a dict representation of the stored state.""" """Return a dict representation of the stored state."""
return {"state": self.state.as_dict(), "last_seen": self.last_seen} return {"state": self.state.as_dict(), "last_seen": self.last_seen}
@classmethod @classmethod
def from_dict(cls, json_dict: Dict) -> StoredState: def from_dict(cls, json_dict: dict) -> StoredState:
"""Initialize a stored state from a dict.""" """Initialize a stored state from a dict."""
last_seen = json_dict["last_seen"] last_seen = json_dict["last_seen"]
@ -106,11 +106,11 @@ class RestoreStateData:
self.store: Store = Store( self.store: Store = Store(
hass, STORAGE_VERSION, STORAGE_KEY, encoder=JSONEncoder hass, STORAGE_VERSION, STORAGE_KEY, encoder=JSONEncoder
) )
self.last_states: Dict[str, StoredState] = {} self.last_states: dict[str, StoredState] = {}
self.entity_ids: Set[str] = set() self.entity_ids: set[str] = set()
@callback @callback
def async_get_stored_states(self) -> List[StoredState]: def async_get_stored_states(self) -> list[StoredState]:
"""Get the set of states which should be stored. """Get the set of states which should be stored.
This includes the states of all registered entities, as well as the This includes the states of all registered entities, as well as the
@ -249,7 +249,7 @@ class RestoreEntity(Entity):
) )
data.async_restore_entity_removed(self.entity_id) data.async_restore_entity_removed(self.entity_id)
async def async_get_last_state(self) -> Optional[State]: async def async_get_last_state(self) -> State | None:
"""Get the entity state from the previous run.""" """Get the entity state from the previous run."""
if self.hass is None or self.entity_id is None: if self.hass is None or self.entity_id is None:
# Return None if this entity isn't added to hass yet # Return None if this entity isn't added to hass yet

View file

@ -1,4 +1,6 @@
"""Helpers to execute scripts.""" """Helpers to execute scripts."""
from __future__ import annotations
import asyncio import asyncio
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from datetime import datetime, timedelta from datetime import datetime, timedelta
@ -6,18 +8,7 @@ from functools import partial
import itertools import itertools
import logging import logging
from types import MappingProxyType from types import MappingProxyType
from typing import ( from typing import Any, Callable, Dict, Sequence, Union, cast
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Set,
Tuple,
Union,
cast,
)
import async_timeout import async_timeout
import voluptuous as vol import voluptuous as vol
@ -232,8 +223,8 @@ STATIC_VALIDATION_ACTION_TYPES = (
async def async_validate_actions_config( async def async_validate_actions_config(
hass: HomeAssistant, actions: List[ConfigType] hass: HomeAssistant, actions: list[ConfigType]
) -> List[ConfigType]: ) -> list[ConfigType]:
"""Validate a list of actions.""" """Validate a list of actions."""
return await asyncio.gather( return await asyncio.gather(
*[async_validate_action_config(hass, action) for action in actions] *[async_validate_action_config(hass, action) for action in actions]
@ -300,8 +291,8 @@ class _ScriptRun:
self, self,
hass: HomeAssistant, hass: HomeAssistant,
script: "Script", script: "Script",
variables: Dict[str, Any], variables: dict[str, Any],
context: Optional[Context], context: Context | None,
log_exceptions: bool, log_exceptions: bool,
) -> None: ) -> None:
self._hass = hass self._hass = hass
@ -310,7 +301,7 @@ class _ScriptRun:
self._context = context self._context = context
self._log_exceptions = log_exceptions self._log_exceptions = log_exceptions
self._step = -1 self._step = -1
self._action: Optional[Dict[str, Any]] = None self._action: dict[str, Any] | None = None
self._stop = asyncio.Event() self._stop = asyncio.Event()
self._stopped = asyncio.Event() self._stopped = asyncio.Event()
@ -890,7 +881,7 @@ async def _async_stop_scripts_at_shutdown(hass, event):
_VarsType = Union[Dict[str, Any], MappingProxyType] _VarsType = Union[Dict[str, Any], MappingProxyType]
def _referenced_extract_ids(data: Dict[str, Any], key: str, found: Set[str]) -> None: def _referenced_extract_ids(data: dict[str, Any], key: str, found: set[str]) -> None:
"""Extract referenced IDs.""" """Extract referenced IDs."""
if not data: if not data:
return return
@ -913,20 +904,20 @@ class Script:
def __init__( def __init__(
self, self,
hass: HomeAssistant, hass: HomeAssistant,
sequence: Sequence[Dict[str, Any]], sequence: Sequence[dict[str, Any]],
name: str, name: str,
domain: str, domain: str,
*, *,
# Used in "Running <running_description>" log message # Used in "Running <running_description>" log message
running_description: Optional[str] = None, running_description: str | None = None,
change_listener: Optional[Callable[..., Any]] = None, change_listener: Callable[..., Any] | None = None,
script_mode: str = DEFAULT_SCRIPT_MODE, script_mode: str = DEFAULT_SCRIPT_MODE,
max_runs: int = DEFAULT_MAX, max_runs: int = DEFAULT_MAX,
max_exceeded: str = DEFAULT_MAX_EXCEEDED, max_exceeded: str = DEFAULT_MAX_EXCEEDED,
logger: Optional[logging.Logger] = None, logger: logging.Logger | None = None,
log_exceptions: bool = True, log_exceptions: bool = True,
top_level: bool = True, top_level: bool = True,
variables: Optional[ScriptVariables] = None, variables: ScriptVariables | None = None,
) -> None: ) -> None:
"""Initialize the script.""" """Initialize the script."""
all_scripts = hass.data.get(DATA_SCRIPTS) all_scripts = hass.data.get(DATA_SCRIPTS)
@ -959,25 +950,25 @@ class Script:
self._log_exceptions = log_exceptions self._log_exceptions = log_exceptions
self.last_action = None self.last_action = None
self.last_triggered: Optional[datetime] = None self.last_triggered: datetime | None = None
self._runs: List[_ScriptRun] = [] self._runs: list[_ScriptRun] = []
self.max_runs = max_runs self.max_runs = max_runs
self._max_exceeded = max_exceeded self._max_exceeded = max_exceeded
if script_mode == SCRIPT_MODE_QUEUED: if script_mode == SCRIPT_MODE_QUEUED:
self._queue_lck = asyncio.Lock() self._queue_lck = asyncio.Lock()
self._config_cache: Dict[Set[Tuple], Callable[..., bool]] = {} self._config_cache: dict[set[tuple], Callable[..., bool]] = {}
self._repeat_script: Dict[int, Script] = {} self._repeat_script: dict[int, Script] = {}
self._choose_data: Dict[int, Dict[str, Any]] = {} self._choose_data: dict[int, dict[str, Any]] = {}
self._referenced_entities: Optional[Set[str]] = None self._referenced_entities: set[str] | None = None
self._referenced_devices: Optional[Set[str]] = None self._referenced_devices: set[str] | None = None
self.variables = variables self.variables = variables
self._variables_dynamic = template.is_complex(variables) self._variables_dynamic = template.is_complex(variables)
if self._variables_dynamic: if self._variables_dynamic:
template.attach(hass, variables) template.attach(hass, variables)
@property @property
def change_listener(self) -> Optional[Callable[..., Any]]: def change_listener(self) -> Callable[..., Any] | None:
"""Return the change_listener.""" """Return the change_listener."""
return self._change_listener return self._change_listener
@ -991,13 +982,13 @@ class Script:
): ):
self._change_listener_job = HassJob(change_listener) self._change_listener_job = HassJob(change_listener)
def _set_logger(self, logger: Optional[logging.Logger] = None) -> None: def _set_logger(self, logger: logging.Logger | None = None) -> None:
if logger: if logger:
self._logger = logger self._logger = logger
else: else:
self._logger = logging.getLogger(f"{__name__}.{slugify(self.name)}") self._logger = logging.getLogger(f"{__name__}.{slugify(self.name)}")
def update_logger(self, logger: Optional[logging.Logger] = None) -> None: def update_logger(self, logger: logging.Logger | None = None) -> None:
"""Update logger.""" """Update logger."""
self._set_logger(logger) self._set_logger(logger)
for script in self._repeat_script.values(): for script in self._repeat_script.values():
@ -1038,7 +1029,7 @@ class Script:
if self._referenced_devices is not None: if self._referenced_devices is not None:
return self._referenced_devices return self._referenced_devices
referenced: Set[str] = set() referenced: set[str] = set()
for step in self.sequence: for step in self.sequence:
action = cv.determine_script_action(step) action = cv.determine_script_action(step)
@ -1067,7 +1058,7 @@ class Script:
if self._referenced_entities is not None: if self._referenced_entities is not None:
return self._referenced_entities return self._referenced_entities
referenced: Set[str] = set() referenced: set[str] = set()
for step in self.sequence: for step in self.sequence:
action = cv.determine_script_action(step) action = cv.determine_script_action(step)
@ -1091,7 +1082,7 @@ class Script:
return referenced return referenced
def run( def run(
self, variables: Optional[_VarsType] = None, context: Optional[Context] = None self, variables: _VarsType | None = None, context: Context | None = None
) -> None: ) -> None:
"""Run script.""" """Run script."""
asyncio.run_coroutine_threadsafe( asyncio.run_coroutine_threadsafe(
@ -1100,9 +1091,9 @@ class Script:
async def async_run( async def async_run(
self, self,
run_variables: Optional[_VarsType] = None, run_variables: _VarsType | None = None,
context: Optional[Context] = None, context: Context | None = None,
started_action: Optional[Callable[..., Any]] = None, started_action: Callable[..., Any] | None = None,
) -> None: ) -> None:
"""Run script.""" """Run script."""
if context is None: if context is None:

View file

@ -1,5 +1,7 @@
"""Script variables.""" """Script variables."""
from typing import Any, Dict, Mapping, Optional from __future__ import annotations
from typing import Any, Mapping
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
@ -9,20 +11,20 @@ from . import template
class ScriptVariables: class ScriptVariables:
"""Class to hold and render script variables.""" """Class to hold and render script variables."""
def __init__(self, variables: Dict[str, Any]): def __init__(self, variables: dict[str, Any]):
"""Initialize script variables.""" """Initialize script variables."""
self.variables = variables self.variables = variables
self._has_template: Optional[bool] = None self._has_template: bool | None = None
@callback @callback
def async_render( def async_render(
self, self,
hass: HomeAssistant, hass: HomeAssistant,
run_variables: Optional[Mapping[str, Any]], run_variables: Mapping[str, Any] | None,
*, *,
render_as_defaults: bool = True, render_as_defaults: bool = True,
limited: bool = False, limited: bool = False,
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Render script variables. """Render script variables.
The run variables are used to compute the static variables. The run variables are used to compute the static variables.

View file

@ -10,14 +10,9 @@ from typing import (
Any, Any,
Awaitable, Awaitable,
Callable, Callable,
Dict,
Iterable, Iterable,
List,
Optional,
Set,
Tuple, Tuple,
TypedDict, TypedDict,
Union,
cast, cast,
) )
@ -79,8 +74,8 @@ class ServiceParams(TypedDict):
domain: str domain: str
service: str service: str
service_data: Dict[str, Any] service_data: dict[str, Any]
target: Optional[Dict] target: dict | None
@dataclasses.dataclass @dataclasses.dataclass
@ -88,17 +83,17 @@ class SelectedEntities:
"""Class to hold the selected entities.""" """Class to hold the selected entities."""
# Entities that were explicitly mentioned. # Entities that were explicitly mentioned.
referenced: Set[str] = dataclasses.field(default_factory=set) referenced: set[str] = dataclasses.field(default_factory=set)
# Entities that were referenced via device/area ID. # Entities that were referenced via device/area ID.
# Should not trigger a warning when they don't exist. # Should not trigger a warning when they don't exist.
indirectly_referenced: Set[str] = dataclasses.field(default_factory=set) indirectly_referenced: set[str] = dataclasses.field(default_factory=set)
# Referenced items that could not be found. # Referenced items that could not be found.
missing_devices: Set[str] = dataclasses.field(default_factory=set) missing_devices: set[str] = dataclasses.field(default_factory=set)
missing_areas: Set[str] = dataclasses.field(default_factory=set) missing_areas: set[str] = dataclasses.field(default_factory=set)
def log_missing(self, missing_entities: Set[str]) -> None: def log_missing(self, missing_entities: set[str]) -> None:
"""Log about missing items.""" """Log about missing items."""
parts = [] parts = []
for label, items in ( for label, items in (
@ -137,7 +132,7 @@ async def async_call_from_config(
blocking: bool = False, blocking: bool = False,
variables: TemplateVarsType = None, variables: TemplateVarsType = None,
validate_config: bool = True, validate_config: bool = True,
context: Optional[ha.Context] = None, context: ha.Context | None = None,
) -> None: ) -> None:
"""Call a service based on a config hash.""" """Call a service based on a config hash."""
try: try:
@ -235,7 +230,7 @@ def async_prepare_call_from_config(
@bind_hass @bind_hass
def extract_entity_ids( def extract_entity_ids(
hass: HomeAssistantType, service_call: ha.ServiceCall, expand_group: bool = True hass: HomeAssistantType, service_call: ha.ServiceCall, expand_group: bool = True
) -> Set[str]: ) -> set[str]:
"""Extract a list of entity ids from a service call. """Extract a list of entity ids from a service call.
Will convert group entity ids to the entity ids it represents. Will convert group entity ids to the entity ids it represents.
@ -251,7 +246,7 @@ async def async_extract_entities(
entities: Iterable[Entity], entities: Iterable[Entity],
service_call: ha.ServiceCall, service_call: ha.ServiceCall,
expand_group: bool = True, expand_group: bool = True,
) -> List[Entity]: ) -> list[Entity]:
"""Extract a list of entity objects from a service call. """Extract a list of entity objects from a service call.
Will convert group entity ids to the entity ids it represents. Will convert group entity ids to the entity ids it represents.
@ -287,7 +282,7 @@ async def async_extract_entities(
@bind_hass @bind_hass
async def async_extract_entity_ids( async def async_extract_entity_ids(
hass: HomeAssistantType, service_call: ha.ServiceCall, expand_group: bool = True hass: HomeAssistantType, service_call: ha.ServiceCall, expand_group: bool = True
) -> Set[str]: ) -> set[str]:
"""Extract a set of entity ids from a service call. """Extract a set of entity ids from a service call.
Will convert group entity ids to the entity ids it represents. Will convert group entity ids to the entity ids it represents.
@ -408,7 +403,7 @@ def _load_services_file(hass: HomeAssistantType, integration: Integration) -> JS
def _load_services_files( def _load_services_files(
hass: HomeAssistantType, integrations: Iterable[Integration] hass: HomeAssistantType, integrations: Iterable[Integration]
) -> List[JSON_TYPE]: ) -> list[JSON_TYPE]:
"""Load service files for multiple intergrations.""" """Load service files for multiple intergrations."""
return [_load_services_file(hass, integration) for integration in integrations] return [_load_services_file(hass, integration) for integration in integrations]
@ -416,7 +411,7 @@ def _load_services_files(
@bind_hass @bind_hass
async def async_get_all_descriptions( async def async_get_all_descriptions(
hass: HomeAssistantType, hass: HomeAssistantType,
) -> Dict[str, Dict[str, Any]]: ) -> dict[str, dict[str, Any]]:
"""Return descriptions (i.e. user documentation) for all service calls.""" """Return descriptions (i.e. user documentation) for all service calls."""
descriptions_cache = hass.data.setdefault(SERVICE_DESCRIPTION_CACHE, {}) descriptions_cache = hass.data.setdefault(SERVICE_DESCRIPTION_CACHE, {})
format_cache_key = "{}.{}".format format_cache_key = "{}.{}".format
@ -448,7 +443,7 @@ async def async_get_all_descriptions(
loaded[domain] = content loaded[domain] = content
# Build response # Build response
descriptions: Dict[str, Dict[str, Any]] = {} descriptions: dict[str, dict[str, Any]] = {}
for domain in services: for domain in services:
descriptions[domain] = {} descriptions[domain] = {}
@ -483,7 +478,7 @@ async def async_get_all_descriptions(
@ha.callback @ha.callback
@bind_hass @bind_hass
def async_set_service_schema( def async_set_service_schema(
hass: HomeAssistantType, domain: str, service: str, schema: Dict[str, Any] hass: HomeAssistantType, domain: str, service: str, schema: dict[str, Any]
) -> None: ) -> None:
"""Register a description for a service.""" """Register a description for a service."""
hass.data.setdefault(SERVICE_DESCRIPTION_CACHE, {}) hass.data.setdefault(SERVICE_DESCRIPTION_CACHE, {})
@ -504,9 +499,9 @@ def async_set_service_schema(
async def entity_service_call( async def entity_service_call(
hass: HomeAssistantType, hass: HomeAssistantType,
platforms: Iterable["EntityPlatform"], platforms: Iterable["EntityPlatform"],
func: Union[str, Callable[..., Any]], func: str | Callable[..., Any],
call: ha.ServiceCall, call: ha.ServiceCall,
required_features: Optional[Iterable[int]] = None, required_features: Iterable[int] | None = None,
) -> None: ) -> None:
"""Handle an entity service call. """Handle an entity service call.
@ -516,17 +511,17 @@ async def entity_service_call(
user = await hass.auth.async_get_user(call.context.user_id) user = await hass.auth.async_get_user(call.context.user_id)
if user is None: if user is None:
raise UnknownUser(context=call.context) raise UnknownUser(context=call.context)
entity_perms: Optional[ entity_perms: None | (
Callable[[str, str], bool] Callable[[str, str], bool]
] = user.permissions.check_entity ) = user.permissions.check_entity
else: else:
entity_perms = None entity_perms = None
target_all_entities = call.data.get(ATTR_ENTITY_ID) == ENTITY_MATCH_ALL target_all_entities = call.data.get(ATTR_ENTITY_ID) == ENTITY_MATCH_ALL
if target_all_entities: if target_all_entities:
referenced: Optional[SelectedEntities] = None referenced: SelectedEntities | None = None
all_referenced: Optional[Set[str]] = None all_referenced: set[str] | None = None
else: else:
# A set of entities we're trying to target. # A set of entities we're trying to target.
referenced = await async_extract_referenced_entity_ids(hass, call, True) referenced = await async_extract_referenced_entity_ids(hass, call, True)
@ -534,7 +529,7 @@ async def entity_service_call(
# If the service function is a string, we'll pass it the service call data # If the service function is a string, we'll pass it the service call data
if isinstance(func, str): if isinstance(func, str):
data: Union[Dict, ha.ServiceCall] = { data: dict | ha.ServiceCall = {
key: val key: val
for key, val in call.data.items() for key, val in call.data.items()
if key not in cv.ENTITY_SERVICE_FIELDS if key not in cv.ENTITY_SERVICE_FIELDS
@ -546,7 +541,7 @@ async def entity_service_call(
# Check the permissions # Check the permissions
# A list with entities to call the service on. # A list with entities to call the service on.
entity_candidates: List["Entity"] = [] entity_candidates: list["Entity"] = []
if entity_perms is None: if entity_perms is None:
for platform in platforms: for platform in platforms:
@ -662,8 +657,8 @@ async def entity_service_call(
async def _handle_entity_call( async def _handle_entity_call(
hass: HomeAssistantType, hass: HomeAssistantType,
entity: Entity, entity: Entity,
func: Union[str, Callable[..., Any]], func: str | Callable[..., Any],
data: Union[Dict, ha.ServiceCall], data: dict | ha.ServiceCall,
context: ha.Context, context: ha.Context,
) -> None: ) -> None:
"""Handle calling service method.""" """Handle calling service method."""
@ -693,7 +688,7 @@ def async_register_admin_service(
hass: HomeAssistantType, hass: HomeAssistantType,
domain: str, domain: str,
service: str, service: str,
service_func: Callable[[ha.ServiceCall], Optional[Awaitable]], service_func: Callable[[ha.ServiceCall], Awaitable | None],
schema: vol.Schema = vol.Schema({}, extra=vol.PREVENT_EXTRA), schema: vol.Schema = vol.Schema({}, extra=vol.PREVENT_EXTRA),
) -> None: ) -> None:
"""Register a service that requires admin access.""" """Register a service that requires admin access."""

View file

@ -29,7 +29,7 @@ The following cases will never be passed to your function:
from __future__ import annotations from __future__ import annotations
from types import MappingProxyType from types import MappingProxyType
from typing import Any, Callable, Dict, Optional, Tuple, Union from typing import Any, Callable, Optional, Union
from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN
from homeassistant.core import HomeAssistant, State, callback from homeassistant.core import HomeAssistant, State, callback
@ -66,7 +66,7 @@ ExtraCheckTypeFunc = Callable[
async def create_checker( async def create_checker(
hass: HomeAssistant, hass: HomeAssistant,
_domain: str, _domain: str,
extra_significant_check: Optional[ExtraCheckTypeFunc] = None, extra_significant_check: ExtraCheckTypeFunc | None = None,
) -> SignificantlyChangedChecker: ) -> SignificantlyChangedChecker:
"""Create a significantly changed checker for a domain.""" """Create a significantly changed checker for a domain."""
await _initialize(hass) await _initialize(hass)
@ -90,15 +90,15 @@ async def _initialize(hass: HomeAssistant) -> None:
await async_process_integration_platforms(hass, PLATFORM, process_platform) await async_process_integration_platforms(hass, PLATFORM, process_platform)
def either_one_none(val1: Optional[Any], val2: Optional[Any]) -> bool: def either_one_none(val1: Any | None, val2: Any | None) -> bool:
"""Test if exactly one value is None.""" """Test if exactly one value is None."""
return (val1 is None and val2 is not None) or (val1 is not None and val2 is None) return (val1 is None and val2 is not None) or (val1 is not None and val2 is None)
def check_numeric_changed( def check_numeric_changed(
val1: Optional[Union[int, float]], val1: int | float | None,
val2: Optional[Union[int, float]], val2: int | float | None,
change: Union[int, float], change: int | float,
) -> bool: ) -> bool:
"""Check if two numeric values have changed.""" """Check if two numeric values have changed."""
if val1 is None and val2 is None: if val1 is None and val2 is None:
@ -125,22 +125,22 @@ class SignificantlyChangedChecker:
def __init__( def __init__(
self, self,
hass: HomeAssistant, hass: HomeAssistant,
extra_significant_check: Optional[ExtraCheckTypeFunc] = None, extra_significant_check: ExtraCheckTypeFunc | None = None,
) -> None: ) -> None:
"""Test if an entity has significantly changed.""" """Test if an entity has significantly changed."""
self.hass = hass self.hass = hass
self.last_approved_entities: Dict[str, Tuple[State, Any]] = {} self.last_approved_entities: dict[str, tuple[State, Any]] = {}
self.extra_significant_check = extra_significant_check self.extra_significant_check = extra_significant_check
@callback @callback
def async_is_significant_change( def async_is_significant_change(
self, new_state: State, *, extra_arg: Optional[Any] = None self, new_state: State, *, extra_arg: Any | None = None
) -> bool: ) -> bool:
"""Return if this was a significant change. """Return if this was a significant change.
Extra kwargs are passed to the extra significant checker. Extra kwargs are passed to the extra significant checker.
""" """
old_data: Optional[Tuple[State, Any]] = self.last_approved_entities.get( old_data: tuple[State, Any] | None = self.last_approved_entities.get(
new_state.entity_id new_state.entity_id
) )
@ -164,9 +164,7 @@ class SignificantlyChangedChecker:
self.last_approved_entities[new_state.entity_id] = (new_state, extra_arg) self.last_approved_entities[new_state.entity_id] = (new_state, extra_arg)
return True return True
functions: Optional[Dict[str, CheckTypeFunc]] = self.hass.data.get( functions: dict[str, CheckTypeFunc] | None = self.hass.data.get(DATA_FUNCTIONS)
DATA_FUNCTIONS
)
if functions is None: if functions is None:
raise RuntimeError("Significant Change not initialized") raise RuntimeError("Significant Change not initialized")

View file

@ -1,7 +1,9 @@
"""Helper to help coordinating calls.""" """Helper to help coordinating calls."""
from __future__ import annotations
import asyncio import asyncio
import functools import functools
from typing import Callable, Optional, TypeVar, cast from typing import Callable, TypeVar, cast
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
@ -24,7 +26,7 @@ def singleton(data_key: str) -> Callable[[FUNC], FUNC]:
@bind_hass @bind_hass
@functools.wraps(func) @functools.wraps(func)
def wrapped(hass: HomeAssistant) -> T: def wrapped(hass: HomeAssistant) -> T:
obj: Optional[T] = hass.data.get(data_key) obj: T | None = hass.data.get(data_key)
if obj is None: if obj is None:
obj = hass.data[data_key] = func(hass) obj = hass.data[data_key] = func(hass)
return obj return obj

View file

@ -1,10 +1,12 @@
"""Helpers that help with state related things.""" """Helpers that help with state related things."""
from __future__ import annotations
import asyncio import asyncio
from collections import defaultdict from collections import defaultdict
import datetime as dt import datetime as dt
import logging import logging
from types import ModuleType, TracebackType from types import ModuleType, TracebackType
from typing import Any, Dict, Iterable, List, Optional, Type, Union from typing import Any, Iterable
from homeassistant.components.sun import STATE_ABOVE_HORIZON, STATE_BELOW_HORIZON from homeassistant.components.sun import STATE_ABOVE_HORIZON, STATE_BELOW_HORIZON
from homeassistant.const import ( from homeassistant.const import (
@ -44,19 +46,19 @@ class AsyncTrackStates:
def __init__(self, hass: HomeAssistantType) -> None: def __init__(self, hass: HomeAssistantType) -> None:
"""Initialize a TrackStates block.""" """Initialize a TrackStates block."""
self.hass = hass self.hass = hass
self.states: List[State] = [] self.states: list[State] = []
# pylint: disable=attribute-defined-outside-init # pylint: disable=attribute-defined-outside-init
def __enter__(self) -> List[State]: def __enter__(self) -> list[State]:
"""Record time from which to track changes.""" """Record time from which to track changes."""
self.now = dt_util.utcnow() self.now = dt_util.utcnow()
return self.states return self.states
def __exit__( def __exit__(
self, self,
exc_type: Optional[Type[BaseException]], exc_type: type[BaseException] | None,
exc_value: Optional[BaseException], exc_value: BaseException | None,
traceback: Optional[TracebackType], traceback: TracebackType | None,
) -> None: ) -> None:
"""Add changes states to changes list.""" """Add changes states to changes list."""
self.states.extend(get_changed_since(self.hass.states.async_all(), self.now)) self.states.extend(get_changed_since(self.hass.states.async_all(), self.now))
@ -64,7 +66,7 @@ class AsyncTrackStates:
def get_changed_since( def get_changed_since(
states: Iterable[State], utc_point_in_time: dt.datetime states: Iterable[State], utc_point_in_time: dt.datetime
) -> List[State]: ) -> list[State]:
"""Return list of states that have been changed since utc_point_in_time. """Return list of states that have been changed since utc_point_in_time.
Deprecated. Remove after June 2021. Deprecated. Remove after June 2021.
@ -76,21 +78,21 @@ def get_changed_since(
@bind_hass @bind_hass
async def async_reproduce_state( async def async_reproduce_state(
hass: HomeAssistantType, hass: HomeAssistantType,
states: Union[State, Iterable[State]], states: State | Iterable[State],
*, *,
context: Optional[Context] = None, context: Context | None = None,
reproduce_options: Optional[Dict[str, Any]] = None, reproduce_options: dict[str, Any] | None = None,
) -> None: ) -> None:
"""Reproduce a list of states on multiple domains.""" """Reproduce a list of states on multiple domains."""
if isinstance(states, State): if isinstance(states, State):
states = [states] states = [states]
to_call: Dict[str, List[State]] = defaultdict(list) to_call: dict[str, list[State]] = defaultdict(list)
for state in states: for state in states:
to_call[state.domain].append(state) to_call[state.domain].append(state)
async def worker(domain: str, states_by_domain: List[State]) -> None: async def worker(domain: str, states_by_domain: list[State]) -> None:
try: try:
integration = await async_get_integration(hass, domain) integration = await async_get_integration(hass, domain)
except IntegrationNotFound: except IntegrationNotFound:
@ -100,7 +102,7 @@ async def async_reproduce_state(
return return
try: try:
platform: Optional[ModuleType] = integration.get_platform("reproduce_state") platform: ModuleType | None = integration.get_platform("reproduce_state")
except ImportError: except ImportError:
_LOGGER.warning("Integration %s does not support reproduce state", domain) _LOGGER.warning("Integration %s does not support reproduce state", domain)
return return

View file

@ -1,9 +1,11 @@
"""Helper to help store data.""" """Helper to help store data."""
from __future__ import annotations
import asyncio import asyncio
from json import JSONEncoder from json import JSONEncoder
import logging import logging
import os import os
from typing import Any, Callable, Dict, List, Optional, Type, Union from typing import Any, Callable
from homeassistant.const import EVENT_HOMEASSISTANT_FINAL_WRITE from homeassistant.const import EVENT_HOMEASSISTANT_FINAL_WRITE
from homeassistant.core import CALLBACK_TYPE, CoreState, HomeAssistant, callback from homeassistant.core import CALLBACK_TYPE, CoreState, HomeAssistant, callback
@ -71,18 +73,18 @@ class Store:
key: str, key: str,
private: bool = False, private: bool = False,
*, *,
encoder: Optional[Type[JSONEncoder]] = None, encoder: type[JSONEncoder] | None = None,
): ):
"""Initialize storage class.""" """Initialize storage class."""
self.version = version self.version = version
self.key = key self.key = key
self.hass = hass self.hass = hass
self._private = private self._private = private
self._data: Optional[Dict[str, Any]] = None self._data: dict[str, Any] | None = None
self._unsub_delay_listener: Optional[CALLBACK_TYPE] = None self._unsub_delay_listener: CALLBACK_TYPE | None = None
self._unsub_final_write_listener: Optional[CALLBACK_TYPE] = None self._unsub_final_write_listener: CALLBACK_TYPE | None = None
self._write_lock = asyncio.Lock() self._write_lock = asyncio.Lock()
self._load_task: Optional[asyncio.Future] = None self._load_task: asyncio.Future | None = None
self._encoder = encoder self._encoder = encoder
@property @property
@ -90,7 +92,7 @@ class Store:
"""Return the config path.""" """Return the config path."""
return self.hass.config.path(STORAGE_DIR, self.key) return self.hass.config.path(STORAGE_DIR, self.key)
async def async_load(self) -> Union[Dict, List, None]: async def async_load(self) -> dict | list | None:
"""Load data. """Load data.
If the expected version does not match the given version, the migrate If the expected version does not match the given version, the migrate
@ -140,7 +142,7 @@ class Store:
return stored return stored
async def async_save(self, data: Union[Dict, List]) -> None: async def async_save(self, data: dict | list) -> None:
"""Save data.""" """Save data."""
self._data = {"version": self.version, "key": self.key, "data": data} self._data = {"version": self.version, "key": self.key, "data": data}
@ -151,7 +153,7 @@ class Store:
await self._async_handle_write_data() await self._async_handle_write_data()
@callback @callback
def async_delay_save(self, data_func: Callable[[], Dict], delay: float = 0) -> None: def async_delay_save(self, data_func: Callable[[], dict], delay: float = 0) -> None:
"""Save data with an optional delay.""" """Save data with an optional delay."""
self._data = {"version": self.version, "key": self.key, "data_func": data_func} self._data = {"version": self.version, "key": self.key, "data_func": data_func}
@ -224,7 +226,7 @@ class Store:
except (json_util.SerializationError, json_util.WriteError) as err: except (json_util.SerializationError, json_util.WriteError) as err:
_LOGGER.error("Error writing config for %s: %s", self.key, err) _LOGGER.error("Error writing config for %s: %s", self.key, err)
def _write_data(self, path: str, data: Dict) -> None: def _write_data(self, path: str, data: dict) -> None:
"""Write the data.""" """Write the data."""
if not os.path.isdir(os.path.dirname(path)): if not os.path.isdir(os.path.dirname(path)):
os.makedirs(os.path.dirname(path)) os.makedirs(os.path.dirname(path))

View file

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
import datetime import datetime
from typing import TYPE_CHECKING, Optional, Union from typing import TYPE_CHECKING
from homeassistant.const import SUN_EVENT_SUNRISE, SUN_EVENT_SUNSET from homeassistant.const import SUN_EVENT_SUNRISE, SUN_EVENT_SUNSET
from homeassistant.core import callback from homeassistant.core import callback
@ -44,8 +44,8 @@ def get_astral_location(hass: HomeAssistantType) -> astral.Location:
def get_astral_event_next( def get_astral_event_next(
hass: HomeAssistantType, hass: HomeAssistantType,
event: str, event: str,
utc_point_in_time: Optional[datetime.datetime] = None, utc_point_in_time: datetime.datetime | None = None,
offset: Optional[datetime.timedelta] = None, offset: datetime.timedelta | None = None,
) -> datetime.datetime: ) -> datetime.datetime:
"""Calculate the next specified solar event.""" """Calculate the next specified solar event."""
location = get_astral_location(hass) location = get_astral_location(hass)
@ -56,8 +56,8 @@ def get_astral_event_next(
def get_location_astral_event_next( def get_location_astral_event_next(
location: "astral.Location", location: "astral.Location",
event: str, event: str,
utc_point_in_time: Optional[datetime.datetime] = None, utc_point_in_time: datetime.datetime | None = None,
offset: Optional[datetime.timedelta] = None, offset: datetime.timedelta | None = None,
) -> datetime.datetime: ) -> datetime.datetime:
"""Calculate the next specified solar event.""" """Calculate the next specified solar event."""
from astral import AstralError # pylint: disable=import-outside-toplevel from astral import AstralError # pylint: disable=import-outside-toplevel
@ -91,8 +91,8 @@ def get_location_astral_event_next(
def get_astral_event_date( def get_astral_event_date(
hass: HomeAssistantType, hass: HomeAssistantType,
event: str, event: str,
date: Union[datetime.date, datetime.datetime, None] = None, date: datetime.date | datetime.datetime | None = None,
) -> Optional[datetime.datetime]: ) -> datetime.datetime | None:
"""Calculate the astral event time for the specified date.""" """Calculate the astral event time for the specified date."""
from astral import AstralError # pylint: disable=import-outside-toplevel from astral import AstralError # pylint: disable=import-outside-toplevel
@ -114,7 +114,7 @@ def get_astral_event_date(
@callback @callback
@bind_hass @bind_hass
def is_up( def is_up(
hass: HomeAssistantType, utc_point_in_time: Optional[datetime.datetime] = None hass: HomeAssistantType, utc_point_in_time: datetime.datetime | None = None
) -> bool: ) -> bool:
"""Calculate if the sun is currently up.""" """Calculate if the sun is currently up."""
if utc_point_in_time is None: if utc_point_in_time is None:

View file

@ -1,7 +1,9 @@
"""Helper to gather system info.""" """Helper to gather system info."""
from __future__ import annotations
import os import os
import platform import platform
from typing import Any, Dict from typing import Any
from homeassistant.const import __version__ as current_version from homeassistant.const import __version__ as current_version
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
@ -11,7 +13,7 @@ from .typing import HomeAssistantType
@bind_hass @bind_hass
async def async_get_system_info(hass: HomeAssistantType) -> Dict[str, Any]: async def async_get_system_info(hass: HomeAssistantType) -> dict[str, Any]:
"""Return info about the system.""" """Return info about the system."""
info_object = { info_object = {
"installation_type": "Unknown", "installation_type": "Unknown",

View file

@ -1,6 +1,7 @@
"""Temperature helpers for Home Assistant.""" """Temperature helpers for Home Assistant."""
from __future__ import annotations
from numbers import Number from numbers import Number
from typing import Optional
from homeassistant.const import PRECISION_HALVES, PRECISION_TENTHS from homeassistant.const import PRECISION_HALVES, PRECISION_TENTHS
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
@ -8,8 +9,8 @@ from homeassistant.util.temperature import convert as convert_temperature
def display_temp( def display_temp(
hass: HomeAssistant, temperature: Optional[float], unit: str, precision: float hass: HomeAssistant, temperature: float | None, unit: str, precision: float
) -> Optional[float]: ) -> float | None:
"""Convert temperature into preferred units/precision for display.""" """Convert temperature into preferred units/precision for display."""
temperature_unit = unit temperature_unit = unit
ha_unit = hass.config.units.temperature_unit ha_unit = hass.config.units.temperature_unit

View file

@ -13,7 +13,7 @@ import math
from operator import attrgetter from operator import attrgetter
import random import random
import re import re
from typing import Any, Dict, Generator, Iterable, Optional, Type, Union, cast from typing import Any, Generator, Iterable, cast
from urllib.parse import urlencode as urllib_urlencode from urllib.parse import urlencode as urllib_urlencode
import weakref import weakref
@ -125,7 +125,7 @@ def is_template_string(maybe_template: str) -> bool:
class ResultWrapper: class ResultWrapper:
"""Result wrapper class to store render result.""" """Result wrapper class to store render result."""
render_result: Optional[str] render_result: str | None
def gen_result_wrapper(kls): def gen_result_wrapper(kls):
@ -134,7 +134,7 @@ def gen_result_wrapper(kls):
class Wrapper(kls, ResultWrapper): class Wrapper(kls, ResultWrapper):
"""Wrapper of a kls that can store render_result.""" """Wrapper of a kls that can store render_result."""
def __init__(self, *args: tuple, render_result: Optional[str] = None) -> None: def __init__(self, *args: tuple, render_result: str | None = None) -> None:
super().__init__(*args) super().__init__(*args)
self.render_result = render_result self.render_result = render_result
@ -156,15 +156,13 @@ class TupleWrapper(tuple, ResultWrapper):
# This is all magic to be allowed to subclass a tuple. # This is all magic to be allowed to subclass a tuple.
def __new__( def __new__(cls, value: tuple, *, render_result: str | None = None) -> TupleWrapper:
cls, value: tuple, *, render_result: Optional[str] = None
) -> TupleWrapper:
"""Create a new tuple class.""" """Create a new tuple class."""
return super().__new__(cls, tuple(value)) return super().__new__(cls, tuple(value))
# pylint: disable=super-init-not-called # pylint: disable=super-init-not-called
def __init__(self, value: tuple, *, render_result: Optional[str] = None): def __init__(self, value: tuple, *, render_result: str | None = None):
"""Initialize a new tuple class.""" """Initialize a new tuple class."""
self.render_result = render_result self.render_result = render_result
@ -176,7 +174,7 @@ class TupleWrapper(tuple, ResultWrapper):
return self.render_result return self.render_result
RESULT_WRAPPERS: Dict[Type, Type] = { RESULT_WRAPPERS: dict[type, type] = {
kls: gen_result_wrapper(kls) # type: ignore[no-untyped-call] kls: gen_result_wrapper(kls) # type: ignore[no-untyped-call]
for kls in (list, dict, set) for kls in (list, dict, set)
} }
@ -200,15 +198,15 @@ class RenderInfo:
# Will be set sensibly once frozen. # Will be set sensibly once frozen.
self.filter_lifecycle = _true self.filter_lifecycle = _true
self.filter = _true self.filter = _true
self._result: Optional[str] = None self._result: str | None = None
self.is_static = False self.is_static = False
self.exception: Optional[TemplateError] = None self.exception: TemplateError | None = None
self.all_states = False self.all_states = False
self.all_states_lifecycle = False self.all_states_lifecycle = False
self.domains = set() self.domains = set()
self.domains_lifecycle = set() self.domains_lifecycle = set()
self.entities = set() self.entities = set()
self.rate_limit: Optional[timedelta] = None self.rate_limit: timedelta | None = None
self.has_time = False self.has_time = False
def __repr__(self) -> str: def __repr__(self) -> str:
@ -294,7 +292,7 @@ class Template:
self.template: str = template.strip() self.template: str = template.strip()
self._compiled_code = None self._compiled_code = None
self._compiled: Optional[Template] = None self._compiled: Template | None = None
self.hass = hass self.hass = hass
self.is_static = not is_template_string(template) self.is_static = not is_template_string(template)
self._limited = None self._limited = None
@ -304,7 +302,7 @@ class Template:
if self.hass is None: if self.hass is None:
return _NO_HASS_ENV return _NO_HASS_ENV
wanted_env = _ENVIRONMENT_LIMITED if self._limited else _ENVIRONMENT wanted_env = _ENVIRONMENT_LIMITED if self._limited else _ENVIRONMENT
ret: Optional[TemplateEnvironment] = self.hass.data.get(wanted_env) ret: TemplateEnvironment | None = self.hass.data.get(wanted_env)
if ret is None: if ret is None:
ret = self.hass.data[wanted_env] = TemplateEnvironment(self.hass, self._limited) # type: ignore[no-untyped-call] ret = self.hass.data[wanted_env] = TemplateEnvironment(self.hass, self._limited) # type: ignore[no-untyped-call]
return ret return ret
@ -776,7 +774,7 @@ def _collect_state(hass: HomeAssistantType, entity_id: str) -> None:
entity_collect.entities.add(entity_id) entity_collect.entities.add(entity_id)
def _state_generator(hass: HomeAssistantType, domain: Optional[str]) -> Generator: def _state_generator(hass: HomeAssistantType, domain: str | None) -> Generator:
"""State generator for a domain or all states.""" """State generator for a domain or all states."""
for state in sorted(hass.states.async_all(domain), key=attrgetter("entity_id")): for state in sorted(hass.states.async_all(domain), key=attrgetter("entity_id")):
yield TemplateState(hass, state, collect=False) yield TemplateState(hass, state, collect=False)
@ -784,20 +782,20 @@ def _state_generator(hass: HomeAssistantType, domain: Optional[str]) -> Generato
def _get_state_if_valid( def _get_state_if_valid(
hass: HomeAssistantType, entity_id: str hass: HomeAssistantType, entity_id: str
) -> Optional[TemplateState]: ) -> TemplateState | None:
state = hass.states.get(entity_id) state = hass.states.get(entity_id)
if state is None and not valid_entity_id(entity_id): if state is None and not valid_entity_id(entity_id):
raise TemplateError(f"Invalid entity ID '{entity_id}'") # type: ignore raise TemplateError(f"Invalid entity ID '{entity_id}'") # type: ignore
return _get_template_state_from_state(hass, entity_id, state) return _get_template_state_from_state(hass, entity_id, state)
def _get_state(hass: HomeAssistantType, entity_id: str) -> Optional[TemplateState]: def _get_state(hass: HomeAssistantType, entity_id: str) -> TemplateState | None:
return _get_template_state_from_state(hass, entity_id, hass.states.get(entity_id)) return _get_template_state_from_state(hass, entity_id, hass.states.get(entity_id))
def _get_template_state_from_state( def _get_template_state_from_state(
hass: HomeAssistantType, entity_id: str, state: Optional[State] hass: HomeAssistantType, entity_id: str, state: State | None
) -> Optional[TemplateState]: ) -> TemplateState | None:
if state is None: if state is None:
# Only need to collect if none, if not none collect first actual # Only need to collect if none, if not none collect first actual
# access to the state properties in the state wrapper. # access to the state properties in the state wrapper.
@ -808,7 +806,7 @@ def _get_template_state_from_state(
def _resolve_state( def _resolve_state(
hass: HomeAssistantType, entity_id_or_state: Any hass: HomeAssistantType, entity_id_or_state: Any
) -> Union[State, TemplateState, None]: ) -> State | TemplateState | None:
"""Return state or entity_id if given.""" """Return state or entity_id if given."""
if isinstance(entity_id_or_state, State): if isinstance(entity_id_or_state, State):
return entity_id_or_state return entity_id_or_state
@ -817,7 +815,7 @@ def _resolve_state(
return None return None
def result_as_boolean(template_result: Optional[str]) -> bool: def result_as_boolean(template_result: str | None) -> bool:
"""Convert the template result to a boolean. """Convert the template result to a boolean.
True/not 0/'1'/'true'/'yes'/'on'/'enable' are considered truthy True/not 0/'1'/'true'/'yes'/'on'/'enable' are considered truthy

View file

@ -1,8 +1,10 @@
"""Helpers for script and condition tracing.""" """Helpers for script and condition tracing."""
from __future__ import annotations
from collections import deque from collections import deque
from contextlib import contextmanager from contextlib import contextmanager
from contextvars import ContextVar from contextvars import ContextVar
from typing import Any, Deque, Dict, Generator, List, Optional, Tuple, Union, cast from typing import Any, Deque, Generator, cast
from homeassistant.helpers.typing import TemplateVarsType from homeassistant.helpers.typing import TemplateVarsType
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
@ -13,9 +15,9 @@ class TraceElement:
def __init__(self, variables: TemplateVarsType, path: str): def __init__(self, variables: TemplateVarsType, path: str):
"""Container for trace data.""" """Container for trace data."""
self._error: Optional[Exception] = None self._error: Exception | None = None
self.path: str = path self.path: str = path
self._result: Optional[dict] = None self._result: dict | None = None
self._timestamp = dt_util.utcnow() self._timestamp = dt_util.utcnow()
if variables is None: if variables is None:
@ -41,9 +43,9 @@ class TraceElement:
"""Set result.""" """Set result."""
self._result = {**kwargs} self._result = {**kwargs}
def as_dict(self) -> Dict[str, Any]: def as_dict(self) -> dict[str, Any]:
"""Return dictionary version of this TraceElement.""" """Return dictionary version of this TraceElement."""
result: Dict[str, Any] = {"path": self.path, "timestamp": self._timestamp} result: dict[str, Any] = {"path": self.path, "timestamp": self._timestamp}
if self._variables: if self._variables:
result["changed_variables"] = self._variables result["changed_variables"] = self._variables
if self._error is not None: if self._error is not None:
@ -55,31 +57,31 @@ class TraceElement:
# Context variables for tracing # Context variables for tracing
# Current trace # Current trace
trace_cv: ContextVar[Optional[Dict[str, Deque[TraceElement]]]] = ContextVar( trace_cv: ContextVar[dict[str, Deque[TraceElement]] | None] = ContextVar(
"trace_cv", default=None "trace_cv", default=None
) )
# Stack of TraceElements # Stack of TraceElements
trace_stack_cv: ContextVar[Optional[List[TraceElement]]] = ContextVar( trace_stack_cv: ContextVar[list[TraceElement] | None] = ContextVar(
"trace_stack_cv", default=None "trace_stack_cv", default=None
) )
# Current location in config tree # Current location in config tree
trace_path_stack_cv: ContextVar[Optional[List[str]]] = ContextVar( trace_path_stack_cv: ContextVar[list[str] | None] = ContextVar(
"trace_path_stack_cv", default=None "trace_path_stack_cv", default=None
) )
# Copy of last variables # Copy of last variables
variables_cv: ContextVar[Optional[Any]] = ContextVar("variables_cv", default=None) variables_cv: ContextVar[Any | None] = ContextVar("variables_cv", default=None)
# Automation ID + Run ID # Automation ID + Run ID
trace_id_cv: ContextVar[Optional[Tuple[str, str]]] = ContextVar( trace_id_cv: ContextVar[tuple[str, str] | None] = ContextVar(
"trace_id_cv", default=None "trace_id_cv", default=None
) )
def trace_id_set(trace_id: Tuple[str, str]) -> None: def trace_id_set(trace_id: tuple[str, str]) -> None:
"""Set id of the current trace.""" """Set id of the current trace."""
trace_id_cv.set(trace_id) trace_id_cv.set(trace_id)
def trace_id_get() -> Optional[Tuple[str, str]]: def trace_id_get() -> tuple[str, str] | None:
"""Get id if the current trace.""" """Get id if the current trace."""
return trace_id_cv.get() return trace_id_cv.get()
@ -99,13 +101,13 @@ def trace_stack_pop(trace_stack_var: ContextVar) -> None:
trace_stack.pop() trace_stack.pop()
def trace_stack_top(trace_stack_var: ContextVar) -> Optional[Any]: def trace_stack_top(trace_stack_var: ContextVar) -> Any | None:
"""Return the element at the top of a trace stack.""" """Return the element at the top of a trace stack."""
trace_stack = trace_stack_var.get() trace_stack = trace_stack_var.get()
return trace_stack[-1] if trace_stack else None return trace_stack[-1] if trace_stack else None
def trace_path_push(suffix: Union[str, List[str]]) -> int: def trace_path_push(suffix: str | list[str]) -> int:
"""Go deeper in the config tree.""" """Go deeper in the config tree."""
if isinstance(suffix, str): if isinstance(suffix, str):
suffix = [suffix] suffix = [suffix]
@ -130,7 +132,7 @@ def trace_path_get() -> str:
def trace_append_element( def trace_append_element(
trace_element: TraceElement, trace_element: TraceElement,
maxlen: Optional[int] = None, maxlen: int | None = None,
) -> None: ) -> None:
"""Append a TraceElement to trace[path].""" """Append a TraceElement to trace[path]."""
path = trace_element.path path = trace_element.path
@ -143,7 +145,7 @@ def trace_append_element(
trace[path].append(trace_element) trace[path].append(trace_element)
def trace_get(clear: bool = True) -> Optional[Dict[str, Deque[TraceElement]]]: def trace_get(clear: bool = True) -> dict[str, Deque[TraceElement]] | None:
"""Return the current trace.""" """Return the current trace."""
if clear: if clear:
trace_clear() trace_clear()
@ -165,7 +167,7 @@ def trace_set_result(**kwargs: Any) -> None:
@contextmanager @contextmanager
def trace_path(suffix: Union[str, List[str]]) -> Generator: def trace_path(suffix: str | list[str]) -> Generator:
"""Go deeper in the config tree.""" """Go deeper in the config tree."""
count = trace_path_push(suffix) count = trace_path_push(suffix)
try: try:

View file

@ -1,8 +1,10 @@
"""Translation string lookup helpers.""" """Translation string lookup helpers."""
from __future__ import annotations
import asyncio import asyncio
from collections import ChainMap from collections import ChainMap
import logging import logging
from typing import Any, Dict, List, Optional, Set from typing import Any
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.loader import ( from homeassistant.loader import (
@ -24,7 +26,7 @@ TRANSLATION_FLATTEN_CACHE = "translation_flatten_cache"
LOCALE_EN = "en" LOCALE_EN = "en"
def recursive_flatten(prefix: Any, data: Dict) -> Dict[str, Any]: def recursive_flatten(prefix: Any, data: dict) -> dict[str, Any]:
"""Return a flattened representation of dict data.""" """Return a flattened representation of dict data."""
output = {} output = {}
for key, value in data.items(): for key, value in data.items():
@ -38,7 +40,7 @@ def recursive_flatten(prefix: Any, data: Dict) -> Dict[str, Any]:
@callback @callback
def component_translation_path( def component_translation_path(
component: str, language: str, integration: Integration component: str, language: str, integration: Integration
) -> Optional[str]: ) -> str | None:
"""Return the translation json file location for a component. """Return the translation json file location for a component.
For component: For component:
@ -69,8 +71,8 @@ def component_translation_path(
def load_translations_files( def load_translations_files(
translation_files: Dict[str, str] translation_files: dict[str, str]
) -> Dict[str, Dict[str, Any]]: ) -> dict[str, dict[str, Any]]:
"""Load and parse translation.json files.""" """Load and parse translation.json files."""
loaded = {} loaded = {}
for component, translation_file in translation_files.items(): for component, translation_file in translation_files.items():
@ -90,13 +92,13 @@ def load_translations_files(
def _merge_resources( def _merge_resources(
translation_strings: Dict[str, Dict[str, Any]], translation_strings: dict[str, dict[str, Any]],
components: Set[str], components: set[str],
category: str, category: str,
) -> Dict[str, Dict[str, Any]]: ) -> dict[str, dict[str, Any]]:
"""Build and merge the resources response for the given components and platforms.""" """Build and merge the resources response for the given components and platforms."""
# Build response # Build response
resources: Dict[str, Dict[str, Any]] = {} resources: dict[str, dict[str, Any]] = {}
for component in components: for component in components:
if "." not in component: if "." not in component:
domain = component domain = component
@ -131,10 +133,10 @@ def _merge_resources(
def _build_resources( def _build_resources(
translation_strings: Dict[str, Dict[str, Any]], translation_strings: dict[str, dict[str, Any]],
components: Set[str], components: set[str],
category: str, category: str,
) -> Dict[str, Dict[str, Any]]: ) -> dict[str, dict[str, Any]]:
"""Build the resources response for the given components.""" """Build the resources response for the given components."""
# Build response # Build response
return { return {
@ -146,8 +148,8 @@ def _build_resources(
async def async_get_component_strings( async def async_get_component_strings(
hass: HomeAssistantType, language: str, components: Set[str] hass: HomeAssistantType, language: str, components: set[str]
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Load translations.""" """Load translations."""
domains = list({loaded.split(".")[-1] for loaded in components}) domains = list({loaded.split(".")[-1] for loaded in components})
integrations = dict( integrations = dict(
@ -160,7 +162,7 @@ async def async_get_component_strings(
) )
) )
translations: Dict[str, Any] = {} translations: dict[str, Any] = {}
# Determine paths of missing components/platforms # Determine paths of missing components/platforms
files_to_load = {} files_to_load = {}
@ -205,15 +207,15 @@ class _TranslationCache:
def __init__(self, hass: HomeAssistantType) -> None: def __init__(self, hass: HomeAssistantType) -> None:
"""Initialize the cache.""" """Initialize the cache."""
self.hass = hass self.hass = hass
self.loaded: Dict[str, Set[str]] = {} self.loaded: dict[str, set[str]] = {}
self.cache: Dict[str, Dict[str, Dict[str, Any]]] = {} self.cache: dict[str, dict[str, dict[str, Any]]] = {}
async def async_fetch( async def async_fetch(
self, self,
language: str, language: str,
category: str, category: str,
components: Set, components: set,
) -> List[Dict[str, Dict[str, Any]]]: ) -> list[dict[str, dict[str, Any]]]:
"""Load resources into the cache.""" """Load resources into the cache."""
components_to_load = components - self.loaded.setdefault(language, set()) components_to_load = components - self.loaded.setdefault(language, set())
@ -224,7 +226,7 @@ class _TranslationCache:
return [cached.get(component, {}).get(category, {}) for component in components] return [cached.get(component, {}).get(category, {}) for component in components]
async def _async_load(self, language: str, components: Set) -> None: async def _async_load(self, language: str, components: set) -> None:
"""Populate the cache for a given set of components.""" """Populate the cache for a given set of components."""
_LOGGER.debug( _LOGGER.debug(
"Cache miss for %s: %s", "Cache miss for %s: %s",
@ -247,12 +249,12 @@ class _TranslationCache:
def _build_category_cache( def _build_category_cache(
self, self,
language: str, language: str,
components: Set, components: set,
translation_strings: Dict[str, Dict[str, Any]], translation_strings: dict[str, dict[str, Any]],
) -> None: ) -> None:
"""Extract resources into the cache.""" """Extract resources into the cache."""
cached = self.cache.setdefault(language, {}) cached = self.cache.setdefault(language, {})
categories: Set[str] = set() categories: set[str] = set()
for resource in translation_strings.values(): for resource in translation_strings.values():
categories.update(resource) categories.update(resource)
@ -263,7 +265,7 @@ class _TranslationCache:
new_resources = resource_func(translation_strings, components, category) new_resources = resource_func(translation_strings, components, category)
for component, resource in new_resources.items(): for component, resource in new_resources.items():
category_cache: Dict[str, Any] = cached.setdefault( category_cache: dict[str, Any] = cached.setdefault(
component, {} component, {}
).setdefault(category, {}) ).setdefault(category, {})
@ -283,9 +285,9 @@ async def async_get_translations(
hass: HomeAssistantType, hass: HomeAssistantType,
language: str, language: str,
category: str, category: str,
integration: Optional[str] = None, integration: str | None = None,
config_flow: Optional[bool] = None, config_flow: bool | None = None,
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Return all backend translations. """Return all backend translations.
If integration specified, load it for that one. If integration specified, load it for that one.

View file

@ -1,8 +1,10 @@
"""Triggers.""" """Triggers."""
from __future__ import annotations
import asyncio import asyncio
import logging import logging
from types import MappingProxyType from types import MappingProxyType
from typing import Any, Callable, Dict, List, Optional, Union from typing import Any, Callable
import voluptuous as vol import voluptuous as vol
@ -39,8 +41,8 @@ async def _async_get_trigger_platform(
async def async_validate_trigger_config( async def async_validate_trigger_config(
hass: HomeAssistantType, trigger_config: List[ConfigType] hass: HomeAssistantType, trigger_config: list[ConfigType]
) -> List[ConfigType]: ) -> list[ConfigType]:
"""Validate triggers.""" """Validate triggers."""
config = [] config = []
for conf in trigger_config: for conf in trigger_config:
@ -55,14 +57,14 @@ async def async_validate_trigger_config(
async def async_initialize_triggers( async def async_initialize_triggers(
hass: HomeAssistantType, hass: HomeAssistantType,
trigger_config: List[ConfigType], trigger_config: list[ConfigType],
action: Callable, action: Callable,
domain: str, domain: str,
name: str, name: str,
log_cb: Callable, log_cb: Callable,
home_assistant_start: bool = False, home_assistant_start: bool = False,
variables: Optional[Union[Dict[str, Any], MappingProxyType]] = None, variables: dict[str, Any] | MappingProxyType | None = None,
) -> Optional[CALLBACK_TYPE]: ) -> CALLBACK_TYPE | None:
"""Initialize triggers.""" """Initialize triggers."""
info = { info = {
"domain": domain, "domain": domain,

View file

@ -1,9 +1,11 @@
"""Helpers to help coordinate updates.""" """Helpers to help coordinate updates."""
from __future__ import annotations
import asyncio import asyncio
from datetime import datetime, timedelta from datetime import datetime, timedelta
import logging import logging
from time import monotonic from time import monotonic
from typing import Any, Awaitable, Callable, Generic, List, Optional, TypeVar from typing import Any, Awaitable, Callable, Generic, TypeVar
import urllib.error import urllib.error
import aiohttp import aiohttp
@ -37,9 +39,9 @@ class DataUpdateCoordinator(Generic[T]):
logger: logging.Logger, logger: logging.Logger,
*, *,
name: str, name: str,
update_interval: Optional[timedelta] = None, update_interval: timedelta | None = None,
update_method: Optional[Callable[[], Awaitable[T]]] = None, update_method: Callable[[], Awaitable[T]] | None = None,
request_refresh_debouncer: Optional[Debouncer] = None, request_refresh_debouncer: Debouncer | None = None,
): ):
"""Initialize global data updater.""" """Initialize global data updater."""
self.hass = hass self.hass = hass
@ -48,12 +50,12 @@ class DataUpdateCoordinator(Generic[T]):
self.update_method = update_method self.update_method = update_method
self.update_interval = update_interval self.update_interval = update_interval
self.data: Optional[T] = None self.data: T | None = None
self._listeners: List[CALLBACK_TYPE] = [] self._listeners: list[CALLBACK_TYPE] = []
self._job = HassJob(self._handle_refresh_interval) self._job = HassJob(self._handle_refresh_interval)
self._unsub_refresh: Optional[CALLBACK_TYPE] = None self._unsub_refresh: CALLBACK_TYPE | None = None
self._request_refresh_task: Optional[asyncio.TimerHandle] = None self._request_refresh_task: asyncio.TimerHandle | None = None
self.last_update_success = True self.last_update_success = True
if request_refresh_debouncer is None: if request_refresh_debouncer is None:
@ -132,7 +134,7 @@ class DataUpdateCoordinator(Generic[T]):
""" """
await self._debounced_refresh.async_call() await self._debounced_refresh.async_call()
async def _async_update_data(self) -> Optional[T]: async def _async_update_data(self) -> T | None:
"""Fetch the latest data from the source.""" """Fetch the latest data from the source."""
if self.update_method is None: if self.update_method is None:
raise NotImplementedError("Update method not implemented") raise NotImplementedError("Update method not implemented")