Fix sql doing I/O in the event loop at startup (#90335)

* Fix sql doing I/O in the event loop

* Fix sql doing I/O in the event loop

* no test query on main db

* fix mocking because it was targeting the recorder
This commit is contained in:
J. Nick Koston 2023-03-26 15:02:24 -10:00 committed by GitHub
parent 75e28826e0
commit 7098debe09
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 46 additions and 23 deletions

View file

@ -136,24 +136,17 @@ async def async_setup_sensor(
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up the SQL sensor."""
try:
engine = sqlalchemy.create_engine(db_url, future=True)
sessmaker = scoped_session(sessionmaker(bind=engine, future=True))
# Run a dummy query just to test the db_url
sess: Session = sessmaker()
sess.execute(sqlalchemy.text("SELECT 1;"))
except SQLAlchemyError as err:
_LOGGER.error(
"Couldn't connect using %s DB_URL: %s",
redact_credentials(db_url),
redact_credentials(str(err)),
instance = get_instance(hass)
sessmaker: scoped_session | None
if use_database_executor := (db_url == instance.db_url):
assert instance.engine is not None
sessmaker = scoped_session(sessionmaker(bind=instance.engine, future=True))
elif not (
sessmaker := await hass.async_add_executor_job(
_validate_and_get_session_maker_for_db_url, db_url
)
):
return
finally:
if sess:
sess.close()
# MSSQL uses TOP and not LIMIT
if not ("LIMIT" in query_str.upper() or "SELECT TOP" in query_str.upper()):
@ -162,8 +155,6 @@ async def async_setup_sensor(
else:
query_str = query_str.replace(";", "") + " LIMIT 1;"
use_database_executor = db_url == get_instance(hass).db_url
async_add_entities(
[
SQLSensor(
@ -184,6 +175,32 @@ async def async_setup_sensor(
)
def _validate_and_get_session_maker_for_db_url(db_url: str) -> scoped_session | None:
"""Validate the db_url and return a session maker.
This does I/O and should be run in the executor.
"""
try:
engine = sqlalchemy.create_engine(db_url, future=True)
sessmaker = scoped_session(sessionmaker(bind=engine, future=True))
# Run a dummy query just to test the db_url
sess: Session = sessmaker()
sess.execute(sqlalchemy.text("SELECT 1;"))
except SQLAlchemyError as err:
_LOGGER.error(
"Couldn't connect using %s DB_URL: %s",
redact_credentials(db_url),
redact_credentials(str(err)),
)
return None
else:
return sessmaker
finally:
if sess:
sess.close()
class SQLSensor(SensorEntity):
"""Representation of an SQL sensor."""

View file

@ -2,6 +2,7 @@
from __future__ import annotations
from datetime import timedelta
from typing import Any
from unittest.mock import patch
import pytest
@ -193,14 +194,19 @@ async def test_invalid_url_on_update(
"column": "value",
"name": "count_tables",
}
await init_integration(hass, config)
class MockSession:
"""Mock session."""
def execute(self, query: Any) -> None:
"""Execute the query."""
raise SQLAlchemyError("sqlite://homeassistant:hunter2@homeassistant.local")
with patch(
"homeassistant.components.sql.sensor.sqlalchemy.engine.cursor.CursorResult",
side_effect=SQLAlchemyError(
"sqlite://homeassistant:hunter2@homeassistant.local"
),
"homeassistant.components.sql.sensor.scoped_session",
return_value=MockSession,
):
await init_integration(hass, config)
async_fire_time_changed(
hass,
dt.utcnow() + timedelta(minutes=1),