Add strict typing to core.py (5) - Task (#63243)

This commit is contained in:
Marc Mueller 2022-01-07 16:48:34 +01:00 committed by GitHub
parent ae3ff0a8ce
commit 54cb1e7556
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 115 additions and 25 deletions

View file

@ -76,7 +76,7 @@ async def async_attach_trigger(
job, job,
{ {
"trigger": { "trigger": {
**trigger_data, **trigger_data, # type: ignore # https://github.com/python/mypy/issues/9117
**config, **config,
"description": f"{DOMAIN} - {entity_id}", "description": f"{DOMAIN} - {entity_id}",
} }

View file

@ -23,10 +23,12 @@ from typing import (
Any, Any,
Awaitable, Awaitable,
Callable, Callable,
Generic,
NamedTuple, NamedTuple,
Optional, Optional,
TypeVar, TypeVar,
cast, cast,
overload,
) )
from urllib.parse import urlparse from urllib.parse import urlparse
@ -94,6 +96,8 @@ async_timeout_backcompat.enable()
block_async_io.enable() block_async_io.enable()
T = TypeVar("T") T = TypeVar("T")
_R = TypeVar("_R")
_R_co = TypeVar("_R_co", covariant=True) # pylint: disable=invalid-name
# Internal; not helpers.typing.UNDEFINED due to circular dependency # Internal; not helpers.typing.UNDEFINED due to circular dependency
_UNDEF: dict[Any, Any] = {} _UNDEF: dict[Any, Any] = {}
# pylint: disable=invalid-name # pylint: disable=invalid-name
@ -174,7 +178,7 @@ class HassJobType(enum.Enum):
Executor = 3 Executor = 3
class HassJob: class HassJob(Generic[_R_co]):
"""Represent a job to be run later. """Represent a job to be run later.
We check the callable type in advance We check the callable type in advance
@ -184,7 +188,7 @@ class HassJob:
__slots__ = ("job_type", "target") __slots__ = ("job_type", "target")
def __init__(self, target: Callable) -> None: def __init__(self, target: Callable[..., _R_co]) -> None:
"""Create a job object.""" """Create a job object."""
if asyncio.iscoroutine(target): if asyncio.iscoroutine(target):
raise ValueError("Coroutine not allowed to be passed to HassJob") raise ValueError("Coroutine not allowed to be passed to HassJob")
@ -197,7 +201,7 @@ class HassJob:
return f"<Job {self.job_type} {self.target}>" return f"<Job {self.job_type} {self.target}>"
def _get_callable_job_type(target: Callable) -> HassJobType: def _get_callable_job_type(target: Callable[..., Any]) -> HassJobType:
"""Determine the job type from the callable.""" """Determine the job type from the callable."""
# Check for partials to properly determine if coroutine function # Check for partials to properly determine if coroutine function
check_target = target check_target = target
@ -236,7 +240,8 @@ class HomeAssistant:
def __init__(self) -> None: def __init__(self) -> None:
"""Initialize new Home Assistant object.""" """Initialize new Home Assistant object."""
self.loop = asyncio.get_running_loop() self.loop = asyncio.get_running_loop()
self._pending_tasks: list = [] # pylint: disable-next=unsubscriptable-object
self._pending_tasks: list[asyncio.Future[Any]] = []
self._track_task = True self._track_task = True
self.bus = EventBus(self) self.bus = EventBus(self)
self.services = ServiceRegistry(self) self.services = ServiceRegistry(self)
@ -354,10 +359,33 @@ class HomeAssistant:
raise ValueError("Don't call add_job with None") raise ValueError("Don't call add_job with None")
self.loop.call_soon_threadsafe(self.async_add_job, target, *args) self.loop.call_soon_threadsafe(self.async_add_job, target, *args)
@overload
@callback @callback
def async_add_job( def async_add_job(
self, target: Callable[..., Any], *args: Any self, target: Callable[..., Awaitable[_R]], *args: Any
) -> asyncio.Future | None: ) -> asyncio.Future[_R] | None: # pylint: disable=unsubscriptable-object
...
@overload
@callback
def async_add_job(
self, target: Callable[..., Awaitable[_R] | _R], *args: Any
) -> asyncio.Future[_R] | None: # pylint: disable=unsubscriptable-object
...
@overload
@callback
def async_add_job(
self, target: Coroutine[Any, Any, _R], *args: Any
) -> asyncio.Future[_R] | None: # pylint: disable=unsubscriptable-object
...
@callback
def async_add_job(
self,
target: Callable[..., Awaitable[_R] | _R] | Coroutine[Any, Any, _R],
*args: Any,
) -> asyncio.Future[_R] | None: # pylint: disable=unsubscriptable-object
"""Add a job to be executed by the event loop or by an executor. """Add a job to be executed by the event loop or by an executor.
If the job is either a coroutine or decorated with @callback, it will be If the job is either a coroutine or decorated with @callback, it will be
@ -374,24 +402,44 @@ class HomeAssistant:
if asyncio.iscoroutine(target): if asyncio.iscoroutine(target):
return self.async_create_task(target) return self.async_create_task(target)
target = cast(Callable[..., _R], target)
return self.async_add_hass_job(HassJob(target), *args) return self.async_add_hass_job(HassJob(target), *args)
@overload
@callback @callback
def async_add_hass_job(self, hassjob: HassJob, *args: Any) -> asyncio.Future | None: def async_add_hass_job(
self, hassjob: HassJob[Awaitable[_R]], *args: Any
) -> asyncio.Future[_R] | None: # pylint: disable=unsubscriptable-object
...
@overload
@callback
def async_add_hass_job(
self, hassjob: HassJob[Awaitable[_R] | _R], *args: Any
) -> asyncio.Future[_R] | None: # pylint: disable=unsubscriptable-object
...
@callback
def async_add_hass_job(
self, hassjob: HassJob[Awaitable[_R] | _R], *args: Any
) -> asyncio.Future[_R] | None: # pylint: disable=unsubscriptable-object
"""Add a HassJob from within the event loop. """Add a HassJob from within the event loop.
This method must be run in the event loop. This method must be run in the event loop.
hassjob: HassJob to call. hassjob: HassJob to call.
args: parameters for method to call. args: parameters for method to call.
""" """
task: asyncio.Future[_R] # pylint: disable=unsubscriptable-object
if hassjob.job_type == HassJobType.Coroutinefunction: if hassjob.job_type == HassJobType.Coroutinefunction:
task = self.loop.create_task(hassjob.target(*args)) task = self.loop.create_task(
cast(Callable[..., Awaitable[_R]], hassjob.target)(*args)
)
elif hassjob.job_type == HassJobType.Callback: elif hassjob.job_type == HassJobType.Callback:
self.loop.call_soon(hassjob.target, *args) self.loop.call_soon(hassjob.target, *args)
return None return None
else: else:
task = self.loop.run_in_executor( # type: ignore task = self.loop.run_in_executor(
None, hassjob.target, *args None, cast(Callable[..., _R], hassjob.target), *args
) )
# If a task is scheduled # If a task is scheduled
@ -400,7 +448,7 @@ class HomeAssistant:
return task return task
def create_task(self, target: Awaitable) -> None: def create_task(self, target: Awaitable[Any]) -> None:
"""Add task to the executor pool. """Add task to the executor pool.
target: target to call. target: target to call.
@ -408,14 +456,14 @@ class HomeAssistant:
self.loop.call_soon_threadsafe(self.async_create_task, target) self.loop.call_soon_threadsafe(self.async_create_task, target)
@callback @callback
def async_create_task(self, target: Awaitable) -> asyncio.Task: def async_create_task(self, target: Awaitable[_R]) -> asyncio.Task[_R]:
"""Create a task from within the eventloop. """Create a task from within the eventloop.
This method must be run in the event loop. This method must be run in the event loop.
target: target to call. target: target to call.
""" """
task: asyncio.Task = self.loop.create_task(target) task = self.loop.create_task(target)
if self._track_task: if self._track_task:
self._pending_tasks.append(task) self._pending_tasks.append(task)
@ -425,7 +473,7 @@ class HomeAssistant:
@callback @callback
def async_add_executor_job( def async_add_executor_job(
self, target: Callable[..., T], *args: Any self, target: Callable[..., T], *args: Any
) -> Awaitable[T]: ) -> asyncio.Future[T]: # pylint: disable=unsubscriptable-object
"""Add an executor job from within the event loop.""" """Add an executor job from within the event loop."""
task = self.loop.run_in_executor(None, target, *args) task = self.loop.run_in_executor(None, target, *args)
@ -445,8 +493,24 @@ class HomeAssistant:
"""Stop track tasks so you can't wait for all tasks to be done.""" """Stop track tasks so you can't wait for all tasks to be done."""
self._track_task = False self._track_task = False
@overload
@callback @callback
def async_run_hass_job(self, hassjob: HassJob, *args: Any) -> asyncio.Future | None: def async_run_hass_job(
self, hassjob: HassJob[Awaitable[_R]], *args: Any
) -> asyncio.Future[_R] | None: # pylint: disable=unsubscriptable-object
...
@overload
@callback
def async_run_hass_job(
self, hassjob: HassJob[Awaitable[_R] | _R], *args: Any
) -> asyncio.Future[_R] | None: # pylint: disable=unsubscriptable-object
...
@callback
def async_run_hass_job(
self, hassjob: HassJob[Awaitable[_R] | _R], *args: Any
) -> asyncio.Future[_R] | None: # pylint: disable=unsubscriptable-object
"""Run a HassJob from within the event loop. """Run a HassJob from within the event loop.
This method must be run in the event loop. This method must be run in the event loop.
@ -455,15 +519,38 @@ class HomeAssistant:
args: parameters for method to call. args: parameters for method to call.
""" """
if hassjob.job_type == HassJobType.Callback: if hassjob.job_type == HassJobType.Callback:
hassjob.target(*args) cast(Callable[..., _R], hassjob.target)(*args)
return None return None
return self.async_add_hass_job(hassjob, *args) return self.async_add_hass_job(hassjob, *args)
@overload
@callback @callback
def async_run_job( def async_run_job(
self, target: Callable[..., None | Awaitable], *args: Any self, target: Callable[..., Awaitable[_R]], *args: Any
) -> asyncio.Future | None: ) -> asyncio.Future[_R] | None: # pylint: disable=unsubscriptable-object
...
@overload
@callback
def async_run_job(
self, target: Callable[..., Awaitable[_R] | _R], *args: Any
) -> asyncio.Future[_R] | None: # pylint: disable=unsubscriptable-object
...
@overload
@callback
def async_run_job(
self, target: Coroutine[Any, Any, _R], *args: Any
) -> asyncio.Future[_R] | None: # pylint: disable=unsubscriptable-object
...
@callback
def async_run_job(
self,
target: Callable[..., Awaitable[_R] | _R] | Coroutine[Any, Any, _R],
*args: Any,
) -> asyncio.Future[_R] | None: # pylint: disable=unsubscriptable-object
"""Run a job from within the event loop. """Run a job from within the event loop.
This method must be run in the event loop. This method must be run in the event loop.
@ -474,6 +561,7 @@ class HomeAssistant:
if asyncio.iscoroutine(target): if asyncio.iscoroutine(target):
return self.async_create_task(target) return self.async_create_task(target)
target = cast(Callable[..., _R], target)
return self.async_run_hass_job(HassJob(target), *args) return self.async_run_hass_job(HassJob(target), *args)
def block_till_done(self) -> None: def block_till_done(self) -> None:
@ -685,7 +773,7 @@ class Event:
class _FilterableJob(NamedTuple): class _FilterableJob(NamedTuple):
"""Event listener job to be executed with optional filter.""" """Event listener job to be executed with optional filter."""
job: HassJob job: HassJob[None | Awaitable[None]]
event_filter: Callable[[Event], bool] | None event_filter: Callable[[Event], bool] | None

View file

@ -1,8 +1,9 @@
"""Helpers for data entry flows for config entries.""" """Helpers for data entry flows for config entries."""
from __future__ import annotations from __future__ import annotations
import asyncio # pylint: disable=unused-import # used in cast as string
import logging import logging
from typing import Any, Awaitable, Callable, Union from typing import Any, Awaitable, Callable, Union, cast
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.components import dhcp, mqtt, ssdp, zeroconf from homeassistant.components import dhcp, mqtt, ssdp, zeroconf
@ -55,9 +56,10 @@ class DiscoveryFlowHandler(config_entries.ConfigFlow):
# Get current discovered entries. # Get current discovered entries.
in_progress = self._async_in_progress() in_progress = self._async_in_progress()
if not (has_devices := in_progress): if not (has_devices := bool(in_progress)):
has_devices = await self.hass.async_add_job( # type: ignore has_devices = await cast(
self._discovery_function, self.hass "asyncio.Future[bool]",
self.hass.async_add_job(self._discovery_function, self.hass),
) )
if not has_devices: if not has_devices:

View file

@ -77,7 +77,7 @@ async def async_setup_component(
) )
try: try:
return await task # type: ignore return await task
finally: finally:
if domain in hass.data.get(DATA_SETUP_DONE, {}): if domain in hass.data.get(DATA_SETUP_DONE, {}):
hass.data[DATA_SETUP_DONE].pop(domain).set() hass.data[DATA_SETUP_DONE].pop(domain).set()