Type hint improvements (#32905)

* Complete helpers.entity_component type hints

* Add discovery info type
This commit is contained in:
Ville Skyttä 2020-03-18 19:27:25 +02:00 committed by GitHub
parent 7c79adad8f
commit 05abf37046
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 120 additions and 76 deletions

View file

@ -12,6 +12,7 @@ from homeassistant.config_entries import CONN_CLASS_LOCAL_POLL, ConfigFlow
from homeassistant.const import CONF_HOST, CONF_NAME
from homeassistant.core import callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.typing import DiscoveryInfoType
from .const import DEFAULT_PORT
from .const import DOMAIN # pylint: disable=unused-import
@ -83,7 +84,7 @@ class DirecTVConfigFlow(ConfigFlow, domain=DOMAIN):
return self.async_create_entry(title=info["title"], data=user_input)
async def async_step_ssdp(
self, discovery_info: Optional[Dict] = None
self, discovery_info: Optional[DiscoveryInfoType] = None
) -> Dict[str, Any]:
"""Handle a flow initialized by discovery."""
host = urlparse(discovery_info[ATTR_SSDP_LOCATION]).hostname

View file

@ -19,7 +19,11 @@ from homeassistant.const import (
)
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.typing import ConfigType, HomeAssistantType
from homeassistant.helpers.typing import (
ConfigType,
DiscoveryInfoType,
HomeAssistantType,
)
from homeassistant.util import slugify
import homeassistant.util.dt as dt_util
@ -332,7 +336,7 @@ def setup_platform(
hass: HomeAssistantType,
config: ConfigType,
add_entities: Callable[[list], None],
discovery_info: Optional[dict] = None,
discovery_info: Optional[DiscoveryInfoType] = None,
) -> None:
"""Set up the GTFS sensor."""
gtfs_dir = hass.config.path(DEFAULT_PATH)

View file

@ -24,6 +24,7 @@ from homeassistant.core import HomeAssistant, State, callback
from homeassistant.helpers import location
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.typing import DiscoveryInfoType
import homeassistant.util.dt as dt
_LOGGER = logging.getLogger(__name__)
@ -144,7 +145,7 @@ async def async_setup_platform(
hass: HomeAssistant,
config: Dict[str, Union[str, bool]],
async_add_entities: Callable,
discovery_info: None = None,
discovery_info: Optional[DiscoveryInfoType] = None,
) -> None:
"""Set up the HERE travel time platform."""

View file

@ -114,9 +114,7 @@ async def async_setup(hass: HomeAssistantType, config: ConfigType) -> bool:
async def async_setup_entry(hass: HomeAssistantType, entry: ConfigEntry) -> bool:
"""Set up a config entry."""
return cast(
bool, await cast(EntityComponent, hass.data[DOMAIN]).async_setup_entry(entry)
)
return await cast(EntityComponent, hass.data[DOMAIN]).async_setup_entry(entry)
async def async_unload_entry(hass: HomeAssistantType, entry: ConfigEntry) -> bool:

View file

@ -1,6 +1,6 @@
"""Light support for switch entities."""
import logging
from typing import Callable, Dict, Optional, Sequence, cast
from typing import Callable, Optional, Sequence, cast
import voluptuous as vol
@ -17,7 +17,11 @@ from homeassistant.core import CALLBACK_TYPE, State, callback
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.event import async_track_state_change
from homeassistant.helpers.typing import ConfigType, HomeAssistantType
from homeassistant.helpers.typing import (
ConfigType,
DiscoveryInfoType,
HomeAssistantType,
)
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
@ -37,7 +41,7 @@ async def async_setup_platform(
hass: HomeAssistantType,
config: ConfigType,
async_add_entities: Callable[[Sequence[Entity], bool], None],
discovery_info: Optional[Dict] = None,
discovery_info: Optional[DiscoveryInfoType] = None,
) -> None:
"""Initialize Light Switch platform."""
async_add_entities(

View file

@ -20,6 +20,7 @@ from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.event import async_track_time_interval
from homeassistant.helpers.typing import (
ContextType,
DiscoveryInfoType,
EventType,
HomeAssistantType,
ServiceCallType,
@ -115,7 +116,7 @@ async def async_setup(hass: HomeAssistantType, config: Dict) -> bool:
hass.data[DOMAIN] = {DATA_DEVICE: device_data}
async def async_switch_platform_discovered(
platform: str, discovery_info: Optional[Dict]
platform: str, discovery_info: DiscoveryInfoType
) -> None:
"""Use for registering services after switch platform is discovered."""
if platform != DOMAIN:

View file

@ -1,7 +1,7 @@
"""Config flow for Vizio."""
import copy
import logging
from typing import Any, Dict
from typing import Any, Dict, Optional
from pyvizio import VizioAsync, async_guess_device_type
import voluptuous as vol
@ -23,6 +23,7 @@ from homeassistant.const import (
from homeassistant.core import callback
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.typing import DiscoveryInfoType
from .const import (
CONF_APPS,
@ -318,7 +319,7 @@ class VizioConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
return await self.async_step_user(user_input=import_config)
async def async_step_zeroconf(
self, discovery_info: Dict[str, Any] = None
self, discovery_info: Optional[DiscoveryInfoType] = None
) -> Dict[str, Any]:
"""Handle zeroconf discovery."""

View file

@ -221,7 +221,7 @@ async def async_setup(hass: HomeAssistant, config: Dict) -> bool:
home_zone = Zone(_home_conf(hass), True,)
home_zone.entity_id = ENTITY_ID_HOME
await component.async_add_entities([home_zone]) # type: ignore
await component.async_add_entities([home_zone])
async def core_config_updated(_: Event) -> None:
"""Handle core config updated."""

View file

@ -266,7 +266,7 @@ def attach_entity_component_collection(
"""Handle a collection change."""
if change_type == CHANGE_ADDED:
entity = create_entity(config)
await entity_component.async_add_entities([entity]) # type: ignore
await entity_component.async_add_entities([entity])
entities[item_id] = entity
return

View file

@ -5,7 +5,7 @@ There are two different types of discoveries that can be fired/listened for.
- listen_platform/discover_platform is for platforms. These are used by
components to allow discovery of their platforms.
"""
from typing import Callable, Collection, Union
from typing import Any, Callable, Collection, Dict, Optional, Union
from homeassistant import core, setup
from homeassistant.const import ATTR_DISCOVERED, ATTR_SERVICE, EVENT_PLATFORM_DISCOVERED
@ -90,7 +90,9 @@ def listen_platform(
@bind_hass
def async_listen_platform(
hass: core.HomeAssistant, component: str, callback: Callable
hass: core.HomeAssistant,
component: str,
callback: Callable[[str, Optional[Dict[str, Any]]], Any],
) -> None:
"""Register a platform loader listener.

View file

@ -4,12 +4,14 @@ from datetime import timedelta
from itertools import chain
import logging
from types import ModuleType
from typing import Dict, Optional, cast
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import voluptuous as vol
from homeassistant import config as conf_util
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_ENTITY_NAMESPACE, CONF_SCAN_INTERVAL
from homeassistant.core import HomeAssistant, callback
from homeassistant.core import HomeAssistant, ServiceCall, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import (
config_per_platform,
@ -18,13 +20,12 @@ from homeassistant.helpers import (
entity,
service,
)
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from homeassistant.loader import async_get_integration, bind_hass
from homeassistant.setup import async_prepare_setup_platform
from .entity_platform import EntityPlatform
# mypy: allow-untyped-defs, no-check-untyped-defs
DEFAULT_SCAN_INTERVAL = timedelta(seconds=15)
DATA_INSTANCES = "entity_components"
@ -75,18 +76,18 @@ class EntityComponent:
self.domain = domain
self.scan_interval = scan_interval
self.config = None
self.config: Optional[ConfigType] = None
self._platforms: Dict[str, EntityPlatform] = {
domain: self._async_init_entity_platform(domain, None)
}
self._platforms: Dict[
Union[str, Tuple[str, Optional[timedelta], Optional[str]]], EntityPlatform
] = {domain: self._async_init_entity_platform(domain, None)}
self.async_add_entities = self._platforms[domain].async_add_entities
self.add_entities = self._platforms[domain].add_entities
hass.data.setdefault(DATA_INSTANCES, {})[domain] = self
@property
def entities(self):
def entities(self) -> Iterable[entity.Entity]:
"""Return an iterable that returns all entities."""
return chain.from_iterable(
platform.entities.values() for platform in self._platforms.values()
@ -95,19 +96,23 @@ class EntityComponent:
def get_entity(self, entity_id: str) -> Optional[entity.Entity]:
"""Get an entity."""
for platform in self._platforms.values():
entity_obj = cast(Optional[entity.Entity], platform.entities.get(entity_id))
entity_obj = platform.entities.get(entity_id)
if entity_obj is not None:
return entity_obj
return None
def setup(self, config):
def setup(self, config: ConfigType) -> None:
"""Set up a full entity component.
This doesn't block the executor to protect from deadlocks.
"""
self.hass.add_job(self.async_setup(config))
self.hass.add_job(
self.async_setup( # type: ignore
config
)
)
async def async_setup(self, config):
async def async_setup(self, config: ConfigType) -> None:
"""Set up a full entity component.
Loads the platforms from the config and will listen for supported
@ -127,7 +132,9 @@ class EntityComponent:
# Generic discovery listener for loading platform dynamically
# Refer to: homeassistant.components.discovery.load_platform()
async def component_platform_discovered(platform, info):
async def component_platform_discovered(
platform: str, info: Optional[Dict[str, Any]]
) -> None:
"""Handle the loading of a platform."""
await self.async_setup_platform(platform, {}, info)
@ -135,7 +142,7 @@ class EntityComponent:
self.hass, self.domain, component_platform_discovered
)
async def async_setup_entry(self, config_entry):
async def async_setup_entry(self, config_entry: ConfigEntry) -> bool:
"""Set up a config entry."""
platform_type = config_entry.domain
platform = await async_prepare_setup_platform(
@ -161,7 +168,7 @@ class EntityComponent:
scan_interval=getattr(platform, "SCAN_INTERVAL", None),
)
return await self._platforms[key].async_setup_entry(config_entry)
return await self._platforms[key].async_setup_entry(config_entry) # type: ignore
async def async_unload_entry(self, config_entry: ConfigEntry) -> bool:
"""Unload a config entry."""
@ -175,24 +182,32 @@ class EntityComponent:
await platform.async_reset()
return True
async def async_extract_from_service(self, service_call, expand_group=True):
async def async_extract_from_service(
self, service_call: ServiceCall, expand_group: bool = True
) -> List[entity.Entity]:
"""Extract all known and available entities from a service call.
Will return an empty list if entities specified but unknown.
This method must be run in the event loop.
"""
return await service.async_extract_entities(
return await service.async_extract_entities( # type: ignore
self.hass, self.entities, service_call, expand_group
)
@callback
def async_register_entity_service(self, name, schema, func, required_features=None):
def async_register_entity_service(
self,
name: str,
schema: Union[Dict[str, Any], vol.Schema],
func: str,
required_features: Optional[int] = None,
) -> None:
"""Register an entity service."""
if isinstance(schema, dict):
schema = cv.make_entity_service_schema(schema)
async def handle_service(call):
async def handle_service(call: Callable) -> None:
"""Handle the service."""
await self.hass.helpers.service.entity_service_call(
self._platforms.values(), func, call, required_features
@ -201,8 +216,11 @@ class EntityComponent:
self.hass.services.async_register(self.domain, name, handle_service, schema)
async def async_setup_platform(
self, platform_type, platform_config, discovery_info=None
):
self,
platform_type: str,
platform_config: ConfigType,
discovery_info: Optional[DiscoveryInfoType] = None,
) -> None:
"""Set up a platform for this component."""
if self.config is None:
raise RuntimeError("async_setup needs to be called first")
@ -227,7 +245,9 @@ class EntityComponent:
platform_type, platform, scan_interval, entity_namespace
)
await self._platforms[key].async_setup(platform_config, discovery_info)
await self._platforms[key].async_setup( # type: ignore
platform_config, discovery_info
)
async def _async_reset(self) -> None:
"""Remove entities and reset the entity component to initial values.
@ -285,7 +305,7 @@ class EntityComponent:
if scan_interval is None:
scan_interval = self.scan_interval
return EntityPlatform( # type: ignore
return EntityPlatform(
hass=self.hass,
logger=self.logger,
domain=self.domain,

View file

@ -1,18 +1,24 @@
"""Class to manage the entities for a single platform."""
import asyncio
from contextvars import ContextVar
from datetime import datetime
from typing import Optional
from datetime import datetime, timedelta
from logging import Logger
from types import ModuleType
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, cast
from homeassistant.const import DEVICE_DEFAULT_NAME
from homeassistant.core import callback, split_entity_id, valid_entity_id
from homeassistant.core import CALLBACK_TYPE, callback, split_entity_id, valid_entity_id
from homeassistant.exceptions import HomeAssistantError, PlatformNotReady
from homeassistant.helpers import config_validation as cv, service
from homeassistant.helpers.typing import HomeAssistantType
from homeassistant.util.async_ import run_callback_threadsafe
from .entity_registry import DISABLED_INTEGRATION
from .event import async_call_later, async_track_time_interval
if TYPE_CHECKING:
from .entity import Entity
# mypy: allow-untyped-defs, no-check-untyped-defs
SLOW_SETUP_WARNING = 10
@ -26,23 +32,15 @@ class EntityPlatform:
def __init__(
self,
*,
hass,
logger,
domain,
platform_name,
platform,
scan_interval,
entity_namespace,
hass: HomeAssistantType,
logger: Logger,
domain: str,
platform_name: str,
platform: Optional[ModuleType],
scan_interval: timedelta,
entity_namespace: Optional[str],
):
"""Initialize the entity platform.
hass: HomeAssistant
logger: Logger
domain: str
platform_name: str
scan_interval: timedelta
entity_namespace: str
"""
"""Initialize the entity platform."""
self.hass = hass
self.logger = logger
self.domain = domain
@ -51,13 +49,13 @@ class EntityPlatform:
self.scan_interval = scan_interval
self.entity_namespace = entity_namespace
self.config_entry = None
self.entities = {}
self._tasks = []
self.entities: Dict[str, Entity] = {} # pylint: disable=used-before-assignment
self._tasks: List[asyncio.Future] = []
# Method to cancel the state change listener
self._async_unsub_polling = None
self._async_unsub_polling: Optional[CALLBACK_TYPE] = None
# Method to cancel the retry of setup
self._async_cancel_retry_setup = None
self._process_updates = None
self._async_cancel_retry_setup: Optional[CALLBACK_TYPE] = None
self._process_updates: Optional[asyncio.Lock] = None
# Platform is None for the EntityComponent "catch-all" EntityPlatform
# which powers entity_component.add_entities
@ -224,7 +222,9 @@ class EntityPlatform:
finally:
warn_task.cancel()
def _schedule_add_entities(self, new_entities, update_before_add=False):
def _schedule_add_entities(
self, new_entities: Iterable["Entity"], update_before_add: bool = False
) -> None:
"""Schedule adding entities for a single platform, synchronously."""
run_callback_threadsafe(
self.hass.loop,
@ -234,17 +234,24 @@ class EntityPlatform:
).result()
@callback
def _async_schedule_add_entities(self, new_entities, update_before_add=False):
def _async_schedule_add_entities(
self, new_entities: Iterable["Entity"], update_before_add: bool = False
) -> None:
"""Schedule adding entities for a single platform async."""
self._tasks.append(
cast(
asyncio.Future,
self.hass.async_add_job(
self.async_add_entities(
self.async_add_entities( # type: ignore
new_entities, update_before_add=update_before_add
)
),
),
)
)
def add_entities(self, new_entities, update_before_add=False):
def add_entities(
self, new_entities: Iterable["Entity"], update_before_add: bool = False
) -> None:
"""Add entities for a single platform."""
# That avoid deadlocks
if update_before_add:
@ -258,7 +265,9 @@ class EntityPlatform:
self.hass.loop,
).result()
async def async_add_entities(self, new_entities, update_before_add=False):
async def async_add_entities(
self, new_entities: Iterable["Entity"], update_before_add: bool = False
) -> None:
"""Add entities for a single platform async.
This method must be run in the event loop.
@ -272,7 +281,7 @@ class EntityPlatform:
device_registry = await hass.helpers.device_registry.async_get_registry()
entity_registry = await hass.helpers.entity_registry.async_get_registry()
tasks = [
self._async_add_entity(
self._async_add_entity( # type: ignore
entity, update_before_add, entity_registry, device_registry
)
for entity in new_entities
@ -290,7 +299,9 @@ class EntityPlatform:
return
self._async_unsub_polling = async_track_time_interval(
self.hass, self._update_entity_states, self.scan_interval
self.hass,
self._update_entity_states, # type: ignore
self.scan_interval,
)
async def _async_add_entity(
@ -515,7 +526,7 @@ class EntityPlatform:
for entity in self.entities.values():
if not entity.should_poll:
continue
tasks.append(entity.async_update_ha_state(True))
tasks.append(entity.async_update_ha_state(True)) # type: ignore
if tasks:
await asyncio.wait(tasks)

View file

@ -8,6 +8,7 @@ import homeassistant.core
GPSType = Tuple[float, float]
ConfigType = Dict[str, Any]
ContextType = homeassistant.core.Context
DiscoveryInfoType = Dict[str, Any]
EventType = homeassistant.core.Event
HomeAssistantType = homeassistant.core.HomeAssistant
ServiceCallType = homeassistant.core.ServiceCall