Upgrade to sqlalchemy 1.4.11 (#49538)

This commit is contained in:
J. Nick Koston 2021-04-21 20:29:36 -10:00 committed by GitHub
parent 303ab36c54
commit c10836fcee
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 84 additions and 74 deletions

View file

@ -13,6 +13,7 @@ env:
CACHE_VERSION: 1 CACHE_VERSION: 1
DEFAULT_PYTHON: 3.8 DEFAULT_PYTHON: 3.8
PRE_COMMIT_CACHE: ~/.cache/pre-commit PRE_COMMIT_CACHE: ~/.cache/pre-commit
SQLALCHEMY_WARN_20: 1
jobs: jobs:
# Separate job to pre-populate the base dependency cache # Separate job to pre-populate the base dependency cache

View file

@ -2,7 +2,7 @@
"domain": "recorder", "domain": "recorder",
"name": "Recorder", "name": "Recorder",
"documentation": "https://www.home-assistant.io/integrations/recorder", "documentation": "https://www.home-assistant.io/integrations/recorder",
"requirements": ["sqlalchemy==1.3.23"], "requirements": ["sqlalchemy==1.4.11"],
"codeowners": [], "codeowners": [],
"quality_scale": "internal", "quality_scale": "internal",
"iot_class": "local_push" "iot_class": "local_push"

View file

@ -1,8 +1,8 @@
"""Schema migration helpers.""" """Schema migration helpers."""
import logging import logging
import sqlalchemy
from sqlalchemy import ForeignKeyConstraint, MetaData, Table, text from sqlalchemy import ForeignKeyConstraint, MetaData, Table, text
from sqlalchemy.engine import reflection
from sqlalchemy.exc import ( from sqlalchemy.exc import (
InternalError, InternalError,
OperationalError, OperationalError,
@ -50,13 +50,13 @@ def migrate_schema(instance, current_version):
for version in range(current_version, SCHEMA_VERSION): for version in range(current_version, SCHEMA_VERSION):
new_version = version + 1 new_version = version + 1
_LOGGER.info("Upgrading recorder db schema to version %s", new_version) _LOGGER.info("Upgrading recorder db schema to version %s", new_version)
_apply_update(instance.engine, new_version, current_version) _apply_update(instance.engine, session, new_version, current_version)
session.add(SchemaChanges(schema_version=new_version)) session.add(SchemaChanges(schema_version=new_version))
_LOGGER.info("Upgrade to version %s done", new_version) _LOGGER.info("Upgrade to version %s done", new_version)
def _create_index(engine, table_name, index_name): def _create_index(connection, table_name, index_name):
"""Create an index for the specified table. """Create an index for the specified table.
The index name should match the name given for the index The index name should match the name given for the index
@ -78,7 +78,7 @@ def _create_index(engine, table_name, index_name):
index_name, index_name,
) )
try: try:
index.create(engine) index.create(connection)
except (InternalError, ProgrammingError, OperationalError) as err: except (InternalError, ProgrammingError, OperationalError) as err:
lower_err_str = str(err).lower() lower_err_str = str(err).lower()
@ -92,7 +92,7 @@ def _create_index(engine, table_name, index_name):
_LOGGER.debug("Finished creating %s", index_name) _LOGGER.debug("Finished creating %s", index_name)
def _drop_index(engine, table_name, index_name): def _drop_index(connection, table_name, index_name):
"""Drop an index from a specified table. """Drop an index from a specified table.
There is no universal way to do something like `DROP INDEX IF EXISTS` There is no universal way to do something like `DROP INDEX IF EXISTS`
@ -108,7 +108,7 @@ def _drop_index(engine, table_name, index_name):
# Engines like DB2/Oracle # Engines like DB2/Oracle
try: try:
engine.execute(text(f"DROP INDEX {index_name}")) connection.execute(text(f"DROP INDEX {index_name}"))
except SQLAlchemyError: except SQLAlchemyError:
pass pass
else: else:
@ -117,7 +117,7 @@ def _drop_index(engine, table_name, index_name):
# Engines like SQLite, SQL Server # Engines like SQLite, SQL Server
if not success: if not success:
try: try:
engine.execute( connection.execute(
text( text(
"DROP INDEX {table}.{index}".format( "DROP INDEX {table}.{index}".format(
index=index_name, table=table_name index=index_name, table=table_name
@ -132,7 +132,7 @@ def _drop_index(engine, table_name, index_name):
if not success: if not success:
# Engines like MySQL, MS Access # Engines like MySQL, MS Access
try: try:
engine.execute( connection.execute(
text( text(
"DROP INDEX {index} ON {table}".format( "DROP INDEX {index} ON {table}".format(
index=index_name, table=table_name index=index_name, table=table_name
@ -163,7 +163,7 @@ def _drop_index(engine, table_name, index_name):
) )
def _add_columns(engine, table_name, columns_def): def _add_columns(connection, table_name, columns_def):
"""Add columns to a table.""" """Add columns to a table."""
_LOGGER.warning( _LOGGER.warning(
"Adding columns %s to table %s. Note: this can take several " "Adding columns %s to table %s. Note: this can take several "
@ -176,7 +176,7 @@ def _add_columns(engine, table_name, columns_def):
columns_def = [f"ADD {col_def}" for col_def in columns_def] columns_def = [f"ADD {col_def}" for col_def in columns_def]
try: try:
engine.execute( connection.execute(
text( text(
"ALTER TABLE {table} {columns_def}".format( "ALTER TABLE {table} {columns_def}".format(
table=table_name, columns_def=", ".join(columns_def) table=table_name, columns_def=", ".join(columns_def)
@ -191,7 +191,7 @@ def _add_columns(engine, table_name, columns_def):
for column_def in columns_def: for column_def in columns_def:
try: try:
engine.execute( connection.execute(
text( text(
"ALTER TABLE {table} {column_def}".format( "ALTER TABLE {table} {column_def}".format(
table=table_name, column_def=column_def table=table_name, column_def=column_def
@ -209,7 +209,7 @@ def _add_columns(engine, table_name, columns_def):
) )
def _modify_columns(engine, table_name, columns_def): def _modify_columns(connection, engine, table_name, columns_def):
"""Modify columns in a table.""" """Modify columns in a table."""
if engine.dialect.name == "sqlite": if engine.dialect.name == "sqlite":
_LOGGER.debug( _LOGGER.debug(
@ -242,7 +242,7 @@ def _modify_columns(engine, table_name, columns_def):
columns_def = [f"MODIFY {col_def}" for col_def in columns_def] columns_def = [f"MODIFY {col_def}" for col_def in columns_def]
try: try:
engine.execute( connection.execute(
text( text(
"ALTER TABLE {table} {columns_def}".format( "ALTER TABLE {table} {columns_def}".format(
table=table_name, columns_def=", ".join(columns_def) table=table_name, columns_def=", ".join(columns_def)
@ -255,7 +255,7 @@ def _modify_columns(engine, table_name, columns_def):
for column_def in columns_def: for column_def in columns_def:
try: try:
engine.execute( connection.execute(
text( text(
"ALTER TABLE {table} {column_def}".format( "ALTER TABLE {table} {column_def}".format(
table=table_name, column_def=column_def table=table_name, column_def=column_def
@ -268,9 +268,9 @@ def _modify_columns(engine, table_name, columns_def):
) )
def _update_states_table_with_foreign_key_options(engine): def _update_states_table_with_foreign_key_options(connection, engine):
"""Add the options to foreign key constraints.""" """Add the options to foreign key constraints."""
inspector = reflection.Inspector.from_engine(engine) inspector = sqlalchemy.inspect(engine)
alters = [] alters = []
for foreign_key in inspector.get_foreign_keys(TABLE_STATES): for foreign_key in inspector.get_foreign_keys(TABLE_STATES):
if foreign_key["name"] and ( if foreign_key["name"] and (
@ -297,25 +297,26 @@ def _update_states_table_with_foreign_key_options(engine):
for alter in alters: for alter in alters:
try: try:
engine.execute(DropConstraint(alter["old_fk"])) connection.execute(DropConstraint(alter["old_fk"]))
for fkc in states_key_constraints: for fkc in states_key_constraints:
if fkc.column_keys == alter["columns"]: if fkc.column_keys == alter["columns"]:
engine.execute(AddConstraint(fkc)) connection.execute(AddConstraint(fkc))
except (InternalError, OperationalError): except (InternalError, OperationalError):
_LOGGER.exception( _LOGGER.exception(
"Could not update foreign options in %s table", TABLE_STATES "Could not update foreign options in %s table", TABLE_STATES
) )
def _apply_update(engine, new_version, old_version): def _apply_update(engine, session, new_version, old_version):
"""Perform operations to bring schema up to date.""" """Perform operations to bring schema up to date."""
connection = session.connection()
if new_version == 1: if new_version == 1:
_create_index(engine, "events", "ix_events_time_fired") _create_index(connection, "events", "ix_events_time_fired")
elif new_version == 2: elif new_version == 2:
# Create compound start/end index for recorder_runs # Create compound start/end index for recorder_runs
_create_index(engine, "recorder_runs", "ix_recorder_runs_start_end") _create_index(connection, "recorder_runs", "ix_recorder_runs_start_end")
# Create indexes for states # Create indexes for states
_create_index(engine, "states", "ix_states_last_updated") _create_index(connection, "states", "ix_states_last_updated")
elif new_version == 3: elif new_version == 3:
# There used to be a new index here, but it was removed in version 4. # There used to be a new index here, but it was removed in version 4.
pass pass
@ -325,41 +326,41 @@ def _apply_update(engine, new_version, old_version):
if old_version == 3: if old_version == 3:
# Remove index that was added in version 3 # Remove index that was added in version 3
_drop_index(engine, "states", "ix_states_created_domain") _drop_index(connection, "states", "ix_states_created_domain")
if old_version == 2: if old_version == 2:
# Remove index that was added in version 2 # Remove index that was added in version 2
_drop_index(engine, "states", "ix_states_entity_id_created") _drop_index(connection, "states", "ix_states_entity_id_created")
# Remove indexes that were added in version 0 # Remove indexes that were added in version 0
_drop_index(engine, "states", "states__state_changes") _drop_index(connection, "states", "states__state_changes")
_drop_index(engine, "states", "states__significant_changes") _drop_index(connection, "states", "states__significant_changes")
_drop_index(engine, "states", "ix_states_entity_id_created") _drop_index(connection, "states", "ix_states_entity_id_created")
_create_index(engine, "states", "ix_states_entity_id_last_updated") _create_index(connection, "states", "ix_states_entity_id_last_updated")
elif new_version == 5: elif new_version == 5:
# Create supporting index for States.event_id foreign key # Create supporting index for States.event_id foreign key
_create_index(engine, "states", "ix_states_event_id") _create_index(connection, "states", "ix_states_event_id")
elif new_version == 6: elif new_version == 6:
_add_columns( _add_columns(
engine, session,
"events", "events",
["context_id CHARACTER(36)", "context_user_id CHARACTER(36)"], ["context_id CHARACTER(36)", "context_user_id CHARACTER(36)"],
) )
_create_index(engine, "events", "ix_events_context_id") _create_index(connection, "events", "ix_events_context_id")
_create_index(engine, "events", "ix_events_context_user_id") _create_index(connection, "events", "ix_events_context_user_id")
_add_columns( _add_columns(
engine, connection,
"states", "states",
["context_id CHARACTER(36)", "context_user_id CHARACTER(36)"], ["context_id CHARACTER(36)", "context_user_id CHARACTER(36)"],
) )
_create_index(engine, "states", "ix_states_context_id") _create_index(connection, "states", "ix_states_context_id")
_create_index(engine, "states", "ix_states_context_user_id") _create_index(connection, "states", "ix_states_context_user_id")
elif new_version == 7: elif new_version == 7:
_create_index(engine, "states", "ix_states_entity_id") _create_index(connection, "states", "ix_states_entity_id")
elif new_version == 8: elif new_version == 8:
_add_columns(engine, "events", ["context_parent_id CHARACTER(36)"]) _add_columns(connection, "events", ["context_parent_id CHARACTER(36)"])
_add_columns(engine, "states", ["old_state_id INTEGER"]) _add_columns(connection, "states", ["old_state_id INTEGER"])
_create_index(engine, "events", "ix_events_context_parent_id") _create_index(connection, "events", "ix_events_context_parent_id")
elif new_version == 9: elif new_version == 9:
# We now get the context from events with a join # We now get the context from events with a join
# since its always there on state_changed events # since its always there on state_changed events
@ -369,32 +370,36 @@ def _apply_update(engine, new_version, old_version):
# and we would have to move to something like # and we would have to move to something like
# sqlalchemy alembic to make that work # sqlalchemy alembic to make that work
# #
_drop_index(engine, "states", "ix_states_context_id") _drop_index(connection, "states", "ix_states_context_id")
_drop_index(engine, "states", "ix_states_context_user_id") _drop_index(connection, "states", "ix_states_context_user_id")
# This index won't be there if they were not running # This index won't be there if they were not running
# nightly but we don't treat that as a critical issue # nightly but we don't treat that as a critical issue
_drop_index(engine, "states", "ix_states_context_parent_id") _drop_index(connection, "states", "ix_states_context_parent_id")
# Redundant keys on composite index: # Redundant keys on composite index:
# We already have ix_states_entity_id_last_updated # We already have ix_states_entity_id_last_updated
_drop_index(engine, "states", "ix_states_entity_id") _drop_index(connection, "states", "ix_states_entity_id")
_create_index(engine, "events", "ix_events_event_type_time_fired") _create_index(connection, "events", "ix_events_event_type_time_fired")
_drop_index(engine, "events", "ix_events_event_type") _drop_index(connection, "events", "ix_events_event_type")
elif new_version == 10: elif new_version == 10:
# Now done in step 11 # Now done in step 11
pass pass
elif new_version == 11: elif new_version == 11:
_create_index(engine, "states", "ix_states_old_state_id") _create_index(connection, "states", "ix_states_old_state_id")
_update_states_table_with_foreign_key_options(engine) _update_states_table_with_foreign_key_options(connection, engine)
elif new_version == 12: elif new_version == 12:
if engine.dialect.name == "mysql": if engine.dialect.name == "mysql":
_modify_columns(engine, "events", ["event_data LONGTEXT"]) _modify_columns(connection, engine, "events", ["event_data LONGTEXT"])
_modify_columns(engine, "states", ["attributes LONGTEXT"]) _modify_columns(connection, engine, "states", ["attributes LONGTEXT"])
elif new_version == 13: elif new_version == 13:
if engine.dialect.name == "mysql": if engine.dialect.name == "mysql":
_modify_columns( _modify_columns(
engine, "events", ["time_fired DATETIME(6)", "created DATETIME(6)"] connection,
engine,
"events",
["time_fired DATETIME(6)", "created DATETIME(6)"],
) )
_modify_columns( _modify_columns(
connection,
engine, engine,
"states", "states",
[ [
@ -404,7 +409,7 @@ def _apply_update(engine, new_version, old_version):
], ],
) )
elif new_version == 14: elif new_version == 14:
_modify_columns(engine, "events", ["event_type VARCHAR(64)"]) _modify_columns(connection, engine, "events", ["event_type VARCHAR(64)"])
else: else:
raise ValueError(f"No schema migration defined for version {new_version}") raise ValueError(f"No schema migration defined for version {new_version}")
@ -418,7 +423,7 @@ def _inspect_schema_version(engine, session):
version 1 are present to make the determination. Eventually this logic version 1 are present to make the determination. Eventually this logic
can be removed and we can assume a new db is being created. can be removed and we can assume a new db is being created.
""" """
inspector = reflection.Inspector.from_engine(engine) inspector = sqlalchemy.inspect(engine)
indexes = inspector.get_indexes("events") indexes = inspector.get_indexes("events")
for index in indexes: for index in indexes:

View file

@ -49,7 +49,7 @@ def session_scope(
need_rollback = False need_rollback = False
try: try:
yield session yield session
if session.transaction: if session.get_transaction():
need_rollback = True need_rollback = True
session.commit() session.commit()
except Exception as err: except Exception as err:

View file

@ -2,7 +2,7 @@
"domain": "sql", "domain": "sql",
"name": "SQL", "name": "SQL",
"documentation": "https://www.home-assistant.io/integrations/sql", "documentation": "https://www.home-assistant.io/integrations/sql",
"requirements": ["sqlalchemy==1.3.23"], "requirements": ["sqlalchemy==1.4.11"],
"codeowners": ["@dgomes"], "codeowners": ["@dgomes"],
"iot_class": "local_polling" "iot_class": "local_polling"
} }

View file

@ -151,7 +151,7 @@ class SQLSensor(SensorEntity):
self._state = None self._state = None
return return
for res in result: for res in result.mappings():
_LOGGER.debug("result = %s", res.items()) _LOGGER.debug("result = %s", res.items())
data = res[self._column_name] data = res[self._column_name]
for key, value in res.items(): for key, value in res.items():

View file

@ -30,7 +30,7 @@ pyyaml==5.4.1
requests==2.25.1 requests==2.25.1
ruamel.yaml==0.15.100 ruamel.yaml==0.15.100
scapy==2.4.4 scapy==2.4.4
sqlalchemy==1.3.23 sqlalchemy==1.4.11
voluptuous-serialize==2.4.0 voluptuous-serialize==2.4.0
voluptuous==0.12.1 voluptuous==0.12.1
yarl==1.6.3 yarl==1.6.3

View file

@ -2136,7 +2136,7 @@ spotipy==2.18.0
# homeassistant.components.recorder # homeassistant.components.recorder
# homeassistant.components.sql # homeassistant.components.sql
sqlalchemy==1.3.23 sqlalchemy==1.4.11
# homeassistant.components.srp_energy # homeassistant.components.srp_energy
srpenergy==1.3.2 srpenergy==1.3.2

View file

@ -1138,7 +1138,7 @@ spotipy==2.18.0
# homeassistant.components.recorder # homeassistant.components.recorder
# homeassistant.components.sql # homeassistant.components.sql
sqlalchemy==1.3.23 sqlalchemy==1.4.11
# homeassistant.components.srp_energy # homeassistant.components.srp_energy
srpenergy==1.3.2 srpenergy==1.3.2

View file

@ -19,7 +19,7 @@ from sqlalchemy import (
Text, Text,
distinct, distinct,
) )
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import declarative_base
from homeassistant.core import Event, EventOrigin, State, split_entity_id from homeassistant.core import Event, EventOrigin, State, split_entity_id
from homeassistant.helpers.json import JSONEncoder from homeassistant.helpers.json import JSONEncoder

View file

@ -2,16 +2,17 @@
# pylint: disable=protected-access # pylint: disable=protected-access
import datetime import datetime
import sqlite3 import sqlite3
from unittest.mock import Mock, PropertyMock, call, patch from unittest.mock import ANY, Mock, PropertyMock, call, patch
import pytest import pytest
from sqlalchemy import create_engine from sqlalchemy import create_engine, text
from sqlalchemy.exc import ( from sqlalchemy.exc import (
DatabaseError, DatabaseError,
InternalError, InternalError,
OperationalError, OperationalError,
ProgrammingError, ProgrammingError,
) )
from sqlalchemy.orm import Session
from sqlalchemy.pool import StaticPool from sqlalchemy.pool import StaticPool
from homeassistant.bootstrap import async_setup_component from homeassistant.bootstrap import async_setup_component
@ -64,7 +65,7 @@ async def test_schema_update_calls(hass):
assert await recorder.async_migration_in_progress(hass) is False assert await recorder.async_migration_in_progress(hass) is False
update.assert_has_calls( update.assert_has_calls(
[ [
call(hass.data[DATA_INSTANCE].engine, version + 1, 0) call(hass.data[DATA_INSTANCE].engine, ANY, version + 1, 0)
for version in range(0, models.SCHEMA_VERSION) for version in range(0, models.SCHEMA_VERSION)
] ]
) )
@ -259,7 +260,7 @@ async def test_schema_migrate(hass):
def test_invalid_update(): def test_invalid_update():
"""Test that an invalid new version raises an exception.""" """Test that an invalid new version raises an exception."""
with pytest.raises(ValueError): with pytest.raises(ValueError):
migration._apply_update(None, -1, 0) migration._apply_update(Mock(), Mock(), -1, 0)
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -273,28 +274,31 @@ def test_invalid_update():
) )
def test_modify_column(engine_type, substr): def test_modify_column(engine_type, substr):
"""Test that modify column generates the expected query.""" """Test that modify column generates the expected query."""
connection = Mock()
engine = Mock() engine = Mock()
engine.dialect.name = engine_type engine.dialect.name = engine_type
migration._modify_columns(engine, "events", ["event_type VARCHAR(64)"]) migration._modify_columns(connection, engine, "events", ["event_type VARCHAR(64)"])
if substr: if substr:
assert substr in engine.execute.call_args[0][0].text assert substr in connection.execute.call_args[0][0].text
else: else:
assert not engine.execute.called assert not connection.execute.called
def test_forgiving_add_column(): def test_forgiving_add_column():
"""Test that add column will continue if column exists.""" """Test that add column will continue if column exists."""
engine = create_engine("sqlite://", poolclass=StaticPool) engine = create_engine("sqlite://", poolclass=StaticPool)
engine.execute("CREATE TABLE hello (id int)") with Session(engine) as session:
migration._add_columns(engine, "hello", ["context_id CHARACTER(36)"]) session.execute(text("CREATE TABLE hello (id int)"))
migration._add_columns(engine, "hello", ["context_id CHARACTER(36)"]) migration._add_columns(session, "hello", ["context_id CHARACTER(36)"])
migration._add_columns(session, "hello", ["context_id CHARACTER(36)"])
def test_forgiving_add_index(): def test_forgiving_add_index():
"""Test that add index will continue if index exists.""" """Test that add index will continue if index exists."""
engine = create_engine("sqlite://", poolclass=StaticPool) engine = create_engine("sqlite://", poolclass=StaticPool)
models.Base.metadata.create_all(engine) models.Base.metadata.create_all(engine)
migration._create_index(engine, "states", "ix_states_context_id") with Session(engine) as session:
migration._create_index(session, "states", "ix_states_context_id")
@pytest.mark.parametrize( @pytest.mark.parametrize(

View file

@ -5,12 +5,12 @@ import sqlite3
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest import pytest
from sqlalchemy import text
from homeassistant.components.recorder import run_information_with_session, util from homeassistant.components.recorder import run_information_with_session, util
from homeassistant.components.recorder.const import DATA_INSTANCE, SQLITE_URL_PREFIX from homeassistant.components.recorder.const import DATA_INSTANCE, SQLITE_URL_PREFIX
from homeassistant.components.recorder.models import RecorderRuns from homeassistant.components.recorder.models import RecorderRuns
from homeassistant.components.recorder.util import end_incomplete_runs, session_scope from homeassistant.components.recorder.util import end_incomplete_runs, session_scope
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
from .common import corrupt_db_file from .common import corrupt_db_file
@ -55,7 +55,7 @@ def test_recorder_bad_commit(hass_recorder):
def work(session): def work(session):
"""Bad work.""" """Bad work."""
session.execute("select * from notthere") session.execute(text("select * from notthere"))
with patch( with patch(
"homeassistant.components.recorder.time.sleep" "homeassistant.components.recorder.time.sleep"
@ -122,7 +122,7 @@ async def test_last_run_was_recently_clean(hass):
is False is False
) )
hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP) await hass.async_add_executor_job(hass.data[DATA_INSTANCE]._end_session)
await hass.async_block_till_done() await hass.async_block_till_done()
assert ( assert (