From 99cf4a6b2d99ce576cf2aee8665ac1388e67d82a Mon Sep 17 00:00:00 2001 From: G Johansson Date: Fri, 8 Dec 2023 21:13:37 +0100 Subject: [PATCH] Add rollback on exception that needs rollback in SQL (#104948) --- homeassistant/components/sql/sensor.py | 2 ++ tests/components/sql/test_sensor.py | 48 ++++++++++++++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/homeassistant/components/sql/sensor.py b/homeassistant/components/sql/sensor.py index 3fdc6b2c079..c4e6db4c623 100644 --- a/homeassistant/components/sql/sensor.py +++ b/homeassistant/components/sql/sensor.py @@ -362,6 +362,8 @@ class SQLSensor(ManualTriggerSensorEntity): self._query, redact_credentials(str(err)), ) + sess.rollback() + sess.close() return for res in result.mappings(): diff --git a/tests/components/sql/test_sensor.py b/tests/components/sql/test_sensor.py index cb988d3f2d4..cdc9a8e07a6 100644 --- a/tests/components/sql/test_sensor.py +++ b/tests/components/sql/test_sensor.py @@ -5,6 +5,7 @@ from datetime import timedelta from typing import Any from unittest.mock import patch +from freezegun.api import FrozenDateTimeFactory import pytest from sqlalchemy import text as sql_text from sqlalchemy.exc import SQLAlchemyError @@ -12,6 +13,7 @@ from sqlalchemy.exc import SQLAlchemyError from homeassistant.components.recorder import Recorder from homeassistant.components.sensor import SensorDeviceClass, SensorStateClass from homeassistant.components.sql.const import CONF_QUERY, DOMAIN +from homeassistant.components.sql.sensor import _generate_lambda_stmt from homeassistant.config_entries import SOURCE_USER from homeassistant.const import ( CONF_ICON, @@ -21,6 +23,7 @@ from homeassistant.const import ( ) from homeassistant.core import HomeAssistant from homeassistant.helpers import issue_registry as ir +from homeassistant.helpers.entity_platform import async_get_platforms from homeassistant.setup import async_setup_component from homeassistant.util import dt as dt_util @@ -570,3 +573,48 @@ async def test_attributes_from_entry_config( assert state.attributes["unit_of_measurement"] == "MiB" assert "device_class" not in state.attributes assert "state_class" not in state.attributes + + +async def test_query_recover_from_rollback( + recorder_mock: Recorder, + hass: HomeAssistant, + freezer: FrozenDateTimeFactory, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test the SQL sensor.""" + config = { + "db_url": "sqlite://", + "query": "SELECT 5 as value", + "column": "value", + "name": "Select value SQL query", + "unique_id": "very_unique_id", + } + await init_integration(hass, config) + platforms = async_get_platforms(hass, "sql") + sql_entity = platforms[0].entities["sensor.select_value_sql_query"] + + state = hass.states.get("sensor.select_value_sql_query") + assert state.state == "5" + assert state.attributes["value"] == 5 + + with patch.object( + sql_entity, + "_lambda_stmt", + _generate_lambda_stmt("Faulty syntax create operational issue"), + ): + freezer.tick(timedelta(minutes=1)) + async_fire_time_changed(hass) + await hass.async_block_till_done() + assert "sqlite3.OperationalError" in caplog.text + + state = hass.states.get("sensor.select_value_sql_query") + assert state.state == "5" + assert state.attributes.get("value") is None + + freezer.tick(timedelta(minutes=1)) + async_fire_time_changed(hass) + await hass.async_block_till_done() + + state = hass.states.get("sensor.select_value_sql_query") + assert state.state == "5" + assert state.attributes.get("value") == 5