Additional strict typing for additional recorder internals (#68689)
* Strict typing for additional recorder internals * revert * fix refactoring error
This commit is contained in:
parent
05ddd773ff
commit
d7634d1cb1
5 changed files with 84 additions and 42 deletions
|
@ -164,11 +164,14 @@ homeassistant.components.pure_energie.*
|
||||||
homeassistant.components.rainmachine.*
|
homeassistant.components.rainmachine.*
|
||||||
homeassistant.components.rdw.*
|
homeassistant.components.rdw.*
|
||||||
homeassistant.components.recollect_waste.*
|
homeassistant.components.recollect_waste.*
|
||||||
|
homeassistant.components.recorder.models
|
||||||
homeassistant.components.recorder.history
|
homeassistant.components.recorder.history
|
||||||
|
homeassistant.components.recorder.pool
|
||||||
homeassistant.components.recorder.purge
|
homeassistant.components.recorder.purge
|
||||||
homeassistant.components.recorder.repack
|
homeassistant.components.recorder.repack
|
||||||
homeassistant.components.recorder.statistics
|
homeassistant.components.recorder.statistics
|
||||||
homeassistant.components.recorder.util
|
homeassistant.components.recorder.util
|
||||||
|
homeassistant.components.recorder.websocket_api
|
||||||
homeassistant.components.remote.*
|
homeassistant.components.remote.*
|
||||||
homeassistant.components.renault.*
|
homeassistant.components.renault.*
|
||||||
homeassistant.components.ridwell.*
|
homeassistant.components.ridwell.*
|
||||||
|
|
|
@ -232,8 +232,7 @@ def run_information(hass, point_in_time: datetime | None = None) -> RecorderRuns
|
||||||
|
|
||||||
There is also the run that covers point_in_time.
|
There is also the run that covers point_in_time.
|
||||||
"""
|
"""
|
||||||
run_info = run_information_from_instance(hass, point_in_time)
|
if run_info := run_information_from_instance(hass, point_in_time):
|
||||||
if run_info:
|
|
||||||
return run_info
|
return run_info
|
||||||
|
|
||||||
with session_scope(hass=hass) as session:
|
with session_scope(hass=hass) as session:
|
||||||
|
@ -1028,8 +1027,7 @@ class Recorder(threading.Thread):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if event.event_type == EVENT_STATE_CHANGED:
|
if event.event_type == EVENT_STATE_CHANGED:
|
||||||
dbevent = Events.from_event(event, event_data="{}")
|
dbevent = Events.from_event(event, event_data=None)
|
||||||
dbevent.event_data = None
|
|
||||||
else:
|
else:
|
||||||
dbevent = Events.from_event(event)
|
dbevent = Events.from_event(event)
|
||||||
except (TypeError, ValueError):
|
except (TypeError, ValueError):
|
||||||
|
|
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, TypedDict, overload
|
from typing import Any, TypedDict, cast, overload
|
||||||
|
|
||||||
from fnvhash import fnv1a_32
|
from fnvhash import fnv1a_32
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
|
@ -35,6 +35,7 @@ from homeassistant.const import (
|
||||||
MAX_LENGTH_STATE_STATE,
|
MAX_LENGTH_STATE_STATE,
|
||||||
)
|
)
|
||||||
from homeassistant.core import Context, Event, EventOrigin, State
|
from homeassistant.core import Context, Event, EventOrigin, State
|
||||||
|
from homeassistant.helpers.typing import UNDEFINED, UndefinedType
|
||||||
import homeassistant.util.dt as dt_util
|
import homeassistant.util.dt as dt_util
|
||||||
|
|
||||||
from .const import JSON_DUMP
|
from .const import JSON_DUMP
|
||||||
|
@ -113,11 +114,13 @@ class Events(Base): # type: ignore[misc,valid-type]
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_event(event, event_data=None):
|
def from_event(
|
||||||
|
event: Event, event_data: UndefinedType | None = UNDEFINED
|
||||||
|
) -> Events:
|
||||||
"""Create an event database object from a native event."""
|
"""Create an event database object from a native event."""
|
||||||
return Events(
|
return Events(
|
||||||
event_type=event.event_type,
|
event_type=event.event_type,
|
||||||
event_data=event_data or JSON_DUMP(event.data),
|
event_data=JSON_DUMP(event.data) if event_data is UNDEFINED else event_data,
|
||||||
origin=str(event.origin.value),
|
origin=str(event.origin.value),
|
||||||
time_fired=event.time_fired,
|
time_fired=event.time_fired,
|
||||||
context_id=event.context.id,
|
context_id=event.context.id,
|
||||||
|
@ -125,7 +128,7 @@ class Events(Base): # type: ignore[misc,valid-type]
|
||||||
context_parent_id=event.context.parent_id,
|
context_parent_id=event.context.parent_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_native(self, validate_entity_id=True):
|
def to_native(self, validate_entity_id: bool = True) -> Event | None:
|
||||||
"""Convert to a native HA Event."""
|
"""Convert to a native HA Event."""
|
||||||
context = Context(
|
context = Context(
|
||||||
id=self.context_id,
|
id=self.context_id,
|
||||||
|
@ -185,7 +188,7 @@ class States(Base): # type: ignore[misc,valid-type]
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_event(event) -> States:
|
def from_event(event: Event) -> States:
|
||||||
"""Create object from a state_changed event."""
|
"""Create object from a state_changed event."""
|
||||||
entity_id = event.data["entity_id"]
|
entity_id = event.data["entity_id"]
|
||||||
state: State | None = event.data.get("new_state")
|
state: State | None = event.data.get("new_state")
|
||||||
|
@ -266,12 +269,12 @@ class StateAttributes(Base): # type: ignore[misc,valid-type]
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def hash_shared_attrs(shared_attrs: str) -> int:
|
def hash_shared_attrs(shared_attrs: str) -> int:
|
||||||
"""Return the hash of json encoded shared attributes."""
|
"""Return the hash of json encoded shared attributes."""
|
||||||
return fnv1a_32(shared_attrs.encode("utf-8"))
|
return cast(int, fnv1a_32(shared_attrs.encode("utf-8")))
|
||||||
|
|
||||||
def to_native(self) -> dict[str, Any]:
|
def to_native(self) -> dict[str, Any]:
|
||||||
"""Convert to an HA state object."""
|
"""Convert to an HA state object."""
|
||||||
try:
|
try:
|
||||||
return json.loads(self.shared_attrs)
|
return cast(dict[str, Any], json.loads(self.shared_attrs))
|
||||||
except ValueError:
|
except ValueError:
|
||||||
# When json.loads fails
|
# When json.loads fails
|
||||||
_LOGGER.exception("Error converting row to state attributes: %s", self)
|
_LOGGER.exception("Error converting row to state attributes: %s", self)
|
||||||
|
@ -311,8 +314,8 @@ class StatisticsBase:
|
||||||
id = Column(Integer, Identity(), primary_key=True)
|
id = Column(Integer, Identity(), primary_key=True)
|
||||||
created = Column(DATETIME_TYPE, default=dt_util.utcnow)
|
created = Column(DATETIME_TYPE, default=dt_util.utcnow)
|
||||||
|
|
||||||
@declared_attr
|
@declared_attr # type: ignore[misc]
|
||||||
def metadata_id(self):
|
def metadata_id(self) -> Column:
|
||||||
"""Define the metadata_id column for sub classes."""
|
"""Define the metadata_id column for sub classes."""
|
||||||
return Column(
|
return Column(
|
||||||
Integer,
|
Integer,
|
||||||
|
@ -329,7 +332,7 @@ class StatisticsBase:
|
||||||
sum = Column(DOUBLE_TYPE)
|
sum = Column(DOUBLE_TYPE)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_stats(cls, metadata_id: int, stats: StatisticData):
|
def from_stats(cls, metadata_id: int, stats: StatisticData) -> StatisticsBase:
|
||||||
"""Create object from a statistics."""
|
"""Create object from a statistics."""
|
||||||
return cls( # type: ignore[call-arg,misc]
|
return cls( # type: ignore[call-arg,misc]
|
||||||
metadata_id=metadata_id,
|
metadata_id=metadata_id,
|
||||||
|
@ -422,7 +425,7 @@ class RecorderRuns(Base): # type: ignore[misc,valid-type]
|
||||||
f")>"
|
f")>"
|
||||||
)
|
)
|
||||||
|
|
||||||
def entity_ids(self, point_in_time=None):
|
def entity_ids(self, point_in_time: datetime | None = None) -> list[str]:
|
||||||
"""Return the entity ids that existed in this run.
|
"""Return the entity ids that existed in this run.
|
||||||
|
|
||||||
Specify point_in_time if you want to know which existed at that point
|
Specify point_in_time if you want to know which existed at that point
|
||||||
|
@ -443,7 +446,7 @@ class RecorderRuns(Base): # type: ignore[misc,valid-type]
|
||||||
|
|
||||||
return [row[0] for row in query]
|
return [row[0] for row in query]
|
||||||
|
|
||||||
def to_native(self, validate_entity_id=True):
|
def to_native(self, validate_entity_id: bool = True) -> RecorderRuns:
|
||||||
"""Return self, native format is this model."""
|
"""Return self, native format is this model."""
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@ -540,16 +543,16 @@ class LazyState(State):
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Init the lazy state."""
|
"""Init the lazy state."""
|
||||||
self._row = row
|
self._row = row
|
||||||
self.entity_id = self._row.entity_id
|
self.entity_id: str = self._row.entity_id
|
||||||
self.state = self._row.state or ""
|
self.state = self._row.state or ""
|
||||||
self._attributes = None
|
self._attributes: dict[str, Any] | None = None
|
||||||
self._last_changed = None
|
self._last_changed: datetime | None = None
|
||||||
self._last_updated = None
|
self._last_updated: datetime | None = None
|
||||||
self._context = None
|
self._context: Context | None = None
|
||||||
self._attr_cache = attr_cache
|
self._attr_cache = attr_cache
|
||||||
|
|
||||||
@property # type: ignore[override]
|
@property # type: ignore[override]
|
||||||
def attributes(self):
|
def attributes(self) -> dict[str, Any]: # type: ignore[override]
|
||||||
"""State attributes."""
|
"""State attributes."""
|
||||||
if self._attributes is None:
|
if self._attributes is None:
|
||||||
source = self._row.shared_attrs or self._row.attributes
|
source = self._row.shared_attrs or self._row.attributes
|
||||||
|
@ -574,47 +577,47 @@ class LazyState(State):
|
||||||
return self._attributes
|
return self._attributes
|
||||||
|
|
||||||
@attributes.setter
|
@attributes.setter
|
||||||
def attributes(self, value):
|
def attributes(self, value: dict[str, Any]) -> None:
|
||||||
"""Set attributes."""
|
"""Set attributes."""
|
||||||
self._attributes = value
|
self._attributes = value
|
||||||
|
|
||||||
@property # type: ignore[override]
|
@property # type: ignore[override]
|
||||||
def context(self):
|
def context(self) -> Context: # type: ignore[override]
|
||||||
"""State context."""
|
"""State context."""
|
||||||
if not self._context:
|
if self._context is None:
|
||||||
self._context = Context(id=None)
|
self._context = Context(id=None) # type: ignore[arg-type]
|
||||||
return self._context
|
return self._context
|
||||||
|
|
||||||
@context.setter
|
@context.setter
|
||||||
def context(self, value):
|
def context(self, value: Context) -> None:
|
||||||
"""Set context."""
|
"""Set context."""
|
||||||
self._context = value
|
self._context = value
|
||||||
|
|
||||||
@property # type: ignore[override]
|
@property # type: ignore[override]
|
||||||
def last_changed(self):
|
def last_changed(self) -> datetime: # type: ignore[override]
|
||||||
"""Last changed datetime."""
|
"""Last changed datetime."""
|
||||||
if not self._last_changed:
|
if self._last_changed is None:
|
||||||
self._last_changed = process_timestamp(self._row.last_changed)
|
self._last_changed = process_timestamp(self._row.last_changed)
|
||||||
return self._last_changed
|
return self._last_changed
|
||||||
|
|
||||||
@last_changed.setter
|
@last_changed.setter
|
||||||
def last_changed(self, value):
|
def last_changed(self, value: datetime) -> None:
|
||||||
"""Set last changed datetime."""
|
"""Set last changed datetime."""
|
||||||
self._last_changed = value
|
self._last_changed = value
|
||||||
|
|
||||||
@property # type: ignore[override]
|
@property # type: ignore[override]
|
||||||
def last_updated(self):
|
def last_updated(self) -> datetime: # type: ignore[override]
|
||||||
"""Last updated datetime."""
|
"""Last updated datetime."""
|
||||||
if not self._last_updated:
|
if self._last_updated is None:
|
||||||
self._last_updated = process_timestamp(self._row.last_updated)
|
self._last_updated = process_timestamp(self._row.last_updated)
|
||||||
return self._last_updated
|
return self._last_updated
|
||||||
|
|
||||||
@last_updated.setter
|
@last_updated.setter
|
||||||
def last_updated(self, value):
|
def last_updated(self, value: datetime) -> None:
|
||||||
"""Set last updated datetime."""
|
"""Set last updated datetime."""
|
||||||
self._last_updated = value
|
self._last_updated = value
|
||||||
|
|
||||||
def as_dict(self):
|
def as_dict(self) -> dict[str, Any]: # type: ignore[override]
|
||||||
"""Return a dict representation of the LazyState.
|
"""Return a dict representation of the LazyState.
|
||||||
|
|
||||||
Async friendly.
|
Async friendly.
|
||||||
|
@ -645,7 +648,7 @@ class LazyState(State):
|
||||||
"last_updated": last_updated_isoformat,
|
"last_updated": last_updated_isoformat,
|
||||||
}
|
}
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other: Any) -> bool:
|
||||||
"""Return the comparison."""
|
"""Return the comparison."""
|
||||||
return (
|
return (
|
||||||
other.__class__ in [self.__class__, State]
|
other.__class__ in [self.__class__, State]
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
"""A pool for sqlite connections."""
|
"""A pool for sqlite connections."""
|
||||||
import threading
|
import threading
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from sqlalchemy.pool import NullPool, SingletonThreadPool
|
from sqlalchemy.pool import NullPool, SingletonThreadPool
|
||||||
|
|
||||||
|
@ -10,14 +11,16 @@ from .const import DB_WORKER_PREFIX
|
||||||
POOL_SIZE = 5
|
POOL_SIZE = 5
|
||||||
|
|
||||||
|
|
||||||
class RecorderPool(SingletonThreadPool, NullPool):
|
class RecorderPool(SingletonThreadPool, NullPool): # type: ignore[misc]
|
||||||
"""A hybrid of NullPool and SingletonThreadPool.
|
"""A hybrid of NullPool and SingletonThreadPool.
|
||||||
|
|
||||||
When called from the creating thread or db executor acts like SingletonThreadPool
|
When called from the creating thread or db executor acts like SingletonThreadPool
|
||||||
When called from any other thread, acts like NullPool
|
When called from any other thread, acts like NullPool
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, *args, **kw): # pylint: disable=super-init-not-called
|
def __init__( # pylint: disable=super-init-not-called
|
||||||
|
self, *args: Any, **kw: Any
|
||||||
|
) -> None:
|
||||||
"""Create the pool."""
|
"""Create the pool."""
|
||||||
kw["pool_size"] = POOL_SIZE
|
kw["pool_size"] = POOL_SIZE
|
||||||
SingletonThreadPool.__init__(self, *args, **kw)
|
SingletonThreadPool.__init__(self, *args, **kw)
|
||||||
|
@ -30,22 +33,24 @@ class RecorderPool(SingletonThreadPool, NullPool):
|
||||||
thread_name == "Recorder" or thread_name.startswith(DB_WORKER_PREFIX)
|
thread_name == "Recorder" or thread_name.startswith(DB_WORKER_PREFIX)
|
||||||
)
|
)
|
||||||
|
|
||||||
def _do_return_conn(self, conn):
|
# Any can be switched out for ConnectionPoolEntry in the next version of sqlalchemy
|
||||||
|
def _do_return_conn(self, conn: Any) -> Any:
|
||||||
if self.recorder_or_dbworker:
|
if self.recorder_or_dbworker:
|
||||||
return super()._do_return_conn(conn)
|
return super()._do_return_conn(conn)
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
def shutdown(self):
|
def shutdown(self) -> None:
|
||||||
"""Close the connection."""
|
"""Close the connection."""
|
||||||
if self.recorder_or_dbworker and self._conn and (conn := self._conn.current()):
|
if self.recorder_or_dbworker and self._conn and (conn := self._conn.current()):
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
def dispose(self):
|
def dispose(self) -> None:
|
||||||
"""Dispose of the connection."""
|
"""Dispose of the connection."""
|
||||||
if self.recorder_or_dbworker:
|
if self.recorder_or_dbworker:
|
||||||
return super().dispose()
|
super().dispose()
|
||||||
|
|
||||||
def _do_get(self):
|
# Any can be switched out for ConnectionPoolEntry in the next version of sqlalchemy
|
||||||
|
def _do_get(self) -> Any:
|
||||||
if self.recorder_or_dbworker:
|
if self.recorder_or_dbworker:
|
||||||
return super()._do_get()
|
return super()._do_get()
|
||||||
report(
|
report(
|
||||||
|
|
33
mypy.ini
33
mypy.ini
|
@ -1606,6 +1606,17 @@ no_implicit_optional = true
|
||||||
warn_return_any = true
|
warn_return_any = true
|
||||||
warn_unreachable = true
|
warn_unreachable = true
|
||||||
|
|
||||||
|
[mypy-homeassistant.components.recorder.models]
|
||||||
|
check_untyped_defs = true
|
||||||
|
disallow_incomplete_defs = true
|
||||||
|
disallow_subclassing_any = true
|
||||||
|
disallow_untyped_calls = true
|
||||||
|
disallow_untyped_decorators = true
|
||||||
|
disallow_untyped_defs = true
|
||||||
|
no_implicit_optional = true
|
||||||
|
warn_return_any = true
|
||||||
|
warn_unreachable = true
|
||||||
|
|
||||||
[mypy-homeassistant.components.recorder.history]
|
[mypy-homeassistant.components.recorder.history]
|
||||||
check_untyped_defs = true
|
check_untyped_defs = true
|
||||||
disallow_incomplete_defs = true
|
disallow_incomplete_defs = true
|
||||||
|
@ -1617,6 +1628,17 @@ no_implicit_optional = true
|
||||||
warn_return_any = true
|
warn_return_any = true
|
||||||
warn_unreachable = true
|
warn_unreachable = true
|
||||||
|
|
||||||
|
[mypy-homeassistant.components.recorder.pool]
|
||||||
|
check_untyped_defs = true
|
||||||
|
disallow_incomplete_defs = true
|
||||||
|
disallow_subclassing_any = true
|
||||||
|
disallow_untyped_calls = true
|
||||||
|
disallow_untyped_decorators = true
|
||||||
|
disallow_untyped_defs = true
|
||||||
|
no_implicit_optional = true
|
||||||
|
warn_return_any = true
|
||||||
|
warn_unreachable = true
|
||||||
|
|
||||||
[mypy-homeassistant.components.recorder.purge]
|
[mypy-homeassistant.components.recorder.purge]
|
||||||
check_untyped_defs = true
|
check_untyped_defs = true
|
||||||
disallow_incomplete_defs = true
|
disallow_incomplete_defs = true
|
||||||
|
@ -1661,6 +1683,17 @@ no_implicit_optional = true
|
||||||
warn_return_any = true
|
warn_return_any = true
|
||||||
warn_unreachable = true
|
warn_unreachable = true
|
||||||
|
|
||||||
|
[mypy-homeassistant.components.recorder.websocket_api]
|
||||||
|
check_untyped_defs = true
|
||||||
|
disallow_incomplete_defs = true
|
||||||
|
disallow_subclassing_any = true
|
||||||
|
disallow_untyped_calls = true
|
||||||
|
disallow_untyped_decorators = true
|
||||||
|
disallow_untyped_defs = true
|
||||||
|
no_implicit_optional = true
|
||||||
|
warn_return_any = true
|
||||||
|
warn_unreachable = true
|
||||||
|
|
||||||
[mypy-homeassistant.components.remote.*]
|
[mypy-homeassistant.components.remote.*]
|
||||||
check_untyped_defs = true
|
check_untyped_defs = true
|
||||||
disallow_incomplete_defs = true
|
disallow_incomplete_defs = true
|
||||||
|
|
Loading…
Add table
Reference in a new issue