Refactor Influx logic to reduce V1 vs V2 code paths (#37232)

* refactoring to share logic and sensor startup error test

* Added handling for V1 InfluxDBServerError at start-up and at runtime, plus a test for it

* Added InfluxDBServerError test to sensor setup tests

* Raising PlatformNotReady exception from sensor for setup failure

* Proper testing of PlatformNotReady error
This commit is contained in:
mdegat01 2020-06-30 14:02:25 -04:00 committed by GitHub
parent 38210ebbc6
commit 24289d5dbb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 460 additions and 345 deletions

View file

@ -1,10 +1,11 @@
"""Support for sending data to an Influx database."""
from dataclasses import dataclass
import logging
import math
import queue
import threading
import time
from typing import Dict
from typing import Any, Callable, Dict, List
from influxdb import InfluxDBClient, exceptions
from influxdb_client import InfluxDBClient as InfluxDBClientV2
@ -15,6 +16,10 @@ import urllib3.exceptions
import voluptuous as vol
from homeassistant.const import (
CONF_DOMAIN,
CONF_ENTITY_ID,
CONF_TIMEOUT,
CONF_UNIT_OF_MEASUREMENT,
CONF_URL,
EVENT_HOMEASSISTANT_STOP,
EVENT_STATE_CHANGED,
@ -33,8 +38,10 @@ from .const import (
API_VERSION_2,
BATCH_BUFFER_SIZE,
BATCH_TIMEOUT,
CLIENT_ERROR_V1_WITH_RETRY,
CLIENT_ERROR_V2_WITH_RETRY,
CATCHING_UP_MESSAGE,
CLIENT_ERROR_V1,
CLIENT_ERROR_V2,
CODE_INVALID_INPUTS,
COMPONENT_CONFIG_SCHEMA_CONNECTION,
CONF_API_VERSION,
CONF_BUCKET,
@ -56,18 +63,32 @@ from .const import (
CONF_TOKEN,
CONF_USERNAME,
CONF_VERIFY_SSL,
CONNECTION_ERROR_WITH_RETRY,
CONNECTION_ERROR,
DEFAULT_API_VERSION,
DEFAULT_HOST_V2,
DEFAULT_SSL_V2,
DOMAIN,
EVENT_NEW_STATE,
INFLUX_CONF_FIELDS,
INFLUX_CONF_MEASUREMENT,
INFLUX_CONF_ORG,
INFLUX_CONF_STATE,
INFLUX_CONF_TAGS,
INFLUX_CONF_TIME,
INFLUX_CONF_VALUE,
QUERY_ERROR,
QUEUE_BACKLOG_SECONDS,
RE_DECIMAL,
RE_DIGIT_TAIL,
RESUMED_MESSAGE,
RETRY_DELAY,
RETRY_INTERVAL,
RETRY_MESSAGE,
TEST_QUERY_V1,
TEST_QUERY_V2,
TIMEOUT,
WRITE_ERROR,
WROTE_MESSAGE,
)
_LOGGER = logging.getLogger(__name__)
@ -120,9 +141,11 @@ def validate_version_specific_config(conf: Dict) -> Dict:
return conf
_CONFIG_SCHEMA_ENTRY = vol.Schema({vol.Optional(CONF_OVERRIDE_MEASUREMENT): cv.string})
_CUSTOMIZE_ENTITY_SCHEMA = vol.Schema(
{vol.Optional(CONF_OVERRIDE_MEASUREMENT): cv.string}
)
_CONFIG_SCHEMA = INCLUDE_EXCLUDE_BASE_FILTER_SCHEMA.extend(
_INFLUX_BASE_SCHEMA = INCLUDE_EXCLUDE_BASE_FILTER_SCHEMA.extend(
{
vol.Optional(CONF_RETRY_COUNT, default=0): cv.positive_int,
vol.Optional(CONF_DEFAULT_MEASUREMENT): cv.string,
@ -132,89 +155,28 @@ _CONFIG_SCHEMA = INCLUDE_EXCLUDE_BASE_FILTER_SCHEMA.extend(
cv.ensure_list, [cv.string]
),
vol.Optional(CONF_COMPONENT_CONFIG, default={}): vol.Schema(
{cv.entity_id: _CONFIG_SCHEMA_ENTRY}
{cv.entity_id: _CUSTOMIZE_ENTITY_SCHEMA}
),
vol.Optional(CONF_COMPONENT_CONFIG_GLOB, default={}): vol.Schema(
{cv.string: _CONFIG_SCHEMA_ENTRY}
{cv.string: _CUSTOMIZE_ENTITY_SCHEMA}
),
vol.Optional(CONF_COMPONENT_CONFIG_DOMAIN, default={}): vol.Schema(
{cv.string: _CONFIG_SCHEMA_ENTRY}
{cv.string: _CUSTOMIZE_ENTITY_SCHEMA}
),
}
)
CONFIG_SCHEMA = vol.Schema(
{
DOMAIN: vol.All(
_CONFIG_SCHEMA.extend(COMPONENT_CONFIG_SCHEMA_CONNECTION),
INFLUX_SCHEMA = vol.All(
_INFLUX_BASE_SCHEMA.extend(COMPONENT_CONFIG_SCHEMA_CONNECTION),
validate_version_specific_config,
create_influx_url,
),
},
extra=vol.ALLOW_EXTRA,
)
def get_influx_connection(client_kwargs, bucket):
"""Create and check the correct influx connection for the API version."""
if bucket is not None:
# Test connection by synchronously writing nothing.
# If config is valid this will generate a `Bad Request` exception but not make anything.
# If config is invalid we will output an error.
# Hopefully a better way to test connection is added in the future.
try:
influx = InfluxDBClientV2(**client_kwargs)
influx.write_api(write_options=SYNCHRONOUS).write(bucket=bucket)
except ApiException as exc:
# 400 is the success state since it means we can write we just gave a bad point.
if exc.status != 400:
raise exc
else:
influx = InfluxDBClient(**client_kwargs)
influx.write_points([])
return influx
CONFIG_SCHEMA = vol.Schema({DOMAIN: INFLUX_SCHEMA}, extra=vol.ALLOW_EXTRA,)
def setup(hass, config):
"""Set up the InfluxDB component."""
conf = config[DOMAIN]
use_v2_api = conf[CONF_API_VERSION] == API_VERSION_2
bucket = None
kwargs = {
"timeout": TIMEOUT,
}
if use_v2_api:
kwargs["url"] = conf[CONF_URL]
kwargs["token"] = conf[CONF_TOKEN]
kwargs["org"] = conf[CONF_ORG]
bucket = conf[CONF_BUCKET]
else:
kwargs["database"] = conf[CONF_DB_NAME]
kwargs["verify_ssl"] = conf[CONF_VERIFY_SSL]
if CONF_USERNAME in conf:
kwargs["username"] = conf[CONF_USERNAME]
if CONF_PASSWORD in conf:
kwargs["password"] = conf[CONF_PASSWORD]
if CONF_HOST in conf:
kwargs["host"] = conf[CONF_HOST]
if CONF_PATH in conf:
kwargs["path"] = conf[CONF_PATH]
if CONF_PORT in conf:
kwargs["port"] = conf[CONF_PORT]
if CONF_SSL in conf:
kwargs["ssl"] = conf[CONF_SSL]
def _generate_event_to_json(conf: Dict) -> Callable[[Dict], str]:
"""Build event to json converter and add to config."""
entity_filter = convert_include_exclude_filter(conf)
tags = conf.get(CONF_TAGS)
tags_attributes = conf.get(CONF_TAGS_ATTRIBUTES)
@ -225,32 +187,10 @@ def setup(hass, config):
conf[CONF_COMPONENT_CONFIG_DOMAIN],
conf[CONF_COMPONENT_CONFIG_GLOB],
)
max_tries = conf.get(CONF_RETRY_COUNT)
try:
influx = get_influx_connection(kwargs, bucket)
if use_v2_api:
write_api = influx.write_api(write_options=ASYNCHRONOUS)
except (
OSError,
requests.exceptions.ConnectionError,
urllib3.exceptions.HTTPError,
) as exc:
_LOGGER.error(CONNECTION_ERROR_WITH_RETRY, exc)
event_helper.call_later(hass, RETRY_INTERVAL, lambda _: setup(hass, config))
return True
except exceptions.InfluxDBClientError as exc:
_LOGGER.error(CLIENT_ERROR_V1_WITH_RETRY, exc)
event_helper.call_later(hass, RETRY_INTERVAL, lambda _: setup(hass, config))
return True
except ApiException as exc:
_LOGGER.error(CLIENT_ERROR_V2_WITH_RETRY, exc)
event_helper.call_later(hass, RETRY_INTERVAL, lambda _: setup(hass, config))
return True
def event_to_json(event):
"""Add an event to the outgoing Influx list."""
state = event.data.get("new_state")
def event_to_json(event: Dict) -> str:
"""Convert event into json in format Influx expects."""
state = event.data.get(EVENT_NEW_STATE)
if (
state is None
or state.state in (STATE_UNKNOWN, "", STATE_UNAVAILABLE)
@ -278,7 +218,7 @@ def setup(hass, config):
if override_measurement:
measurement = override_measurement
else:
measurement = state.attributes.get("unit_of_measurement")
measurement = state.attributes.get(CONF_UNIT_OF_MEASUREMENT)
if measurement in (None, ""):
if default_measurement:
measurement = default_measurement
@ -288,57 +228,206 @@ def setup(hass, config):
include_uom = False
json = {
"measurement": measurement,
"tags": {"domain": state.domain, "entity_id": state.object_id},
"time": event.time_fired,
"fields": {},
INFLUX_CONF_MEASUREMENT: measurement,
INFLUX_CONF_TAGS: {
CONF_DOMAIN: state.domain,
CONF_ENTITY_ID: state.object_id,
},
INFLUX_CONF_TIME: event.time_fired,
INFLUX_CONF_FIELDS: {},
}
if _include_state:
json["fields"]["state"] = state.state
json[INFLUX_CONF_FIELDS][INFLUX_CONF_STATE] = state.state
if _include_value:
json["fields"]["value"] = _state_as_value
json[INFLUX_CONF_FIELDS][INFLUX_CONF_VALUE] = _state_as_value
for key, value in state.attributes.items():
if key in tags_attributes:
json["tags"][key] = value
elif key != "unit_of_measurement" or include_uom:
json[INFLUX_CONF_TAGS][key] = value
elif key != CONF_UNIT_OF_MEASUREMENT or include_uom:
# If the key is already in fields
if key in json["fields"]:
if key in json[INFLUX_CONF_FIELDS]:
key = f"{key}_"
# Prevent column data errors in influxDB.
# For each value we try to cast it as float
# But if we can not do it we store the value
# as string add "_str" postfix to the field key
try:
json["fields"][key] = float(value)
json[INFLUX_CONF_FIELDS][key] = float(value)
except (ValueError, TypeError):
new_key = f"{key}_str"
new_value = str(value)
json["fields"][new_key] = new_value
json[INFLUX_CONF_FIELDS][new_key] = new_value
if RE_DIGIT_TAIL.match(new_value):
json["fields"][key] = float(RE_DECIMAL.sub("", new_value))
json[INFLUX_CONF_FIELDS][key] = float(
RE_DECIMAL.sub("", new_value)
)
# Infinity and NaN are not valid floats in InfluxDB
try:
if not math.isfinite(json["fields"][key]):
del json["fields"][key]
if not math.isfinite(json[INFLUX_CONF_FIELDS][key]):
del json[INFLUX_CONF_FIELDS][key]
except (KeyError, TypeError):
pass
json["tags"].update(tags)
json[INFLUX_CONF_TAGS].update(tags)
return json
if use_v2_api:
instance = hass.data[DOMAIN] = InfluxThread(
hass, None, bucket, write_api, event_to_json, max_tries
)
else:
instance = hass.data[DOMAIN] = InfluxThread(
hass, influx, None, None, event_to_json, max_tries
)
return event_to_json
@dataclass
class InfluxClient:
    """An InfluxDB client wrapper for V1 or V2.

    Bundles the three operations the integration needs behind one uniform
    interface so callers never have to branch on the API version; each field
    holds a closure bound to the appropriate underlying client.
    """

    # Writes a record/point list to the database; implementations raise
    # ValueError for invalid input and ConnectionError for connectivity or
    # client failures.
    write: Callable[[str], None]
    # Runs a query and returns the results; the second argument is the V1
    # database name and is ignored by the V2 implementation.
    query: Callable[[str, str], List[Any]]
    # Releases the underlying client's resources.
    close: Callable[[], None]
def get_influx_connection(conf, test_write=False, test_read=False):
    """Create the correct influx connection for the API version.

    Returns an InfluxClient whose write/query/close callables hide the
    V1-vs-V2 differences from callers.  When test_write/test_read is set, a
    probe write/query is issued so that a broken connection or bad
    credentials surface here (as ConnectionError) rather than later at
    runtime.
    """
    kwargs = {
        CONF_TIMEOUT: TIMEOUT,
    }

    if conf[CONF_API_VERSION] == API_VERSION_2:
        kwargs[CONF_URL] = conf[CONF_URL]
        kwargs[CONF_TOKEN] = conf[CONF_TOKEN]
        kwargs[INFLUX_CONF_ORG] = conf[CONF_ORG]
        bucket = conf.get(CONF_BUCKET)
        influx = InfluxDBClientV2(**kwargs)
        query_api = influx.query_api()
        # Start in synchronous mode when probing so the test write below
        # reports errors immediately; switched back to async afterwards.
        initial_write_mode = SYNCHRONOUS if test_write else ASYNCHRONOUS
        write_api = influx.write_api(write_options=initial_write_mode)

        def write_v2(json):
            """Write data to V2 influx."""
            try:
                write_api.write(bucket=bucket, record=json)
            except (urllib3.exceptions.HTTPError, OSError) as exc:
                raise ConnectionError(CONNECTION_ERROR % exc)
            except ApiException as exc:
                # 400 means we connected and authenticated but sent bad data.
                if exc.status == CODE_INVALID_INPUTS:
                    raise ValueError(WRITE_ERROR % (json, exc))
                raise ConnectionError(CLIENT_ERROR_V2 % exc)

        def query_v2(query, _=None):
            """Query V2 influx."""
            # The unused second parameter only exists to match query_v1.
            try:
                return query_api.query(query)
            except (urllib3.exceptions.HTTPError, OSError) as exc:
                raise ConnectionError(CONNECTION_ERROR % exc)
            except ApiException as exc:
                if exc.status == CODE_INVALID_INPUTS:
                    raise ValueError(QUERY_ERROR % (query, exc))
                raise ConnectionError(CLIENT_ERROR_V2 % exc)

        def close_v2():
            """Close V2 influx client."""
            influx.close()

        influx_client = InfluxClient(write_v2, query_v2, close_v2)
        if test_write:
            # Try to write [] to influx. If we can connect and creds are valid
            # Then invalid inputs is returned. Anything else is a broken config
            try:
                influx_client.write([])
            except ValueError:
                pass
            # Probe done: rebind to an asynchronous write api.  write_v2
            # closes over the name, so later writes use the new api object.
            write_api = influx.write_api(write_options=ASYNCHRONOUS)
        if test_read:
            influx_client.query(TEST_QUERY_V2)
        return influx_client

    # Else it's a V1 client
    kwargs[CONF_VERIFY_SSL] = conf[CONF_VERIFY_SSL]
    if CONF_DB_NAME in conf:
        kwargs[CONF_DB_NAME] = conf[CONF_DB_NAME]
    if CONF_USERNAME in conf:
        kwargs[CONF_USERNAME] = conf[CONF_USERNAME]
    if CONF_PASSWORD in conf:
        kwargs[CONF_PASSWORD] = conf[CONF_PASSWORD]
    if CONF_HOST in conf:
        kwargs[CONF_HOST] = conf[CONF_HOST]
    if CONF_PATH in conf:
        kwargs[CONF_PATH] = conf[CONF_PATH]
    if CONF_PORT in conf:
        kwargs[CONF_PORT] = conf[CONF_PORT]
    if CONF_SSL in conf:
        kwargs[CONF_SSL] = conf[CONF_SSL]
    influx = InfluxDBClient(**kwargs)

    def write_v1(json):
        """Write data to V1 influx."""
        try:
            influx.write_points(json)
        except (
            requests.exceptions.RequestException,
            exceptions.InfluxDBServerError,
            OSError,
        ) as exc:
            raise ConnectionError(CONNECTION_ERROR % exc)
        except exceptions.InfluxDBClientError as exc:
            # 400 means we connected and authenticated but sent bad data.
            if exc.code == CODE_INVALID_INPUTS:
                raise ValueError(WRITE_ERROR % (json, exc))
            raise ConnectionError(CLIENT_ERROR_V1 % exc)

    def query_v1(query, database=None):
        """Query V1 influx."""
        try:
            return list(influx.query(query, database=database).get_points())
        except (
            requests.exceptions.RequestException,
            exceptions.InfluxDBServerError,
            OSError,
        ) as exc:
            raise ConnectionError(CONNECTION_ERROR % exc)
        except exceptions.InfluxDBClientError as exc:
            if exc.code == CODE_INVALID_INPUTS:
                raise ValueError(QUERY_ERROR % (query, exc))
            raise ConnectionError(CLIENT_ERROR_V1 % exc)

    def close_v1():
        """Close the V1 Influx client."""
        influx.close()

    influx_client = InfluxClient(write_v1, query_v1, close_v1)
    if test_write:
        # Probe connection; failures surface as ConnectionError from write_v1.
        influx_client.write([])
    if test_read:
        influx_client.query(TEST_QUERY_V1)
    return influx_client
def setup(hass, config):
"""Set up the InfluxDB component."""
conf = config[DOMAIN]
try:
influx = get_influx_connection(conf, test_write=True)
except ConnectionError as exc:
_LOGGER.error(RETRY_MESSAGE, exc)
event_helper.call_later(hass, RETRY_INTERVAL, lambda _: setup(hass, config))
return True
event_to_json = _generate_event_to_json(conf)
max_tries = conf.get(CONF_RETRY_COUNT)
instance = hass.data[DOMAIN] = InfluxThread(hass, influx, event_to_json, max_tries)
instance.start()
def shutdown(event):
@ -355,13 +444,11 @@ def setup(hass, config):
class InfluxThread(threading.Thread):
"""A threaded event handler class."""
def __init__(self, hass, influx, bucket, write_api, event_to_json, max_tries):
def __init__(self, hass, influx, event_to_json, max_tries):
"""Initialize the listener."""
threading.Thread.__init__(self, name="InfluxDB")
threading.Thread.__init__(self, name=DOMAIN)
self.queue = queue.Queue()
self.influx = influx
self.bucket = bucket
self.write_api = write_api
self.event_to_json = event_to_json
self.max_tries = max_tries
self.write_errors = 0
@ -410,7 +497,7 @@ class InfluxThread(threading.Thread):
pass
if dropped:
_LOGGER.warning("Catching up, dropped %d old events", dropped)
_LOGGER.warning(CATCHING_UP_MESSAGE, dropped)
return count, json
@ -418,28 +505,23 @@ class InfluxThread(threading.Thread):
"""Write preprocessed events to influxdb, with retry."""
for retry in range(self.max_tries + 1):
try:
if self.write_api is not None:
self.write_api.write(bucket=self.bucket, record=json)
else:
self.influx.write_points(json)
self.influx.write(json)
if self.write_errors:
_LOGGER.error("Resumed, lost %d events", self.write_errors)
_LOGGER.error(RESUMED_MESSAGE, self.write_errors)
self.write_errors = 0
_LOGGER.debug("Wrote %d events", len(json))
_LOGGER.debug(WROTE_MESSAGE, len(json))
break
except (
exceptions.InfluxDBClientError,
exceptions.InfluxDBServerError,
OSError,
ApiException,
) as err:
except ValueError as err:
_LOGGER.error(err)
break
except ConnectionError as err:
if retry < self.max_tries:
time.sleep(RETRY_DELAY)
else:
if not self.write_errors:
_LOGGER.error(WRITE_ERROR, json, err)
_LOGGER.error(err)
self.write_errors += len(json)
def run(self):

View file

@ -53,7 +53,18 @@ DEFAULT_GROUP_FUNCTION = "mean"
DEFAULT_FIELD = "value"
DEFAULT_RANGE_START = "-15m"
DEFAULT_RANGE_STOP = "now()"
DEFAULT_FUNCTION_FLUX = "|> limit(n: 1)"
INFLUX_CONF_MEASUREMENT = "measurement"
INFLUX_CONF_TAGS = "tags"
INFLUX_CONF_TIME = "time"
INFLUX_CONF_FIELDS = "fields"
INFLUX_CONF_STATE = "state"
INFLUX_CONF_VALUE = "value"
INFLUX_CONF_VALUE_V2 = "_value"
INFLUX_CONF_ORG = "org"
EVENT_NEW_STATE = "new_state"
DOMAIN = "influxdb"
API_VERSION_2 = "2"
TIMEOUT = 5
@ -65,7 +76,8 @@ BATCH_BUFFER_SIZE = 100
LANGUAGE_INFLUXQL = "influxQL"
LANGUAGE_FLUX = "flux"
TEST_QUERY_V1 = "SHOW SERIES LIMIT 1;"
TEST_QUERY_V2 = "buckets() |> limit(n:1)"
TEST_QUERY_V2 = f"buckets() {DEFAULT_FUNCTION_FLUX}"
CODE_INVALID_INPUTS = 400
MIN_TIME_BETWEEN_UPDATES = timedelta(seconds=60)
@ -91,11 +103,19 @@ WRITE_ERROR = "Could not write '%s' to influx due to '%s'."
QUERY_ERROR = (
"Could not execute query '%s' due to '%s'. Check the syntax of your query."
)
RETRY_MESSAGE = f"Retrying again in {RETRY_INTERVAL} seconds."
CONNECTION_ERROR_WITH_RETRY = f"{CONNECTION_ERROR} {RETRY_MESSAGE}"
CLIENT_ERROR_V1_WITH_RETRY = f"{CLIENT_ERROR_V1} {RETRY_MESSAGE}"
CLIENT_ERROR_V2_WITH_RETRY = f"{CLIENT_ERROR_V2} {RETRY_MESSAGE}"
RETRY_MESSAGE = f"%s Retrying in {RETRY_INTERVAL} seconds."
CATCHING_UP_MESSAGE = "Catching up, dropped %d old events."
RESUMED_MESSAGE = "Resumed, lost %d events."
WROTE_MESSAGE = "Wrote %d events."
RUNNING_QUERY_MESSAGE = "Running query: %s."
QUERY_NO_RESULTS_MESSAGE = "Query returned no results, sensor state set to UNKNOWN: %s."
QUERY_MULTIPLE_RESULTS_MESSAGE = (
"Query returned multiple results, only value from first one is shown: %s."
)
RENDERING_QUERY_MESSAGE = "Rendering query: %s."
RENDERING_QUERY_ERROR_MESSAGE = "Could not render query template: %s."
RENDERING_WHERE_MESSAGE = "Rendering where: %s."
RENDERING_WHERE_ERROR_MESSAGE = "Could not render where template: %s."
COMPONENT_CONFIG_SCHEMA_CONNECTION = {
# Connection config for V1 and V2 APIs.

View file

@ -2,34 +2,23 @@
import logging
from typing import Dict
from influxdb import InfluxDBClient, exceptions
from influxdb_client import InfluxDBClient as InfluxDBClientV2
from influxdb_client.rest import ApiException
import voluptuous as vol
from homeassistant.components.sensor import PLATFORM_SCHEMA
from homeassistant.components.sensor import PLATFORM_SCHEMA as SENSOR_PLATFORM_SCHEMA
from homeassistant.const import (
CONF_API_VERSION,
CONF_HOST,
CONF_NAME,
CONF_PASSWORD,
CONF_PATH,
CONF_PORT,
CONF_SSL,
CONF_TOKEN,
CONF_UNIT_OF_MEASUREMENT,
CONF_URL,
CONF_USERNAME,
CONF_VALUE_TEMPLATE,
CONF_VERIFY_SSL,
EVENT_HOMEASSISTANT_STOP,
STATE_UNKNOWN,
)
from homeassistant.exceptions import TemplateError
from homeassistant.exceptions import PlatformNotReady, TemplateError
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity import Entity
from homeassistant.util import Throttle
from . import create_influx_url, validate_version_specific_config
from . import create_influx_url, get_influx_connection, validate_version_specific_config
from .const import (
API_VERSION_2,
COMPONENT_CONFIG_SCHEMA_CONNECTION,
@ -38,8 +27,8 @@ from .const import (
CONF_FIELD,
CONF_GROUP_FUNCTION,
CONF_IMPORTS,
CONF_LANGUAGE,
CONF_MEASUREMENT_NAME,
CONF_ORG,
CONF_QUERIES,
CONF_QUERIES_FLUX,
CONF_QUERY,
@ -48,16 +37,63 @@ from .const import (
CONF_WHERE,
DEFAULT_API_VERSION,
DEFAULT_FIELD,
DEFAULT_FUNCTION_FLUX,
DEFAULT_GROUP_FUNCTION,
DEFAULT_RANGE_START,
DEFAULT_RANGE_STOP,
INFLUX_CONF_VALUE,
INFLUX_CONF_VALUE_V2,
LANGUAGE_FLUX,
LANGUAGE_INFLUXQL,
MIN_TIME_BETWEEN_UPDATES,
TEST_QUERY_V1,
TEST_QUERY_V2,
QUERY_MULTIPLE_RESULTS_MESSAGE,
QUERY_NO_RESULTS_MESSAGE,
RENDERING_QUERY_ERROR_MESSAGE,
RENDERING_QUERY_MESSAGE,
RENDERING_WHERE_ERROR_MESSAGE,
RENDERING_WHERE_MESSAGE,
RUNNING_QUERY_MESSAGE,
)
_LOGGER = logging.getLogger(__name__)
def _merge_connection_config_into_query(conf, query):
    """Copy connection settings from the platform config onto one query.

    Keys already set on the query take precedence, and the query-list keys
    themselves are never copied across.
    """
    excluded = (CONF_QUERIES, CONF_QUERIES_FLUX)
    for key, value in conf.items():
        if key in excluded or key in query:
            continue
        query[key] = value
def validate_query_format_for_version(conf: Dict) -> Dict:
    """Ensure queries are provided in correct format based on API version.

    Selects the version-appropriate query list, merges connection config into
    every query, tags each query with its language, and removes the now
    redundant top-level bucket/database key.
    """
    if conf[CONF_API_VERSION] == API_VERSION_2:
        queries_key = CONF_QUERIES_FLUX
        language = LANGUAGE_FLUX
        stale_key = CONF_BUCKET
        missing_msg = (
            f"{CONF_QUERIES_FLUX} is required when {CONF_API_VERSION} is {API_VERSION_2}"
        )
    else:
        queries_key = CONF_QUERIES
        language = LANGUAGE_INFLUXQL
        stale_key = CONF_DB_NAME
        missing_msg = (
            f"{CONF_QUERIES} is required when {CONF_API_VERSION} is {DEFAULT_API_VERSION}"
        )

    if queries_key not in conf:
        raise vol.Invalid(missing_msg)

    for query in conf[queries_key]:
        _merge_connection_config_into_query(conf, query)
        query[CONF_LANGUAGE] = language

    # The merge above copied the top-level key into each query; drop it here.
    del conf[stale_key]
    return conf
_QUERY_SENSOR_SCHEMA = vol.Schema(
{
vol.Required(CONF_NAME): cv.string,
@ -67,7 +103,7 @@ _QUERY_SENSOR_SCHEMA = vol.Schema(
)
_QUERY_SCHEMA = {
"InfluxQL": _QUERY_SENSOR_SCHEMA.extend(
LANGUAGE_INFLUXQL: _QUERY_SENSOR_SCHEMA.extend(
{
vol.Optional(CONF_DB_NAME): cv.string,
vol.Required(CONF_MEASUREMENT_NAME): cv.string,
@ -78,7 +114,7 @@ _QUERY_SCHEMA = {
vol.Required(CONF_WHERE): cv.template,
}
),
"Flux": _QUERY_SENSOR_SCHEMA.extend(
LANGUAGE_FLUX: _QUERY_SENSOR_SCHEMA.extend(
{
vol.Optional(CONF_BUCKET): cv.string,
vol.Optional(CONF_RANGE_START, default=DEFAULT_RANGE_START): cv.string,
@ -90,29 +126,11 @@ _QUERY_SCHEMA = {
),
}
def validate_query_format_for_version(conf: Dict) -> Dict:
"""Ensure queries are provided in correct format based on API version."""
if conf[CONF_API_VERSION] == API_VERSION_2:
if CONF_QUERIES_FLUX not in conf:
raise vol.Invalid(
f"{CONF_QUERIES_FLUX} is required when {CONF_API_VERSION} is {API_VERSION_2}"
)
else:
if CONF_QUERIES not in conf:
raise vol.Invalid(
f"{CONF_QUERIES} is required when {CONF_API_VERSION} is {DEFAULT_API_VERSION}"
)
return conf
PLATFORM_SCHEMA = vol.All(
PLATFORM_SCHEMA.extend(COMPONENT_CONFIG_SCHEMA_CONNECTION).extend(
SENSOR_PLATFORM_SCHEMA.extend(COMPONENT_CONFIG_SCHEMA_CONNECTION).extend(
{
vol.Exclusive(CONF_QUERIES, "queries"): [_QUERY_SCHEMA["InfluxQL"]],
vol.Exclusive(CONF_QUERIES_FLUX, "queries"): [_QUERY_SCHEMA["Flux"]],
vol.Exclusive(CONF_QUERIES, "queries"): [_QUERY_SCHEMA[LANGUAGE_INFLUXQL]],
vol.Exclusive(CONF_QUERIES_FLUX, "queries"): [_QUERY_SCHEMA[LANGUAGE_FLUX]],
}
),
validate_version_specific_config,
@ -123,61 +141,23 @@ PLATFORM_SCHEMA = vol.All(
def setup_platform(hass, config, add_entities, discovery_info=None):
"""Set up the InfluxDB component."""
use_v2_api = config[CONF_API_VERSION] == API_VERSION_2
queries = None
try:
influx = get_influx_connection(config, test_read=True)
except ConnectionError as exc:
_LOGGER.error(exc)
raise PlatformNotReady()
if use_v2_api:
influx_conf = {
"url": config[CONF_URL],
"token": config[CONF_TOKEN],
"org": config[CONF_ORG],
}
bucket = config[CONF_BUCKET]
queries = config[CONF_QUERIES_FLUX]
queries = config[CONF_QUERIES_FLUX if CONF_QUERIES_FLUX in config else CONF_QUERIES]
entities = [InfluxSensor(hass, influx, query) for query in queries]
add_entities(entities, update_before_add=True)
for v2_query in queries:
if CONF_BUCKET not in v2_query:
v2_query[CONF_BUCKET] = bucket
else:
influx_conf = {
"database": config[CONF_DB_NAME],
"verify_ssl": config[CONF_VERIFY_SSL],
}
if CONF_USERNAME in config:
influx_conf["username"] = config[CONF_USERNAME]
if CONF_PASSWORD in config:
influx_conf["password"] = config[CONF_PASSWORD]
if CONF_HOST in config:
influx_conf["host"] = config[CONF_HOST]
if CONF_PATH in config:
influx_conf["path"] = config[CONF_PATH]
if CONF_PORT in config:
influx_conf["port"] = config[CONF_PORT]
if CONF_SSL in config:
influx_conf["ssl"] = config[CONF_SSL]
queries = config[CONF_QUERIES]
entities = []
for query in queries:
sensor = InfluxSensor(hass, influx_conf, query, use_v2_api)
if sensor.connected:
entities.append(sensor)
add_entities(entities, True)
hass.bus.listen_once(EVENT_HOMEASSISTANT_STOP, lambda _: influx.close())
class InfluxSensor(Entity):
"""Implementation of a Influxdb sensor."""
def __init__(self, hass, influx_conf, query, use_v2_api):
def __init__(self, hass, influx, query):
"""Initialize the sensor."""
self._name = query.get(CONF_NAME)
self._unit_of_measurement = query.get(CONF_UNIT_OF_MEASUREMENT)
@ -190,32 +170,12 @@ class InfluxSensor(Entity):
self._state = None
self._hass = hass
if use_v2_api:
influx = InfluxDBClientV2(**influx_conf)
query_api = influx.query_api()
if query[CONF_LANGUAGE] == LANGUAGE_FLUX:
query_clause = query.get(CONF_QUERY)
query_clause.hass = hass
bucket = query[CONF_BUCKET]
else:
if CONF_DB_NAME in query:
kwargs = influx_conf.copy()
kwargs[CONF_DB_NAME] = query[CONF_DB_NAME]
else:
kwargs = influx_conf
influx = InfluxDBClient(**kwargs)
where_clause = query.get(CONF_WHERE)
where_clause.hass = hass
query_api = None
try:
if query_api is not None:
query_api.query(TEST_QUERY_V2)
self.connected = True
self.data = InfluxSensorDataV2(
query_api,
bucket,
self.data = InfluxFluxSensorData(
influx,
query.get(CONF_BUCKET),
query.get(CONF_RANGE_START),
query.get(CONF_RANGE_STOP),
query_clause,
@ -224,32 +184,16 @@ class InfluxSensor(Entity):
)
else:
influx.query(TEST_QUERY_V1)
self.connected = True
self.data = InfluxSensorDataV1(
where_clause = query.get(CONF_WHERE)
where_clause.hass = hass
self.data = InfluxQLSensorData(
influx,
query.get(CONF_DB_NAME),
query.get(CONF_GROUP_FUNCTION),
query.get(CONF_FIELD),
query.get(CONF_MEASUREMENT_NAME),
where_clause,
)
except exceptions.InfluxDBClientError as exc:
_LOGGER.error(
"Database host is not accessible due to '%s', please"
" check your entries in the configuration file and"
" that the database exists and is READ/WRITE",
exc,
)
self.connected = False
except ApiException as exc:
_LOGGER.error(
"Bucket is not accessible due to '%s', please "
"check your entries in the configuration file (url, org, "
"bucket, etc.) and verify that the org and bucket exist and the "
"provided token has READ access.",
exc,
)
self.connected = False
@property
def name(self):
@ -285,14 +229,12 @@ class InfluxSensor(Entity):
self._state = value
class InfluxSensorDataV2:
"""Class for handling the data retrieval with v2 API."""
class InfluxFluxSensorData:
"""Class for handling the data retrieval from Influx with Flux query."""
def __init__(
self, query_api, bucket, range_start, range_stop, query, imports, group
):
def __init__(self, influx, bucket, range_start, range_stop, query, imports, group):
"""Initialize the data object."""
self.query_api = query_api
self.influx = influx
self.bucket = bucket
self.range_start = range_start
self.range_stop = range_stop
@ -308,57 +250,47 @@ class InfluxSensorDataV2:
self.query_prefix = f'import "{i}" {self.query_prefix}'
if group is None:
self.query_postfix = "|> limit(n: 1)"
self.query_postfix = DEFAULT_FUNCTION_FLUX
else:
self.query_postfix = f'|> {group}(column: "_value")'
self.query_postfix = f'|> {group}(column: "{INFLUX_CONF_VALUE_V2}")'
@Throttle(MIN_TIME_BETWEEN_UPDATES)
def update(self):
"""Get the latest data by querying influx."""
_LOGGER.debug("Rendering query: %s", self.query)
_LOGGER.debug(RENDERING_QUERY_MESSAGE, self.query)
try:
rendered_query = self.query.render()
except TemplateError as ex:
_LOGGER.error("Could not render query template: %s", ex)
_LOGGER.error(RENDERING_QUERY_ERROR_MESSAGE, ex)
return
self.full_query = f"{self.query_prefix} {rendered_query} {self.query_postfix}"
_LOGGER.info("Running query: %s", self.full_query)
_LOGGER.debug(RUNNING_QUERY_MESSAGE, self.full_query)
try:
tables = self.query_api.query(self.full_query)
except (OSError, ApiException) as exc:
_LOGGER.error(
"Could not execute query '%s' due to '%s', "
"Check the syntax of your query",
self.full_query,
exc,
)
tables = self.influx.query(self.full_query)
except (ConnectionError, ValueError) as exc:
_LOGGER.error(exc)
self.value = None
return
if not tables:
_LOGGER.warning(
"Query returned no results, sensor state set to UNKNOWN: %s",
self.full_query,
)
_LOGGER.warning(QUERY_NO_RESULTS_MESSAGE, self.full_query)
self.value = None
else:
if len(tables) > 1:
_LOGGER.warning(
"Query returned multiple tables, only value from first one is shown: %s",
self.full_query,
)
self.value = tables[0].records[0].values["_value"]
if len(tables) > 1 or len(tables[0].records) > 1:
_LOGGER.warning(QUERY_MULTIPLE_RESULTS_MESSAGE, self.full_query)
self.value = tables[0].records[0].values[INFLUX_CONF_VALUE_V2]
class InfluxSensorDataV1:
class InfluxQLSensorData:
"""Class for handling the data retrieval with v1 API."""
def __init__(self, influx, group, field, measurement, where):
def __init__(self, influx, db_name, group, field, measurement, where):
"""Initialize the data object."""
self.influx = influx
self.db_name = db_name
self.group = group
self.field = field
self.measurement = measurement
@ -369,38 +301,28 @@ class InfluxSensorDataV1:
@Throttle(MIN_TIME_BETWEEN_UPDATES)
def update(self):
"""Get the latest data with a shell command."""
_LOGGER.info("Rendering where: %s", self.where)
_LOGGER.debug(RENDERING_WHERE_MESSAGE, self.where)
try:
where_clause = self.where.render()
except TemplateError as ex:
_LOGGER.error("Could not render where clause template: %s", ex)
_LOGGER.error(RENDERING_WHERE_ERROR_MESSAGE, ex)
return
self.query = f"select {self.group}({self.field}) as value from {self.measurement} where {where_clause}"
self.query = f"select {self.group}({self.field}) as {INFLUX_CONF_VALUE} from {self.measurement} where {where_clause}"
_LOGGER.info("Running query: %s", self.query)
_LOGGER.debug(RUNNING_QUERY_MESSAGE, self.query)
try:
points = list(self.influx.query(self.query).get_points())
except (OSError, exceptions.InfluxDBClientError) as exc:
_LOGGER.error(
"Could not execute query '%s' due to '%s', "
"Check the syntax of your query",
self.query,
exc,
)
points = self.influx.query(self.query, self.db_name)
except (ConnectionError, ValueError) as exc:
_LOGGER.error(exc)
self.value = None
return
if not points:
_LOGGER.warning(
"Query returned no points, sensor state set to UNKNOWN: %s", self.query
)
_LOGGER.warning(QUERY_NO_RESULTS_MESSAGE, self.query)
self.value = None
else:
if len(points) > 1:
_LOGGER.warning(
"Query returned multiple points, only first one shown: %s",
self.query,
)
self.value = points[0].get("value")
_LOGGER.warning(QUERY_MULTIPLE_RESULTS_MESSAGE, self.query)
self.value = points[0].get(INFLUX_CONF_VALUE)

View file

@ -1226,6 +1226,13 @@ async def test_event_listener_attribute_name_conflict(
influxdb.DEFAULT_API_VERSION,
influxdb.exceptions.InfluxDBClientError("fail"),
),
(
influxdb.DEFAULT_API_VERSION,
BASE_V1_CONFIG,
_get_write_api_mock_v1,
influxdb.DEFAULT_API_VERSION,
influxdb.exceptions.InfluxDBServerError("fail"),
),
(
influxdb.API_VERSION_2,
BASE_V2_CONFIG,

View file

@ -1,8 +1,9 @@
"""The tests for the InfluxDB sensor."""
from dataclasses import dataclass
from datetime import timedelta
from typing import Dict, List, Type
from influxdb.exceptions import InfluxDBClientError
from influxdb.exceptions import InfluxDBClientError, InfluxDBServerError
from influxdb_client.rest import ApiException
import pytest
from voluptuous import Invalid
@ -18,12 +19,15 @@ from homeassistant.components.influxdb.sensor import PLATFORM_SCHEMA
import homeassistant.components.sensor as sensor
from homeassistant.const import STATE_UNKNOWN
from homeassistant.setup import async_setup_component
from homeassistant.util import dt as dt_util
from tests.async_mock import MagicMock, patch
from tests.common import async_fire_time_changed
INFLUXDB_PATH = "homeassistant.components.influxdb"
INFLUXDB_CLIENT_PATH = f"{INFLUXDB_PATH}.sensor.InfluxDBClient"
INFLUXDB_CLIENT_PATH = f"{INFLUXDB_PATH}.InfluxDBClient"
INFLUXDB_SENSOR_PATH = f"{INFLUXDB_PATH}.sensor"
PLATFORM_NOT_READY_BASE_WAIT_TIME = 30
BASE_V1_CONFIG = {}
BASE_V2_CONFIG = {
@ -137,6 +141,8 @@ def _set_query_mock_v1(mock_influx_client, return_value=None, side_effect=None):
query_api.side_effect = get_return_value
return query_api
def _set_query_mock_v2(mock_influx_client, return_value=None, side_effect=None):
"""Set return value or side effect for the V2 client."""
@ -149,6 +155,8 @@ def _set_query_mock_v2(mock_influx_client, return_value=None, side_effect=None):
query_api.return_value = return_value
return query_api
async def _setup(hass, config_ext, queries, expected_sensors):
"""Create client and test expected sensors."""
@ -451,3 +459,79 @@ async def test_error_rendering_template(
assert (
len([record for record in caplog.records if record.levelname == "ERROR"]) == 1
)
@pytest.mark.parametrize(
    "mock_client, config_ext, queries, set_query_mock, test_exception, make_resultset",
    [
        # V1: connection, client-error and server-error failures.
        (
            DEFAULT_API_VERSION,
            BASE_V1_CONFIG,
            BASE_V1_QUERY,
            _set_query_mock_v1,
            OSError("fail"),
            _make_v1_resultset,
        ),
        (
            DEFAULT_API_VERSION,
            BASE_V1_CONFIG,
            BASE_V1_QUERY,
            _set_query_mock_v1,
            InfluxDBClientError("fail"),
            _make_v1_resultset,
        ),
        (
            DEFAULT_API_VERSION,
            BASE_V1_CONFIG,
            BASE_V1_QUERY,
            _set_query_mock_v1,
            InfluxDBServerError("fail"),
            _make_v1_resultset,
        ),
        # V2: connection and API-error failures.
        (
            API_VERSION_2,
            BASE_V2_CONFIG,
            BASE_V2_QUERY,
            _set_query_mock_v2,
            OSError("fail"),
            _make_v2_resultset,
        ),
        (
            API_VERSION_2,
            BASE_V2_CONFIG,
            BASE_V2_QUERY,
            _set_query_mock_v2,
            ApiException(),
            _make_v2_resultset,
        ),
    ],
    indirect=["mock_client"],
)
async def test_connection_error_at_startup(
    hass,
    caplog,
    mock_client,
    config_ext,
    queries,
    set_query_mock,
    test_exception,
    make_resultset,
):
    """Test behavior of sensor when influx returns error.

    The platform should log one error, create no entity, and then retry
    (PlatformNotReady) and succeed once the connection recovers.
    """
    query_api = set_query_mock(mock_client, side_effect=test_exception)
    expected_sensor = "sensor.test"

    # Test sensor is not setup first time due to connection error
    await _setup(hass, config_ext, queries, [])
    assert hass.states.get(expected_sensor) is None
    assert (
        len([record for record in caplog.records if record.levelname == "ERROR"]) == 1
    )

    # Stop throwing exception and advance time to test setup succeeds
    query_api.reset_mock(side_effect=True)
    set_query_mock(mock_client, return_value=make_resultset(42))
    new_time = dt_util.utcnow() + timedelta(seconds=PLATFORM_NOT_READY_BASE_WAIT_TIME)
    async_fire_time_changed(hass, new_time)
    await hass.async_block_till_done()
    assert hass.states.get(expected_sensor) is not None