From 54cb1e7556afa5e96ddbf0e4aba9d4a82f198ff3 Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Fri, 7 Jan 2022 16:48:34 +0100 Subject: [PATCH] Add strict typing to `core.py` (5) - Task (#63243) --- .../components/arcam_fmj/device_trigger.py | 2 +- homeassistant/core.py | 126 +++++++++++++++--- homeassistant/helpers/config_entry_flow.py | 10 +- homeassistant/setup.py | 2 +- 4 files changed, 115 insertions(+), 25 deletions(-) diff --git a/homeassistant/components/arcam_fmj/device_trigger.py b/homeassistant/components/arcam_fmj/device_trigger.py index ed9308a89c6..b33710bf936 100644 --- a/homeassistant/components/arcam_fmj/device_trigger.py +++ b/homeassistant/components/arcam_fmj/device_trigger.py @@ -76,7 +76,7 @@ async def async_attach_trigger( job, { "trigger": { - **trigger_data, + **trigger_data, # type: ignore # https://github.com/python/mypy/issues/9117 **config, "description": f"{DOMAIN} - {entity_id}", } diff --git a/homeassistant/core.py b/homeassistant/core.py index 137f80659c9..fb0b244c23b 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -23,10 +23,12 @@ from typing import ( Any, Awaitable, Callable, + Generic, NamedTuple, Optional, TypeVar, cast, + overload, ) from urllib.parse import urlparse @@ -94,6 +96,8 @@ async_timeout_backcompat.enable() block_async_io.enable() T = TypeVar("T") +_R = TypeVar("_R") +_R_co = TypeVar("_R_co", covariant=True) # pylint: disable=invalid-name # Internal; not helpers.typing.UNDEFINED due to circular dependency _UNDEF: dict[Any, Any] = {} # pylint: disable=invalid-name @@ -174,7 +178,7 @@ class HassJobType(enum.Enum): Executor = 3 -class HassJob: +class HassJob(Generic[_R_co]): """Represent a job to be run later. We check the callable type in advance @@ -184,7 +188,7 @@ class HassJob: __slots__ = ("job_type", "target") - def __init__(self, target: Callable) -> None: + def __init__(self, target: Callable[..., _R_co]) -> None: """Create a job object.""" if asyncio.iscoroutine(target): raise ValueError("Coroutine not allowed to be passed to HassJob") @@ -197,7 +201,7 @@ class HassJob: return f"" -def _get_callable_job_type(target: Callable) -> HassJobType: +def _get_callable_job_type(target: Callable[..., Any]) -> HassJobType: """Determine the job type from the callable.""" # Check for partials to properly determine if coroutine function check_target = target @@ -236,7 +240,8 @@ class HomeAssistant: def __init__(self) -> None: """Initialize new Home Assistant object.""" self.loop = asyncio.get_running_loop() - self._pending_tasks: list = [] + # pylint: disable-next=unsubscriptable-object + self._pending_tasks: list[asyncio.Future[Any]] = [] self._track_task = True self.bus = EventBus(self) self.services = ServiceRegistry(self) @@ -354,10 +359,33 @@ class HomeAssistant: raise ValueError("Don't call add_job with None") self.loop.call_soon_threadsafe(self.async_add_job, target, *args) + @overload @callback def async_add_job( - self, target: Callable[..., Any], *args: Any - ) -> asyncio.Future | None: + self, target: Callable[..., Awaitable[_R]], *args: Any + ) -> asyncio.Future[_R] | None: # pylint: disable=unsubscriptable-object + ... + + @overload + @callback + def async_add_job( + self, target: Callable[..., Awaitable[_R] | _R], *args: Any + ) -> asyncio.Future[_R] | None: # pylint: disable=unsubscriptable-object + ... + + @overload + @callback + def async_add_job( + self, target: Coroutine[Any, Any, _R], *args: Any + ) -> asyncio.Future[_R] | None: # pylint: disable=unsubscriptable-object + ... + + @callback + def async_add_job( + self, + target: Callable[..., Awaitable[_R] | _R] | Coroutine[Any, Any, _R], + *args: Any, + ) -> asyncio.Future[_R] | None: # pylint: disable=unsubscriptable-object """Add a job to be executed by the event loop or by an executor. If the job is either a coroutine or decorated with @callback, it will be @@ -374,24 +402,44 @@ class HomeAssistant: if asyncio.iscoroutine(target): return self.async_create_task(target) + target = cast(Callable[..., _R], target) return self.async_add_hass_job(HassJob(target), *args) + @overload @callback - def async_add_hass_job(self, hassjob: HassJob, *args: Any) -> asyncio.Future | None: + def async_add_hass_job( + self, hassjob: HassJob[Awaitable[_R]], *args: Any + ) -> asyncio.Future[_R] | None: # pylint: disable=unsubscriptable-object + ... + + @overload + @callback + def async_add_hass_job( + self, hassjob: HassJob[Awaitable[_R] | _R], *args: Any + ) -> asyncio.Future[_R] | None: # pylint: disable=unsubscriptable-object + ... + + @callback + def async_add_hass_job( + self, hassjob: HassJob[Awaitable[_R] | _R], *args: Any + ) -> asyncio.Future[_R] | None: # pylint: disable=unsubscriptable-object """Add a HassJob from within the event loop. This method must be run in the event loop. hassjob: HassJob to call. args: parameters for method to call. """ + task: asyncio.Future[_R] # pylint: disable=unsubscriptable-object if hassjob.job_type == HassJobType.Coroutinefunction: - task = self.loop.create_task(hassjob.target(*args)) + task = self.loop.create_task( + cast(Callable[..., Awaitable[_R]], hassjob.target)(*args) + ) elif hassjob.job_type == HassJobType.Callback: self.loop.call_soon(hassjob.target, *args) return None else: - task = self.loop.run_in_executor( # type: ignore - None, hassjob.target, *args + task = self.loop.run_in_executor( + None, cast(Callable[..., _R], hassjob.target), *args ) # If a task is scheduled @@ -400,7 +448,7 @@ class HomeAssistant: return task - def create_task(self, target: Awaitable) -> None: + def create_task(self, target: Awaitable[Any]) -> None: """Add task to the executor pool. target: target to call. @@ -408,14 +456,14 @@ class HomeAssistant: self.loop.call_soon_threadsafe(self.async_create_task, target) @callback - def async_create_task(self, target: Awaitable) -> asyncio.Task: + def async_create_task(self, target: Awaitable[_R]) -> asyncio.Task[_R]: """Create a task from within the eventloop. This method must be run in the event loop. target: target to call. """ - task: asyncio.Task = self.loop.create_task(target) + task = self.loop.create_task(target) if self._track_task: self._pending_tasks.append(task) @@ -425,7 +473,7 @@ class HomeAssistant: @callback def async_add_executor_job( self, target: Callable[..., T], *args: Any - ) -> Awaitable[T]: + ) -> asyncio.Future[T]: # pylint: disable=unsubscriptable-object """Add an executor job from within the event loop.""" task = self.loop.run_in_executor(None, target, *args) @@ -445,8 +493,24 @@ class HomeAssistant: """Stop track tasks so you can't wait for all tasks to be done.""" self._track_task = False + @overload @callback - def async_run_hass_job(self, hassjob: HassJob, *args: Any) -> asyncio.Future | None: + def async_run_hass_job( + self, hassjob: HassJob[Awaitable[_R]], *args: Any + ) -> asyncio.Future[_R] | None: # pylint: disable=unsubscriptable-object + ... + + @overload + @callback + def async_run_hass_job( + self, hassjob: HassJob[Awaitable[_R] | _R], *args: Any + ) -> asyncio.Future[_R] | None: # pylint: disable=unsubscriptable-object + ... + + @callback + def async_run_hass_job( + self, hassjob: HassJob[Awaitable[_R] | _R], *args: Any + ) -> asyncio.Future[_R] | None: # pylint: disable=unsubscriptable-object """Run a HassJob from within the event loop. This method must be run in the event loop. @@ -455,15 +519,38 @@ class HomeAssistant: args: parameters for method to call. """ if hassjob.job_type == HassJobType.Callback: - hassjob.target(*args) + cast(Callable[..., _R], hassjob.target)(*args) return None return self.async_add_hass_job(hassjob, *args) + @overload @callback def async_run_job( - self, target: Callable[..., None | Awaitable], *args: Any - ) -> asyncio.Future | None: + self, target: Callable[..., Awaitable[_R]], *args: Any + ) -> asyncio.Future[_R] | None: # pylint: disable=unsubscriptable-object + ... + + @overload + @callback + def async_run_job( + self, target: Callable[..., Awaitable[_R] | _R], *args: Any + ) -> asyncio.Future[_R] | None: # pylint: disable=unsubscriptable-object + ... + + @overload + @callback + def async_run_job( + self, target: Coroutine[Any, Any, _R], *args: Any + ) -> asyncio.Future[_R] | None: # pylint: disable=unsubscriptable-object + ... + + @callback + def async_run_job( + self, + target: Callable[..., Awaitable[_R] | _R] | Coroutine[Any, Any, _R], + *args: Any, + ) -> asyncio.Future[_R] | None: # pylint: disable=unsubscriptable-object """Run a job from within the event loop. This method must be run in the event loop. @@ -474,6 +561,7 @@ class HomeAssistant: if asyncio.iscoroutine(target): return self.async_create_task(target) + target = cast(Callable[..., _R], target) return self.async_run_hass_job(HassJob(target), *args) def block_till_done(self) -> None: @@ -685,7 +773,7 @@ class Event: class _FilterableJob(NamedTuple): """Event listener job to be executed with optional filter.""" - job: HassJob + job: HassJob[None | Awaitable[None]] event_filter: Callable[[Event], bool] | None diff --git a/homeassistant/helpers/config_entry_flow.py b/homeassistant/helpers/config_entry_flow.py index a3e7ae4869d..73125b707ec 100644 --- a/homeassistant/helpers/config_entry_flow.py +++ b/homeassistant/helpers/config_entry_flow.py @@ -1,8 +1,9 @@ """Helpers for data entry flows for config entries.""" from __future__ import annotations +import asyncio # pylint: disable=unused-import # used in cast as string import logging -from typing import Any, Awaitable, Callable, Union +from typing import Any, Awaitable, Callable, Union, cast from homeassistant import config_entries from homeassistant.components import dhcp, mqtt, ssdp, zeroconf @@ -55,9 +56,10 @@ class DiscoveryFlowHandler(config_entries.ConfigFlow): # Get current discovered entries. in_progress = self._async_in_progress() - if not (has_devices := in_progress): - has_devices = await self.hass.async_add_job( # type: ignore - self._discovery_function, self.hass + if not (has_devices := bool(in_progress)): + has_devices = await cast( + "asyncio.Future[bool]", + self.hass.async_add_job(self._discovery_function, self.hass), ) if not has_devices: diff --git a/homeassistant/setup.py b/homeassistant/setup.py index b509419bcb5..f694e7a79c6 100644 --- a/homeassistant/setup.py +++ b/homeassistant/setup.py @@ -77,7 +77,7 @@ async def async_setup_component( ) try: - return await task # type: ignore + return await task finally: if domain in hass.data.get(DATA_SETUP_DONE, {}): hass.data[DATA_SETUP_DONE].pop(domain).set()