Refactor event time trackers to avoid using nonlocal (#107997)

This commit is contained in:
J. Nick Koston 2024-01-13 17:17:55 -10:00 committed by GitHub
parent e7c25d1c36
commit 659ee51914
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -10,7 +10,7 @@ import functools as ft
import logging
from random import randint
import time
from typing import Any, Concatenate, ParamSpec, TypedDict, TypeVar
from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypedDict, TypeVar
import attr
@ -1389,6 +1389,45 @@ def async_track_point_in_time(
track_point_in_time = threaded_listener_factory(async_track_point_in_time)
@dataclass(slots=True)
class _TrackPointUTCTime:
hass: HomeAssistant
job: HassJob[[datetime], Coroutine[Any, Any, None] | None]
utc_point_in_time: datetime
expected_fire_timestamp: float
_cancel_callback: asyncio.TimerHandle | None = None
def async_attach(self) -> None:
"""Initialize track job."""
loop = self.hass.loop
self._cancel_callback = loop.call_at(
loop.time() + self.expected_fire_timestamp - time.time(), self._run_action
)
@callback
def _run_action(self) -> None:
"""Call the action."""
# Depending on the available clock support (including timer hardware
# and the OS kernel) it can happen that we fire a little bit too early
# as measured by utcnow(). That is bad when callbacks have assumptions
# about the current time. Thus, we rearm the timer for the remaining
# time.
if (delta := (self.expected_fire_timestamp - time_tracker_timestamp())) > 0:
_LOGGER.debug("Called %f seconds too early, rearming", delta)
loop = self.hass.loop
self._cancel_callback = loop.call_at(loop.time() + delta, self._run_action)
return
self.hass.async_run_hass_job(self.job, self.utc_point_in_time)
@callback
def async_cancel(self) -> None:
"""Cancel the call_at."""
if TYPE_CHECKING:
assert self._cancel_callback is not None
self._cancel_callback.cancel()
@callback
@bind_hass
def async_track_point_in_utc_time(
@ -1404,44 +1443,14 @@ def async_track_point_in_utc_time(
# Ensure point_in_time is UTC
utc_point_in_time = dt_util.as_utc(point_in_time)
expected_fire_timestamp = dt_util.utc_to_timestamp(utc_point_in_time)
# Since this is called once, we accept a HassJob so we can avoid
# having to figure out how to call the action every time its called.
cancel_callback: asyncio.TimerHandle | None = None
loop = hass.loop
@callback
def run_action(job: HassJob[[datetime], Coroutine[Any, Any, None] | None]) -> None:
"""Call the action."""
nonlocal cancel_callback
# Depending on the available clock support (including timer hardware
# and the OS kernel) it can happen that we fire a little bit too early
# as measured by utcnow(). That is bad when callbacks have assumptions
# about the current time. Thus, we rearm the timer for the remaining
# time.
if (delta := (expected_fire_timestamp - time_tracker_timestamp())) > 0:
_LOGGER.debug("Called %f seconds too early, rearming", delta)
cancel_callback = loop.call_at(loop.time() + delta, run_action, job)
return
hass.async_run_hass_job(job, utc_point_in_time)
job = (
action
if isinstance(action, HassJob)
else HassJob(action, f"track point in utc time {utc_point_in_time}")
)
delta = expected_fire_timestamp - time.time()
cancel_callback = loop.call_at(loop.time() + delta, run_action, job)
@callback
def unsub_point_in_time_listener() -> None:
"""Cancel the call_at."""
assert cancel_callback is not None
cancel_callback.cancel()
return unsub_point_in_time_listener
track = _TrackPointUTCTime(hass, job, utc_point_in_time, expected_fire_timestamp)
track.async_attach()
return track.async_cancel
track_point_in_utc_time = threaded_listener_factory(async_track_point_in_utc_time)
@ -1500,6 +1509,61 @@ def async_call_later(
call_later = threaded_listener_factory(async_call_later)
@dataclass(slots=True)
class _TrackTimeInterval:
"""Helper class to help listen to time interval events."""
hass: HomeAssistant
seconds: float
job_name: str
action: Callable[[datetime], Coroutine[Any, Any, None] | None]
cancel_on_shutdown: bool | None
_track_job: HassJob[[datetime], Coroutine[Any, Any, None] | None] | None = None
_run_job: HassJob[[datetime], Coroutine[Any, Any, None] | None] | None = None
_cancel_callback: CALLBACK_TYPE | None = None
def async_attach(self) -> None:
"""Initialize track job."""
hass = self.hass
self._track_job = HassJob(
self._interval_listener,
self.job_name,
job_type=HassJobType.Callback,
cancel_on_shutdown=self.cancel_on_shutdown,
)
self._run_job = HassJob(
self.action,
f"track time interval {self.seconds}",
cancel_on_shutdown=self.cancel_on_shutdown,
)
self._cancel_callback = async_call_at(
hass,
self._track_job,
hass.loop.time() + self.seconds,
)
@callback
def _interval_listener(self, now: datetime) -> None:
"""Handle elapsed intervals."""
if TYPE_CHECKING:
assert self._run_job is not None
assert self._track_job is not None
hass = self.hass
self._cancel_callback = async_call_at(
hass,
self._track_job,
hass.loop.time() + self.seconds,
)
hass.async_run_hass_job(self._run_job, now)
@callback
def async_cancel(self) -> None:
"""Cancel the call_at."""
if TYPE_CHECKING:
assert self._cancel_callback is not None
self._cancel_callback()
@callback
@bind_hass
def async_track_time_interval(
@ -1514,41 +1578,13 @@ def async_track_time_interval(
The listener is passed the time it fires in UTC time.
"""
remove: CALLBACK_TYPE
interval_listener_job: HassJob[[datetime], None]
interval_seconds = interval.total_seconds()
job = HassJob(
action, f"track time interval {interval}", cancel_on_shutdown=cancel_on_shutdown
)
@callback
def interval_listener(now: datetime) -> None:
"""Handle elapsed intervals."""
nonlocal remove
nonlocal interval_listener_job
remove = async_call_later(hass, interval_seconds, interval_listener_job)
hass.async_run_hass_job(job, now)
seconds = interval.total_seconds()
job_name = f"track time interval {seconds} {action}"
if name:
job_name = f"{name}: track time interval {interval} {action}"
else:
job_name = f"track time interval {interval} {action}"
interval_listener_job = HassJob(
interval_listener,
job_name,
cancel_on_shutdown=cancel_on_shutdown,
job_type=HassJobType.Callback,
)
remove = async_call_later(hass, interval_seconds, interval_listener_job)
def remove_listener() -> None:
"""Remove interval listener."""
remove()
return remove_listener
job_name = f"{name}: {job_name}"
track = _TrackTimeInterval(hass, seconds, job_name, action, cancel_on_shutdown)
track.async_attach()
return track.async_cancel
track_time_interval = threaded_listener_factory(async_track_time_interval)