Additional strict typing for additional recorder internals (#68689)

* Strict typing for additional recorder internals

* revert

* fix refactoring error
This commit is contained in:
J. Nick Koston 2022-03-28 21:45:25 -10:00 committed by GitHub
parent 05ddd773ff
commit d7634d1cb1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 84 additions and 42 deletions

View file

@ -164,11 +164,14 @@ homeassistant.components.pure_energie.*
homeassistant.components.rainmachine.* homeassistant.components.rainmachine.*
homeassistant.components.rdw.* homeassistant.components.rdw.*
homeassistant.components.recollect_waste.* homeassistant.components.recollect_waste.*
homeassistant.components.recorder.models
homeassistant.components.recorder.history homeassistant.components.recorder.history
homeassistant.components.recorder.pool
homeassistant.components.recorder.purge homeassistant.components.recorder.purge
homeassistant.components.recorder.repack homeassistant.components.recorder.repack
homeassistant.components.recorder.statistics homeassistant.components.recorder.statistics
homeassistant.components.recorder.util homeassistant.components.recorder.util
homeassistant.components.recorder.websocket_api
homeassistant.components.remote.* homeassistant.components.remote.*
homeassistant.components.renault.* homeassistant.components.renault.*
homeassistant.components.ridwell.* homeassistant.components.ridwell.*

View file

@ -232,8 +232,7 @@ def run_information(hass, point_in_time: datetime | None = None) -> RecorderRuns
There is also the run that covers point_in_time. There is also the run that covers point_in_time.
""" """
run_info = run_information_from_instance(hass, point_in_time) if run_info := run_information_from_instance(hass, point_in_time):
if run_info:
return run_info return run_info
with session_scope(hass=hass) as session: with session_scope(hass=hass) as session:
@ -1028,8 +1027,7 @@ class Recorder(threading.Thread):
try: try:
if event.event_type == EVENT_STATE_CHANGED: if event.event_type == EVENT_STATE_CHANGED:
dbevent = Events.from_event(event, event_data="{}") dbevent = Events.from_event(event, event_data=None)
dbevent.event_data = None
else: else:
dbevent = Events.from_event(event) dbevent = Events.from_event(event)
except (TypeError, ValueError): except (TypeError, ValueError):

View file

@ -4,7 +4,7 @@ from __future__ import annotations
from datetime import datetime, timedelta from datetime import datetime, timedelta
import json import json
import logging import logging
from typing import Any, TypedDict, overload from typing import Any, TypedDict, cast, overload
from fnvhash import fnv1a_32 from fnvhash import fnv1a_32
from sqlalchemy import ( from sqlalchemy import (
@ -35,6 +35,7 @@ from homeassistant.const import (
MAX_LENGTH_STATE_STATE, MAX_LENGTH_STATE_STATE,
) )
from homeassistant.core import Context, Event, EventOrigin, State from homeassistant.core import Context, Event, EventOrigin, State
from homeassistant.helpers.typing import UNDEFINED, UndefinedType
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from .const import JSON_DUMP from .const import JSON_DUMP
@ -113,11 +114,13 @@ class Events(Base): # type: ignore[misc,valid-type]
) )
@staticmethod @staticmethod
def from_event(event, event_data=None): def from_event(
event: Event, event_data: UndefinedType | None = UNDEFINED
) -> Events:
"""Create an event database object from a native event.""" """Create an event database object from a native event."""
return Events( return Events(
event_type=event.event_type, event_type=event.event_type,
event_data=event_data or JSON_DUMP(event.data), event_data=JSON_DUMP(event.data) if event_data is UNDEFINED else event_data,
origin=str(event.origin.value), origin=str(event.origin.value),
time_fired=event.time_fired, time_fired=event.time_fired,
context_id=event.context.id, context_id=event.context.id,
@ -125,7 +128,7 @@ class Events(Base): # type: ignore[misc,valid-type]
context_parent_id=event.context.parent_id, context_parent_id=event.context.parent_id,
) )
def to_native(self, validate_entity_id=True): def to_native(self, validate_entity_id: bool = True) -> Event | None:
"""Convert to a native HA Event.""" """Convert to a native HA Event."""
context = Context( context = Context(
id=self.context_id, id=self.context_id,
@ -185,7 +188,7 @@ class States(Base): # type: ignore[misc,valid-type]
) )
@staticmethod @staticmethod
def from_event(event) -> States: def from_event(event: Event) -> States:
"""Create object from a state_changed event.""" """Create object from a state_changed event."""
entity_id = event.data["entity_id"] entity_id = event.data["entity_id"]
state: State | None = event.data.get("new_state") state: State | None = event.data.get("new_state")
@ -266,12 +269,12 @@ class StateAttributes(Base): # type: ignore[misc,valid-type]
@staticmethod @staticmethod
def hash_shared_attrs(shared_attrs: str) -> int: def hash_shared_attrs(shared_attrs: str) -> int:
"""Return the hash of json encoded shared attributes.""" """Return the hash of json encoded shared attributes."""
return fnv1a_32(shared_attrs.encode("utf-8")) return cast(int, fnv1a_32(shared_attrs.encode("utf-8")))
def to_native(self) -> dict[str, Any]: def to_native(self) -> dict[str, Any]:
"""Convert to an HA state object.""" """Convert to an HA state object."""
try: try:
return json.loads(self.shared_attrs) return cast(dict[str, Any], json.loads(self.shared_attrs))
except ValueError: except ValueError:
# When json.loads fails # When json.loads fails
_LOGGER.exception("Error converting row to state attributes: %s", self) _LOGGER.exception("Error converting row to state attributes: %s", self)
@ -311,8 +314,8 @@ class StatisticsBase:
id = Column(Integer, Identity(), primary_key=True) id = Column(Integer, Identity(), primary_key=True)
created = Column(DATETIME_TYPE, default=dt_util.utcnow) created = Column(DATETIME_TYPE, default=dt_util.utcnow)
@declared_attr @declared_attr # type: ignore[misc]
def metadata_id(self): def metadata_id(self) -> Column:
"""Define the metadata_id column for sub classes.""" """Define the metadata_id column for sub classes."""
return Column( return Column(
Integer, Integer,
@ -329,7 +332,7 @@ class StatisticsBase:
sum = Column(DOUBLE_TYPE) sum = Column(DOUBLE_TYPE)
@classmethod @classmethod
def from_stats(cls, metadata_id: int, stats: StatisticData): def from_stats(cls, metadata_id: int, stats: StatisticData) -> StatisticsBase:
"""Create object from a statistics.""" """Create object from a statistics."""
return cls( # type: ignore[call-arg,misc] return cls( # type: ignore[call-arg,misc]
metadata_id=metadata_id, metadata_id=metadata_id,
@ -422,7 +425,7 @@ class RecorderRuns(Base): # type: ignore[misc,valid-type]
f")>" f")>"
) )
def entity_ids(self, point_in_time=None): def entity_ids(self, point_in_time: datetime | None = None) -> list[str]:
"""Return the entity ids that existed in this run. """Return the entity ids that existed in this run.
Specify point_in_time if you want to know which existed at that point Specify point_in_time if you want to know which existed at that point
@ -443,7 +446,7 @@ class RecorderRuns(Base): # type: ignore[misc,valid-type]
return [row[0] for row in query] return [row[0] for row in query]
def to_native(self, validate_entity_id=True): def to_native(self, validate_entity_id: bool = True) -> RecorderRuns:
"""Return self, native format is this model.""" """Return self, native format is this model."""
return self return self
@ -540,16 +543,16 @@ class LazyState(State):
) -> None: ) -> None:
"""Init the lazy state.""" """Init the lazy state."""
self._row = row self._row = row
self.entity_id = self._row.entity_id self.entity_id: str = self._row.entity_id
self.state = self._row.state or "" self.state = self._row.state or ""
self._attributes = None self._attributes: dict[str, Any] | None = None
self._last_changed = None self._last_changed: datetime | None = None
self._last_updated = None self._last_updated: datetime | None = None
self._context = None self._context: Context | None = None
self._attr_cache = attr_cache self._attr_cache = attr_cache
@property # type: ignore[override] @property # type: ignore[override]
def attributes(self): def attributes(self) -> dict[str, Any]: # type: ignore[override]
"""State attributes.""" """State attributes."""
if self._attributes is None: if self._attributes is None:
source = self._row.shared_attrs or self._row.attributes source = self._row.shared_attrs or self._row.attributes
@ -574,47 +577,47 @@ class LazyState(State):
return self._attributes return self._attributes
@attributes.setter @attributes.setter
def attributes(self, value): def attributes(self, value: dict[str, Any]) -> None:
"""Set attributes.""" """Set attributes."""
self._attributes = value self._attributes = value
@property # type: ignore[override] @property # type: ignore[override]
def context(self): def context(self) -> Context: # type: ignore[override]
"""State context.""" """State context."""
if not self._context: if self._context is None:
self._context = Context(id=None) self._context = Context(id=None) # type: ignore[arg-type]
return self._context return self._context
@context.setter @context.setter
def context(self, value): def context(self, value: Context) -> None:
"""Set context.""" """Set context."""
self._context = value self._context = value
@property # type: ignore[override] @property # type: ignore[override]
def last_changed(self): def last_changed(self) -> datetime: # type: ignore[override]
"""Last changed datetime.""" """Last changed datetime."""
if not self._last_changed: if self._last_changed is None:
self._last_changed = process_timestamp(self._row.last_changed) self._last_changed = process_timestamp(self._row.last_changed)
return self._last_changed return self._last_changed
@last_changed.setter @last_changed.setter
def last_changed(self, value): def last_changed(self, value: datetime) -> None:
"""Set last changed datetime.""" """Set last changed datetime."""
self._last_changed = value self._last_changed = value
@property # type: ignore[override] @property # type: ignore[override]
def last_updated(self): def last_updated(self) -> datetime: # type: ignore[override]
"""Last updated datetime.""" """Last updated datetime."""
if not self._last_updated: if self._last_updated is None:
self._last_updated = process_timestamp(self._row.last_updated) self._last_updated = process_timestamp(self._row.last_updated)
return self._last_updated return self._last_updated
@last_updated.setter @last_updated.setter
def last_updated(self, value): def last_updated(self, value: datetime) -> None:
"""Set last updated datetime.""" """Set last updated datetime."""
self._last_updated = value self._last_updated = value
def as_dict(self): def as_dict(self) -> dict[str, Any]: # type: ignore[override]
"""Return a dict representation of the LazyState. """Return a dict representation of the LazyState.
Async friendly. Async friendly.
@ -645,7 +648,7 @@ class LazyState(State):
"last_updated": last_updated_isoformat, "last_updated": last_updated_isoformat,
} }
def __eq__(self, other): def __eq__(self, other: Any) -> bool:
"""Return the comparison.""" """Return the comparison."""
return ( return (
other.__class__ in [self.__class__, State] other.__class__ in [self.__class__, State]

View file

@ -1,5 +1,6 @@
"""A pool for sqlite connections.""" """A pool for sqlite connections."""
import threading import threading
from typing import Any
from sqlalchemy.pool import NullPool, SingletonThreadPool from sqlalchemy.pool import NullPool, SingletonThreadPool
@ -10,14 +11,16 @@ from .const import DB_WORKER_PREFIX
POOL_SIZE = 5 POOL_SIZE = 5
class RecorderPool(SingletonThreadPool, NullPool): class RecorderPool(SingletonThreadPool, NullPool): # type: ignore[misc]
"""A hybrid of NullPool and SingletonThreadPool. """A hybrid of NullPool and SingletonThreadPool.
When called from the creating thread or db executor acts like SingletonThreadPool When called from the creating thread or db executor acts like SingletonThreadPool
When called from any other thread, acts like NullPool When called from any other thread, acts like NullPool
""" """
def __init__(self, *args, **kw): # pylint: disable=super-init-not-called def __init__( # pylint: disable=super-init-not-called
self, *args: Any, **kw: Any
) -> None:
"""Create the pool.""" """Create the pool."""
kw["pool_size"] = POOL_SIZE kw["pool_size"] = POOL_SIZE
SingletonThreadPool.__init__(self, *args, **kw) SingletonThreadPool.__init__(self, *args, **kw)
@ -30,22 +33,24 @@ class RecorderPool(SingletonThreadPool, NullPool):
thread_name == "Recorder" or thread_name.startswith(DB_WORKER_PREFIX) thread_name == "Recorder" or thread_name.startswith(DB_WORKER_PREFIX)
) )
def _do_return_conn(self, conn): # Any can be switched out for ConnectionPoolEntry in the next version of sqlalchemy
def _do_return_conn(self, conn: Any) -> Any:
if self.recorder_or_dbworker: if self.recorder_or_dbworker:
return super()._do_return_conn(conn) return super()._do_return_conn(conn)
conn.close() conn.close()
def shutdown(self): def shutdown(self) -> None:
"""Close the connection.""" """Close the connection."""
if self.recorder_or_dbworker and self._conn and (conn := self._conn.current()): if self.recorder_or_dbworker and self._conn and (conn := self._conn.current()):
conn.close() conn.close()
def dispose(self): def dispose(self) -> None:
"""Dispose of the connection.""" """Dispose of the connection."""
if self.recorder_or_dbworker: if self.recorder_or_dbworker:
return super().dispose() super().dispose()
def _do_get(self): # Any can be switched out for ConnectionPoolEntry in the next version of sqlalchemy
def _do_get(self) -> Any:
if self.recorder_or_dbworker: if self.recorder_or_dbworker:
return super()._do_get() return super()._do_get()
report( report(

View file

@ -1606,6 +1606,17 @@ no_implicit_optional = true
warn_return_any = true warn_return_any = true
warn_unreachable = true warn_unreachable = true
[mypy-homeassistant.components.recorder.models]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.recorder.history] [mypy-homeassistant.components.recorder.history]
check_untyped_defs = true check_untyped_defs = true
disallow_incomplete_defs = true disallow_incomplete_defs = true
@ -1617,6 +1628,17 @@ no_implicit_optional = true
warn_return_any = true warn_return_any = true
warn_unreachable = true warn_unreachable = true
[mypy-homeassistant.components.recorder.pool]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.recorder.purge] [mypy-homeassistant.components.recorder.purge]
check_untyped_defs = true check_untyped_defs = true
disallow_incomplete_defs = true disallow_incomplete_defs = true
@ -1661,6 +1683,17 @@ no_implicit_optional = true
warn_return_any = true warn_return_any = true
warn_unreachable = true warn_unreachable = true
[mypy-homeassistant.components.recorder.websocket_api]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.remote.*] [mypy-homeassistant.components.remote.*]
check_untyped_defs = true check_untyped_defs = true
disallow_incomplete_defs = true disallow_incomplete_defs = true