Improve code quality sql (#65321)

This commit is contained in:
G Johansson 2022-02-12 15:13:01 +01:00 committed by GitHub
parent 65ce2108d3
commit 8da150bd71
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 167 additions and 64 deletions

View file

@ -1,8 +1,7 @@
"""Sensor from an SQL Query.""" """Sensor from an SQL Query."""
from __future__ import annotations from __future__ import annotations
import datetime from datetime import date
import decimal
import logging import logging
import re import re
@ -11,11 +10,15 @@ from sqlalchemy.orm import scoped_session, sessionmaker
import voluptuous as vol import voluptuous as vol
from homeassistant.components.recorder import CONF_DB_URL, DEFAULT_DB_FILE, DEFAULT_URL from homeassistant.components.recorder import CONF_DB_URL, DEFAULT_DB_FILE, DEFAULT_URL
from homeassistant.components.sensor import PLATFORM_SCHEMA, SensorEntity from homeassistant.components.sensor import (
PLATFORM_SCHEMA as PARENT_PLATFORM_SCHEMA,
SensorEntity,
)
from homeassistant.const import CONF_NAME, CONF_UNIT_OF_MEASUREMENT, CONF_VALUE_TEMPLATE from homeassistant.const import CONF_NAME, CONF_UNIT_OF_MEASUREMENT, CONF_VALUE_TEMPLATE
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.template import Template
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -27,12 +30,12 @@ CONF_QUERY = "query"
DB_URL_RE = re.compile("//.*:.*@") DB_URL_RE = re.compile("//.*:.*@")
def redact_credentials(data): def redact_credentials(data: str) -> str:
"""Redact credentials from string data.""" """Redact credentials from string data."""
return DB_URL_RE.sub("//****:****@", data) return DB_URL_RE.sub("//****:****@", data)
def validate_sql_select(value): def validate_sql_select(value: str) -> str:
"""Validate that value is a SQL SELECT query.""" """Validate that value is a SQL SELECT query."""
if not value.lstrip().lower().startswith("select"): if not value.lstrip().lower().startswith("select"):
raise vol.Invalid("Only SELECT queries allowed") raise vol.Invalid("Only SELECT queries allowed")
@ -49,7 +52,7 @@ _QUERY_SCHEME = vol.Schema(
} }
) )
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend( PLATFORM_SCHEMA = PARENT_PLATFORM_SCHEMA.extend(
{vol.Required(CONF_QUERIES): [_QUERY_SCHEME], vol.Optional(CONF_DB_URL): cv.string} {vol.Required(CONF_QUERIES): [_QUERY_SCHEME], vol.Optional(CONF_DB_URL): cv.string}
) )
@ -64,7 +67,7 @@ def setup_platform(
if not (db_url := config.get(CONF_DB_URL)): if not (db_url := config.get(CONF_DB_URL)):
db_url = DEFAULT_URL.format(hass_config_path=hass.config.path(DEFAULT_DB_FILE)) db_url = DEFAULT_URL.format(hass_config_path=hass.config.path(DEFAULT_DB_FILE))
sess = None sess: scoped_session | None = None
try: try:
engine = sqlalchemy.create_engine(db_url) engine = sqlalchemy.create_engine(db_url)
sessmaker = scoped_session(sessionmaker(bind=engine)) sessmaker = scoped_session(sessionmaker(bind=engine))
@ -87,11 +90,11 @@ def setup_platform(
queries = [] queries = []
for query in config[CONF_QUERIES]: for query in config[CONF_QUERIES]:
name = query.get(CONF_NAME) name: str = query[CONF_NAME]
query_str = query.get(CONF_QUERY) query_str: str = query[CONF_QUERY]
unit = query.get(CONF_UNIT_OF_MEASUREMENT) unit: str | None = query.get(CONF_UNIT_OF_MEASUREMENT)
value_template = query.get(CONF_VALUE_TEMPLATE) value_template: Template | None = query.get(CONF_VALUE_TEMPLATE)
column_name = query.get(CONF_COLUMN_NAME) column_name: str = query[CONF_COLUMN_NAME]
if value_template is not None: if value_template is not None:
value_template.hass = hass value_template.hass = hass
@ -115,60 +118,32 @@ def setup_platform(
class SQLSensor(SensorEntity): class SQLSensor(SensorEntity):
"""Representation of an SQL sensor.""" """Representation of an SQL sensor."""
def __init__(self, name, sessmaker, query, column, unit, value_template): def __init__(
self,
name: str,
sessmaker: scoped_session,
query: str,
column: str,
unit: str | None,
value_template: Template | None,
) -> None:
"""Initialize the SQL sensor.""" """Initialize the SQL sensor."""
self._name = name self._attr_name = name
self._query = query self._query = query
self._unit_of_measurement = unit self._attr_native_unit_of_measurement = unit
self._template = value_template self._template = value_template
self._column_name = column self._column_name = column
self.sessionmaker = sessmaker self.sessionmaker = sessmaker
self._state = None self._attr_extra_state_attributes = {}
self._attributes = None
@property def update(self) -> None:
def name(self):
"""Return the name of the query."""
return self._name
@property
def native_value(self):
"""Return the query's current state."""
return self._state
@property
def native_unit_of_measurement(self):
"""Return the unit of measurement."""
return self._unit_of_measurement
@property
def extra_state_attributes(self):
"""Return the state attributes."""
return self._attributes
def update(self):
"""Retrieve sensor data from the query.""" """Retrieve sensor data from the query."""
data = None data = None
self._attr_extra_state_attributes = {}
sess: scoped_session = self.sessionmaker()
try: try:
sess = self.sessionmaker()
result = sess.execute(self._query) result = sess.execute(self._query)
self._attributes = {}
if not result.returns_rows or result.rowcount == 0:
_LOGGER.warning("%s returned no results", self._query)
self._state = None
return
for res in result.mappings():
_LOGGER.debug("result = %s", res.items())
data = res[self._column_name]
for key, value in res.items():
if isinstance(value, decimal.Decimal):
value = float(value)
if isinstance(value, datetime.date):
value = str(value)
self._attributes[key] = value
except sqlalchemy.exc.SQLAlchemyError as err: except sqlalchemy.exc.SQLAlchemyError as err:
_LOGGER.error( _LOGGER.error(
"Error executing query %s: %s", "Error executing query %s: %s",
@ -176,12 +151,27 @@ class SQLSensor(SensorEntity):
redact_credentials(str(err)), redact_credentials(str(err)),
) )
return return
finally:
sess.close() _LOGGER.debug("Result %s, ResultMapping %s", result, result.mappings())
for res in result.mappings():
_LOGGER.debug("result = %s", res.items())
data = res[self._column_name]
for key, value in res.items():
if isinstance(value, float):
value = float(value)
if isinstance(value, date):
value = value.isoformat()
self._attr_extra_state_attributes[key] = value
if data is not None and self._template is not None: if data is not None and self._template is not None:
self._state = self._template.async_render_with_possible_json_value( self._attr_native_value = (
data, None self._template.async_render_with_possible_json_value(data, None)
) )
else: else:
self._state = data self._attr_native_value = data
if not data:
_LOGGER.warning("%s returned no results", self._query)
sess.close()

View file

@ -1,13 +1,27 @@
"""The test for the sql sensor platform.""" """The test for the sql sensor platform."""
import os
import pytest import pytest
import voluptuous as vol import voluptuous as vol
from homeassistant.components.sql.sensor import validate_sql_select from homeassistant.components.sql.sensor import validate_sql_select
from homeassistant.const import STATE_UNKNOWN from homeassistant.const import STATE_UNKNOWN
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from tests.common import get_test_config_dir
async def test_query(hass):
@pytest.fixture(autouse=True)
def remove_file():
"""Remove db."""
yield
file = os.path.join(get_test_config_dir(), "home-assistant_v2.db")
if os.path.isfile(file):
os.remove(file)
async def test_query(hass: HomeAssistant) -> None:
"""Test the SQL sensor.""" """Test the SQL sensor."""
config = { config = {
"sensor": { "sensor": {
@ -31,7 +45,53 @@ async def test_query(hass):
assert state.attributes["value"] == 5 assert state.attributes["value"] == 5
async def test_query_limit(hass): async def test_query_no_db(hass: HomeAssistant) -> None:
"""Test the SQL sensor."""
config = {
"sensor": {
"platform": "sql",
"queries": [
{
"name": "count_tables",
"query": "SELECT 5 as value",
"column": "value",
}
],
}
}
assert await async_setup_component(hass, "sensor", config)
await hass.async_block_till_done()
state = hass.states.get("sensor.count_tables")
assert state.state == "5"
async def test_query_value_template(hass: HomeAssistant) -> None:
"""Test the SQL sensor."""
config = {
"sensor": {
"platform": "sql",
"db_url": "sqlite://",
"queries": [
{
"name": "count_tables",
"query": "SELECT 5.01 as value",
"column": "value",
"value_template": "{{ value | int }}",
}
],
}
}
assert await async_setup_component(hass, "sensor", config)
await hass.async_block_till_done()
state = hass.states.get("sensor.count_tables")
assert state.state == "5"
async def test_query_limit(hass: HomeAssistant) -> None:
"""Test the SQL sensor with a query containing 'LIMIT' in lowercase.""" """Test the SQL sensor with a query containing 'LIMIT' in lowercase."""
config = { config = {
"sensor": { "sensor": {
@ -55,7 +115,30 @@ async def test_query_limit(hass):
assert state.attributes["value"] == 5 assert state.attributes["value"] == 5
async def test_invalid_query(hass): async def test_query_no_value(hass: HomeAssistant) -> None:
"""Test the SQL sensor with a query that returns no value."""
config = {
"sensor": {
"platform": "sql",
"db_url": "sqlite://",
"queries": [
{
"name": "count_tables",
"query": "SELECT 5 as value where 1=2",
"column": "value",
}
],
}
}
assert await async_setup_component(hass, "sensor", config)
await hass.async_block_till_done()
state = hass.states.get("sensor.count_tables")
assert state.state == STATE_UNKNOWN
async def test_invalid_query(hass: HomeAssistant) -> None:
"""Test the SQL sensor for invalid queries.""" """Test the SQL sensor for invalid queries."""
with pytest.raises(vol.Invalid): with pytest.raises(vol.Invalid):
validate_sql_select("DROP TABLE *") validate_sql_select("DROP TABLE *")
@ -81,6 +164,30 @@ async def test_invalid_query(hass):
assert state.state == STATE_UNKNOWN assert state.state == STATE_UNKNOWN
async def test_value_float_and_date(hass: HomeAssistant) -> None:
"""Test the SQL sensor with a query has float as value."""
config = {
"sensor": {
"platform": "sql",
"db_url": "sqlite://",
"queries": [
{
"name": "float_value",
"query": "SELECT 5 as value, cast(5.01 as decimal(10,2)) as value2",
"column": "value",
},
],
}
}
assert await async_setup_component(hass, "sensor", config)
await hass.async_block_till_done()
state = hass.states.get("sensor.float_value")
assert state.state == "5"
assert isinstance(state.attributes["value2"], float)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"url,expected_patterns,not_expected_patterns", "url,expected_patterns,not_expected_patterns",
[ [
@ -96,7 +203,13 @@ async def test_invalid_query(hass):
), ),
], ],
) )
async def test_invalid_url(hass, caplog, url, expected_patterns, not_expected_patterns): async def test_invalid_url(
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,
url: str,
expected_patterns: str,
not_expected_patterns: str,
):
"""Test credentials in url is not logged.""" """Test credentials in url is not logged."""
config = { config = {
"sensor": { "sensor": {