Modify query as template in SQL integration
This commit is contained in:
parent
611723e44b
commit
45aab80b59
5 changed files with 52 additions and 36 deletions
|
@ -26,6 +26,7 @@ from homeassistant.const import (
|
|||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import discovery
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.template import Template
|
||||
from homeassistant.helpers.trigger_template_entity import (
|
||||
CONF_AVAILABILITY,
|
||||
CONF_PICTURE,
|
||||
|
@ -38,23 +39,24 @@ from .util import redact_credentials
|
|||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def validate_sql_select(value: str) -> str:
|
||||
def validate_sql_select(value: Template) -> Template:
|
||||
"""Validate that value is a SQL SELECT query."""
|
||||
if len(query := sqlparse.parse(value.lstrip().lstrip(";"))) > 1:
|
||||
rendered_value = value.async_render()
|
||||
if len(query := sqlparse.parse(rendered_value.lstrip().lstrip(";"))) > 1:
|
||||
raise vol.Invalid("Multiple SQL queries are not supported")
|
||||
if len(query) == 0 or (query_type := query[0].get_type()) == "UNKNOWN":
|
||||
raise vol.Invalid("Invalid SQL query")
|
||||
if query_type != "SELECT":
|
||||
_LOGGER.debug("The SQL query %s is of type %s", query, query_type)
|
||||
raise vol.Invalid("Only SELECT queries allowed")
|
||||
return str(query[0])
|
||||
return value
|
||||
|
||||
|
||||
QUERY_SCHEMA = vol.Schema(
|
||||
{
|
||||
vol.Required(CONF_COLUMN_NAME): cv.string,
|
||||
vol.Required(CONF_NAME): cv.template,
|
||||
vol.Required(CONF_QUERY): vol.All(cv.string, validate_sql_select),
|
||||
vol.Required(CONF_QUERY): vol.All(cv.template, validate_sql_select),
|
||||
vol.Optional(CONF_UNIT_OF_MEASUREMENT): cv.string,
|
||||
vol.Optional(CONF_VALUE_TEMPLATE): cv.template,
|
||||
vol.Optional(CONF_UNIQUE_ID): cv.string,
|
||||
|
|
|
@ -31,8 +31,9 @@ from homeassistant.const import (
|
|||
CONF_UNIT_OF_MEASUREMENT,
|
||||
CONF_VALUE_TEMPLATE,
|
||||
)
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.core import async_get_hass, callback
|
||||
from homeassistant.helpers import selector
|
||||
from homeassistant.helpers.template import Template
|
||||
|
||||
from .const import CONF_COLUMN_NAME, CONF_QUERY, DOMAIN
|
||||
from .util import resolve_db_url
|
||||
|
@ -50,7 +51,7 @@ OPTIONS_SCHEMA: vol.Schema = vol.Schema(
|
|||
): selector.TextSelector(),
|
||||
vol.Required(
|
||||
CONF_QUERY,
|
||||
): selector.TextSelector(selector.TextSelectorConfig(multiline=True)),
|
||||
): selector.TemplateSelector(),
|
||||
vol.Optional(
|
||||
CONF_UNIT_OF_MEASUREMENT,
|
||||
): selector.TextSelector(),
|
||||
|
@ -89,6 +90,7 @@ CONFIG_SCHEMA: vol.Schema = vol.Schema(
|
|||
|
||||
def validate_sql_select(value: str) -> str:
|
||||
"""Validate that value is a SQL SELECT query."""
|
||||
value = Template(value, async_get_hass()).async_render()
|
||||
if len(query := sqlparse.parse(value.lstrip().lstrip(";"))) > 1:
|
||||
raise MultipleResultsFound
|
||||
if len(query) == 0 or (query_type := query[0].get_type()) == "UNKNOWN":
|
||||
|
@ -96,7 +98,7 @@ def validate_sql_select(value: str) -> str:
|
|||
if query_type != "SELECT":
|
||||
_LOGGER.debug("The SQL query %s is of type %s", query, query_type)
|
||||
raise SQLParseError
|
||||
return str(query[0])
|
||||
return value
|
||||
|
||||
|
||||
def validate_query(db_url: str, query: str, column: str) -> bool:
|
||||
|
@ -160,10 +162,10 @@ class SQLConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||
db_url_for_validation = None
|
||||
|
||||
try:
|
||||
query = validate_sql_select(query)
|
||||
test_query = validate_sql_select(query)
|
||||
db_url_for_validation = resolve_db_url(self.hass, db_url)
|
||||
await self.hass.async_add_executor_job(
|
||||
validate_query, db_url_for_validation, query, column
|
||||
validate_query, db_url_for_validation, test_query, column
|
||||
)
|
||||
except NoSuchColumnError:
|
||||
errors["column"] = "column_invalid"
|
||||
|
@ -226,10 +228,10 @@ class SQLOptionsFlowHandler(OptionsFlowWithConfigEntry):
|
|||
name = self.options.get(CONF_NAME, self.config_entry.title)
|
||||
|
||||
try:
|
||||
query = validate_sql_select(query)
|
||||
test_query = validate_sql_select(query)
|
||||
db_url_for_validation = resolve_db_url(self.hass, db_url)
|
||||
await self.hass.async_add_executor_job(
|
||||
validate_query, db_url_for_validation, query, column
|
||||
validate_query, db_url_for_validation, test_query, column
|
||||
)
|
||||
except NoSuchColumnError:
|
||||
errors["column"] = "column_invalid"
|
||||
|
|
|
@ -75,7 +75,7 @@ async def async_setup_platform(
|
|||
return
|
||||
|
||||
name: Template = conf[CONF_NAME]
|
||||
query_str: str = conf[CONF_QUERY]
|
||||
query: Template = conf[CONF_QUERY]
|
||||
value_template: Template | None = conf.get(CONF_VALUE_TEMPLATE)
|
||||
column_name: str = conf[CONF_COLUMN_NAME]
|
||||
unique_id: str | None = conf.get(CONF_UNIQUE_ID)
|
||||
|
@ -90,7 +90,7 @@ async def async_setup_platform(
|
|||
await async_setup_sensor(
|
||||
hass,
|
||||
trigger_entity_config,
|
||||
query_str,
|
||||
query,
|
||||
column_name,
|
||||
value_template,
|
||||
unique_id,
|
||||
|
@ -119,6 +119,7 @@ async def async_setup_entry(
|
|||
except TemplateError:
|
||||
value_template = None
|
||||
|
||||
query = Template(query_str, hass)
|
||||
name_template = Template(name, hass)
|
||||
trigger_entity_config = {CONF_NAME: name_template, CONF_UNIQUE_ID: entry.entry_id}
|
||||
for key in TRIGGER_ENTITY_OPTIONS:
|
||||
|
@ -129,7 +130,7 @@ async def async_setup_entry(
|
|||
await async_setup_sensor(
|
||||
hass,
|
||||
trigger_entity_config,
|
||||
query_str,
|
||||
query,
|
||||
column_name,
|
||||
value_template,
|
||||
entry.entry_id,
|
||||
|
@ -172,7 +173,7 @@ def _async_get_or_init_domain_data(hass: HomeAssistant) -> SQLData:
|
|||
async def async_setup_sensor(
|
||||
hass: HomeAssistant,
|
||||
trigger_entity_config: ConfigType,
|
||||
query_str: str,
|
||||
query: Template,
|
||||
column_name: str,
|
||||
value_template: Template | None,
|
||||
unique_id: str | None,
|
||||
|
@ -209,9 +210,10 @@ async def async_setup_sensor(
|
|||
else:
|
||||
return
|
||||
|
||||
upper_query = query_str.upper()
|
||||
rendered_query = query.async_render()
|
||||
upper_query = rendered_query.upper()
|
||||
if uses_recorder_db:
|
||||
redacted_query = redact_credentials(query_str)
|
||||
redacted_query = redact_credentials(rendered_query)
|
||||
|
||||
issue_key = unique_id if unique_id else redacted_query
|
||||
# If the query has a unique id and they fix it we can dismiss the issue
|
||||
|
@ -247,17 +249,20 @@ async def async_setup_sensor(
|
|||
|
||||
# MSSQL uses TOP and not LIMIT
|
||||
if not ("LIMIT" in upper_query or "SELECT TOP" in upper_query):
|
||||
query_str = query.template
|
||||
if "mssql" in db_url:
|
||||
query_str = upper_query.replace("SELECT", "SELECT TOP 1")
|
||||
query = Template(query_str, hass)
|
||||
else:
|
||||
query_str = query_str.replace(";", "") + " LIMIT 1;"
|
||||
query = Template(query_str, hass)
|
||||
|
||||
async_add_entities(
|
||||
[
|
||||
SQLSensor(
|
||||
trigger_entity_config,
|
||||
sessmaker,
|
||||
query_str,
|
||||
query,
|
||||
column_name,
|
||||
value_template,
|
||||
yaml,
|
||||
|
@ -309,7 +314,7 @@ class SQLSensor(ManualTriggerSensorEntity):
|
|||
self,
|
||||
trigger_entity_config: ConfigType,
|
||||
sessmaker: scoped_session,
|
||||
query: str,
|
||||
query: Template,
|
||||
column: str,
|
||||
value_template: Template | None,
|
||||
yaml: bool,
|
||||
|
@ -323,7 +328,7 @@ class SQLSensor(ManualTriggerSensorEntity):
|
|||
self.sessionmaker = sessmaker
|
||||
self._attr_extra_state_attributes = {}
|
||||
self._use_database_executor = use_database_executor
|
||||
self._lambda_stmt = _generate_lambda_stmt(query)
|
||||
self._lambda_stmt = _generate_lambda_stmt(query.async_render())
|
||||
if not yaml and (unique_id := trigger_entity_config.get(CONF_UNIQUE_ID)):
|
||||
self._attr_name = None
|
||||
self._attr_has_entity_name = True
|
||||
|
@ -346,6 +351,7 @@ class SQLSensor(ManualTriggerSensorEntity):
|
|||
|
||||
async def async_update(self) -> None:
|
||||
"""Retrieve sensor data from the query using the right executor."""
|
||||
self._lambda_stmt = _generate_lambda_stmt(self._query.async_render())
|
||||
if self._use_database_executor:
|
||||
data = await get_instance(self.hass).async_add_executor_job(self._update)
|
||||
else:
|
||||
|
@ -362,7 +368,7 @@ class SQLSensor(ManualTriggerSensorEntity):
|
|||
except SQLAlchemyError as err:
|
||||
_LOGGER.error(
|
||||
"Error executing query %s: %s",
|
||||
self._query,
|
||||
self._query.template,
|
||||
redact_credentials(str(err)),
|
||||
)
|
||||
sess.rollback()
|
||||
|
@ -370,7 +376,7 @@ class SQLSensor(ManualTriggerSensorEntity):
|
|||
return None
|
||||
|
||||
for res in result.mappings():
|
||||
_LOGGER.debug("Query %s result in %s", self._query, res.items())
|
||||
_LOGGER.debug("Query %s result in %s", self._query.template, res.items())
|
||||
data = res[self._column_name]
|
||||
for key, value in res.items():
|
||||
if isinstance(value, decimal.Decimal):
|
||||
|
@ -392,7 +398,7 @@ class SQLSensor(ManualTriggerSensorEntity):
|
|||
self._attr_native_value = data
|
||||
|
||||
if data is None:
|
||||
_LOGGER.warning("%s returned no results", self._query)
|
||||
_LOGGER.warning("%s returned no results", self._query.template)
|
||||
|
||||
sess.close()
|
||||
return data
|
||||
|
|
|
@ -13,6 +13,7 @@ from homeassistant.components.sql import validate_sql_select
|
|||
from homeassistant.components.sql.const import DOMAIN
|
||||
from homeassistant.config_entries import ConfigEntryState
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.template import Template
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from . import YAML_CONFIG_INVALID, YAML_CONFIG_NO_DB, init_integration
|
||||
|
@ -57,33 +58,39 @@ async def test_setup_invalid_config(
|
|||
async def test_invalid_query(hass: HomeAssistant) -> None:
|
||||
"""Test invalid query."""
|
||||
with pytest.raises(vol.Invalid):
|
||||
validate_sql_select("DROP TABLE *")
|
||||
validate_sql_select(Template("DROP TABLE *"))
|
||||
|
||||
with pytest.raises(vol.Invalid):
|
||||
validate_sql_select("SELECT5 as value")
|
||||
validate_sql_select(Template("SELECT5 as value"))
|
||||
|
||||
with pytest.raises(vol.Invalid):
|
||||
validate_sql_select(";;")
|
||||
validate_sql_select(Template(";;"))
|
||||
|
||||
|
||||
async def test_query_no_read_only(hass: HomeAssistant) -> None:
|
||||
"""Test query no read only."""
|
||||
with pytest.raises(vol.Invalid):
|
||||
validate_sql_select("UPDATE states SET state = 999999 WHERE state_id = 11125")
|
||||
validate_sql_select(
|
||||
Template("UPDATE states SET state = 999999 WHERE state_id = 11125")
|
||||
)
|
||||
|
||||
|
||||
async def test_query_no_read_only_cte(hass: HomeAssistant) -> None:
|
||||
"""Test query no read only CTE."""
|
||||
with pytest.raises(vol.Invalid):
|
||||
validate_sql_select(
|
||||
Template(
|
||||
"WITH test AS (SELECT state FROM states) UPDATE states SET states.state = test.state;"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def test_multiple_queries(hass: HomeAssistant) -> None:
|
||||
"""Test multiple queries."""
|
||||
with pytest.raises(vol.Invalid):
|
||||
validate_sql_select("SELECT 5 as value; UPDATE states SET state = 10;")
|
||||
validate_sql_select(
|
||||
Template("SELECT 5 as value; UPDATE states SET state = 10;")
|
||||
)
|
||||
|
||||
|
||||
async def test_remove_configured_db_url_if_not_needed_when_not_needed(
|
||||
|
|
|
@ -25,7 +25,6 @@ from homeassistant.const import (
|
|||
)
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import issue_registry as ir
|
||||
from homeassistant.helpers.entity_platform import async_get_platforms
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
||||
|
@ -613,17 +612,17 @@ async def test_query_recover_from_rollback(
|
|||
"unique_id": "very_unique_id",
|
||||
}
|
||||
await init_integration(hass, config)
|
||||
platforms = async_get_platforms(hass, "sql")
|
||||
sql_entity = platforms[0].entities["sensor.select_value_sql_query"]
|
||||
|
||||
state = hass.states.get("sensor.select_value_sql_query")
|
||||
assert state.state == "5"
|
||||
assert state.attributes["value"] == 5
|
||||
|
||||
with patch.object(
|
||||
sql_entity,
|
||||
"_lambda_stmt",
|
||||
_generate_lambda_stmt("Faulty syntax create operational issue"),
|
||||
incorrect_lambda_stmt = _generate_lambda_stmt(
|
||||
"Faulty syntax create operational issue"
|
||||
)
|
||||
with patch(
|
||||
"homeassistant.components.sql.sensor._generate_lambda_stmt",
|
||||
return_value=incorrect_lambda_stmt,
|
||||
):
|
||||
freezer.tick(timedelta(minutes=1))
|
||||
async_fire_time_changed(hass)
|
||||
|
|
Loading…
Add table
Reference in a new issue