Support reconfiguring MQTT config entry (#36537)

This commit is contained in:
Erik Montnemery 2020-06-23 02:49:01 +02:00 committed by GitHub
parent ee816ed3dd
commit 747490ab34
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 881 additions and 249 deletions

View file

@ -19,7 +19,7 @@ DEFAULT_QOS = 0
TRIGGER_SCHEMA = vol.Schema(
{
vol.Required(CONF_PLATFORM): mqtt.DOMAIN,
vol.Required(CONF_TOPIC): mqtt.valid_subscribe_topic,
vol.Required(CONF_TOPIC): mqtt.util.valid_subscribe_topic,
vol.Optional(CONF_PAYLOAD): cv.string,
vol.Optional(CONF_ENCODING, default=DEFAULT_ENCODING): cv.string,
vol.Optional(CONF_QOS, default=DEFAULT_QOS): vol.All(

View file

@ -274,7 +274,7 @@ async def system_options_update(hass, connection, msg):
{"type": "config_entries/update", "entry_id": str, vol.Optional("title"): str}
)
async def config_entry_update(hass, connection, msg):
"""Update config entry system options."""
"""Update config entry."""
changes = dict(msg)
changes.pop("id")
changes.pop("type")

View file

@ -12,7 +12,7 @@ import sys
from typing import Any, Callable, List, Optional, Union
import attr
import requests.certs
import certifi
import voluptuous as vol
from homeassistant import config_entries
@ -46,11 +46,20 @@ from . import debug_info, discovery, server
from .const import (
ATTR_DISCOVERY_HASH,
ATTR_DISCOVERY_TOPIC,
ATTR_PAYLOAD,
ATTR_QOS,
ATTR_RETAIN,
ATTR_TOPIC,
CONF_BIRTH_MESSAGE,
CONF_BROKER,
CONF_DISCOVERY,
CONF_QOS,
CONF_RETAIN,
CONF_STATE_TOPIC,
CONF_WILL_MESSAGE,
DEFAULT_DISCOVERY,
DEFAULT_QOS,
DEFAULT_RETAIN,
MQTT_CONNECTED,
MQTT_DISCONNECTED,
PROTOCOL_311,
@ -59,6 +68,7 @@ from .debug_info import log_messages
from .discovery import MQTT_DISCOVERY_UPDATED, clear_discovery_hash, set_discovery_hash
from .models import Message, MessageCallbackType, PublishPayloadType
from .subscription import async_subscribe_topics, async_unsubscribe_topics
from .util import _VALID_QOS_SCHEMA, valid_publish_topic, valid_subscribe_topic
_LOGGER = logging.getLogger(__name__)
@ -80,17 +90,12 @@ CONF_CLIENT_CERT = "client_cert"
CONF_TLS_INSECURE = "tls_insecure"
CONF_TLS_VERSION = "tls_version"
CONF_BIRTH_MESSAGE = "birth_message"
CONF_WILL_MESSAGE = "will_message"
CONF_COMMAND_TOPIC = "command_topic"
CONF_AVAILABILITY_TOPIC = "availability_topic"
CONF_PAYLOAD_AVAILABLE = "payload_available"
CONF_PAYLOAD_NOT_AVAILABLE = "payload_not_available"
CONF_JSON_ATTRS_TOPIC = "json_attributes_topic"
CONF_JSON_ATTRS_TEMPLATE = "json_attributes_template"
CONF_QOS = "qos"
CONF_RETAIN = "retain"
CONF_UNIQUE_ID = "unique_id"
CONF_IDENTIFIERS = "identifiers"
@ -105,18 +110,13 @@ PROTOCOL_31 = "3.1"
DEFAULT_PORT = 1883
DEFAULT_KEEPALIVE = 60
DEFAULT_RETAIN = False
DEFAULT_PROTOCOL = PROTOCOL_311
DEFAULT_DISCOVERY_PREFIX = "homeassistant"
DEFAULT_TLS_PROTOCOL = "auto"
DEFAULT_PAYLOAD_AVAILABLE = "online"
DEFAULT_PAYLOAD_NOT_AVAILABLE = "offline"
ATTR_TOPIC = "topic"
ATTR_PAYLOAD = "payload"
ATTR_PAYLOAD_TEMPLATE = "payload_template"
ATTR_QOS = CONF_QOS
ATTR_RETAIN = CONF_RETAIN
MAX_RECONNECT_WAIT = 300 # seconds
@ -125,59 +125,6 @@ CONNECTION_FAILED = "connection_failed"
CONNECTION_FAILED_RECOVERABLE = "connection_failed_recoverable"
def valid_topic(value: Any) -> str:
"""Validate that this is a valid topic name/filter."""
value = cv.string(value)
try:
raw_value = value.encode("utf-8")
except UnicodeError:
raise vol.Invalid("MQTT topic name/filter must be valid UTF-8 string.")
if not raw_value:
raise vol.Invalid("MQTT topic name/filter must not be empty.")
if len(raw_value) > 65535:
raise vol.Invalid(
"MQTT topic name/filter must not be longer than 65535 encoded bytes."
)
if "\0" in value:
raise vol.Invalid("MQTT topic name/filter must not contain null character.")
return value
def valid_subscribe_topic(value: Any) -> str:
"""Validate that we can subscribe using this MQTT topic."""
value = valid_topic(value)
for i in (i for i, c in enumerate(value) if c == "+"):
if (i > 0 and value[i - 1] != "/") or (
i < len(value) - 1 and value[i + 1] != "/"
):
raise vol.Invalid(
"Single-level wildcard must occupy an entire level of the filter"
)
index = value.find("#")
if index != -1:
if index != len(value) - 1:
# If there are multiple wildcards, this will also trigger
raise vol.Invalid(
"Multi-level wildcard must be the last "
"character in the topic filter."
)
if len(value) > 1 and value[index - 1] != "/":
raise vol.Invalid(
"Multi-level wildcard must be after a topic level separator."
)
return value
def valid_publish_topic(value: Any) -> str:
"""Validate that we can publish using this MQTT topic."""
value = valid_topic(value)
if "+" in value or "#" in value:
raise vol.Invalid("Wildcards can not be used in topic names")
return value
def validate_device_has_at_least_one_identifier(value: ConfigType) -> ConfigType:
"""Validate that a device info entry has at least one identifying value."""
if not value.get(CONF_IDENTIFIERS) and not value.get(CONF_CONNECTIONS):
@ -188,8 +135,6 @@ def validate_device_has_at_least_one_identifier(value: ConfigType) -> ConfigType
return value
_VALID_QOS_SCHEMA = vol.All(vol.Coerce(int), vol.In([0, 1, 2]))
CLIENT_KEY_AUTH_MSG = (
"client_key and client_cert must both be present in "
"the MQTT broker configuration"
@ -554,6 +499,11 @@ async def async_setup(hass: HomeAssistantType, config: ConfigType) -> bool:
return True
def _merge_config(entry, conf):
"""Merge configuration.yaml config with config entry."""
return {**conf, **entry.data}
async def async_setup_entry(hass, entry):
"""Load a config entry."""
conf = hass.data.get(DATA_MQTT_CONFIG)
@ -574,76 +524,9 @@ async def async_setup_entry(hass, entry):
entry.data,
)
conf.update(entry.data)
conf = _merge_config(entry, conf)
broker = conf[CONF_BROKER]
port = conf[CONF_PORT]
client_id = conf.get(CONF_CLIENT_ID)
keepalive = conf[CONF_KEEPALIVE]
username = conf.get(CONF_USERNAME)
password = conf.get(CONF_PASSWORD)
certificate = conf.get(CONF_CERTIFICATE)
client_key = conf.get(CONF_CLIENT_KEY)
client_cert = conf.get(CONF_CLIENT_CERT)
tls_insecure = conf.get(CONF_TLS_INSECURE)
protocol = conf[CONF_PROTOCOL]
# For cloudmqtt.com, secured connection, auto fill in certificate
if (
certificate is None
and 19999 < conf[CONF_PORT] < 30000
and broker.endswith(".cloudmqtt.com")
):
certificate = os.path.join(
os.path.dirname(__file__), "addtrustexternalcaroot.crt"
)
# When the certificate is set to auto, use bundled certs from requests
elif certificate == "auto":
certificate = requests.certs.where()
if CONF_WILL_MESSAGE in conf:
will_message = Message(**conf[CONF_WILL_MESSAGE])
else:
will_message = None
if CONF_BIRTH_MESSAGE in conf:
birth_message = Message(**conf[CONF_BIRTH_MESSAGE])
else:
birth_message = None
# Be able to override versions other than TLSv1.0 under Python3.6
conf_tls_version: str = conf.get(CONF_TLS_VERSION)
if conf_tls_version == "1.2":
tls_version = ssl.PROTOCOL_TLSv1_2
elif conf_tls_version == "1.1":
tls_version = ssl.PROTOCOL_TLSv1_1
elif conf_tls_version == "1.0":
tls_version = ssl.PROTOCOL_TLSv1
else:
# Python3.6 supports automatic negotiation of highest TLS version
if sys.hexversion >= 0x03060000:
tls_version = ssl.PROTOCOL_TLS # pylint: disable=no-member
else:
tls_version = ssl.PROTOCOL_TLSv1
hass.data[DATA_MQTT] = MQTT(
hass,
broker=broker,
port=port,
client_id=client_id,
keepalive=keepalive,
username=username,
password=password,
certificate=certificate,
client_key=client_key,
client_cert=client_cert,
tls_insecure=tls_insecure,
protocol=protocol,
will_message=will_message,
birth_message=birth_message,
tls_version=tls_version,
)
hass.data[DATA_MQTT] = MQTT(hass, entry, conf,)
await hass.data[DATA_MQTT].async_connect()
@ -732,53 +615,101 @@ class Subscription:
class MQTT:
"""Home Assistant MQTT client."""
def __init__(
self,
hass: HomeAssistantType,
broker: str,
port: int,
client_id: Optional[str],
keepalive: Optional[int],
username: Optional[str],
password: Optional[str],
certificate: Optional[str],
client_key: Optional[str],
client_cert: Optional[str],
tls_insecure: Optional[bool],
protocol: Optional[str],
will_message: Optional[Message],
birth_message: Optional[Message],
tls_version: Optional[int],
) -> None:
def __init__(self, hass: HomeAssistantType, config_entry, conf,) -> None:
"""Initialize Home Assistant MQTT client."""
# We don't import them on the top because some integrations
# We don't import on the top because some integrations
# should be able to optionally rely on MQTT.
# pylint: disable=import-outside-toplevel
import paho.mqtt.client as mqtt
import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel
self.hass = hass
self.broker = broker
self.port = port
self.keepalive = keepalive
self.config_entry = config_entry
self.conf = conf
self.subscriptions: List[Subscription] = []
self.birth_message = birth_message
self.connected = False
self._mqttc: mqtt.Client = None
self._paho_lock = asyncio.Lock()
if protocol == PROTOCOL_31:
self.init_client()
self.config_entry.add_update_listener(self.async_config_entry_updated)
@staticmethod
async def async_config_entry_updated(hass, entry) -> None:
"""Handle signals of config entry being updated.
This is a static method because a class method (bound method), can not be used with weak references.
Causes for this is config entry options changing.
"""
self = hass.data[DATA_MQTT]
conf = hass.data.get(DATA_MQTT_CONFIG)
if conf is None:
conf = CONFIG_SCHEMA({DOMAIN: dict(entry.data)})[DOMAIN]
self.conf = _merge_config(entry, conf)
await self.async_disconnect()
self.init_client()
await self.async_connect()
await discovery.async_stop(hass)
if self.conf.get(CONF_DISCOVERY):
await _async_setup_discovery(hass, self.conf, entry)
def init_client(self):
"""Initialize paho client."""
# We don't import on the top because some integrations
# should be able to optionally rely on MQTT.
import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel
if self.conf[CONF_PROTOCOL] == PROTOCOL_31:
proto: int = mqtt.MQTTv31
else:
proto = mqtt.MQTTv311
client_id = self.conf.get(CONF_CLIENT_ID)
if client_id is None:
self._mqttc = mqtt.Client(protocol=proto)
else:
self._mqttc = mqtt.Client(client_id, protocol=proto)
username = self.conf.get(CONF_USERNAME)
password = self.conf.get(CONF_PASSWORD)
if username is not None:
self._mqttc.username_pw_set(username, password)
certificate = self.conf.get(CONF_CERTIFICATE)
# For cloudmqtt.com, secured connection, auto fill in certificate
if (
certificate is None
and 19999 < self.conf[CONF_PORT] < 30000
and self.conf[CONF_BROKER].endswith(".cloudmqtt.com")
):
certificate = os.path.join(
os.path.dirname(__file__), "addtrustexternalcaroot.crt"
)
# When the certificate is set to auto, use bundled certs from certifi
elif certificate == "auto":
certificate = certifi.where()
# Be able to override versions other than TLSv1.0 under Python3.6
conf_tls_version: str = self.conf.get(CONF_TLS_VERSION)
if conf_tls_version == "1.2":
tls_version = ssl.PROTOCOL_TLSv1_2
elif conf_tls_version == "1.1":
tls_version = ssl.PROTOCOL_TLSv1_1
elif conf_tls_version == "1.0":
tls_version = ssl.PROTOCOL_TLSv1
else:
# Python3.6 supports automatic negotiation of highest TLS version
if sys.hexversion >= 0x03060000:
tls_version = ssl.PROTOCOL_TLS # pylint: disable=no-member
else:
tls_version = ssl.PROTOCOL_TLSv1
client_key = self.conf.get(CONF_CLIENT_KEY)
client_cert = self.conf.get(CONF_CLIENT_CERT)
tls_insecure = self.conf.get(CONF_TLS_INSECURE)
if certificate is not None:
self._mqttc.tls_set(
certificate,
@ -794,6 +725,11 @@ class MQTT:
self._mqttc.on_disconnect = self._mqtt_on_disconnect
self._mqttc.on_message = self._mqtt_on_message
if CONF_WILL_MESSAGE in self.conf:
will_message = Message(**self.conf[CONF_WILL_MESSAGE])
else:
will_message = None
if will_message is not None:
self._mqttc.will_set( # pylint: disable=no-value-for-parameter
*attr.astuple(
@ -813,14 +749,17 @@ class MQTT:
)
async def async_connect(self) -> str:
"""Connect to the host. Does process messages yet."""
"""Connect to the host. Does not process messages yet."""
# pylint: disable=import-outside-toplevel
import paho.mqtt.client as mqtt
result: int = None
try:
result = await self.hass.async_add_executor_job(
self._mqttc.connect, self.broker, self.port, self.keepalive
self._mqttc.connect,
self.conf[CONF_BROKER],
self.conf[CONF_PORT],
self.conf[CONF_KEEPALIVE],
)
except OSError as err:
_LOGGER.error("Failed to connect to MQTT server due to exception: %s", err)
@ -922,7 +861,12 @@ class MQTT:
self.connected = True
dispatcher_send(self.hass, MQTT_CONNECTED)
_LOGGER.info("Connected to MQTT server (%s)", result_code)
_LOGGER.info(
"Connected to MQTT server %s:%s (%s)",
self.conf[CONF_BROKER],
self.conf[CONF_PORT],
result_code,
)
# Group subscriptions to only re-subscribe once for each topic.
keyfunc = attrgetter("topic")
@ -931,11 +875,12 @@ class MQTT:
max_qos = max(subscription.qos for subscription in subs)
self.hass.add_job(self._async_perform_subscription, topic, max_qos)
if self.birth_message:
if CONF_BIRTH_MESSAGE in self.conf:
birth_message = Message(**self.conf[CONF_BIRTH_MESSAGE])
self.hass.add_job(
self.async_publish( # pylint: disable=no-value-for-parameter
*attr.astuple(
self.birth_message,
birth_message,
filter=lambda attr, value: attr.name
not in ["subscribed_topic", "timestamp"],
)
@ -990,7 +935,12 @@ class MQTT:
"""Disconnected callback."""
self.connected = False
dispatcher_send(self.hass, MQTT_DISCONNECTED)
_LOGGER.warning("Disconnected from MQTT server (%s)", result_code)
_LOGGER.warning(
"Disconnected from MQTT server %s:%s (%s)",
self.conf[CONF_BROKER],
self.conf[CONF_PORT],
result_code,
)
def _raise_on_error(result_code: int) -> None:

View file

@ -1,5 +1,6 @@
"""Config flow for MQTT."""
from collections import OrderedDict
import logging
import queue
import voluptuous as vol
@ -13,7 +14,22 @@ from homeassistant.const import (
CONF_USERNAME,
)
from .const import CONF_BROKER, CONF_DISCOVERY, DEFAULT_DISCOVERY
from .const import (
ATTR_PAYLOAD,
ATTR_QOS,
ATTR_RETAIN,
ATTR_TOPIC,
CONF_BIRTH_MESSAGE,
CONF_BROKER,
CONF_DISCOVERY,
CONF_WILL_MESSAGE,
DEFAULT_DISCOVERY,
DEFAULT_QOS,
DEFAULT_RETAIN,
)
from .util import MQTT_WILL_BIRTH_SCHEMA
_LOGGER = logging.getLogger(__name__)
@config_entries.HANDLERS.register("mqtt")
@ -25,6 +41,11 @@ class FlowHandler(config_entries.ConfigFlow):
_hassio_discovery = None
@staticmethod
def async_get_options_flow(config_entry):
"""Get the options flow for this handler."""
return MQTTOptionsFlowHandler(config_entry)
async def async_step_user(self, user_input=None):
"""Handle a flow initialized by the user."""
if self._async_current_entries():
@ -123,6 +144,163 @@ class FlowHandler(config_entries.ConfigFlow):
)
class MQTTOptionsFlowHandler(config_entries.OptionsFlow):
"""Handle MQTT options."""
def __init__(self, config_entry):
"""Initialize MQTT options flow."""
self.config_entry = config_entry
self.broker_config = {}
self.options = dict(config_entry.options)
async def async_step_init(self, user_input=None):
"""Manage the MQTT options."""
return await self.async_step_broker()
async def async_step_broker(self, user_input=None):
"""Manage the MQTT options."""
errors = {}
current_config = self.config_entry.data
if user_input is not None:
can_connect = await self.hass.async_add_executor_job(
try_connection,
user_input[CONF_BROKER],
user_input[CONF_PORT],
user_input.get(CONF_USERNAME),
user_input.get(CONF_PASSWORD),
)
if can_connect:
self.broker_config.update(user_input)
return await self.async_step_options()
errors["base"] = "cannot_connect"
fields = OrderedDict()
fields[vol.Required(CONF_BROKER, default=current_config[CONF_BROKER])] = str
fields[vol.Required(CONF_PORT, default=current_config[CONF_PORT])] = vol.Coerce(
int
)
fields[
vol.Optional(
CONF_USERNAME,
description={"suggested_value": current_config.get(CONF_USERNAME)},
)
] = str
fields[
vol.Optional(
CONF_PASSWORD,
description={"suggested_value": current_config.get(CONF_PASSWORD)},
)
] = str
return self.async_show_form(
step_id="broker", data_schema=vol.Schema(fields), errors=errors,
)
async def async_step_options(self, user_input=None):
"""Manage the MQTT options."""
errors = {}
current_config = self.config_entry.data
options_config = {}
if user_input is not None:
bad_birth = False
bad_will = False
if "birth_topic" in user_input:
birth_message = {
ATTR_TOPIC: user_input["birth_topic"],
ATTR_PAYLOAD: user_input.get("birth_payload", ""),
ATTR_QOS: user_input["birth_qos"],
ATTR_RETAIN: user_input["birth_retain"],
}
try:
birth_message = MQTT_WILL_BIRTH_SCHEMA(birth_message)
options_config[CONF_BIRTH_MESSAGE] = birth_message
except vol.Invalid:
errors["base"] = "bad_birth"
bad_birth = True
if "will_topic" in user_input:
will_message = {
ATTR_TOPIC: user_input["will_topic"],
ATTR_PAYLOAD: user_input.get("will_payload", ""),
ATTR_QOS: user_input["will_qos"],
ATTR_RETAIN: user_input["will_retain"],
}
try:
will_message = MQTT_WILL_BIRTH_SCHEMA(will_message)
options_config[CONF_WILL_MESSAGE] = will_message
except vol.Invalid:
errors["base"] = "bad_will"
bad_will = True
options_config[CONF_DISCOVERY] = user_input[CONF_DISCOVERY]
if not bad_birth and not bad_will:
updated_config = {}
updated_config.update(self.broker_config)
updated_config.update(options_config)
self.hass.config_entries.async_update_entry(
self.config_entry, data=updated_config
)
return self.async_create_entry(title="", data=None)
birth_topic = None
birth_payload = None
birth_qos = DEFAULT_QOS
birth_retain = DEFAULT_RETAIN
if CONF_BIRTH_MESSAGE in current_config:
birth_topic = current_config[CONF_BIRTH_MESSAGE][ATTR_TOPIC]
birth_payload = current_config[CONF_BIRTH_MESSAGE][ATTR_PAYLOAD]
birth_qos = current_config[CONF_BIRTH_MESSAGE].get(ATTR_QOS, DEFAULT_QOS)
birth_retain = current_config[CONF_BIRTH_MESSAGE].get(
ATTR_RETAIN, DEFAULT_RETAIN
)
will_topic = None
will_payload = None
will_qos = DEFAULT_QOS
will_retain = DEFAULT_RETAIN
if CONF_WILL_MESSAGE in current_config:
will_topic = current_config[CONF_WILL_MESSAGE][ATTR_TOPIC]
will_payload = current_config[CONF_WILL_MESSAGE][ATTR_PAYLOAD]
will_qos = current_config[CONF_WILL_MESSAGE].get(ATTR_QOS, DEFAULT_QOS)
will_retain = current_config[CONF_WILL_MESSAGE].get(
ATTR_RETAIN, DEFAULT_RETAIN
)
fields = OrderedDict()
fields[
vol.Optional(
CONF_DISCOVERY,
default=current_config.get(CONF_DISCOVERY, DEFAULT_DISCOVERY),
)
] = bool
fields[
vol.Optional("birth_topic", description={"suggested_value": birth_topic})
] = str
fields[
vol.Optional(
"birth_payload", description={"suggested_value": birth_payload}
)
] = str
fields[vol.Optional("birth_qos", default=birth_qos)] = vol.In([0, 1, 2])
fields[vol.Optional("birth_retain", default=birth_retain)] = bool
fields[
vol.Optional("will_topic", description={"suggested_value": will_topic})
] = str
fields[
vol.Optional("will_payload", description={"suggested_value": will_payload})
] = str
fields[vol.Optional("will_qos", default=will_qos)] = vol.In([0, 1, 2])
fields[vol.Optional("will_retain", default=will_retain)] = bool
return self.async_show_form(
step_id="options", data_schema=vol.Schema(fields), errors=errors,
)
def try_connection(broker, port, username, password, protocol="3.1"):
"""Test if we can connect to an MQTT broker."""
# pylint: disable=import-outside-toplevel

View file

@ -1,14 +1,25 @@
"""Constants used by multiple MQTT modules."""
CONF_BROKER = "broker"
CONF_DISCOVERY = "discovery"
DEFAULT_DISCOVERY = False
ATTR_DISCOVERY_HASH = "discovery_hash"
ATTR_DISCOVERY_PAYLOAD = "discovery_payload"
ATTR_DISCOVERY_TOPIC = "discovery_topic"
ATTR_PAYLOAD = "payload"
ATTR_QOS = "qos"
ATTR_RETAIN = "retain"
ATTR_TOPIC = "topic"
CONF_BROKER = "broker"
CONF_BIRTH_MESSAGE = "birth_message"
CONF_DISCOVERY = "discovery"
CONF_QOS = ATTR_QOS
CONF_RETAIN = ATTR_RETAIN
CONF_STATE_TOPIC = "state_topic"
PROTOCOL_311 = "3.1.1"
CONF_WILL_MESSAGE = "will_message"
DEFAULT_DISCOVERY = False
DEFAULT_QOS = 0
DEFAULT_RETAIN = False
MQTT_CONNECTED = "mqtt_connected"
MQTT_DISCONNECTED = "mqtt_disconnected"
PROTOCOL_311 = "3.1.1"

View file

@ -35,8 +35,9 @@ SUPPORTED_COMPONENTS = [
]
ALREADY_DISCOVERED = "mqtt_discovered_components"
DATA_CONFIG_ENTRY_LOCK = "mqtt_config_entry_lock"
CONFIG_ENTRY_IS_SETUP = "mqtt_config_entry_is_setup"
DATA_CONFIG_ENTRY_LOCK = "mqtt_config_entry_lock"
DISCOVERY_UNSUBSCRIBE = "mqtt_discovery_unsubscribe"
MQTT_DISCOVERY_UPDATED = "mqtt_discovery_updated_{}"
MQTT_DISCOVERY_NEW = "mqtt_discovery_new_{}_{}"
@ -163,8 +164,15 @@ async def async_start(
hass.data[DATA_CONFIG_ENTRY_LOCK] = asyncio.Lock()
hass.data[CONFIG_ENTRY_IS_SETUP] = set()
await mqtt.async_subscribe(
hass.data[DISCOVERY_UNSUBSCRIBE] = await mqtt.async_subscribe(
hass, f"{discovery_topic}/#", async_device_message_received, 0
)
return True
async def async_stop(hass: HomeAssistantType) -> bool:
"""Stop MQTT Discovery."""
if DISCOVERY_UNSUBSCRIBE in hass.data and hass.data[DISCOVERY_UNSUBSCRIBE]:
hass.data[DISCOVERY_UNSUBSCRIBE]()
hass.data[DISCOVERY_UNSUBSCRIBE] = None

View file

@ -47,5 +47,37 @@
"button_5": "Fifth button",
"button_6": "Sixth button"
}
},
"options": {
"step": {
"broker": {
"description": "Please enter the connection information of your MQTT broker.",
"data": {
"broker": "Broker",
"port": "[%key:common::config_flow::data::port%]",
"username": "[%key:common::config_flow::data::username%]",
"password": "[%key:common::config_flow::data::password%]"
}
},
"options": {
"description": "Please select MQTT options.",
"data": {
"discovery": "Enable discovery",
"birth_topic": "Birth message topic",
"birth_payload": "Birth message payload",
"birth_qos": "Birth message QoS",
"birth_retain": "Birth message retain",
"will_topic": "Will message topic",
"will_payload": "Will message payload",
"will_qos": "Will message QoS",
"will_retain": "Will message retain"
}
}
},
"error": {
"cannot_connect": "Unable to connect to the broker.",
"bad_birth": "Invalid birth topic.",
"bad_will": "Invalid will topic."
}
}
}

View file

@ -0,0 +1,82 @@
"""Utility functions for the MQTT integration."""
from typing import Any
import voluptuous as vol
from homeassistant.const import CONF_PAYLOAD
from homeassistant.helpers import config_validation as cv
from .const import (
ATTR_PAYLOAD,
ATTR_QOS,
ATTR_RETAIN,
ATTR_TOPIC,
DEFAULT_QOS,
DEFAULT_RETAIN,
)
def valid_topic(value: Any) -> str:
"""Validate that this is a valid topic name/filter."""
value = cv.string(value)
try:
raw_value = value.encode("utf-8")
except UnicodeError:
raise vol.Invalid("MQTT topic name/filter must be valid UTF-8 string.")
if not raw_value:
raise vol.Invalid("MQTT topic name/filter must not be empty.")
if len(raw_value) > 65535:
raise vol.Invalid(
"MQTT topic name/filter must not be longer than 65535 encoded bytes."
)
if "\0" in value:
raise vol.Invalid("MQTT topic name/filter must not contain null character.")
return value
def valid_subscribe_topic(value: Any) -> str:
"""Validate that we can subscribe using this MQTT topic."""
value = valid_topic(value)
for i in (i for i, c in enumerate(value) if c == "+"):
if (i > 0 and value[i - 1] != "/") or (
i < len(value) - 1 and value[i + 1] != "/"
):
raise vol.Invalid(
"Single-level wildcard must occupy an entire level of the filter"
)
index = value.find("#")
if index != -1:
if index != len(value) - 1:
# If there are multiple wildcards, this will also trigger
raise vol.Invalid(
"Multi-level wildcard must be the last "
"character in the topic filter."
)
if len(value) > 1 and value[index - 1] != "/":
raise vol.Invalid(
"Multi-level wildcard must be after a topic level separator."
)
return value
def valid_publish_topic(value: Any) -> str:
"""Validate that we can publish using this MQTT topic."""
value = valid_topic(value)
if "+" in value or "#" in value:
raise vol.Invalid("Wildcards can not be used in topic names")
return value
_VALID_QOS_SCHEMA = vol.All(vol.Coerce(int), vol.In([0, 1, 2]))
MQTT_WILL_BIRTH_SCHEMA = vol.Schema(
{
vol.Required(ATTR_TOPIC): valid_publish_topic,
vol.Required(ATTR_PAYLOAD, CONF_PAYLOAD): cv.string,
vol.Optional(ATTR_QOS, default=DEFAULT_QOS): _VALID_QOS_SCHEMA,
vol.Optional(ATTR_RETAIN, default=DEFAULT_RETAIN): cv.boolean,
},
required=True,
)

View file

@ -1008,7 +1008,8 @@ class OptionsFlowManager(data_entry_flow.FlowManager):
entry = self.hass.config_entries.async_get_entry(flow.handler)
if entry is None:
raise UnknownEntry(flow.handler)
self.hass.config_entries.async_update_entry(entry, options=result["data"])
if result["data"] is not None:
self.hass.config_entries.async_update_entry(entry, options=result["data"])
result["result"] = True
return result

View file

@ -1,7 +1,11 @@
"""Test config flow."""
import pytest
import voluptuous as vol
from homeassistant import data_entry_flow
from homeassistant.components import mqtt
from homeassistant.components.mqtt.discovery import async_start
from homeassistant.setup import async_setup_component
from tests.async_mock import patch
@ -144,3 +148,340 @@ async def test_hassio_confirm(hass, mock_try_connection, mock_finish_setup):
assert len(mock_try_connection.mock_calls) == 1
# Check config entry got setup
assert len(mock_finish_setup.mock_calls) == 1
async def test_option_flow(hass, mqtt_mock, mock_try_connection):
"""Test config flow options."""
mock_try_connection.return_value = True
config_entry = hass.config_entries.async_entries(mqtt.DOMAIN)[0]
await async_start(hass, "homeassistant", config_entry)
config_entry.data = {
mqtt.CONF_BROKER: "test-broker",
mqtt.CONF_PORT: 1234,
}
mqtt_mock.async_connect.reset_mock()
result = await hass.config_entries.options.async_init(config_entry.entry_id)
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result["step_id"] == "broker"
result = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input={
mqtt.CONF_BROKER: "another-broker",
mqtt.CONF_PORT: 2345,
mqtt.CONF_USERNAME: "user",
mqtt.CONF_PASSWORD: "pass",
},
)
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result["step_id"] == "options"
await hass.async_block_till_done()
assert mqtt_mock.async_connect.call_count == 0
result = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input={
mqtt.CONF_DISCOVERY: True,
"birth_topic": "ha_state/online",
"birth_payload": "online",
"birth_qos": 1,
"birth_retain": True,
"will_topic": "ha_state/offline",
"will_payload": "offline",
"will_qos": 2,
"will_retain": True,
},
)
assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
assert result["data"] is None
assert config_entry.data == {
mqtt.CONF_BROKER: "another-broker",
mqtt.CONF_PORT: 2345,
mqtt.CONF_USERNAME: "user",
mqtt.CONF_PASSWORD: "pass",
mqtt.CONF_DISCOVERY: True,
mqtt.CONF_BIRTH_MESSAGE: {
mqtt.ATTR_TOPIC: "ha_state/online",
mqtt.ATTR_PAYLOAD: "online",
mqtt.ATTR_QOS: 1,
mqtt.ATTR_RETAIN: True,
},
mqtt.CONF_WILL_MESSAGE: {
mqtt.ATTR_TOPIC: "ha_state/offline",
mqtt.ATTR_PAYLOAD: "offline",
mqtt.ATTR_QOS: 2,
mqtt.ATTR_RETAIN: True,
},
}
await hass.async_block_till_done()
assert mqtt_mock.async_connect.call_count == 1
def get_default(schema, key):
"""Get default value for key in voluptuous schema."""
for k in schema.keys():
if k == key:
if k.default == vol.UNDEFINED:
return None
return k.default()
def get_suggested(schema, key):
"""Get suggested value for key in voluptuous schema."""
for k in schema.keys():
if k == key:
if k.description is None or "suggested_value" not in k.description:
return None
return k.description["suggested_value"]
async def test_option_flow_default_suggested_values(
hass, mqtt_mock, mock_try_connection
):
"""Test config flow options has default/suggested values."""
mock_try_connection.return_value = True
config_entry = hass.config_entries.async_entries(mqtt.DOMAIN)[0]
await async_start(hass, "homeassistant", config_entry)
config_entry.data = {
mqtt.CONF_BROKER: "test-broker",
mqtt.CONF_PORT: 1234,
mqtt.CONF_USERNAME: "user",
mqtt.CONF_PASSWORD: "pass",
mqtt.CONF_DISCOVERY: True,
mqtt.CONF_BIRTH_MESSAGE: {
mqtt.ATTR_TOPIC: "ha_state/online",
mqtt.ATTR_PAYLOAD: "online",
mqtt.ATTR_QOS: 1,
mqtt.ATTR_RETAIN: True,
},
mqtt.CONF_WILL_MESSAGE: {
mqtt.ATTR_TOPIC: "ha_state/offline",
mqtt.ATTR_PAYLOAD: "offline",
mqtt.ATTR_QOS: 2,
mqtt.ATTR_RETAIN: False,
},
}
# Test default/suggested values from config
result = await hass.config_entries.options.async_init(config_entry.entry_id)
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result["step_id"] == "broker"
defaults = {
mqtt.CONF_BROKER: "test-broker",
mqtt.CONF_PORT: 1234,
}
suggested = {
mqtt.CONF_USERNAME: "user",
mqtt.CONF_PASSWORD: "pass",
}
for k, v in defaults.items():
assert get_default(result["data_schema"].schema, k) == v
for k, v in suggested.items():
assert get_suggested(result["data_schema"].schema, k) == v
result = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input={
mqtt.CONF_BROKER: "another-broker",
mqtt.CONF_PORT: 2345,
mqtt.CONF_USERNAME: "us3r",
mqtt.CONF_PASSWORD: "p4ss",
},
)
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result["step_id"] == "options"
defaults = {
mqtt.CONF_DISCOVERY: True,
"birth_qos": 1,
"birth_retain": True,
"will_qos": 2,
"will_retain": False,
}
suggested = {
"birth_topic": "ha_state/online",
"birth_payload": "online",
"will_topic": "ha_state/offline",
"will_payload": "offline",
}
for k, v in defaults.items():
assert get_default(result["data_schema"].schema, k) == v
for k, v in suggested.items():
assert get_suggested(result["data_schema"].schema, k) == v
result = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input={
mqtt.CONF_DISCOVERY: False,
"birth_topic": "ha_state/onl1ne",
"birth_payload": "onl1ne",
"birth_qos": 2,
"birth_retain": False,
"will_topic": "ha_state/offl1ne",
"will_payload": "offl1ne",
"will_qos": 1,
"will_retain": True,
},
)
assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
# Test updated default/suggested values from config
result = await hass.config_entries.options.async_init(config_entry.entry_id)
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result["step_id"] == "broker"
defaults = {
mqtt.CONF_BROKER: "another-broker",
mqtt.CONF_PORT: 2345,
}
suggested = {
mqtt.CONF_USERNAME: "us3r",
mqtt.CONF_PASSWORD: "p4ss",
}
for k, v in defaults.items():
assert get_default(result["data_schema"].schema, k) == v
for k, v in suggested.items():
assert get_suggested(result["data_schema"].schema, k) == v
result = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input={mqtt.CONF_BROKER: "another-broker", mqtt.CONF_PORT: 2345},
)
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result["step_id"] == "options"
defaults = {
mqtt.CONF_DISCOVERY: False,
"birth_qos": 2,
"birth_retain": False,
"will_qos": 1,
"will_retain": True,
}
suggested = {
"birth_topic": "ha_state/onl1ne",
"birth_payload": "onl1ne",
"will_topic": "ha_state/offl1ne",
"will_payload": "offl1ne",
}
for k, v in defaults.items():
assert get_default(result["data_schema"].schema, k) == v
for k, v in suggested.items():
assert get_suggested(result["data_schema"].schema, k) == v
result = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input={
mqtt.CONF_DISCOVERY: True,
"birth_topic": "ha_state/onl1ne",
"birth_payload": "onl1ne",
"birth_qos": 2,
"birth_retain": False,
"will_topic": "ha_state/offl1ne",
"will_payload": "offl1ne",
"will_qos": 1,
"will_retain": True,
},
)
assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
async def test_options_user_connection_fails(hass, mock_try_connection):
"""Test if connection cannot be made."""
config_entry = MockConfigEntry(domain=mqtt.DOMAIN)
config_entry.add_to_hass(hass)
config_entry.data = {
mqtt.CONF_BROKER: "test-broker",
mqtt.CONF_PORT: 1234,
}
mock_try_connection.return_value = False
result = await hass.config_entries.options.async_init(config_entry.entry_id)
assert result["type"] == "form"
result = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input={mqtt.CONF_BROKER: "bad-broker", mqtt.CONF_PORT: 2345},
)
assert result["type"] == "form"
assert result["errors"]["base"] == "cannot_connect"
# Check we tried the connection
assert len(mock_try_connection.mock_calls) == 1
# Check config entry did not update
assert config_entry.data == {
mqtt.CONF_BROKER: "test-broker",
mqtt.CONF_PORT: 1234,
}
async def test_options_bad_birth_message_fails(hass, mock_try_connection):
"""Test bad birth message."""
config_entry = MockConfigEntry(domain=mqtt.DOMAIN)
config_entry.add_to_hass(hass)
config_entry.data = {
mqtt.CONF_BROKER: "test-broker",
mqtt.CONF_PORT: 1234,
}
mock_try_connection.return_value = True
result = await hass.config_entries.options.async_init(config_entry.entry_id)
assert result["type"] == "form"
result = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input={mqtt.CONF_BROKER: "another-broker", mqtt.CONF_PORT: 2345},
)
assert result["type"] == "form"
assert result["step_id"] == "options"
result = await hass.config_entries.options.async_configure(
result["flow_id"], user_input={"birth_topic": "ha_state/online/#"},
)
assert result["type"] == "form"
assert result["errors"]["base"] == "bad_birth"
# Check config entry did not update
assert config_entry.data == {
mqtt.CONF_BROKER: "test-broker",
mqtt.CONF_PORT: 1234,
}
async def test_options_bad_will_message_fails(hass, mock_try_connection):
"""Test bad will message."""
config_entry = MockConfigEntry(domain=mqtt.DOMAIN)
config_entry.add_to_hass(hass)
config_entry.data = {
mqtt.CONF_BROKER: "test-broker",
mqtt.CONF_PORT: 1234,
}
mock_try_connection.return_value = True
result = await hass.config_entries.options.async_init(config_entry.entry_id)
assert result["type"] == "form"
result = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input={mqtt.CONF_BROKER: "another-broker", mqtt.CONF_PORT: 2345},
)
assert result["type"] == "form"
assert result["step_id"] == "options"
result = await hass.config_entries.options.async_configure(
result["flow_id"], user_input={"will_topic": "ha_state/offline/#"},
)
assert result["type"] == "form"
assert result["errors"]["base"] == "bad_will"
# Check config entry did not update
assert config_entry.data == {
mqtt.CONF_BROKER: "test-broker",
mqtt.CONF_PORT: 1234,
}

View file

@ -171,26 +171,26 @@ def test_validate_topic():
"""Test topic name/filter validation."""
# Invalid UTF-8, must not contain U+D800 to U+DFFF.
with pytest.raises(vol.Invalid):
mqtt.valid_topic("\ud800")
mqtt.util.valid_topic("\ud800")
with pytest.raises(vol.Invalid):
mqtt.valid_topic("\udfff")
mqtt.util.valid_topic("\udfff")
# Topic MUST NOT be empty
with pytest.raises(vol.Invalid):
mqtt.valid_topic("")
mqtt.util.valid_topic("")
# Topic MUST NOT be longer than 65535 encoded bytes.
with pytest.raises(vol.Invalid):
mqtt.valid_topic("ü" * 32768)
mqtt.util.valid_topic("ü" * 32768)
# UTF-8 MUST NOT include null character
with pytest.raises(vol.Invalid):
mqtt.valid_topic("bad\0one")
mqtt.util.valid_topic("bad\0one")
# Topics "SHOULD NOT" include these special characters
# (not MUST NOT, RFC2119). The receiver MAY close the connection.
mqtt.valid_topic("\u0001")
mqtt.valid_topic("\u001F")
mqtt.valid_topic("\u009F")
mqtt.valid_topic("\u009F")
mqtt.valid_topic("\uffff")
mqtt.util.valid_topic("\u0001")
mqtt.util.valid_topic("\u001F")
mqtt.util.valid_topic("\u009F")
mqtt.util.valid_topic("\u009F")
mqtt.util.valid_topic("\uffff")
def test_validate_subscribe_topic():
@ -587,7 +587,7 @@ async def test_retained_message_on_subscribe_received(
mqtt_client_mock.subscribe.side_effect = side_effect
# Fake that the client is connected
mqtt_mock.connected = True
mqtt_mock().connected = True
calls_a = MagicMock()
await mqtt.async_subscribe(hass, "test/state", calls_a)
@ -605,7 +605,7 @@ async def test_not_calling_unsubscribe_with_active_subscribers(
):
"""Test not calling unsubscribe() when other subscribers are active."""
# Fake that the client is connected
mqtt_mock.connected = True
mqtt_mock().connected = True
unsub = await mqtt.async_subscribe(hass, "test/state", None)
await mqtt.async_subscribe(hass, "test/state", None)
@ -620,7 +620,7 @@ async def test_not_calling_unsubscribe_with_active_subscribers(
async def test_restore_subscriptions_on_reconnect(hass, mqtt_client_mock, mqtt_mock):
"""Test subscriptions are restored on reconnect."""
# Fake that the client is connected
mqtt_mock.connected = True
mqtt_mock().connected = True
await mqtt.async_subscribe(hass, "test/state", None)
await hass.async_block_till_done()
@ -637,7 +637,7 @@ async def test_restore_all_active_subscriptions_on_reconnect(
):
"""Test active subscriptions are restored correctly on reconnect."""
# Fake that the client is connected
mqtt_mock.connected = True
mqtt_mock().connected = True
mqtt_client_mock.subscribe.side_effect = (
(0, 1),
@ -716,81 +716,107 @@ async def test_setup_raises_ConfigEntryNotReady_if_no_connect_broker(hass, caplo
assert "Failed to connect to MQTT server due to exception:" in caplog.text
async def test_setup_uses_certificate_on_certificate_set_to_auto(hass, mock_mqtt):
async def test_setup_uses_certificate_on_certificate_set_to_auto(hass):
"""Test setup uses bundled certs when certificate is set to auto."""
entry = MockConfigEntry(
domain=mqtt.DOMAIN,
data={mqtt.CONF_BROKER: "test-broker", "certificate": "auto"},
)
calls = []
assert await mqtt.async_setup_entry(hass, entry)
def mock_tls_set(certificate, certfile=None, keyfile=None, tls_version=None):
calls.append((certificate, certfile, keyfile, tls_version))
assert mock_mqtt.called
with patch("paho.mqtt.client.Client") as mock_client:
mock_client().tls_set = mock_tls_set
entry = MockConfigEntry(
domain=mqtt.DOMAIN,
data={mqtt.CONF_BROKER: "test-broker", "certificate": "auto"},
)
import requests.certs
assert await mqtt.async_setup_entry(hass, entry)
expectedCertificate = requests.certs.where()
assert mock_mqtt.mock_calls[0][2]["certificate"] == expectedCertificate
assert calls
import certifi
expectedCertificate = certifi.where()
# assert mock_mqtt.mock_calls[0][1][2]["certificate"] == expectedCertificate
assert calls[0][0] == expectedCertificate
async def test_setup_does_not_use_certificate_on_mqtts_port(hass, mock_mqtt):
"""Test setup doesn't use bundled certs when ssl set."""
entry = MockConfigEntry(
domain=mqtt.DOMAIN, data={mqtt.CONF_BROKER: "test-broker", "port": 8883}
)
assert await mqtt.async_setup_entry(hass, entry)
assert mock_mqtt.called
assert mock_mqtt.mock_calls[0][2]["port"] == 8883
import requests.certs
mqttsCertificateBundle = requests.certs.where()
assert mock_mqtt.mock_calls[0][2]["port"] != mqttsCertificateBundle
async def test_setup_without_tls_config_uses_tlsv1_under_python36(hass, mock_mqtt):
async def test_setup_without_tls_config_uses_tlsv1_under_python36(hass):
"""Test setup defaults to TLSv1 under python3.6."""
entry = MockConfigEntry(domain=mqtt.DOMAIN, data={mqtt.CONF_BROKER: "test-broker"})
calls = []
assert await mqtt.async_setup_entry(hass, entry)
def mock_tls_set(certificate, certfile=None, keyfile=None, tls_version=None):
calls.append((certificate, certfile, keyfile, tls_version))
assert mock_mqtt.called
with patch("paho.mqtt.client.Client") as mock_client:
mock_client().tls_set = mock_tls_set
entry = MockConfigEntry(
domain=mqtt.DOMAIN,
data={"certificate": "auto", mqtt.CONF_BROKER: "test-broker"},
)
import sys
assert await mqtt.async_setup_entry(hass, entry)
if sys.hexversion >= 0x03060000:
expectedTlsVersion = ssl.PROTOCOL_TLS # pylint: disable=no-member
else:
expectedTlsVersion = ssl.PROTOCOL_TLSv1
assert calls
assert mock_mqtt.mock_calls[0][2]["tls_version"] == expectedTlsVersion
import sys
if sys.hexversion >= 0x03060000:
expectedTlsVersion = ssl.PROTOCOL_TLS # pylint: disable=no-member
else:
expectedTlsVersion = ssl.PROTOCOL_TLSv1
assert calls[0][3] == expectedTlsVersion
async def test_setup_with_tls_config_uses_tls_version1_2(hass, mock_mqtt):
async def test_setup_with_tls_config_uses_tls_version1_2(hass):
"""Test setup uses specified TLS version."""
entry = MockConfigEntry(
domain=mqtt.DOMAIN, data={mqtt.CONF_BROKER: "test-broker", "tls_version": "1.2"}
)
calls = []
assert await mqtt.async_setup_entry(hass, entry)
def mock_tls_set(certificate, certfile=None, keyfile=None, tls_version=None):
calls.append((certificate, certfile, keyfile, tls_version))
assert mock_mqtt.called
with patch("paho.mqtt.client.Client") as mock_client:
mock_client().tls_set = mock_tls_set
entry = MockConfigEntry(
domain=mqtt.DOMAIN,
data={
"certificate": "auto",
mqtt.CONF_BROKER: "test-broker",
"tls_version": "1.2",
},
)
assert mock_mqtt.mock_calls[0][2]["tls_version"] == ssl.PROTOCOL_TLSv1_2
assert await mqtt.async_setup_entry(hass, entry)
assert calls
assert calls[0][3] == ssl.PROTOCOL_TLSv1_2
async def test_setup_with_tls_config_of_v1_under_python36_only_uses_v1(hass, mock_mqtt):
async def test_setup_with_tls_config_of_v1_under_python36_only_uses_v1(hass):
"""Test setup uses TLSv1.0 if explicitly chosen."""
entry = MockConfigEntry(
domain=mqtt.DOMAIN, data={mqtt.CONF_BROKER: "test-broker", "tls_version": "1.0"}
)
calls = []
assert await mqtt.async_setup_entry(hass, entry)
def mock_tls_set(certificate, certfile=None, keyfile=None, tls_version=None):
calls.append((certificate, certfile, keyfile, tls_version))
assert mock_mqtt.called
assert mock_mqtt.mock_calls[0][2]["tls_version"] == ssl.PROTOCOL_TLSv1
with patch("paho.mqtt.client.Client") as mock_client:
mock_client().tls_set = mock_tls_set
entry = MockConfigEntry(
domain=mqtt.DOMAIN,
data={
"certificate": "auto",
mqtt.CONF_BROKER: "test-broker",
"tls_version": "1.0",
},
)
assert await mqtt.async_setup_entry(hass, entry)
assert calls
assert calls[0][3] == ssl.PROTOCOL_TLSv1
@pytest.mark.parametrize(

View file

@ -46,8 +46,8 @@ class TestMQTT:
)
self.hass.block_till_done()
assert mock_mqtt.called
assert mock_mqtt.mock_calls[1][2]["username"] == "homeassistant"
assert mock_mqtt.mock_calls[1][2]["password"] == password
assert mock_mqtt.mock_calls[1][1][2]["username"] == "homeassistant"
assert mock_mqtt.mock_calls[1][1][2]["password"] == password
@patch("passlib.apps.custom_app_context", Mock(return_value=""))
@patch("tempfile.NamedTemporaryFile", Mock(return_value=MagicMock()))
@ -69,8 +69,8 @@ class TestMQTT:
)
self.hass.block_till_done()
assert mock_mqtt.called
assert mock_mqtt.mock_calls[1][2]["username"] == "homeassistant"
assert mock_mqtt.mock_calls[1][2]["password"] == password
assert mock_mqtt.mock_calls[1][1][2]["username"] == "homeassistant"
assert mock_mqtt.mock_calls[1][1][2]["password"] == password
@patch("tempfile.NamedTemporaryFile", Mock(return_value=MagicMock()))
@patch("hbmqtt.broker.Broker.start", return_value=mock_coro())

View file

@ -304,8 +304,11 @@ async def mqtt_mock(hass, mqtt_client_mock, mqtt_config):
assert result
await hass.async_block_till_done()
mqtt_component_mock = MagicMock(spec_set=hass.data["mqtt"], wraps=hass.data["mqtt"])
hass.data["mqtt"].connected = mqtt_component_mock.connected
mqtt_component_mock = MagicMock(
return_value=hass.data["mqtt"],
spec_set=hass.data["mqtt"],
wraps=hass.data["mqtt"],
)
mqtt_component_mock._mqttc = mqtt_client_mock
hass.data["mqtt"] = mqtt_component_mock