Support reconfiguring MQTT config entry (#36537)
This commit is contained in:
parent
ee816ed3dd
commit
747490ab34
13 changed files with 881 additions and 249 deletions
|
@ -19,7 +19,7 @@ DEFAULT_QOS = 0
|
|||
TRIGGER_SCHEMA = vol.Schema(
|
||||
{
|
||||
vol.Required(CONF_PLATFORM): mqtt.DOMAIN,
|
||||
vol.Required(CONF_TOPIC): mqtt.valid_subscribe_topic,
|
||||
vol.Required(CONF_TOPIC): mqtt.util.valid_subscribe_topic,
|
||||
vol.Optional(CONF_PAYLOAD): cv.string,
|
||||
vol.Optional(CONF_ENCODING, default=DEFAULT_ENCODING): cv.string,
|
||||
vol.Optional(CONF_QOS, default=DEFAULT_QOS): vol.All(
|
||||
|
|
|
@ -274,7 +274,7 @@ async def system_options_update(hass, connection, msg):
|
|||
{"type": "config_entries/update", "entry_id": str, vol.Optional("title"): str}
|
||||
)
|
||||
async def config_entry_update(hass, connection, msg):
|
||||
"""Update config entry system options."""
|
||||
"""Update config entry."""
|
||||
changes = dict(msg)
|
||||
changes.pop("id")
|
||||
changes.pop("type")
|
||||
|
|
|
@ -12,7 +12,7 @@ import sys
|
|||
from typing import Any, Callable, List, Optional, Union
|
||||
|
||||
import attr
|
||||
import requests.certs
|
||||
import certifi
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant import config_entries
|
||||
|
@ -46,11 +46,20 @@ from . import debug_info, discovery, server
|
|||
from .const import (
|
||||
ATTR_DISCOVERY_HASH,
|
||||
ATTR_DISCOVERY_TOPIC,
|
||||
ATTR_PAYLOAD,
|
||||
ATTR_QOS,
|
||||
ATTR_RETAIN,
|
||||
ATTR_TOPIC,
|
||||
CONF_BIRTH_MESSAGE,
|
||||
CONF_BROKER,
|
||||
CONF_DISCOVERY,
|
||||
CONF_QOS,
|
||||
CONF_RETAIN,
|
||||
CONF_STATE_TOPIC,
|
||||
CONF_WILL_MESSAGE,
|
||||
DEFAULT_DISCOVERY,
|
||||
DEFAULT_QOS,
|
||||
DEFAULT_RETAIN,
|
||||
MQTT_CONNECTED,
|
||||
MQTT_DISCONNECTED,
|
||||
PROTOCOL_311,
|
||||
|
@ -59,6 +68,7 @@ from .debug_info import log_messages
|
|||
from .discovery import MQTT_DISCOVERY_UPDATED, clear_discovery_hash, set_discovery_hash
|
||||
from .models import Message, MessageCallbackType, PublishPayloadType
|
||||
from .subscription import async_subscribe_topics, async_unsubscribe_topics
|
||||
from .util import _VALID_QOS_SCHEMA, valid_publish_topic, valid_subscribe_topic
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -80,17 +90,12 @@ CONF_CLIENT_CERT = "client_cert"
|
|||
CONF_TLS_INSECURE = "tls_insecure"
|
||||
CONF_TLS_VERSION = "tls_version"
|
||||
|
||||
CONF_BIRTH_MESSAGE = "birth_message"
|
||||
CONF_WILL_MESSAGE = "will_message"
|
||||
|
||||
CONF_COMMAND_TOPIC = "command_topic"
|
||||
CONF_AVAILABILITY_TOPIC = "availability_topic"
|
||||
CONF_PAYLOAD_AVAILABLE = "payload_available"
|
||||
CONF_PAYLOAD_NOT_AVAILABLE = "payload_not_available"
|
||||
CONF_JSON_ATTRS_TOPIC = "json_attributes_topic"
|
||||
CONF_JSON_ATTRS_TEMPLATE = "json_attributes_template"
|
||||
CONF_QOS = "qos"
|
||||
CONF_RETAIN = "retain"
|
||||
|
||||
CONF_UNIQUE_ID = "unique_id"
|
||||
CONF_IDENTIFIERS = "identifiers"
|
||||
|
@ -105,18 +110,13 @@ PROTOCOL_31 = "3.1"
|
|||
|
||||
DEFAULT_PORT = 1883
|
||||
DEFAULT_KEEPALIVE = 60
|
||||
DEFAULT_RETAIN = False
|
||||
DEFAULT_PROTOCOL = PROTOCOL_311
|
||||
DEFAULT_DISCOVERY_PREFIX = "homeassistant"
|
||||
DEFAULT_TLS_PROTOCOL = "auto"
|
||||
DEFAULT_PAYLOAD_AVAILABLE = "online"
|
||||
DEFAULT_PAYLOAD_NOT_AVAILABLE = "offline"
|
||||
|
||||
ATTR_TOPIC = "topic"
|
||||
ATTR_PAYLOAD = "payload"
|
||||
ATTR_PAYLOAD_TEMPLATE = "payload_template"
|
||||
ATTR_QOS = CONF_QOS
|
||||
ATTR_RETAIN = CONF_RETAIN
|
||||
|
||||
MAX_RECONNECT_WAIT = 300 # seconds
|
||||
|
||||
|
@ -125,59 +125,6 @@ CONNECTION_FAILED = "connection_failed"
|
|||
CONNECTION_FAILED_RECOVERABLE = "connection_failed_recoverable"
|
||||
|
||||
|
||||
def valid_topic(value: Any) -> str:
|
||||
"""Validate that this is a valid topic name/filter."""
|
||||
value = cv.string(value)
|
||||
try:
|
||||
raw_value = value.encode("utf-8")
|
||||
except UnicodeError:
|
||||
raise vol.Invalid("MQTT topic name/filter must be valid UTF-8 string.")
|
||||
if not raw_value:
|
||||
raise vol.Invalid("MQTT topic name/filter must not be empty.")
|
||||
if len(raw_value) > 65535:
|
||||
raise vol.Invalid(
|
||||
"MQTT topic name/filter must not be longer than 65535 encoded bytes."
|
||||
)
|
||||
if "\0" in value:
|
||||
raise vol.Invalid("MQTT topic name/filter must not contain null character.")
|
||||
return value
|
||||
|
||||
|
||||
def valid_subscribe_topic(value: Any) -> str:
|
||||
"""Validate that we can subscribe using this MQTT topic."""
|
||||
value = valid_topic(value)
|
||||
for i in (i for i, c in enumerate(value) if c == "+"):
|
||||
if (i > 0 and value[i - 1] != "/") or (
|
||||
i < len(value) - 1 and value[i + 1] != "/"
|
||||
):
|
||||
raise vol.Invalid(
|
||||
"Single-level wildcard must occupy an entire level of the filter"
|
||||
)
|
||||
|
||||
index = value.find("#")
|
||||
if index != -1:
|
||||
if index != len(value) - 1:
|
||||
# If there are multiple wildcards, this will also trigger
|
||||
raise vol.Invalid(
|
||||
"Multi-level wildcard must be the last "
|
||||
"character in the topic filter."
|
||||
)
|
||||
if len(value) > 1 and value[index - 1] != "/":
|
||||
raise vol.Invalid(
|
||||
"Multi-level wildcard must be after a topic level separator."
|
||||
)
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def valid_publish_topic(value: Any) -> str:
|
||||
"""Validate that we can publish using this MQTT topic."""
|
||||
value = valid_topic(value)
|
||||
if "+" in value or "#" in value:
|
||||
raise vol.Invalid("Wildcards can not be used in topic names")
|
||||
return value
|
||||
|
||||
|
||||
def validate_device_has_at_least_one_identifier(value: ConfigType) -> ConfigType:
|
||||
"""Validate that a device info entry has at least one identifying value."""
|
||||
if not value.get(CONF_IDENTIFIERS) and not value.get(CONF_CONNECTIONS):
|
||||
|
@ -188,8 +135,6 @@ def validate_device_has_at_least_one_identifier(value: ConfigType) -> ConfigType
|
|||
return value
|
||||
|
||||
|
||||
_VALID_QOS_SCHEMA = vol.All(vol.Coerce(int), vol.In([0, 1, 2]))
|
||||
|
||||
CLIENT_KEY_AUTH_MSG = (
|
||||
"client_key and client_cert must both be present in "
|
||||
"the MQTT broker configuration"
|
||||
|
@ -554,6 +499,11 @@ async def async_setup(hass: HomeAssistantType, config: ConfigType) -> bool:
|
|||
return True
|
||||
|
||||
|
||||
def _merge_config(entry, conf):
|
||||
"""Merge configuration.yaml config with config entry."""
|
||||
return {**conf, **entry.data}
|
||||
|
||||
|
||||
async def async_setup_entry(hass, entry):
|
||||
"""Load a config entry."""
|
||||
conf = hass.data.get(DATA_MQTT_CONFIG)
|
||||
|
@ -574,76 +524,9 @@ async def async_setup_entry(hass, entry):
|
|||
entry.data,
|
||||
)
|
||||
|
||||
conf.update(entry.data)
|
||||
conf = _merge_config(entry, conf)
|
||||
|
||||
broker = conf[CONF_BROKER]
|
||||
port = conf[CONF_PORT]
|
||||
client_id = conf.get(CONF_CLIENT_ID)
|
||||
keepalive = conf[CONF_KEEPALIVE]
|
||||
username = conf.get(CONF_USERNAME)
|
||||
password = conf.get(CONF_PASSWORD)
|
||||
certificate = conf.get(CONF_CERTIFICATE)
|
||||
client_key = conf.get(CONF_CLIENT_KEY)
|
||||
client_cert = conf.get(CONF_CLIENT_CERT)
|
||||
tls_insecure = conf.get(CONF_TLS_INSECURE)
|
||||
protocol = conf[CONF_PROTOCOL]
|
||||
|
||||
# For cloudmqtt.com, secured connection, auto fill in certificate
|
||||
if (
|
||||
certificate is None
|
||||
and 19999 < conf[CONF_PORT] < 30000
|
||||
and broker.endswith(".cloudmqtt.com")
|
||||
):
|
||||
certificate = os.path.join(
|
||||
os.path.dirname(__file__), "addtrustexternalcaroot.crt"
|
||||
)
|
||||
|
||||
# When the certificate is set to auto, use bundled certs from requests
|
||||
elif certificate == "auto":
|
||||
certificate = requests.certs.where()
|
||||
|
||||
if CONF_WILL_MESSAGE in conf:
|
||||
will_message = Message(**conf[CONF_WILL_MESSAGE])
|
||||
else:
|
||||
will_message = None
|
||||
|
||||
if CONF_BIRTH_MESSAGE in conf:
|
||||
birth_message = Message(**conf[CONF_BIRTH_MESSAGE])
|
||||
else:
|
||||
birth_message = None
|
||||
|
||||
# Be able to override versions other than TLSv1.0 under Python3.6
|
||||
conf_tls_version: str = conf.get(CONF_TLS_VERSION)
|
||||
if conf_tls_version == "1.2":
|
||||
tls_version = ssl.PROTOCOL_TLSv1_2
|
||||
elif conf_tls_version == "1.1":
|
||||
tls_version = ssl.PROTOCOL_TLSv1_1
|
||||
elif conf_tls_version == "1.0":
|
||||
tls_version = ssl.PROTOCOL_TLSv1
|
||||
else:
|
||||
# Python3.6 supports automatic negotiation of highest TLS version
|
||||
if sys.hexversion >= 0x03060000:
|
||||
tls_version = ssl.PROTOCOL_TLS # pylint: disable=no-member
|
||||
else:
|
||||
tls_version = ssl.PROTOCOL_TLSv1
|
||||
|
||||
hass.data[DATA_MQTT] = MQTT(
|
||||
hass,
|
||||
broker=broker,
|
||||
port=port,
|
||||
client_id=client_id,
|
||||
keepalive=keepalive,
|
||||
username=username,
|
||||
password=password,
|
||||
certificate=certificate,
|
||||
client_key=client_key,
|
||||
client_cert=client_cert,
|
||||
tls_insecure=tls_insecure,
|
||||
protocol=protocol,
|
||||
will_message=will_message,
|
||||
birth_message=birth_message,
|
||||
tls_version=tls_version,
|
||||
)
|
||||
hass.data[DATA_MQTT] = MQTT(hass, entry, conf,)
|
||||
|
||||
await hass.data[DATA_MQTT].async_connect()
|
||||
|
||||
|
@ -732,53 +615,101 @@ class Subscription:
|
|||
class MQTT:
|
||||
"""Home Assistant MQTT client."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistantType,
|
||||
broker: str,
|
||||
port: int,
|
||||
client_id: Optional[str],
|
||||
keepalive: Optional[int],
|
||||
username: Optional[str],
|
||||
password: Optional[str],
|
||||
certificate: Optional[str],
|
||||
client_key: Optional[str],
|
||||
client_cert: Optional[str],
|
||||
tls_insecure: Optional[bool],
|
||||
protocol: Optional[str],
|
||||
will_message: Optional[Message],
|
||||
birth_message: Optional[Message],
|
||||
tls_version: Optional[int],
|
||||
) -> None:
|
||||
def __init__(self, hass: HomeAssistantType, config_entry, conf,) -> None:
|
||||
"""Initialize Home Assistant MQTT client."""
|
||||
# We don't import them on the top because some integrations
|
||||
# We don't import on the top because some integrations
|
||||
# should be able to optionally rely on MQTT.
|
||||
# pylint: disable=import-outside-toplevel
|
||||
import paho.mqtt.client as mqtt
|
||||
import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel
|
||||
|
||||
self.hass = hass
|
||||
self.broker = broker
|
||||
self.port = port
|
||||
self.keepalive = keepalive
|
||||
self.config_entry = config_entry
|
||||
self.conf = conf
|
||||
self.subscriptions: List[Subscription] = []
|
||||
self.birth_message = birth_message
|
||||
self.connected = False
|
||||
self._mqttc: mqtt.Client = None
|
||||
self._paho_lock = asyncio.Lock()
|
||||
|
||||
if protocol == PROTOCOL_31:
|
||||
self.init_client()
|
||||
self.config_entry.add_update_listener(self.async_config_entry_updated)
|
||||
|
||||
@staticmethod
|
||||
async def async_config_entry_updated(hass, entry) -> None:
|
||||
"""Handle signals of config entry being updated.
|
||||
|
||||
This is a static method because a class method (bound method), can not be used with weak references.
|
||||
Causes for this is config entry options changing.
|
||||
"""
|
||||
self = hass.data[DATA_MQTT]
|
||||
|
||||
conf = hass.data.get(DATA_MQTT_CONFIG)
|
||||
if conf is None:
|
||||
conf = CONFIG_SCHEMA({DOMAIN: dict(entry.data)})[DOMAIN]
|
||||
|
||||
self.conf = _merge_config(entry, conf)
|
||||
await self.async_disconnect()
|
||||
self.init_client()
|
||||
await self.async_connect()
|
||||
|
||||
await discovery.async_stop(hass)
|
||||
if self.conf.get(CONF_DISCOVERY):
|
||||
await _async_setup_discovery(hass, self.conf, entry)
|
||||
|
||||
def init_client(self):
|
||||
"""Initialize paho client."""
|
||||
# We don't import on the top because some integrations
|
||||
# should be able to optionally rely on MQTT.
|
||||
import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel
|
||||
|
||||
if self.conf[CONF_PROTOCOL] == PROTOCOL_31:
|
||||
proto: int = mqtt.MQTTv31
|
||||
else:
|
||||
proto = mqtt.MQTTv311
|
||||
|
||||
client_id = self.conf.get(CONF_CLIENT_ID)
|
||||
if client_id is None:
|
||||
self._mqttc = mqtt.Client(protocol=proto)
|
||||
else:
|
||||
self._mqttc = mqtt.Client(client_id, protocol=proto)
|
||||
|
||||
username = self.conf.get(CONF_USERNAME)
|
||||
password = self.conf.get(CONF_PASSWORD)
|
||||
if username is not None:
|
||||
self._mqttc.username_pw_set(username, password)
|
||||
|
||||
certificate = self.conf.get(CONF_CERTIFICATE)
|
||||
|
||||
# For cloudmqtt.com, secured connection, auto fill in certificate
|
||||
if (
|
||||
certificate is None
|
||||
and 19999 < self.conf[CONF_PORT] < 30000
|
||||
and self.conf[CONF_BROKER].endswith(".cloudmqtt.com")
|
||||
):
|
||||
certificate = os.path.join(
|
||||
os.path.dirname(__file__), "addtrustexternalcaroot.crt"
|
||||
)
|
||||
|
||||
# When the certificate is set to auto, use bundled certs from certifi
|
||||
elif certificate == "auto":
|
||||
certificate = certifi.where()
|
||||
|
||||
# Be able to override versions other than TLSv1.0 under Python3.6
|
||||
conf_tls_version: str = self.conf.get(CONF_TLS_VERSION)
|
||||
if conf_tls_version == "1.2":
|
||||
tls_version = ssl.PROTOCOL_TLSv1_2
|
||||
elif conf_tls_version == "1.1":
|
||||
tls_version = ssl.PROTOCOL_TLSv1_1
|
||||
elif conf_tls_version == "1.0":
|
||||
tls_version = ssl.PROTOCOL_TLSv1
|
||||
else:
|
||||
# Python3.6 supports automatic negotiation of highest TLS version
|
||||
if sys.hexversion >= 0x03060000:
|
||||
tls_version = ssl.PROTOCOL_TLS # pylint: disable=no-member
|
||||
else:
|
||||
tls_version = ssl.PROTOCOL_TLSv1
|
||||
|
||||
client_key = self.conf.get(CONF_CLIENT_KEY)
|
||||
client_cert = self.conf.get(CONF_CLIENT_CERT)
|
||||
tls_insecure = self.conf.get(CONF_TLS_INSECURE)
|
||||
if certificate is not None:
|
||||
self._mqttc.tls_set(
|
||||
certificate,
|
||||
|
@ -794,6 +725,11 @@ class MQTT:
|
|||
self._mqttc.on_disconnect = self._mqtt_on_disconnect
|
||||
self._mqttc.on_message = self._mqtt_on_message
|
||||
|
||||
if CONF_WILL_MESSAGE in self.conf:
|
||||
will_message = Message(**self.conf[CONF_WILL_MESSAGE])
|
||||
else:
|
||||
will_message = None
|
||||
|
||||
if will_message is not None:
|
||||
self._mqttc.will_set( # pylint: disable=no-value-for-parameter
|
||||
*attr.astuple(
|
||||
|
@ -813,14 +749,17 @@ class MQTT:
|
|||
)
|
||||
|
||||
async def async_connect(self) -> str:
|
||||
"""Connect to the host. Does process messages yet."""
|
||||
"""Connect to the host. Does not process messages yet."""
|
||||
# pylint: disable=import-outside-toplevel
|
||||
import paho.mqtt.client as mqtt
|
||||
|
||||
result: int = None
|
||||
try:
|
||||
result = await self.hass.async_add_executor_job(
|
||||
self._mqttc.connect, self.broker, self.port, self.keepalive
|
||||
self._mqttc.connect,
|
||||
self.conf[CONF_BROKER],
|
||||
self.conf[CONF_PORT],
|
||||
self.conf[CONF_KEEPALIVE],
|
||||
)
|
||||
except OSError as err:
|
||||
_LOGGER.error("Failed to connect to MQTT server due to exception: %s", err)
|
||||
|
@ -922,7 +861,12 @@ class MQTT:
|
|||
|
||||
self.connected = True
|
||||
dispatcher_send(self.hass, MQTT_CONNECTED)
|
||||
_LOGGER.info("Connected to MQTT server (%s)", result_code)
|
||||
_LOGGER.info(
|
||||
"Connected to MQTT server %s:%s (%s)",
|
||||
self.conf[CONF_BROKER],
|
||||
self.conf[CONF_PORT],
|
||||
result_code,
|
||||
)
|
||||
|
||||
# Group subscriptions to only re-subscribe once for each topic.
|
||||
keyfunc = attrgetter("topic")
|
||||
|
@ -931,11 +875,12 @@ class MQTT:
|
|||
max_qos = max(subscription.qos for subscription in subs)
|
||||
self.hass.add_job(self._async_perform_subscription, topic, max_qos)
|
||||
|
||||
if self.birth_message:
|
||||
if CONF_BIRTH_MESSAGE in self.conf:
|
||||
birth_message = Message(**self.conf[CONF_BIRTH_MESSAGE])
|
||||
self.hass.add_job(
|
||||
self.async_publish( # pylint: disable=no-value-for-parameter
|
||||
*attr.astuple(
|
||||
self.birth_message,
|
||||
birth_message,
|
||||
filter=lambda attr, value: attr.name
|
||||
not in ["subscribed_topic", "timestamp"],
|
||||
)
|
||||
|
@ -990,7 +935,12 @@ class MQTT:
|
|||
"""Disconnected callback."""
|
||||
self.connected = False
|
||||
dispatcher_send(self.hass, MQTT_DISCONNECTED)
|
||||
_LOGGER.warning("Disconnected from MQTT server (%s)", result_code)
|
||||
_LOGGER.warning(
|
||||
"Disconnected from MQTT server %s:%s (%s)",
|
||||
self.conf[CONF_BROKER],
|
||||
self.conf[CONF_PORT],
|
||||
result_code,
|
||||
)
|
||||
|
||||
|
||||
def _raise_on_error(result_code: int) -> None:
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
"""Config flow for MQTT."""
|
||||
from collections import OrderedDict
|
||||
import logging
|
||||
import queue
|
||||
|
||||
import voluptuous as vol
|
||||
|
@ -13,7 +14,22 @@ from homeassistant.const import (
|
|||
CONF_USERNAME,
|
||||
)
|
||||
|
||||
from .const import CONF_BROKER, CONF_DISCOVERY, DEFAULT_DISCOVERY
|
||||
from .const import (
|
||||
ATTR_PAYLOAD,
|
||||
ATTR_QOS,
|
||||
ATTR_RETAIN,
|
||||
ATTR_TOPIC,
|
||||
CONF_BIRTH_MESSAGE,
|
||||
CONF_BROKER,
|
||||
CONF_DISCOVERY,
|
||||
CONF_WILL_MESSAGE,
|
||||
DEFAULT_DISCOVERY,
|
||||
DEFAULT_QOS,
|
||||
DEFAULT_RETAIN,
|
||||
)
|
||||
from .util import MQTT_WILL_BIRTH_SCHEMA
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@config_entries.HANDLERS.register("mqtt")
|
||||
|
@ -25,6 +41,11 @@ class FlowHandler(config_entries.ConfigFlow):
|
|||
|
||||
_hassio_discovery = None
|
||||
|
||||
@staticmethod
|
||||
def async_get_options_flow(config_entry):
|
||||
"""Get the options flow for this handler."""
|
||||
return MQTTOptionsFlowHandler(config_entry)
|
||||
|
||||
async def async_step_user(self, user_input=None):
|
||||
"""Handle a flow initialized by the user."""
|
||||
if self._async_current_entries():
|
||||
|
@ -123,6 +144,163 @@ class FlowHandler(config_entries.ConfigFlow):
|
|||
)
|
||||
|
||||
|
||||
class MQTTOptionsFlowHandler(config_entries.OptionsFlow):
|
||||
"""Handle MQTT options."""
|
||||
|
||||
def __init__(self, config_entry):
|
||||
"""Initialize MQTT options flow."""
|
||||
self.config_entry = config_entry
|
||||
self.broker_config = {}
|
||||
self.options = dict(config_entry.options)
|
||||
|
||||
async def async_step_init(self, user_input=None):
|
||||
"""Manage the MQTT options."""
|
||||
return await self.async_step_broker()
|
||||
|
||||
async def async_step_broker(self, user_input=None):
|
||||
"""Manage the MQTT options."""
|
||||
errors = {}
|
||||
current_config = self.config_entry.data
|
||||
if user_input is not None:
|
||||
can_connect = await self.hass.async_add_executor_job(
|
||||
try_connection,
|
||||
user_input[CONF_BROKER],
|
||||
user_input[CONF_PORT],
|
||||
user_input.get(CONF_USERNAME),
|
||||
user_input.get(CONF_PASSWORD),
|
||||
)
|
||||
|
||||
if can_connect:
|
||||
self.broker_config.update(user_input)
|
||||
return await self.async_step_options()
|
||||
|
||||
errors["base"] = "cannot_connect"
|
||||
|
||||
fields = OrderedDict()
|
||||
fields[vol.Required(CONF_BROKER, default=current_config[CONF_BROKER])] = str
|
||||
fields[vol.Required(CONF_PORT, default=current_config[CONF_PORT])] = vol.Coerce(
|
||||
int
|
||||
)
|
||||
fields[
|
||||
vol.Optional(
|
||||
CONF_USERNAME,
|
||||
description={"suggested_value": current_config.get(CONF_USERNAME)},
|
||||
)
|
||||
] = str
|
||||
fields[
|
||||
vol.Optional(
|
||||
CONF_PASSWORD,
|
||||
description={"suggested_value": current_config.get(CONF_PASSWORD)},
|
||||
)
|
||||
] = str
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="broker", data_schema=vol.Schema(fields), errors=errors,
|
||||
)
|
||||
|
||||
async def async_step_options(self, user_input=None):
|
||||
"""Manage the MQTT options."""
|
||||
errors = {}
|
||||
current_config = self.config_entry.data
|
||||
options_config = {}
|
||||
if user_input is not None:
|
||||
bad_birth = False
|
||||
bad_will = False
|
||||
|
||||
if "birth_topic" in user_input:
|
||||
birth_message = {
|
||||
ATTR_TOPIC: user_input["birth_topic"],
|
||||
ATTR_PAYLOAD: user_input.get("birth_payload", ""),
|
||||
ATTR_QOS: user_input["birth_qos"],
|
||||
ATTR_RETAIN: user_input["birth_retain"],
|
||||
}
|
||||
try:
|
||||
birth_message = MQTT_WILL_BIRTH_SCHEMA(birth_message)
|
||||
options_config[CONF_BIRTH_MESSAGE] = birth_message
|
||||
except vol.Invalid:
|
||||
errors["base"] = "bad_birth"
|
||||
bad_birth = True
|
||||
|
||||
if "will_topic" in user_input:
|
||||
will_message = {
|
||||
ATTR_TOPIC: user_input["will_topic"],
|
||||
ATTR_PAYLOAD: user_input.get("will_payload", ""),
|
||||
ATTR_QOS: user_input["will_qos"],
|
||||
ATTR_RETAIN: user_input["will_retain"],
|
||||
}
|
||||
try:
|
||||
will_message = MQTT_WILL_BIRTH_SCHEMA(will_message)
|
||||
options_config[CONF_WILL_MESSAGE] = will_message
|
||||
except vol.Invalid:
|
||||
errors["base"] = "bad_will"
|
||||
bad_will = True
|
||||
|
||||
options_config[CONF_DISCOVERY] = user_input[CONF_DISCOVERY]
|
||||
|
||||
if not bad_birth and not bad_will:
|
||||
updated_config = {}
|
||||
updated_config.update(self.broker_config)
|
||||
updated_config.update(options_config)
|
||||
self.hass.config_entries.async_update_entry(
|
||||
self.config_entry, data=updated_config
|
||||
)
|
||||
return self.async_create_entry(title="", data=None)
|
||||
|
||||
birth_topic = None
|
||||
birth_payload = None
|
||||
birth_qos = DEFAULT_QOS
|
||||
birth_retain = DEFAULT_RETAIN
|
||||
if CONF_BIRTH_MESSAGE in current_config:
|
||||
birth_topic = current_config[CONF_BIRTH_MESSAGE][ATTR_TOPIC]
|
||||
birth_payload = current_config[CONF_BIRTH_MESSAGE][ATTR_PAYLOAD]
|
||||
birth_qos = current_config[CONF_BIRTH_MESSAGE].get(ATTR_QOS, DEFAULT_QOS)
|
||||
birth_retain = current_config[CONF_BIRTH_MESSAGE].get(
|
||||
ATTR_RETAIN, DEFAULT_RETAIN
|
||||
)
|
||||
|
||||
will_topic = None
|
||||
will_payload = None
|
||||
will_qos = DEFAULT_QOS
|
||||
will_retain = DEFAULT_RETAIN
|
||||
if CONF_WILL_MESSAGE in current_config:
|
||||
will_topic = current_config[CONF_WILL_MESSAGE][ATTR_TOPIC]
|
||||
will_payload = current_config[CONF_WILL_MESSAGE][ATTR_PAYLOAD]
|
||||
will_qos = current_config[CONF_WILL_MESSAGE].get(ATTR_QOS, DEFAULT_QOS)
|
||||
will_retain = current_config[CONF_WILL_MESSAGE].get(
|
||||
ATTR_RETAIN, DEFAULT_RETAIN
|
||||
)
|
||||
|
||||
fields = OrderedDict()
|
||||
fields[
|
||||
vol.Optional(
|
||||
CONF_DISCOVERY,
|
||||
default=current_config.get(CONF_DISCOVERY, DEFAULT_DISCOVERY),
|
||||
)
|
||||
] = bool
|
||||
fields[
|
||||
vol.Optional("birth_topic", description={"suggested_value": birth_topic})
|
||||
] = str
|
||||
fields[
|
||||
vol.Optional(
|
||||
"birth_payload", description={"suggested_value": birth_payload}
|
||||
)
|
||||
] = str
|
||||
fields[vol.Optional("birth_qos", default=birth_qos)] = vol.In([0, 1, 2])
|
||||
fields[vol.Optional("birth_retain", default=birth_retain)] = bool
|
||||
fields[
|
||||
vol.Optional("will_topic", description={"suggested_value": will_topic})
|
||||
] = str
|
||||
fields[
|
||||
vol.Optional("will_payload", description={"suggested_value": will_payload})
|
||||
] = str
|
||||
fields[vol.Optional("will_qos", default=will_qos)] = vol.In([0, 1, 2])
|
||||
fields[vol.Optional("will_retain", default=will_retain)] = bool
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="options", data_schema=vol.Schema(fields), errors=errors,
|
||||
)
|
||||
|
||||
|
||||
def try_connection(broker, port, username, password, protocol="3.1"):
|
||||
"""Test if we can connect to an MQTT broker."""
|
||||
# pylint: disable=import-outside-toplevel
|
||||
|
|
|
@ -1,14 +1,25 @@
|
|||
"""Constants used by multiple MQTT modules."""
|
||||
CONF_BROKER = "broker"
|
||||
CONF_DISCOVERY = "discovery"
|
||||
DEFAULT_DISCOVERY = False
|
||||
|
||||
ATTR_DISCOVERY_HASH = "discovery_hash"
|
||||
ATTR_DISCOVERY_PAYLOAD = "discovery_payload"
|
||||
ATTR_DISCOVERY_TOPIC = "discovery_topic"
|
||||
ATTR_PAYLOAD = "payload"
|
||||
ATTR_QOS = "qos"
|
||||
ATTR_RETAIN = "retain"
|
||||
ATTR_TOPIC = "topic"
|
||||
|
||||
CONF_BROKER = "broker"
|
||||
CONF_BIRTH_MESSAGE = "birth_message"
|
||||
CONF_DISCOVERY = "discovery"
|
||||
CONF_QOS = ATTR_QOS
|
||||
CONF_RETAIN = ATTR_RETAIN
|
||||
CONF_STATE_TOPIC = "state_topic"
|
||||
PROTOCOL_311 = "3.1.1"
|
||||
CONF_WILL_MESSAGE = "will_message"
|
||||
|
||||
DEFAULT_DISCOVERY = False
|
||||
DEFAULT_QOS = 0
|
||||
DEFAULT_RETAIN = False
|
||||
|
||||
MQTT_CONNECTED = "mqtt_connected"
|
||||
MQTT_DISCONNECTED = "mqtt_disconnected"
|
||||
|
||||
PROTOCOL_311 = "3.1.1"
|
||||
|
|
|
@ -35,8 +35,9 @@ SUPPORTED_COMPONENTS = [
|
|||
]
|
||||
|
||||
ALREADY_DISCOVERED = "mqtt_discovered_components"
|
||||
DATA_CONFIG_ENTRY_LOCK = "mqtt_config_entry_lock"
|
||||
CONFIG_ENTRY_IS_SETUP = "mqtt_config_entry_is_setup"
|
||||
DATA_CONFIG_ENTRY_LOCK = "mqtt_config_entry_lock"
|
||||
DISCOVERY_UNSUBSCRIBE = "mqtt_discovery_unsubscribe"
|
||||
MQTT_DISCOVERY_UPDATED = "mqtt_discovery_updated_{}"
|
||||
MQTT_DISCOVERY_NEW = "mqtt_discovery_new_{}_{}"
|
||||
|
||||
|
@ -163,8 +164,15 @@ async def async_start(
|
|||
hass.data[DATA_CONFIG_ENTRY_LOCK] = asyncio.Lock()
|
||||
hass.data[CONFIG_ENTRY_IS_SETUP] = set()
|
||||
|
||||
await mqtt.async_subscribe(
|
||||
hass.data[DISCOVERY_UNSUBSCRIBE] = await mqtt.async_subscribe(
|
||||
hass, f"{discovery_topic}/#", async_device_message_received, 0
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def async_stop(hass: HomeAssistantType) -> bool:
|
||||
"""Stop MQTT Discovery."""
|
||||
if DISCOVERY_UNSUBSCRIBE in hass.data and hass.data[DISCOVERY_UNSUBSCRIBE]:
|
||||
hass.data[DISCOVERY_UNSUBSCRIBE]()
|
||||
hass.data[DISCOVERY_UNSUBSCRIBE] = None
|
||||
|
|
|
@ -47,5 +47,37 @@
|
|||
"button_5": "Fifth button",
|
||||
"button_6": "Sixth button"
|
||||
}
|
||||
},
|
||||
"options": {
|
||||
"step": {
|
||||
"broker": {
|
||||
"description": "Please enter the connection information of your MQTT broker.",
|
||||
"data": {
|
||||
"broker": "Broker",
|
||||
"port": "[%key:common::config_flow::data::port%]",
|
||||
"username": "[%key:common::config_flow::data::username%]",
|
||||
"password": "[%key:common::config_flow::data::password%]"
|
||||
}
|
||||
},
|
||||
"options": {
|
||||
"description": "Please select MQTT options.",
|
||||
"data": {
|
||||
"discovery": "Enable discovery",
|
||||
"birth_topic": "Birth message topic",
|
||||
"birth_payload": "Birth message payload",
|
||||
"birth_qos": "Birth message QoS",
|
||||
"birth_retain": "Birth message retain",
|
||||
"will_topic": "Will message topic",
|
||||
"will_payload": "Will message payload",
|
||||
"will_qos": "Will message QoS",
|
||||
"will_retain": "Will message retain"
|
||||
}
|
||||
}
|
||||
},
|
||||
"error": {
|
||||
"cannot_connect": "Unable to connect to the broker.",
|
||||
"bad_birth": "Invalid birth topic.",
|
||||
"bad_will": "Invalid will topic."
|
||||
}
|
||||
}
|
||||
}
|
82
homeassistant/components/mqtt/util.py
Normal file
82
homeassistant/components/mqtt/util.py
Normal file
|
@ -0,0 +1,82 @@
|
|||
"""Utility functions for the MQTT integration."""
|
||||
from typing import Any
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.const import CONF_PAYLOAD
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
|
||||
from .const import (
|
||||
ATTR_PAYLOAD,
|
||||
ATTR_QOS,
|
||||
ATTR_RETAIN,
|
||||
ATTR_TOPIC,
|
||||
DEFAULT_QOS,
|
||||
DEFAULT_RETAIN,
|
||||
)
|
||||
|
||||
|
||||
def valid_topic(value: Any) -> str:
|
||||
"""Validate that this is a valid topic name/filter."""
|
||||
value = cv.string(value)
|
||||
try:
|
||||
raw_value = value.encode("utf-8")
|
||||
except UnicodeError:
|
||||
raise vol.Invalid("MQTT topic name/filter must be valid UTF-8 string.")
|
||||
if not raw_value:
|
||||
raise vol.Invalid("MQTT topic name/filter must not be empty.")
|
||||
if len(raw_value) > 65535:
|
||||
raise vol.Invalid(
|
||||
"MQTT topic name/filter must not be longer than 65535 encoded bytes."
|
||||
)
|
||||
if "\0" in value:
|
||||
raise vol.Invalid("MQTT topic name/filter must not contain null character.")
|
||||
return value
|
||||
|
||||
|
||||
def valid_subscribe_topic(value: Any) -> str:
|
||||
"""Validate that we can subscribe using this MQTT topic."""
|
||||
value = valid_topic(value)
|
||||
for i in (i for i, c in enumerate(value) if c == "+"):
|
||||
if (i > 0 and value[i - 1] != "/") or (
|
||||
i < len(value) - 1 and value[i + 1] != "/"
|
||||
):
|
||||
raise vol.Invalid(
|
||||
"Single-level wildcard must occupy an entire level of the filter"
|
||||
)
|
||||
|
||||
index = value.find("#")
|
||||
if index != -1:
|
||||
if index != len(value) - 1:
|
||||
# If there are multiple wildcards, this will also trigger
|
||||
raise vol.Invalid(
|
||||
"Multi-level wildcard must be the last "
|
||||
"character in the topic filter."
|
||||
)
|
||||
if len(value) > 1 and value[index - 1] != "/":
|
||||
raise vol.Invalid(
|
||||
"Multi-level wildcard must be after a topic level separator."
|
||||
)
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def valid_publish_topic(value: Any) -> str:
|
||||
"""Validate that we can publish using this MQTT topic."""
|
||||
value = valid_topic(value)
|
||||
if "+" in value or "#" in value:
|
||||
raise vol.Invalid("Wildcards can not be used in topic names")
|
||||
return value
|
||||
|
||||
|
||||
_VALID_QOS_SCHEMA = vol.All(vol.Coerce(int), vol.In([0, 1, 2]))
|
||||
|
||||
MQTT_WILL_BIRTH_SCHEMA = vol.Schema(
|
||||
{
|
||||
vol.Required(ATTR_TOPIC): valid_publish_topic,
|
||||
vol.Required(ATTR_PAYLOAD, CONF_PAYLOAD): cv.string,
|
||||
vol.Optional(ATTR_QOS, default=DEFAULT_QOS): _VALID_QOS_SCHEMA,
|
||||
vol.Optional(ATTR_RETAIN, default=DEFAULT_RETAIN): cv.boolean,
|
||||
},
|
||||
required=True,
|
||||
)
|
|
@ -1008,7 +1008,8 @@ class OptionsFlowManager(data_entry_flow.FlowManager):
|
|||
entry = self.hass.config_entries.async_get_entry(flow.handler)
|
||||
if entry is None:
|
||||
raise UnknownEntry(flow.handler)
|
||||
self.hass.config_entries.async_update_entry(entry, options=result["data"])
|
||||
if result["data"] is not None:
|
||||
self.hass.config_entries.async_update_entry(entry, options=result["data"])
|
||||
|
||||
result["result"] = True
|
||||
return result
|
||||
|
|
|
@ -1,7 +1,11 @@
|
|||
"""Test config flow."""
|
||||
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant import data_entry_flow
|
||||
from homeassistant.components import mqtt
|
||||
from homeassistant.components.mqtt.discovery import async_start
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.async_mock import patch
|
||||
|
@ -144,3 +148,340 @@ async def test_hassio_confirm(hass, mock_try_connection, mock_finish_setup):
|
|||
assert len(mock_try_connection.mock_calls) == 1
|
||||
# Check config entry got setup
|
||||
assert len(mock_finish_setup.mock_calls) == 1
|
||||
|
||||
|
||||
async def test_option_flow(hass, mqtt_mock, mock_try_connection):
|
||||
"""Test config flow options."""
|
||||
mock_try_connection.return_value = True
|
||||
config_entry = hass.config_entries.async_entries(mqtt.DOMAIN)[0]
|
||||
await async_start(hass, "homeassistant", config_entry)
|
||||
config_entry.data = {
|
||||
mqtt.CONF_BROKER: "test-broker",
|
||||
mqtt.CONF_PORT: 1234,
|
||||
}
|
||||
|
||||
mqtt_mock.async_connect.reset_mock()
|
||||
|
||||
result = await hass.config_entries.options.async_init(config_entry.entry_id)
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||
assert result["step_id"] == "broker"
|
||||
|
||||
result = await hass.config_entries.options.async_configure(
|
||||
result["flow_id"],
|
||||
user_input={
|
||||
mqtt.CONF_BROKER: "another-broker",
|
||||
mqtt.CONF_PORT: 2345,
|
||||
mqtt.CONF_USERNAME: "user",
|
||||
mqtt.CONF_PASSWORD: "pass",
|
||||
},
|
||||
)
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||
assert result["step_id"] == "options"
|
||||
|
||||
await hass.async_block_till_done()
|
||||
assert mqtt_mock.async_connect.call_count == 0
|
||||
|
||||
result = await hass.config_entries.options.async_configure(
|
||||
result["flow_id"],
|
||||
user_input={
|
||||
mqtt.CONF_DISCOVERY: True,
|
||||
"birth_topic": "ha_state/online",
|
||||
"birth_payload": "online",
|
||||
"birth_qos": 1,
|
||||
"birth_retain": True,
|
||||
"will_topic": "ha_state/offline",
|
||||
"will_payload": "offline",
|
||||
"will_qos": 2,
|
||||
"will_retain": True,
|
||||
},
|
||||
)
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
||||
assert result["data"] is None
|
||||
assert config_entry.data == {
|
||||
mqtt.CONF_BROKER: "another-broker",
|
||||
mqtt.CONF_PORT: 2345,
|
||||
mqtt.CONF_USERNAME: "user",
|
||||
mqtt.CONF_PASSWORD: "pass",
|
||||
mqtt.CONF_DISCOVERY: True,
|
||||
mqtt.CONF_BIRTH_MESSAGE: {
|
||||
mqtt.ATTR_TOPIC: "ha_state/online",
|
||||
mqtt.ATTR_PAYLOAD: "online",
|
||||
mqtt.ATTR_QOS: 1,
|
||||
mqtt.ATTR_RETAIN: True,
|
||||
},
|
||||
mqtt.CONF_WILL_MESSAGE: {
|
||||
mqtt.ATTR_TOPIC: "ha_state/offline",
|
||||
mqtt.ATTR_PAYLOAD: "offline",
|
||||
mqtt.ATTR_QOS: 2,
|
||||
mqtt.ATTR_RETAIN: True,
|
||||
},
|
||||
}
|
||||
|
||||
await hass.async_block_till_done()
|
||||
assert mqtt_mock.async_connect.call_count == 1
|
||||
|
||||
|
||||
def get_default(schema, key):
|
||||
"""Get default value for key in voluptuous schema."""
|
||||
for k in schema.keys():
|
||||
if k == key:
|
||||
if k.default == vol.UNDEFINED:
|
||||
return None
|
||||
return k.default()
|
||||
|
||||
|
||||
def get_suggested(schema, key):
|
||||
"""Get suggested value for key in voluptuous schema."""
|
||||
for k in schema.keys():
|
||||
if k == key:
|
||||
if k.description is None or "suggested_value" not in k.description:
|
||||
return None
|
||||
return k.description["suggested_value"]
|
||||
|
||||
|
||||
async def test_option_flow_default_suggested_values(
|
||||
hass, mqtt_mock, mock_try_connection
|
||||
):
|
||||
"""Test config flow options has default/suggested values."""
|
||||
mock_try_connection.return_value = True
|
||||
config_entry = hass.config_entries.async_entries(mqtt.DOMAIN)[0]
|
||||
await async_start(hass, "homeassistant", config_entry)
|
||||
config_entry.data = {
|
||||
mqtt.CONF_BROKER: "test-broker",
|
||||
mqtt.CONF_PORT: 1234,
|
||||
mqtt.CONF_USERNAME: "user",
|
||||
mqtt.CONF_PASSWORD: "pass",
|
||||
mqtt.CONF_DISCOVERY: True,
|
||||
mqtt.CONF_BIRTH_MESSAGE: {
|
||||
mqtt.ATTR_TOPIC: "ha_state/online",
|
||||
mqtt.ATTR_PAYLOAD: "online",
|
||||
mqtt.ATTR_QOS: 1,
|
||||
mqtt.ATTR_RETAIN: True,
|
||||
},
|
||||
mqtt.CONF_WILL_MESSAGE: {
|
||||
mqtt.ATTR_TOPIC: "ha_state/offline",
|
||||
mqtt.ATTR_PAYLOAD: "offline",
|
||||
mqtt.ATTR_QOS: 2,
|
||||
mqtt.ATTR_RETAIN: False,
|
||||
},
|
||||
}
|
||||
|
||||
# Test default/suggested values from config
|
||||
result = await hass.config_entries.options.async_init(config_entry.entry_id)
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||
assert result["step_id"] == "broker"
|
||||
defaults = {
|
||||
mqtt.CONF_BROKER: "test-broker",
|
||||
mqtt.CONF_PORT: 1234,
|
||||
}
|
||||
suggested = {
|
||||
mqtt.CONF_USERNAME: "user",
|
||||
mqtt.CONF_PASSWORD: "pass",
|
||||
}
|
||||
for k, v in defaults.items():
|
||||
assert get_default(result["data_schema"].schema, k) == v
|
||||
for k, v in suggested.items():
|
||||
assert get_suggested(result["data_schema"].schema, k) == v
|
||||
|
||||
result = await hass.config_entries.options.async_configure(
|
||||
result["flow_id"],
|
||||
user_input={
|
||||
mqtt.CONF_BROKER: "another-broker",
|
||||
mqtt.CONF_PORT: 2345,
|
||||
mqtt.CONF_USERNAME: "us3r",
|
||||
mqtt.CONF_PASSWORD: "p4ss",
|
||||
},
|
||||
)
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||
assert result["step_id"] == "options"
|
||||
defaults = {
|
||||
mqtt.CONF_DISCOVERY: True,
|
||||
"birth_qos": 1,
|
||||
"birth_retain": True,
|
||||
"will_qos": 2,
|
||||
"will_retain": False,
|
||||
}
|
||||
suggested = {
|
||||
"birth_topic": "ha_state/online",
|
||||
"birth_payload": "online",
|
||||
"will_topic": "ha_state/offline",
|
||||
"will_payload": "offline",
|
||||
}
|
||||
for k, v in defaults.items():
|
||||
assert get_default(result["data_schema"].schema, k) == v
|
||||
for k, v in suggested.items():
|
||||
assert get_suggested(result["data_schema"].schema, k) == v
|
||||
|
||||
result = await hass.config_entries.options.async_configure(
|
||||
result["flow_id"],
|
||||
user_input={
|
||||
mqtt.CONF_DISCOVERY: False,
|
||||
"birth_topic": "ha_state/onl1ne",
|
||||
"birth_payload": "onl1ne",
|
||||
"birth_qos": 2,
|
||||
"birth_retain": False,
|
||||
"will_topic": "ha_state/offl1ne",
|
||||
"will_payload": "offl1ne",
|
||||
"will_qos": 1,
|
||||
"will_retain": True,
|
||||
},
|
||||
)
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
||||
|
||||
# Test updated default/suggested values from config
|
||||
result = await hass.config_entries.options.async_init(config_entry.entry_id)
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||
assert result["step_id"] == "broker"
|
||||
defaults = {
|
||||
mqtt.CONF_BROKER: "another-broker",
|
||||
mqtt.CONF_PORT: 2345,
|
||||
}
|
||||
suggested = {
|
||||
mqtt.CONF_USERNAME: "us3r",
|
||||
mqtt.CONF_PASSWORD: "p4ss",
|
||||
}
|
||||
for k, v in defaults.items():
|
||||
assert get_default(result["data_schema"].schema, k) == v
|
||||
for k, v in suggested.items():
|
||||
assert get_suggested(result["data_schema"].schema, k) == v
|
||||
|
||||
result = await hass.config_entries.options.async_configure(
|
||||
result["flow_id"],
|
||||
user_input={mqtt.CONF_BROKER: "another-broker", mqtt.CONF_PORT: 2345},
|
||||
)
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||
assert result["step_id"] == "options"
|
||||
defaults = {
|
||||
mqtt.CONF_DISCOVERY: False,
|
||||
"birth_qos": 2,
|
||||
"birth_retain": False,
|
||||
"will_qos": 1,
|
||||
"will_retain": True,
|
||||
}
|
||||
suggested = {
|
||||
"birth_topic": "ha_state/onl1ne",
|
||||
"birth_payload": "onl1ne",
|
||||
"will_topic": "ha_state/offl1ne",
|
||||
"will_payload": "offl1ne",
|
||||
}
|
||||
for k, v in defaults.items():
|
||||
assert get_default(result["data_schema"].schema, k) == v
|
||||
for k, v in suggested.items():
|
||||
assert get_suggested(result["data_schema"].schema, k) == v
|
||||
|
||||
result = await hass.config_entries.options.async_configure(
|
||||
result["flow_id"],
|
||||
user_input={
|
||||
mqtt.CONF_DISCOVERY: True,
|
||||
"birth_topic": "ha_state/onl1ne",
|
||||
"birth_payload": "onl1ne",
|
||||
"birth_qos": 2,
|
||||
"birth_retain": False,
|
||||
"will_topic": "ha_state/offl1ne",
|
||||
"will_payload": "offl1ne",
|
||||
"will_qos": 1,
|
||||
"will_retain": True,
|
||||
},
|
||||
)
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
||||
|
||||
|
||||
async def test_options_user_connection_fails(hass, mock_try_connection):
|
||||
"""Test if connection cannot be made."""
|
||||
config_entry = MockConfigEntry(domain=mqtt.DOMAIN)
|
||||
config_entry.add_to_hass(hass)
|
||||
config_entry.data = {
|
||||
mqtt.CONF_BROKER: "test-broker",
|
||||
mqtt.CONF_PORT: 1234,
|
||||
}
|
||||
|
||||
mock_try_connection.return_value = False
|
||||
|
||||
result = await hass.config_entries.options.async_init(config_entry.entry_id)
|
||||
assert result["type"] == "form"
|
||||
|
||||
result = await hass.config_entries.options.async_configure(
|
||||
result["flow_id"],
|
||||
user_input={mqtt.CONF_BROKER: "bad-broker", mqtt.CONF_PORT: 2345},
|
||||
)
|
||||
|
||||
assert result["type"] == "form"
|
||||
assert result["errors"]["base"] == "cannot_connect"
|
||||
|
||||
# Check we tried the connection
|
||||
assert len(mock_try_connection.mock_calls) == 1
|
||||
# Check config entry did not update
|
||||
assert config_entry.data == {
|
||||
mqtt.CONF_BROKER: "test-broker",
|
||||
mqtt.CONF_PORT: 1234,
|
||||
}
|
||||
|
||||
|
||||
async def test_options_bad_birth_message_fails(hass, mock_try_connection):
|
||||
"""Test bad birth message."""
|
||||
config_entry = MockConfigEntry(domain=mqtt.DOMAIN)
|
||||
config_entry.add_to_hass(hass)
|
||||
config_entry.data = {
|
||||
mqtt.CONF_BROKER: "test-broker",
|
||||
mqtt.CONF_PORT: 1234,
|
||||
}
|
||||
|
||||
mock_try_connection.return_value = True
|
||||
|
||||
result = await hass.config_entries.options.async_init(config_entry.entry_id)
|
||||
assert result["type"] == "form"
|
||||
|
||||
result = await hass.config_entries.options.async_configure(
|
||||
result["flow_id"],
|
||||
user_input={mqtt.CONF_BROKER: "another-broker", mqtt.CONF_PORT: 2345},
|
||||
)
|
||||
|
||||
assert result["type"] == "form"
|
||||
assert result["step_id"] == "options"
|
||||
|
||||
result = await hass.config_entries.options.async_configure(
|
||||
result["flow_id"], user_input={"birth_topic": "ha_state/online/#"},
|
||||
)
|
||||
assert result["type"] == "form"
|
||||
assert result["errors"]["base"] == "bad_birth"
|
||||
|
||||
# Check config entry did not update
|
||||
assert config_entry.data == {
|
||||
mqtt.CONF_BROKER: "test-broker",
|
||||
mqtt.CONF_PORT: 1234,
|
||||
}
|
||||
|
||||
|
||||
async def test_options_bad_will_message_fails(hass, mock_try_connection):
|
||||
"""Test bad will message."""
|
||||
config_entry = MockConfigEntry(domain=mqtt.DOMAIN)
|
||||
config_entry.add_to_hass(hass)
|
||||
config_entry.data = {
|
||||
mqtt.CONF_BROKER: "test-broker",
|
||||
mqtt.CONF_PORT: 1234,
|
||||
}
|
||||
|
||||
mock_try_connection.return_value = True
|
||||
|
||||
result = await hass.config_entries.options.async_init(config_entry.entry_id)
|
||||
assert result["type"] == "form"
|
||||
|
||||
result = await hass.config_entries.options.async_configure(
|
||||
result["flow_id"],
|
||||
user_input={mqtt.CONF_BROKER: "another-broker", mqtt.CONF_PORT: 2345},
|
||||
)
|
||||
|
||||
assert result["type"] == "form"
|
||||
assert result["step_id"] == "options"
|
||||
|
||||
result = await hass.config_entries.options.async_configure(
|
||||
result["flow_id"], user_input={"will_topic": "ha_state/offline/#"},
|
||||
)
|
||||
assert result["type"] == "form"
|
||||
assert result["errors"]["base"] == "bad_will"
|
||||
|
||||
# Check config entry did not update
|
||||
assert config_entry.data == {
|
||||
mqtt.CONF_BROKER: "test-broker",
|
||||
mqtt.CONF_PORT: 1234,
|
||||
}
|
||||
|
|
|
@ -171,26 +171,26 @@ def test_validate_topic():
|
|||
"""Test topic name/filter validation."""
|
||||
# Invalid UTF-8, must not contain U+D800 to U+DFFF.
|
||||
with pytest.raises(vol.Invalid):
|
||||
mqtt.valid_topic("\ud800")
|
||||
mqtt.util.valid_topic("\ud800")
|
||||
with pytest.raises(vol.Invalid):
|
||||
mqtt.valid_topic("\udfff")
|
||||
mqtt.util.valid_topic("\udfff")
|
||||
# Topic MUST NOT be empty
|
||||
with pytest.raises(vol.Invalid):
|
||||
mqtt.valid_topic("")
|
||||
mqtt.util.valid_topic("")
|
||||
# Topic MUST NOT be longer than 65535 encoded bytes.
|
||||
with pytest.raises(vol.Invalid):
|
||||
mqtt.valid_topic("ü" * 32768)
|
||||
mqtt.util.valid_topic("ü" * 32768)
|
||||
# UTF-8 MUST NOT include null character
|
||||
with pytest.raises(vol.Invalid):
|
||||
mqtt.valid_topic("bad\0one")
|
||||
mqtt.util.valid_topic("bad\0one")
|
||||
|
||||
# Topics "SHOULD NOT" include these special characters
|
||||
# (not MUST NOT, RFC2119). The receiver MAY close the connection.
|
||||
mqtt.valid_topic("\u0001")
|
||||
mqtt.valid_topic("\u001F")
|
||||
mqtt.valid_topic("\u009F")
|
||||
mqtt.valid_topic("\u009F")
|
||||
mqtt.valid_topic("\uffff")
|
||||
mqtt.util.valid_topic("\u0001")
|
||||
mqtt.util.valid_topic("\u001F")
|
||||
mqtt.util.valid_topic("\u009F")
|
||||
mqtt.util.valid_topic("\u009F")
|
||||
mqtt.util.valid_topic("\uffff")
|
||||
|
||||
|
||||
def test_validate_subscribe_topic():
|
||||
|
@ -587,7 +587,7 @@ async def test_retained_message_on_subscribe_received(
|
|||
mqtt_client_mock.subscribe.side_effect = side_effect
|
||||
|
||||
# Fake that the client is connected
|
||||
mqtt_mock.connected = True
|
||||
mqtt_mock().connected = True
|
||||
|
||||
calls_a = MagicMock()
|
||||
await mqtt.async_subscribe(hass, "test/state", calls_a)
|
||||
|
@ -605,7 +605,7 @@ async def test_not_calling_unsubscribe_with_active_subscribers(
|
|||
):
|
||||
"""Test not calling unsubscribe() when other subscribers are active."""
|
||||
# Fake that the client is connected
|
||||
mqtt_mock.connected = True
|
||||
mqtt_mock().connected = True
|
||||
|
||||
unsub = await mqtt.async_subscribe(hass, "test/state", None)
|
||||
await mqtt.async_subscribe(hass, "test/state", None)
|
||||
|
@ -620,7 +620,7 @@ async def test_not_calling_unsubscribe_with_active_subscribers(
|
|||
async def test_restore_subscriptions_on_reconnect(hass, mqtt_client_mock, mqtt_mock):
|
||||
"""Test subscriptions are restored on reconnect."""
|
||||
# Fake that the client is connected
|
||||
mqtt_mock.connected = True
|
||||
mqtt_mock().connected = True
|
||||
|
||||
await mqtt.async_subscribe(hass, "test/state", None)
|
||||
await hass.async_block_till_done()
|
||||
|
@ -637,7 +637,7 @@ async def test_restore_all_active_subscriptions_on_reconnect(
|
|||
):
|
||||
"""Test active subscriptions are restored correctly on reconnect."""
|
||||
# Fake that the client is connected
|
||||
mqtt_mock.connected = True
|
||||
mqtt_mock().connected = True
|
||||
|
||||
mqtt_client_mock.subscribe.side_effect = (
|
||||
(0, 1),
|
||||
|
@ -716,81 +716,107 @@ async def test_setup_raises_ConfigEntryNotReady_if_no_connect_broker(hass, caplo
|
|||
assert "Failed to connect to MQTT server due to exception:" in caplog.text
|
||||
|
||||
|
||||
async def test_setup_uses_certificate_on_certificate_set_to_auto(hass, mock_mqtt):
|
||||
async def test_setup_uses_certificate_on_certificate_set_to_auto(hass):
|
||||
"""Test setup uses bundled certs when certificate is set to auto."""
|
||||
entry = MockConfigEntry(
|
||||
domain=mqtt.DOMAIN,
|
||||
data={mqtt.CONF_BROKER: "test-broker", "certificate": "auto"},
|
||||
)
|
||||
calls = []
|
||||
|
||||
assert await mqtt.async_setup_entry(hass, entry)
|
||||
def mock_tls_set(certificate, certfile=None, keyfile=None, tls_version=None):
|
||||
calls.append((certificate, certfile, keyfile, tls_version))
|
||||
|
||||
assert mock_mqtt.called
|
||||
with patch("paho.mqtt.client.Client") as mock_client:
|
||||
mock_client().tls_set = mock_tls_set
|
||||
entry = MockConfigEntry(
|
||||
domain=mqtt.DOMAIN,
|
||||
data={mqtt.CONF_BROKER: "test-broker", "certificate": "auto"},
|
||||
)
|
||||
|
||||
import requests.certs
|
||||
assert await mqtt.async_setup_entry(hass, entry)
|
||||
|
||||
expectedCertificate = requests.certs.where()
|
||||
assert mock_mqtt.mock_calls[0][2]["certificate"] == expectedCertificate
|
||||
assert calls
|
||||
|
||||
import certifi
|
||||
|
||||
expectedCertificate = certifi.where()
|
||||
# assert mock_mqtt.mock_calls[0][1][2]["certificate"] == expectedCertificate
|
||||
assert calls[0][0] == expectedCertificate
|
||||
|
||||
|
||||
async def test_setup_does_not_use_certificate_on_mqtts_port(hass, mock_mqtt):
|
||||
"""Test setup doesn't use bundled certs when ssl set."""
|
||||
entry = MockConfigEntry(
|
||||
domain=mqtt.DOMAIN, data={mqtt.CONF_BROKER: "test-broker", "port": 8883}
|
||||
)
|
||||
|
||||
assert await mqtt.async_setup_entry(hass, entry)
|
||||
|
||||
assert mock_mqtt.called
|
||||
assert mock_mqtt.mock_calls[0][2]["port"] == 8883
|
||||
|
||||
import requests.certs
|
||||
|
||||
mqttsCertificateBundle = requests.certs.where()
|
||||
assert mock_mqtt.mock_calls[0][2]["port"] != mqttsCertificateBundle
|
||||
|
||||
|
||||
async def test_setup_without_tls_config_uses_tlsv1_under_python36(hass, mock_mqtt):
|
||||
async def test_setup_without_tls_config_uses_tlsv1_under_python36(hass):
|
||||
"""Test setup defaults to TLSv1 under python3.6."""
|
||||
entry = MockConfigEntry(domain=mqtt.DOMAIN, data={mqtt.CONF_BROKER: "test-broker"})
|
||||
calls = []
|
||||
|
||||
assert await mqtt.async_setup_entry(hass, entry)
|
||||
def mock_tls_set(certificate, certfile=None, keyfile=None, tls_version=None):
|
||||
calls.append((certificate, certfile, keyfile, tls_version))
|
||||
|
||||
assert mock_mqtt.called
|
||||
with patch("paho.mqtt.client.Client") as mock_client:
|
||||
mock_client().tls_set = mock_tls_set
|
||||
entry = MockConfigEntry(
|
||||
domain=mqtt.DOMAIN,
|
||||
data={"certificate": "auto", mqtt.CONF_BROKER: "test-broker"},
|
||||
)
|
||||
|
||||
import sys
|
||||
assert await mqtt.async_setup_entry(hass, entry)
|
||||
|
||||
if sys.hexversion >= 0x03060000:
|
||||
expectedTlsVersion = ssl.PROTOCOL_TLS # pylint: disable=no-member
|
||||
else:
|
||||
expectedTlsVersion = ssl.PROTOCOL_TLSv1
|
||||
assert calls
|
||||
|
||||
assert mock_mqtt.mock_calls[0][2]["tls_version"] == expectedTlsVersion
|
||||
import sys
|
||||
|
||||
if sys.hexversion >= 0x03060000:
|
||||
expectedTlsVersion = ssl.PROTOCOL_TLS # pylint: disable=no-member
|
||||
else:
|
||||
expectedTlsVersion = ssl.PROTOCOL_TLSv1
|
||||
|
||||
assert calls[0][3] == expectedTlsVersion
|
||||
|
||||
|
||||
async def test_setup_with_tls_config_uses_tls_version1_2(hass, mock_mqtt):
|
||||
async def test_setup_with_tls_config_uses_tls_version1_2(hass):
|
||||
"""Test setup uses specified TLS version."""
|
||||
entry = MockConfigEntry(
|
||||
domain=mqtt.DOMAIN, data={mqtt.CONF_BROKER: "test-broker", "tls_version": "1.2"}
|
||||
)
|
||||
calls = []
|
||||
|
||||
assert await mqtt.async_setup_entry(hass, entry)
|
||||
def mock_tls_set(certificate, certfile=None, keyfile=None, tls_version=None):
|
||||
calls.append((certificate, certfile, keyfile, tls_version))
|
||||
|
||||
assert mock_mqtt.called
|
||||
with patch("paho.mqtt.client.Client") as mock_client:
|
||||
mock_client().tls_set = mock_tls_set
|
||||
entry = MockConfigEntry(
|
||||
domain=mqtt.DOMAIN,
|
||||
data={
|
||||
"certificate": "auto",
|
||||
mqtt.CONF_BROKER: "test-broker",
|
||||
"tls_version": "1.2",
|
||||
},
|
||||
)
|
||||
|
||||
assert mock_mqtt.mock_calls[0][2]["tls_version"] == ssl.PROTOCOL_TLSv1_2
|
||||
assert await mqtt.async_setup_entry(hass, entry)
|
||||
|
||||
assert calls
|
||||
|
||||
assert calls[0][3] == ssl.PROTOCOL_TLSv1_2
|
||||
|
||||
|
||||
async def test_setup_with_tls_config_of_v1_under_python36_only_uses_v1(hass, mock_mqtt):
|
||||
async def test_setup_with_tls_config_of_v1_under_python36_only_uses_v1(hass):
|
||||
"""Test setup uses TLSv1.0 if explicitly chosen."""
|
||||
entry = MockConfigEntry(
|
||||
domain=mqtt.DOMAIN, data={mqtt.CONF_BROKER: "test-broker", "tls_version": "1.0"}
|
||||
)
|
||||
calls = []
|
||||
|
||||
assert await mqtt.async_setup_entry(hass, entry)
|
||||
def mock_tls_set(certificate, certfile=None, keyfile=None, tls_version=None):
|
||||
calls.append((certificate, certfile, keyfile, tls_version))
|
||||
|
||||
assert mock_mqtt.called
|
||||
assert mock_mqtt.mock_calls[0][2]["tls_version"] == ssl.PROTOCOL_TLSv1
|
||||
with patch("paho.mqtt.client.Client") as mock_client:
|
||||
mock_client().tls_set = mock_tls_set
|
||||
entry = MockConfigEntry(
|
||||
domain=mqtt.DOMAIN,
|
||||
data={
|
||||
"certificate": "auto",
|
||||
mqtt.CONF_BROKER: "test-broker",
|
||||
"tls_version": "1.0",
|
||||
},
|
||||
)
|
||||
|
||||
assert await mqtt.async_setup_entry(hass, entry)
|
||||
|
||||
assert calls
|
||||
|
||||
assert calls[0][3] == ssl.PROTOCOL_TLSv1
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
|
|
@ -46,8 +46,8 @@ class TestMQTT:
|
|||
)
|
||||
self.hass.block_till_done()
|
||||
assert mock_mqtt.called
|
||||
assert mock_mqtt.mock_calls[1][2]["username"] == "homeassistant"
|
||||
assert mock_mqtt.mock_calls[1][2]["password"] == password
|
||||
assert mock_mqtt.mock_calls[1][1][2]["username"] == "homeassistant"
|
||||
assert mock_mqtt.mock_calls[1][1][2]["password"] == password
|
||||
|
||||
@patch("passlib.apps.custom_app_context", Mock(return_value=""))
|
||||
@patch("tempfile.NamedTemporaryFile", Mock(return_value=MagicMock()))
|
||||
|
@ -69,8 +69,8 @@ class TestMQTT:
|
|||
)
|
||||
self.hass.block_till_done()
|
||||
assert mock_mqtt.called
|
||||
assert mock_mqtt.mock_calls[1][2]["username"] == "homeassistant"
|
||||
assert mock_mqtt.mock_calls[1][2]["password"] == password
|
||||
assert mock_mqtt.mock_calls[1][1][2]["username"] == "homeassistant"
|
||||
assert mock_mqtt.mock_calls[1][1][2]["password"] == password
|
||||
|
||||
@patch("tempfile.NamedTemporaryFile", Mock(return_value=MagicMock()))
|
||||
@patch("hbmqtt.broker.Broker.start", return_value=mock_coro())
|
||||
|
|
|
@ -304,8 +304,11 @@ async def mqtt_mock(hass, mqtt_client_mock, mqtt_config):
|
|||
assert result
|
||||
await hass.async_block_till_done()
|
||||
|
||||
mqtt_component_mock = MagicMock(spec_set=hass.data["mqtt"], wraps=hass.data["mqtt"])
|
||||
hass.data["mqtt"].connected = mqtt_component_mock.connected
|
||||
mqtt_component_mock = MagicMock(
|
||||
return_value=hass.data["mqtt"],
|
||||
spec_set=hass.data["mqtt"],
|
||||
wraps=hass.data["mqtt"],
|
||||
)
|
||||
mqtt_component_mock._mqttc = mqtt_client_mock
|
||||
|
||||
hass.data["mqtt"] = mqtt_component_mock
|
||||
|
|
Loading…
Add table
Reference in a new issue