diff --git a/homeassistant/components/sql/__init__.py b/homeassistant/components/sql/__init__.py index 71e3671ce96..7267ed04408 100644 --- a/homeassistant/components/sql/__init__.py +++ b/homeassistant/components/sql/__init__.py @@ -26,6 +26,7 @@ from homeassistant.const import ( from homeassistant.core import HomeAssistant from homeassistant.helpers import discovery import homeassistant.helpers.config_validation as cv +from homeassistant.helpers.template import Template from homeassistant.helpers.trigger_template_entity import ( CONF_AVAILABILITY, CONF_PICTURE, @@ -38,23 +39,24 @@ from .util import redact_credentials _LOGGER = logging.getLogger(__name__) -def validate_sql_select(value: str) -> str: +def validate_sql_select(value: Template) -> Template: """Validate that value is a SQL SELECT query.""" - if len(query := sqlparse.parse(value.lstrip().lstrip(";"))) > 1: + rendered_value = value.async_render() + if len(query := sqlparse.parse(rendered_value.lstrip().lstrip(";"))) > 1: raise vol.Invalid("Multiple SQL queries are not supported") if len(query) == 0 or (query_type := query[0].get_type()) == "UNKNOWN": raise vol.Invalid("Invalid SQL query") if query_type != "SELECT": _LOGGER.debug("The SQL query %s is of type %s", query, query_type) raise vol.Invalid("Only SELECT queries allowed") - return str(query[0]) + return value QUERY_SCHEMA = vol.Schema( { vol.Required(CONF_COLUMN_NAME): cv.string, vol.Required(CONF_NAME): cv.template, - vol.Required(CONF_QUERY): vol.All(cv.string, validate_sql_select), + vol.Required(CONF_QUERY): vol.All(cv.template, validate_sql_select), vol.Optional(CONF_UNIT_OF_MEASUREMENT): cv.string, vol.Optional(CONF_VALUE_TEMPLATE): cv.template, vol.Optional(CONF_UNIQUE_ID): cv.string, diff --git a/homeassistant/components/sql/config_flow.py b/homeassistant/components/sql/config_flow.py index 5537c7ff3b0..5acc3637230 100644 --- a/homeassistant/components/sql/config_flow.py +++ b/homeassistant/components/sql/config_flow.py @@ -31,8 +31,9 @@ from homeassistant.const import ( CONF_UNIT_OF_MEASUREMENT, CONF_VALUE_TEMPLATE, ) -from homeassistant.core import callback +from homeassistant.core import async_get_hass, callback from homeassistant.helpers import selector +from homeassistant.helpers.template import Template from .const import CONF_COLUMN_NAME, CONF_QUERY, DOMAIN from .util import resolve_db_url @@ -50,7 +51,7 @@ OPTIONS_SCHEMA: vol.Schema = vol.Schema( ): selector.TextSelector(), vol.Required( CONF_QUERY, - ): selector.TextSelector(selector.TextSelectorConfig(multiline=True)), + ): selector.TemplateSelector(), vol.Optional( CONF_UNIT_OF_MEASUREMENT, ): selector.TextSelector(), @@ -89,6 +90,7 @@ CONFIG_SCHEMA: vol.Schema = vol.Schema( def validate_sql_select(value: str) -> str: """Validate that value is a SQL SELECT query.""" + value = Template(value, async_get_hass()).async_render() if len(query := sqlparse.parse(value.lstrip().lstrip(";"))) > 1: raise MultipleResultsFound if len(query) == 0 or (query_type := query[0].get_type()) == "UNKNOWN": @@ -96,7 +98,7 @@ def validate_sql_select(value: str) -> str: if query_type != "SELECT": _LOGGER.debug("The SQL query %s is of type %s", query, query_type) raise SQLParseError - return str(query[0]) + return value def validate_query(db_url: str, query: str, column: str) -> bool: @@ -160,10 +162,10 @@ class SQLConfigFlow(ConfigFlow, domain=DOMAIN): db_url_for_validation = None try: - query = validate_sql_select(query) + test_query = validate_sql_select(query) db_url_for_validation = resolve_db_url(self.hass, db_url) await self.hass.async_add_executor_job( - validate_query, db_url_for_validation, query, column + validate_query, db_url_for_validation, test_query, column ) except NoSuchColumnError: errors["column"] = "column_invalid" @@ -226,10 +228,10 @@ class SQLOptionsFlowHandler(OptionsFlowWithConfigEntry): name = self.options.get(CONF_NAME, self.config_entry.title) try: - query = validate_sql_select(query) + test_query = validate_sql_select(query) db_url_for_validation = resolve_db_url(self.hass, db_url) await self.hass.async_add_executor_job( - validate_query, db_url_for_validation, query, column + validate_query, db_url_for_validation, test_query, column ) except NoSuchColumnError: errors["column"] = "column_invalid" diff --git a/homeassistant/components/sql/sensor.py b/homeassistant/components/sql/sensor.py index 1d033728c0d..0e876e81165 100644 --- a/homeassistant/components/sql/sensor.py +++ b/homeassistant/components/sql/sensor.py @@ -75,7 +75,7 @@ async def async_setup_platform( return name: Template = conf[CONF_NAME] - query_str: str = conf[CONF_QUERY] + query: Template = conf[CONF_QUERY] value_template: Template | None = conf.get(CONF_VALUE_TEMPLATE) column_name: str = conf[CONF_COLUMN_NAME] unique_id: str | None = conf.get(CONF_UNIQUE_ID) @@ -90,7 +90,7 @@ async def async_setup_platform( await async_setup_sensor( hass, trigger_entity_config, - query_str, + query, column_name, value_template, unique_id, @@ -119,6 +119,7 @@ async def async_setup_entry( except TemplateError: value_template = None + query = Template(query_str, hass) name_template = Template(name, hass) trigger_entity_config = {CONF_NAME: name_template, CONF_UNIQUE_ID: entry.entry_id} for key in TRIGGER_ENTITY_OPTIONS: @@ -129,7 +130,7 @@ async def async_setup_entry( await async_setup_sensor( hass, trigger_entity_config, - query_str, + query, column_name, value_template, entry.entry_id, @@ -172,7 +173,7 @@ def _async_get_or_init_domain_data(hass: HomeAssistant) -> SQLData: async def async_setup_sensor( hass: HomeAssistant, trigger_entity_config: ConfigType, - query_str: str, + query: Template, column_name: str, value_template: Template | None, unique_id: str | None, @@ -209,9 +210,10 @@ async def async_setup_sensor( else: return - upper_query = query_str.upper() + rendered_query = query.async_render() + upper_query = rendered_query.upper() if uses_recorder_db: - redacted_query = redact_credentials(query_str) + redacted_query = redact_credentials(rendered_query) issue_key = unique_id if unique_id else redacted_query # If the query has a unique id and they fix it we can dismiss the issue @@ -247,17 +249,20 @@ async def async_setup_sensor( # MSSQL uses TOP and not LIMIT if not ("LIMIT" in upper_query or "SELECT TOP" in upper_query): + query_str = query.template if "mssql" in db_url: query_str = upper_query.replace("SELECT", "SELECT TOP 1") + query = Template(query_str, hass) else: query_str = query_str.replace(";", "") + " LIMIT 1;" + query = Template(query_str, hass) async_add_entities( [ SQLSensor( trigger_entity_config, sessmaker, - query_str, + query, column_name, value_template, yaml, @@ -309,7 +314,7 @@ class SQLSensor(ManualTriggerSensorEntity): self, trigger_entity_config: ConfigType, sessmaker: scoped_session, - query: str, + query: Template, column: str, value_template: Template | None, yaml: bool, @@ -323,7 +328,7 @@ class SQLSensor(ManualTriggerSensorEntity): self.sessionmaker = sessmaker self._attr_extra_state_attributes = {} self._use_database_executor = use_database_executor - self._lambda_stmt = _generate_lambda_stmt(query) + self._lambda_stmt = _generate_lambda_stmt(query.async_render()) if not yaml and (unique_id := trigger_entity_config.get(CONF_UNIQUE_ID)): self._attr_name = None self._attr_has_entity_name = True @@ -346,6 +351,7 @@ class SQLSensor(ManualTriggerSensorEntity): async def async_update(self) -> None: """Retrieve sensor data from the query using the right executor.""" + self._lambda_stmt = _generate_lambda_stmt(self._query.async_render()) if self._use_database_executor: data = await get_instance(self.hass).async_add_executor_job(self._update) else: @@ -362,7 +368,7 @@ class SQLSensor(ManualTriggerSensorEntity): except SQLAlchemyError as err: _LOGGER.error( "Error executing query %s: %s", - self._query, + self._query.template, redact_credentials(str(err)), ) sess.rollback() @@ -370,7 +376,7 @@ class SQLSensor(ManualTriggerSensorEntity): return None for res in result.mappings(): - _LOGGER.debug("Query %s result in %s", self._query, res.items()) + _LOGGER.debug("Query %s result in %s", self._query.template, res.items()) data = res[self._column_name] for key, value in res.items(): if isinstance(value, decimal.Decimal): @@ -392,7 +398,7 @@ class SQLSensor(ManualTriggerSensorEntity): self._attr_native_value = data if data is None: - _LOGGER.warning("%s returned no results", self._query) + _LOGGER.warning("%s returned no results", self._query.template) sess.close() return data diff --git a/tests/components/sql/test_init.py b/tests/components/sql/test_init.py index 409ebca27c0..9d6beed778c 100644 --- a/tests/components/sql/test_init.py +++ b/tests/components/sql/test_init.py @@ -13,6 +13,7 @@ from homeassistant.components.sql import validate_sql_select from homeassistant.components.sql.const import DOMAIN from homeassistant.config_entries import ConfigEntryState from homeassistant.core import HomeAssistant +from homeassistant.helpers.template import Template from homeassistant.setup import async_setup_component from . import YAML_CONFIG_INVALID, YAML_CONFIG_NO_DB, init_integration @@ -57,33 +58,39 @@ async def test_setup_invalid_config( async def test_invalid_query(hass: HomeAssistant) -> None: """Test invalid query.""" with pytest.raises(vol.Invalid): - validate_sql_select("DROP TABLE *") + validate_sql_select(Template("DROP TABLE *")) with pytest.raises(vol.Invalid): - validate_sql_select("SELECT5 as value") + validate_sql_select(Template("SELECT5 as value")) with pytest.raises(vol.Invalid): - validate_sql_select(";;") + validate_sql_select(Template(";;")) async def test_query_no_read_only(hass: HomeAssistant) -> None: """Test query no read only.""" with pytest.raises(vol.Invalid): - validate_sql_select("UPDATE states SET state = 999999 WHERE state_id = 11125") + validate_sql_select( + Template("UPDATE states SET state = 999999 WHERE state_id = 11125") + ) async def test_query_no_read_only_cte(hass: HomeAssistant) -> None: """Test query no read only CTE.""" with pytest.raises(vol.Invalid): validate_sql_select( - "WITH test AS (SELECT state FROM states) UPDATE states SET states.state = test.state;" + Template( + "WITH test AS (SELECT state FROM states) UPDATE states SET states.state = test.state;" + ) ) async def test_multiple_queries(hass: HomeAssistant) -> None: """Test multiple queries.""" with pytest.raises(vol.Invalid): - validate_sql_select("SELECT 5 as value; UPDATE states SET state = 10;") + validate_sql_select( + Template("SELECT 5 as value; UPDATE states SET state = 10;") + ) async def test_remove_configured_db_url_if_not_needed_when_not_needed( diff --git a/tests/components/sql/test_sensor.py b/tests/components/sql/test_sensor.py index b219ad47f3a..499a9547762 100644 --- a/tests/components/sql/test_sensor.py +++ b/tests/components/sql/test_sensor.py @@ -25,7 +25,6 @@ from homeassistant.const import ( ) from homeassistant.core import HomeAssistant from homeassistant.helpers import issue_registry as ir -from homeassistant.helpers.entity_platform import async_get_platforms from homeassistant.setup import async_setup_component from homeassistant.util import dt as dt_util @@ -613,17 +612,17 @@ async def test_query_recover_from_rollback( "unique_id": "very_unique_id", } await init_integration(hass, config) - platforms = async_get_platforms(hass, "sql") - sql_entity = platforms[0].entities["sensor.select_value_sql_query"] state = hass.states.get("sensor.select_value_sql_query") assert state.state == "5" assert state.attributes["value"] == 5 - with patch.object( - sql_entity, - "_lambda_stmt", - _generate_lambda_stmt("Faulty syntax create operational issue"), + incorrect_lambda_stmt = _generate_lambda_stmt( + "Faulty syntax create operational issue" + ) + with patch( + "homeassistant.components.sql.sensor._generate_lambda_stmt", + return_value=incorrect_lambda_stmt, ): freezer.tick(timedelta(minutes=1)) async_fire_time_changed(hass)