Move imports in recorder component (#27859)

* move imports to top-level in recorder init

* move imports to top-level in recorder migration

* move imports to top-level in recorder models

* move imports to top-level in recorder purge

* move imports to top-level in recorder util

* fix pylint
This commit is contained in:
Malte Franken 2019-10-19 04:14:54 +11:00 committed by Paulus Schoutsen
parent e95b8035ed
commit 83a709b768
6 changed files with 27 additions and 50 deletions

View file

@ -8,8 +8,14 @@ import queue
import threading import threading
import time import time
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from sqlite3 import Connection
import voluptuous as vol import voluptuous as vol
from sqlalchemy import exc, create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.event import listens_for
from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy.pool import StaticPool
from homeassistant.const import ( from homeassistant.const import (
ATTR_ENTITY_ID, ATTR_ENTITY_ID,
@ -23,6 +29,7 @@ from homeassistant.const import (
EVENT_TIME_CHANGED, EVENT_TIME_CHANGED,
MATCH_ALL, MATCH_ALL,
) )
from homeassistant.components import persistent_notification
from homeassistant.core import CoreState, HomeAssistant, callback from homeassistant.core import CoreState, HomeAssistant, callback
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entityfilter import generate_filter from homeassistant.helpers.entityfilter import generate_filter
@ -31,6 +38,7 @@ import homeassistant.util.dt as dt_util
from . import migration, purge from . import migration, purge
from .const import DATA_INSTANCE from .const import DATA_INSTANCE
from .models import Base, Events, RecorderRuns, States
from .util import session_scope from .util import session_scope
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -100,11 +108,9 @@ def run_information(hass, point_in_time: Optional[datetime] = None):
There is also the run that covers point_in_time. There is also the run that covers point_in_time.
""" """
from . import models
ins = hass.data[DATA_INSTANCE] ins = hass.data[DATA_INSTANCE]
recorder_runs = models.RecorderRuns recorder_runs = RecorderRuns
if point_in_time is None or point_in_time > ins.recording_start: if point_in_time is None or point_in_time > ins.recording_start:
return ins.run_info return ins.run_info
@ -208,10 +214,6 @@ class Recorder(threading.Thread):
def run(self): def run(self):
"""Start processing events to save.""" """Start processing events to save."""
from .models import States, Events
from homeassistant.components import persistent_notification
from sqlalchemy import exc
tries = 1 tries = 1
connected = False connected = False
@ -393,18 +395,10 @@ class Recorder(threading.Thread):
def _setup_connection(self): def _setup_connection(self):
"""Ensure database is ready to fly.""" """Ensure database is ready to fly."""
from sqlalchemy import create_engine, event
from sqlalchemy.engine import Engine
from sqlalchemy.orm import scoped_session
from sqlalchemy.orm import sessionmaker
from sqlite3 import Connection
from . import models
kwargs = {} kwargs = {}
# pylint: disable=unused-variable # pylint: disable=unused-variable
@event.listens_for(Engine, "connect") @listens_for(Engine, "connect")
def set_sqlite_pragma(dbapi_connection, connection_record): def set_sqlite_pragma(dbapi_connection, connection_record):
"""Set sqlite's WAL mode.""" """Set sqlite's WAL mode."""
if isinstance(dbapi_connection, Connection): if isinstance(dbapi_connection, Connection):
@ -416,8 +410,6 @@ class Recorder(threading.Thread):
dbapi_connection.isolation_level = old_isolation dbapi_connection.isolation_level = old_isolation
if self.db_url == "sqlite://" or ":memory:" in self.db_url: if self.db_url == "sqlite://" or ":memory:" in self.db_url:
from sqlalchemy.pool import StaticPool
kwargs["connect_args"] = {"check_same_thread": False} kwargs["connect_args"] = {"check_same_thread": False}
kwargs["poolclass"] = StaticPool kwargs["poolclass"] = StaticPool
kwargs["pool_reset_on_return"] = None kwargs["pool_reset_on_return"] = None
@ -428,7 +420,7 @@ class Recorder(threading.Thread):
self.engine.dispose() self.engine.dispose()
self.engine = create_engine(self.db_url, **kwargs) self.engine = create_engine(self.db_url, **kwargs)
models.Base.metadata.create_all(self.engine) Base.metadata.create_all(self.engine)
self.get_session = scoped_session(sessionmaker(bind=self.engine)) self.get_session = scoped_session(sessionmaker(bind=self.engine))
def _close_connection(self): def _close_connection(self):
@ -439,8 +431,6 @@ class Recorder(threading.Thread):
def _setup_run(self): def _setup_run(self):
"""Log the start of the current run.""" """Log the start of the current run."""
from .models import RecorderRuns
with session_scope(session=self.get_session()) as session: with session_scope(session=self.get_session()) as session:
for run in session.query(RecorderRuns).filter_by(end=None): for run in session.query(RecorderRuns).filter_by(end=None):
run.closed_incorrect = True run.closed_incorrect = True

View file

@ -2,6 +2,11 @@
import logging import logging
import os import os
from sqlalchemy import Table, text
from sqlalchemy.engine import reflection
from sqlalchemy.exc import OperationalError, SQLAlchemyError
from .models import SchemaChanges, SCHEMA_VERSION, Base
from .util import session_scope from .util import session_scope
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -10,8 +15,6 @@ PROGRESS_FILE = ".migration_progress"
def migrate_schema(instance): def migrate_schema(instance):
"""Check if the schema needs to be upgraded.""" """Check if the schema needs to be upgraded."""
from .models import SchemaChanges, SCHEMA_VERSION
progress_path = instance.hass.config.path(PROGRESS_FILE) progress_path = instance.hass.config.path(PROGRESS_FILE)
with session_scope(session=instance.get_session()) as session: with session_scope(session=instance.get_session()) as session:
@ -60,11 +63,7 @@ def _create_index(engine, table_name, index_name):
The index name should match the name given for the index The index name should match the name given for the index
within the table definition described in the models within the table definition described in the models
""" """
from sqlalchemy import Table table = Table(table_name, Base.metadata)
from sqlalchemy.exc import OperationalError
from . import models
table = Table(table_name, models.Base.metadata)
_LOGGER.debug("Looking up index for table %s", table_name) _LOGGER.debug("Looking up index for table %s", table_name)
# Look up the index object by name from the table is the models # Look up the index object by name from the table is the models
index = next(idx for idx in table.indexes if idx.name == index_name) index = next(idx for idx in table.indexes if idx.name == index_name)
@ -99,9 +98,6 @@ def _drop_index(engine, table_name, index_name):
string here is generated from the method parameters without sanitizing. string here is generated from the method parameters without sanitizing.
DO NOT USE THIS FUNCTION IN ANY OPERATION THAT TAKES USER INPUT. DO NOT USE THIS FUNCTION IN ANY OPERATION THAT TAKES USER INPUT.
""" """
from sqlalchemy import text
from sqlalchemy.exc import SQLAlchemyError
_LOGGER.debug("Dropping index %s from table %s", index_name, table_name) _LOGGER.debug("Dropping index %s from table %s", index_name, table_name)
success = False success = False
@ -159,9 +155,6 @@ def _drop_index(engine, table_name, index_name):
def _add_columns(engine, table_name, columns_def): def _add_columns(engine, table_name, columns_def):
"""Add columns to a table.""" """Add columns to a table."""
from sqlalchemy import text
from sqlalchemy.exc import OperationalError
_LOGGER.info( _LOGGER.info(
"Adding columns %s to table %s. Note: this can take several " "Adding columns %s to table %s. Note: this can take several "
"minutes on large databases and slow computers. Please " "minutes on large databases and slow computers. Please "
@ -277,9 +270,6 @@ def _inspect_schema_version(engine, session):
version 1 are present to make the determination. Eventually this logic version 1 are present to make the determination. Eventually this logic
can be removed and we can assume a new db is being created. can be removed and we can assume a new db is being created.
""" """
from sqlalchemy.engine import reflection
from .models import SchemaChanges, SCHEMA_VERSION
inspector = reflection.Inspector.from_engine(engine) inspector = reflection.Inspector.from_engine(engine)
indexes = inspector.get_indexes("events") indexes = inspector.get_indexes("events")

View file

@ -15,6 +15,7 @@ from sqlalchemy import (
distinct, distinct,
) )
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm.session import Session
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from homeassistant.core import Context, Event, EventOrigin, State, split_entity_id from homeassistant.core import Context, Event, EventOrigin, State, split_entity_id
@ -164,8 +165,6 @@ class RecorderRuns(Base): # type: ignore
Specify point_in_time if you want to know which existed at that point Specify point_in_time if you want to know which existed at that point
in time inside the run. in time inside the run.
""" """
from sqlalchemy.orm.session import Session
session = Session.object_session(self) session = Session.object_session(self)
assert session is not None, "RecorderRuns need to be persisted" assert session is not None, "RecorderRuns need to be persisted"

View file

@ -2,7 +2,10 @@
from datetime import timedelta from datetime import timedelta
import logging import logging
from sqlalchemy.exc import SQLAlchemyError
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from .models import Events, States
from .util import session_scope from .util import session_scope
@ -11,9 +14,6 @@ _LOGGER = logging.getLogger(__name__)
def purge_old_data(instance, purge_days, repack): def purge_old_data(instance, purge_days, repack):
"""Purge events and states older than purge_days ago.""" """Purge events and states older than purge_days ago."""
from .models import States, Events
from sqlalchemy.exc import SQLAlchemyError
purge_before = dt_util.utcnow() - timedelta(days=purge_days) purge_before = dt_util.utcnow() - timedelta(days=purge_days)
_LOGGER.debug("Purging events before %s", purge_before) _LOGGER.debug("Purging events before %s", purge_before)

View file

@ -3,6 +3,8 @@ from contextlib import contextmanager
import logging import logging
import time import time
from sqlalchemy.exc import OperationalError, SQLAlchemyError
from .const import DATA_INSTANCE from .const import DATA_INSTANCE
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -37,8 +39,6 @@ def session_scope(*, hass=None, session=None):
def commit(session, work): def commit(session, work):
"""Commit & retry work: Either a model or in a function.""" """Commit & retry work: Either a model or in a function."""
import sqlalchemy.exc
for _ in range(0, RETRIES): for _ in range(0, RETRIES):
try: try:
if callable(work): if callable(work):
@ -47,7 +47,7 @@ def commit(session, work):
session.add(work) session.add(work)
session.commit() session.commit()
return True return True
except sqlalchemy.exc.OperationalError as err: except OperationalError as err:
_LOGGER.error("Error executing query: %s", err) _LOGGER.error("Error executing query: %s", err)
session.rollback() session.rollback()
time.sleep(QUERY_RETRY_WAIT) time.sleep(QUERY_RETRY_WAIT)
@ -59,8 +59,6 @@ def execute(qry):
This method also retries a few times in the case of stale connections. This method also retries a few times in the case of stale connections.
""" """
from sqlalchemy.exc import SQLAlchemyError
for tryno in range(0, RETRIES): for tryno in range(0, RETRIES):
try: try:
timer_start = time.perf_counter() timer_start = time.perf_counter()

View file

@ -23,9 +23,9 @@ def create_engine_test(*args, **kwargs):
async def test_schema_update_calls(hass): async def test_schema_update_calls(hass):
"""Test that schema migrations occur in correct order.""" """Test that schema migrations occur in correct order."""
with patch("sqlalchemy.create_engine", new=create_engine_test), patch( with patch(
"homeassistant.components.recorder.migration._apply_update" "homeassistant.components.recorder.create_engine", new=create_engine_test
) as update: ), patch("homeassistant.components.recorder.migration._apply_update") as update:
await async_setup_component( await async_setup_component(
hass, "recorder", {"recorder": {"db_url": "sqlite://"}} hass, "recorder", {"recorder": {"db_url": "sqlite://"}}
) )