Add rollback on exception that needs rollback in SQL (#104948)

This commit is contained in:
G Johansson 2023-12-08 21:13:37 +01:00 committed by GitHub
parent 4bb0e13cda
commit 99cf4a6b2d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 50 additions and 0 deletions

View file

@ -362,6 +362,8 @@ class SQLSensor(ManualTriggerSensorEntity):
self._query,
redact_credentials(str(err)),
)
sess.rollback()
sess.close()
return
for res in result.mappings():

View file

@ -5,6 +5,7 @@ from datetime import timedelta
from typing import Any
from unittest.mock import patch
from freezegun.api import FrozenDateTimeFactory
import pytest
from sqlalchemy import text as sql_text
from sqlalchemy.exc import SQLAlchemyError
@ -12,6 +13,7 @@ from sqlalchemy.exc import SQLAlchemyError
from homeassistant.components.recorder import Recorder
from homeassistant.components.sensor import SensorDeviceClass, SensorStateClass
from homeassistant.components.sql.const import CONF_QUERY, DOMAIN
from homeassistant.components.sql.sensor import _generate_lambda_stmt
from homeassistant.config_entries import SOURCE_USER
from homeassistant.const import (
CONF_ICON,
@ -21,6 +23,7 @@ from homeassistant.const import (
)
from homeassistant.core import HomeAssistant
from homeassistant.helpers import issue_registry as ir
from homeassistant.helpers.entity_platform import async_get_platforms
from homeassistant.setup import async_setup_component
from homeassistant.util import dt as dt_util
@ -570,3 +573,48 @@ async def test_attributes_from_entry_config(
assert state.attributes["unit_of_measurement"] == "MiB"
assert "device_class" not in state.attributes
assert "state_class" not in state.attributes
async def test_query_recover_from_rollback(
recorder_mock: Recorder,
hass: HomeAssistant,
freezer: FrozenDateTimeFactory,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test the SQL sensor."""
config = {
"db_url": "sqlite://",
"query": "SELECT 5 as value",
"column": "value",
"name": "Select value SQL query",
"unique_id": "very_unique_id",
}
await init_integration(hass, config)
platforms = async_get_platforms(hass, "sql")
sql_entity = platforms[0].entities["sensor.select_value_sql_query"]
state = hass.states.get("sensor.select_value_sql_query")
assert state.state == "5"
assert state.attributes["value"] == 5
with patch.object(
sql_entity,
"_lambda_stmt",
_generate_lambda_stmt("Faulty syntax create operational issue"),
):
freezer.tick(timedelta(minutes=1))
async_fire_time_changed(hass)
await hass.async_block_till_done()
assert "sqlite3.OperationalError" in caplog.text
state = hass.states.get("sensor.select_value_sql_query")
assert state.state == "5"
assert state.attributes.get("value") is None
freezer.tick(timedelta(minutes=1))
async_fire_time_changed(hass)
await hass.async_block_till_done()
state = hass.states.get("sensor.select_value_sql_query")
assert state.state == "5"
assert state.attributes.get("value") == 5