Support reconfiguring MQTT config entry (#36537)

This commit is contained in:
Erik Montnemery 2020-06-23 02:49:01 +02:00 committed by GitHub
parent ee816ed3dd
commit 747490ab34
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 881 additions and 249 deletions

View file

@ -19,7 +19,7 @@ DEFAULT_QOS = 0
TRIGGER_SCHEMA = vol.Schema( TRIGGER_SCHEMA = vol.Schema(
{ {
vol.Required(CONF_PLATFORM): mqtt.DOMAIN, vol.Required(CONF_PLATFORM): mqtt.DOMAIN,
vol.Required(CONF_TOPIC): mqtt.valid_subscribe_topic, vol.Required(CONF_TOPIC): mqtt.util.valid_subscribe_topic,
vol.Optional(CONF_PAYLOAD): cv.string, vol.Optional(CONF_PAYLOAD): cv.string,
vol.Optional(CONF_ENCODING, default=DEFAULT_ENCODING): cv.string, vol.Optional(CONF_ENCODING, default=DEFAULT_ENCODING): cv.string,
vol.Optional(CONF_QOS, default=DEFAULT_QOS): vol.All( vol.Optional(CONF_QOS, default=DEFAULT_QOS): vol.All(

View file

@ -274,7 +274,7 @@ async def system_options_update(hass, connection, msg):
{"type": "config_entries/update", "entry_id": str, vol.Optional("title"): str} {"type": "config_entries/update", "entry_id": str, vol.Optional("title"): str}
) )
async def config_entry_update(hass, connection, msg): async def config_entry_update(hass, connection, msg):
"""Update config entry system options.""" """Update config entry."""
changes = dict(msg) changes = dict(msg)
changes.pop("id") changes.pop("id")
changes.pop("type") changes.pop("type")

View file

@ -12,7 +12,7 @@ import sys
from typing import Any, Callable, List, Optional, Union from typing import Any, Callable, List, Optional, Union
import attr import attr
import requests.certs import certifi
import voluptuous as vol import voluptuous as vol
from homeassistant import config_entries from homeassistant import config_entries
@ -46,11 +46,20 @@ from . import debug_info, discovery, server
from .const import ( from .const import (
ATTR_DISCOVERY_HASH, ATTR_DISCOVERY_HASH,
ATTR_DISCOVERY_TOPIC, ATTR_DISCOVERY_TOPIC,
ATTR_PAYLOAD,
ATTR_QOS,
ATTR_RETAIN,
ATTR_TOPIC,
CONF_BIRTH_MESSAGE,
CONF_BROKER, CONF_BROKER,
CONF_DISCOVERY, CONF_DISCOVERY,
CONF_QOS,
CONF_RETAIN,
CONF_STATE_TOPIC, CONF_STATE_TOPIC,
CONF_WILL_MESSAGE,
DEFAULT_DISCOVERY, DEFAULT_DISCOVERY,
DEFAULT_QOS, DEFAULT_QOS,
DEFAULT_RETAIN,
MQTT_CONNECTED, MQTT_CONNECTED,
MQTT_DISCONNECTED, MQTT_DISCONNECTED,
PROTOCOL_311, PROTOCOL_311,
@ -59,6 +68,7 @@ from .debug_info import log_messages
from .discovery import MQTT_DISCOVERY_UPDATED, clear_discovery_hash, set_discovery_hash from .discovery import MQTT_DISCOVERY_UPDATED, clear_discovery_hash, set_discovery_hash
from .models import Message, MessageCallbackType, PublishPayloadType from .models import Message, MessageCallbackType, PublishPayloadType
from .subscription import async_subscribe_topics, async_unsubscribe_topics from .subscription import async_subscribe_topics, async_unsubscribe_topics
from .util import _VALID_QOS_SCHEMA, valid_publish_topic, valid_subscribe_topic
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -80,17 +90,12 @@ CONF_CLIENT_CERT = "client_cert"
CONF_TLS_INSECURE = "tls_insecure" CONF_TLS_INSECURE = "tls_insecure"
CONF_TLS_VERSION = "tls_version" CONF_TLS_VERSION = "tls_version"
CONF_BIRTH_MESSAGE = "birth_message"
CONF_WILL_MESSAGE = "will_message"
CONF_COMMAND_TOPIC = "command_topic" CONF_COMMAND_TOPIC = "command_topic"
CONF_AVAILABILITY_TOPIC = "availability_topic" CONF_AVAILABILITY_TOPIC = "availability_topic"
CONF_PAYLOAD_AVAILABLE = "payload_available" CONF_PAYLOAD_AVAILABLE = "payload_available"
CONF_PAYLOAD_NOT_AVAILABLE = "payload_not_available" CONF_PAYLOAD_NOT_AVAILABLE = "payload_not_available"
CONF_JSON_ATTRS_TOPIC = "json_attributes_topic" CONF_JSON_ATTRS_TOPIC = "json_attributes_topic"
CONF_JSON_ATTRS_TEMPLATE = "json_attributes_template" CONF_JSON_ATTRS_TEMPLATE = "json_attributes_template"
CONF_QOS = "qos"
CONF_RETAIN = "retain"
CONF_UNIQUE_ID = "unique_id" CONF_UNIQUE_ID = "unique_id"
CONF_IDENTIFIERS = "identifiers" CONF_IDENTIFIERS = "identifiers"
@ -105,18 +110,13 @@ PROTOCOL_31 = "3.1"
DEFAULT_PORT = 1883 DEFAULT_PORT = 1883
DEFAULT_KEEPALIVE = 60 DEFAULT_KEEPALIVE = 60
DEFAULT_RETAIN = False
DEFAULT_PROTOCOL = PROTOCOL_311 DEFAULT_PROTOCOL = PROTOCOL_311
DEFAULT_DISCOVERY_PREFIX = "homeassistant" DEFAULT_DISCOVERY_PREFIX = "homeassistant"
DEFAULT_TLS_PROTOCOL = "auto" DEFAULT_TLS_PROTOCOL = "auto"
DEFAULT_PAYLOAD_AVAILABLE = "online" DEFAULT_PAYLOAD_AVAILABLE = "online"
DEFAULT_PAYLOAD_NOT_AVAILABLE = "offline" DEFAULT_PAYLOAD_NOT_AVAILABLE = "offline"
ATTR_TOPIC = "topic"
ATTR_PAYLOAD = "payload"
ATTR_PAYLOAD_TEMPLATE = "payload_template" ATTR_PAYLOAD_TEMPLATE = "payload_template"
ATTR_QOS = CONF_QOS
ATTR_RETAIN = CONF_RETAIN
MAX_RECONNECT_WAIT = 300 # seconds MAX_RECONNECT_WAIT = 300 # seconds
@ -125,59 +125,6 @@ CONNECTION_FAILED = "connection_failed"
CONNECTION_FAILED_RECOVERABLE = "connection_failed_recoverable" CONNECTION_FAILED_RECOVERABLE = "connection_failed_recoverable"
def valid_topic(value: Any) -> str:
"""Validate that this is a valid topic name/filter."""
value = cv.string(value)
try:
raw_value = value.encode("utf-8")
except UnicodeError:
raise vol.Invalid("MQTT topic name/filter must be valid UTF-8 string.")
if not raw_value:
raise vol.Invalid("MQTT topic name/filter must not be empty.")
if len(raw_value) > 65535:
raise vol.Invalid(
"MQTT topic name/filter must not be longer than 65535 encoded bytes."
)
if "\0" in value:
raise vol.Invalid("MQTT topic name/filter must not contain null character.")
return value
def valid_subscribe_topic(value: Any) -> str:
"""Validate that we can subscribe using this MQTT topic."""
value = valid_topic(value)
for i in (i for i, c in enumerate(value) if c == "+"):
if (i > 0 and value[i - 1] != "/") or (
i < len(value) - 1 and value[i + 1] != "/"
):
raise vol.Invalid(
"Single-level wildcard must occupy an entire level of the filter"
)
index = value.find("#")
if index != -1:
if index != len(value) - 1:
# If there are multiple wildcards, this will also trigger
raise vol.Invalid(
"Multi-level wildcard must be the last "
"character in the topic filter."
)
if len(value) > 1 and value[index - 1] != "/":
raise vol.Invalid(
"Multi-level wildcard must be after a topic level separator."
)
return value
def valid_publish_topic(value: Any) -> str:
"""Validate that we can publish using this MQTT topic."""
value = valid_topic(value)
if "+" in value or "#" in value:
raise vol.Invalid("Wildcards can not be used in topic names")
return value
def validate_device_has_at_least_one_identifier(value: ConfigType) -> ConfigType: def validate_device_has_at_least_one_identifier(value: ConfigType) -> ConfigType:
"""Validate that a device info entry has at least one identifying value.""" """Validate that a device info entry has at least one identifying value."""
if not value.get(CONF_IDENTIFIERS) and not value.get(CONF_CONNECTIONS): if not value.get(CONF_IDENTIFIERS) and not value.get(CONF_CONNECTIONS):
@ -188,8 +135,6 @@ def validate_device_has_at_least_one_identifier(value: ConfigType) -> ConfigType
return value return value
_VALID_QOS_SCHEMA = vol.All(vol.Coerce(int), vol.In([0, 1, 2]))
CLIENT_KEY_AUTH_MSG = ( CLIENT_KEY_AUTH_MSG = (
"client_key and client_cert must both be present in " "client_key and client_cert must both be present in "
"the MQTT broker configuration" "the MQTT broker configuration"
@ -554,6 +499,11 @@ async def async_setup(hass: HomeAssistantType, config: ConfigType) -> bool:
return True return True
def _merge_config(entry, conf):
"""Merge configuration.yaml config with config entry."""
return {**conf, **entry.data}
async def async_setup_entry(hass, entry): async def async_setup_entry(hass, entry):
"""Load a config entry.""" """Load a config entry."""
conf = hass.data.get(DATA_MQTT_CONFIG) conf = hass.data.get(DATA_MQTT_CONFIG)
@ -574,76 +524,9 @@ async def async_setup_entry(hass, entry):
entry.data, entry.data,
) )
conf.update(entry.data) conf = _merge_config(entry, conf)
broker = conf[CONF_BROKER] hass.data[DATA_MQTT] = MQTT(hass, entry, conf,)
port = conf[CONF_PORT]
client_id = conf.get(CONF_CLIENT_ID)
keepalive = conf[CONF_KEEPALIVE]
username = conf.get(CONF_USERNAME)
password = conf.get(CONF_PASSWORD)
certificate = conf.get(CONF_CERTIFICATE)
client_key = conf.get(CONF_CLIENT_KEY)
client_cert = conf.get(CONF_CLIENT_CERT)
tls_insecure = conf.get(CONF_TLS_INSECURE)
protocol = conf[CONF_PROTOCOL]
# For cloudmqtt.com, secured connection, auto fill in certificate
if (
certificate is None
and 19999 < conf[CONF_PORT] < 30000
and broker.endswith(".cloudmqtt.com")
):
certificate = os.path.join(
os.path.dirname(__file__), "addtrustexternalcaroot.crt"
)
# When the certificate is set to auto, use bundled certs from requests
elif certificate == "auto":
certificate = requests.certs.where()
if CONF_WILL_MESSAGE in conf:
will_message = Message(**conf[CONF_WILL_MESSAGE])
else:
will_message = None
if CONF_BIRTH_MESSAGE in conf:
birth_message = Message(**conf[CONF_BIRTH_MESSAGE])
else:
birth_message = None
# Be able to override versions other than TLSv1.0 under Python3.6
conf_tls_version: str = conf.get(CONF_TLS_VERSION)
if conf_tls_version == "1.2":
tls_version = ssl.PROTOCOL_TLSv1_2
elif conf_tls_version == "1.1":
tls_version = ssl.PROTOCOL_TLSv1_1
elif conf_tls_version == "1.0":
tls_version = ssl.PROTOCOL_TLSv1
else:
# Python3.6 supports automatic negotiation of highest TLS version
if sys.hexversion >= 0x03060000:
tls_version = ssl.PROTOCOL_TLS # pylint: disable=no-member
else:
tls_version = ssl.PROTOCOL_TLSv1
hass.data[DATA_MQTT] = MQTT(
hass,
broker=broker,
port=port,
client_id=client_id,
keepalive=keepalive,
username=username,
password=password,
certificate=certificate,
client_key=client_key,
client_cert=client_cert,
tls_insecure=tls_insecure,
protocol=protocol,
will_message=will_message,
birth_message=birth_message,
tls_version=tls_version,
)
await hass.data[DATA_MQTT].async_connect() await hass.data[DATA_MQTT].async_connect()
@ -732,53 +615,101 @@ class Subscription:
class MQTT: class MQTT:
"""Home Assistant MQTT client.""" """Home Assistant MQTT client."""
def __init__( def __init__(self, hass: HomeAssistantType, config_entry, conf,) -> None:
self,
hass: HomeAssistantType,
broker: str,
port: int,
client_id: Optional[str],
keepalive: Optional[int],
username: Optional[str],
password: Optional[str],
certificate: Optional[str],
client_key: Optional[str],
client_cert: Optional[str],
tls_insecure: Optional[bool],
protocol: Optional[str],
will_message: Optional[Message],
birth_message: Optional[Message],
tls_version: Optional[int],
) -> None:
"""Initialize Home Assistant MQTT client.""" """Initialize Home Assistant MQTT client."""
# We don't import them on the top because some integrations # We don't import on the top because some integrations
# should be able to optionally rely on MQTT. # should be able to optionally rely on MQTT.
# pylint: disable=import-outside-toplevel import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel
import paho.mqtt.client as mqtt
self.hass = hass self.hass = hass
self.broker = broker self.config_entry = config_entry
self.port = port self.conf = conf
self.keepalive = keepalive
self.subscriptions: List[Subscription] = [] self.subscriptions: List[Subscription] = []
self.birth_message = birth_message
self.connected = False self.connected = False
self._mqttc: mqtt.Client = None self._mqttc: mqtt.Client = None
self._paho_lock = asyncio.Lock() self._paho_lock = asyncio.Lock()
if protocol == PROTOCOL_31: self.init_client()
self.config_entry.add_update_listener(self.async_config_entry_updated)
@staticmethod
async def async_config_entry_updated(hass, entry) -> None:
"""Handle signals of config entry being updated.
This is a static method because a class method (bound method), can not be used with weak references.
Causes for this is config entry options changing.
"""
self = hass.data[DATA_MQTT]
conf = hass.data.get(DATA_MQTT_CONFIG)
if conf is None:
conf = CONFIG_SCHEMA({DOMAIN: dict(entry.data)})[DOMAIN]
self.conf = _merge_config(entry, conf)
await self.async_disconnect()
self.init_client()
await self.async_connect()
await discovery.async_stop(hass)
if self.conf.get(CONF_DISCOVERY):
await _async_setup_discovery(hass, self.conf, entry)
def init_client(self):
"""Initialize paho client."""
# We don't import on the top because some integrations
# should be able to optionally rely on MQTT.
import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel
if self.conf[CONF_PROTOCOL] == PROTOCOL_31:
proto: int = mqtt.MQTTv31 proto: int = mqtt.MQTTv31
else: else:
proto = mqtt.MQTTv311 proto = mqtt.MQTTv311
client_id = self.conf.get(CONF_CLIENT_ID)
if client_id is None: if client_id is None:
self._mqttc = mqtt.Client(protocol=proto) self._mqttc = mqtt.Client(protocol=proto)
else: else:
self._mqttc = mqtt.Client(client_id, protocol=proto) self._mqttc = mqtt.Client(client_id, protocol=proto)
username = self.conf.get(CONF_USERNAME)
password = self.conf.get(CONF_PASSWORD)
if username is not None: if username is not None:
self._mqttc.username_pw_set(username, password) self._mqttc.username_pw_set(username, password)
certificate = self.conf.get(CONF_CERTIFICATE)
# For cloudmqtt.com, secured connection, auto fill in certificate
if (
certificate is None
and 19999 < self.conf[CONF_PORT] < 30000
and self.conf[CONF_BROKER].endswith(".cloudmqtt.com")
):
certificate = os.path.join(
os.path.dirname(__file__), "addtrustexternalcaroot.crt"
)
# When the certificate is set to auto, use bundled certs from certifi
elif certificate == "auto":
certificate = certifi.where()
# Be able to override versions other than TLSv1.0 under Python3.6
conf_tls_version: str = self.conf.get(CONF_TLS_VERSION)
if conf_tls_version == "1.2":
tls_version = ssl.PROTOCOL_TLSv1_2
elif conf_tls_version == "1.1":
tls_version = ssl.PROTOCOL_TLSv1_1
elif conf_tls_version == "1.0":
tls_version = ssl.PROTOCOL_TLSv1
else:
# Python3.6 supports automatic negotiation of highest TLS version
if sys.hexversion >= 0x03060000:
tls_version = ssl.PROTOCOL_TLS # pylint: disable=no-member
else:
tls_version = ssl.PROTOCOL_TLSv1
client_key = self.conf.get(CONF_CLIENT_KEY)
client_cert = self.conf.get(CONF_CLIENT_CERT)
tls_insecure = self.conf.get(CONF_TLS_INSECURE)
if certificate is not None: if certificate is not None:
self._mqttc.tls_set( self._mqttc.tls_set(
certificate, certificate,
@ -794,6 +725,11 @@ class MQTT:
self._mqttc.on_disconnect = self._mqtt_on_disconnect self._mqttc.on_disconnect = self._mqtt_on_disconnect
self._mqttc.on_message = self._mqtt_on_message self._mqttc.on_message = self._mqtt_on_message
if CONF_WILL_MESSAGE in self.conf:
will_message = Message(**self.conf[CONF_WILL_MESSAGE])
else:
will_message = None
if will_message is not None: if will_message is not None:
self._mqttc.will_set( # pylint: disable=no-value-for-parameter self._mqttc.will_set( # pylint: disable=no-value-for-parameter
*attr.astuple( *attr.astuple(
@ -813,14 +749,17 @@ class MQTT:
) )
async def async_connect(self) -> str: async def async_connect(self) -> str:
"""Connect to the host. Does process messages yet.""" """Connect to the host. Does not process messages yet."""
# pylint: disable=import-outside-toplevel # pylint: disable=import-outside-toplevel
import paho.mqtt.client as mqtt import paho.mqtt.client as mqtt
result: int = None result: int = None
try: try:
result = await self.hass.async_add_executor_job( result = await self.hass.async_add_executor_job(
self._mqttc.connect, self.broker, self.port, self.keepalive self._mqttc.connect,
self.conf[CONF_BROKER],
self.conf[CONF_PORT],
self.conf[CONF_KEEPALIVE],
) )
except OSError as err: except OSError as err:
_LOGGER.error("Failed to connect to MQTT server due to exception: %s", err) _LOGGER.error("Failed to connect to MQTT server due to exception: %s", err)
@ -922,7 +861,12 @@ class MQTT:
self.connected = True self.connected = True
dispatcher_send(self.hass, MQTT_CONNECTED) dispatcher_send(self.hass, MQTT_CONNECTED)
_LOGGER.info("Connected to MQTT server (%s)", result_code) _LOGGER.info(
"Connected to MQTT server %s:%s (%s)",
self.conf[CONF_BROKER],
self.conf[CONF_PORT],
result_code,
)
# Group subscriptions to only re-subscribe once for each topic. # Group subscriptions to only re-subscribe once for each topic.
keyfunc = attrgetter("topic") keyfunc = attrgetter("topic")
@ -931,11 +875,12 @@ class MQTT:
max_qos = max(subscription.qos for subscription in subs) max_qos = max(subscription.qos for subscription in subs)
self.hass.add_job(self._async_perform_subscription, topic, max_qos) self.hass.add_job(self._async_perform_subscription, topic, max_qos)
if self.birth_message: if CONF_BIRTH_MESSAGE in self.conf:
birth_message = Message(**self.conf[CONF_BIRTH_MESSAGE])
self.hass.add_job( self.hass.add_job(
self.async_publish( # pylint: disable=no-value-for-parameter self.async_publish( # pylint: disable=no-value-for-parameter
*attr.astuple( *attr.astuple(
self.birth_message, birth_message,
filter=lambda attr, value: attr.name filter=lambda attr, value: attr.name
not in ["subscribed_topic", "timestamp"], not in ["subscribed_topic", "timestamp"],
) )
@ -990,7 +935,12 @@ class MQTT:
"""Disconnected callback.""" """Disconnected callback."""
self.connected = False self.connected = False
dispatcher_send(self.hass, MQTT_DISCONNECTED) dispatcher_send(self.hass, MQTT_DISCONNECTED)
_LOGGER.warning("Disconnected from MQTT server (%s)", result_code) _LOGGER.warning(
"Disconnected from MQTT server %s:%s (%s)",
self.conf[CONF_BROKER],
self.conf[CONF_PORT],
result_code,
)
def _raise_on_error(result_code: int) -> None: def _raise_on_error(result_code: int) -> None:

View file

@ -1,5 +1,6 @@
"""Config flow for MQTT.""" """Config flow for MQTT."""
from collections import OrderedDict from collections import OrderedDict
import logging
import queue import queue
import voluptuous as vol import voluptuous as vol
@ -13,7 +14,22 @@ from homeassistant.const import (
CONF_USERNAME, CONF_USERNAME,
) )
from .const import CONF_BROKER, CONF_DISCOVERY, DEFAULT_DISCOVERY from .const import (
ATTR_PAYLOAD,
ATTR_QOS,
ATTR_RETAIN,
ATTR_TOPIC,
CONF_BIRTH_MESSAGE,
CONF_BROKER,
CONF_DISCOVERY,
CONF_WILL_MESSAGE,
DEFAULT_DISCOVERY,
DEFAULT_QOS,
DEFAULT_RETAIN,
)
from .util import MQTT_WILL_BIRTH_SCHEMA
_LOGGER = logging.getLogger(__name__)
@config_entries.HANDLERS.register("mqtt") @config_entries.HANDLERS.register("mqtt")
@ -25,6 +41,11 @@ class FlowHandler(config_entries.ConfigFlow):
_hassio_discovery = None _hassio_discovery = None
@staticmethod
def async_get_options_flow(config_entry):
"""Get the options flow for this handler."""
return MQTTOptionsFlowHandler(config_entry)
async def async_step_user(self, user_input=None): async def async_step_user(self, user_input=None):
"""Handle a flow initialized by the user.""" """Handle a flow initialized by the user."""
if self._async_current_entries(): if self._async_current_entries():
@ -123,6 +144,163 @@ class FlowHandler(config_entries.ConfigFlow):
) )
class MQTTOptionsFlowHandler(config_entries.OptionsFlow):
"""Handle MQTT options."""
def __init__(self, config_entry):
"""Initialize MQTT options flow."""
self.config_entry = config_entry
self.broker_config = {}
self.options = dict(config_entry.options)
async def async_step_init(self, user_input=None):
"""Manage the MQTT options."""
return await self.async_step_broker()
async def async_step_broker(self, user_input=None):
"""Manage the MQTT options."""
errors = {}
current_config = self.config_entry.data
if user_input is not None:
can_connect = await self.hass.async_add_executor_job(
try_connection,
user_input[CONF_BROKER],
user_input[CONF_PORT],
user_input.get(CONF_USERNAME),
user_input.get(CONF_PASSWORD),
)
if can_connect:
self.broker_config.update(user_input)
return await self.async_step_options()
errors["base"] = "cannot_connect"
fields = OrderedDict()
fields[vol.Required(CONF_BROKER, default=current_config[CONF_BROKER])] = str
fields[vol.Required(CONF_PORT, default=current_config[CONF_PORT])] = vol.Coerce(
int
)
fields[
vol.Optional(
CONF_USERNAME,
description={"suggested_value": current_config.get(CONF_USERNAME)},
)
] = str
fields[
vol.Optional(
CONF_PASSWORD,
description={"suggested_value": current_config.get(CONF_PASSWORD)},
)
] = str
return self.async_show_form(
step_id="broker", data_schema=vol.Schema(fields), errors=errors,
)
async def async_step_options(self, user_input=None):
"""Manage the MQTT options."""
errors = {}
current_config = self.config_entry.data
options_config = {}
if user_input is not None:
bad_birth = False
bad_will = False
if "birth_topic" in user_input:
birth_message = {
ATTR_TOPIC: user_input["birth_topic"],
ATTR_PAYLOAD: user_input.get("birth_payload", ""),
ATTR_QOS: user_input["birth_qos"],
ATTR_RETAIN: user_input["birth_retain"],
}
try:
birth_message = MQTT_WILL_BIRTH_SCHEMA(birth_message)
options_config[CONF_BIRTH_MESSAGE] = birth_message
except vol.Invalid:
errors["base"] = "bad_birth"
bad_birth = True
if "will_topic" in user_input:
will_message = {
ATTR_TOPIC: user_input["will_topic"],
ATTR_PAYLOAD: user_input.get("will_payload", ""),
ATTR_QOS: user_input["will_qos"],
ATTR_RETAIN: user_input["will_retain"],
}
try:
will_message = MQTT_WILL_BIRTH_SCHEMA(will_message)
options_config[CONF_WILL_MESSAGE] = will_message
except vol.Invalid:
errors["base"] = "bad_will"
bad_will = True
options_config[CONF_DISCOVERY] = user_input[CONF_DISCOVERY]
if not bad_birth and not bad_will:
updated_config = {}
updated_config.update(self.broker_config)
updated_config.update(options_config)
self.hass.config_entries.async_update_entry(
self.config_entry, data=updated_config
)
return self.async_create_entry(title="", data=None)
birth_topic = None
birth_payload = None
birth_qos = DEFAULT_QOS
birth_retain = DEFAULT_RETAIN
if CONF_BIRTH_MESSAGE in current_config:
birth_topic = current_config[CONF_BIRTH_MESSAGE][ATTR_TOPIC]
birth_payload = current_config[CONF_BIRTH_MESSAGE][ATTR_PAYLOAD]
birth_qos = current_config[CONF_BIRTH_MESSAGE].get(ATTR_QOS, DEFAULT_QOS)
birth_retain = current_config[CONF_BIRTH_MESSAGE].get(
ATTR_RETAIN, DEFAULT_RETAIN
)
will_topic = None
will_payload = None
will_qos = DEFAULT_QOS
will_retain = DEFAULT_RETAIN
if CONF_WILL_MESSAGE in current_config:
will_topic = current_config[CONF_WILL_MESSAGE][ATTR_TOPIC]
will_payload = current_config[CONF_WILL_MESSAGE][ATTR_PAYLOAD]
will_qos = current_config[CONF_WILL_MESSAGE].get(ATTR_QOS, DEFAULT_QOS)
will_retain = current_config[CONF_WILL_MESSAGE].get(
ATTR_RETAIN, DEFAULT_RETAIN
)
fields = OrderedDict()
fields[
vol.Optional(
CONF_DISCOVERY,
default=current_config.get(CONF_DISCOVERY, DEFAULT_DISCOVERY),
)
] = bool
fields[
vol.Optional("birth_topic", description={"suggested_value": birth_topic})
] = str
fields[
vol.Optional(
"birth_payload", description={"suggested_value": birth_payload}
)
] = str
fields[vol.Optional("birth_qos", default=birth_qos)] = vol.In([0, 1, 2])
fields[vol.Optional("birth_retain", default=birth_retain)] = bool
fields[
vol.Optional("will_topic", description={"suggested_value": will_topic})
] = str
fields[
vol.Optional("will_payload", description={"suggested_value": will_payload})
] = str
fields[vol.Optional("will_qos", default=will_qos)] = vol.In([0, 1, 2])
fields[vol.Optional("will_retain", default=will_retain)] = bool
return self.async_show_form(
step_id="options", data_schema=vol.Schema(fields), errors=errors,
)
def try_connection(broker, port, username, password, protocol="3.1"): def try_connection(broker, port, username, password, protocol="3.1"):
"""Test if we can connect to an MQTT broker.""" """Test if we can connect to an MQTT broker."""
# pylint: disable=import-outside-toplevel # pylint: disable=import-outside-toplevel

View file

@ -1,14 +1,25 @@
"""Constants used by multiple MQTT modules.""" """Constants used by multiple MQTT modules."""
CONF_BROKER = "broker"
CONF_DISCOVERY = "discovery"
DEFAULT_DISCOVERY = False
ATTR_DISCOVERY_HASH = "discovery_hash" ATTR_DISCOVERY_HASH = "discovery_hash"
ATTR_DISCOVERY_PAYLOAD = "discovery_payload" ATTR_DISCOVERY_PAYLOAD = "discovery_payload"
ATTR_DISCOVERY_TOPIC = "discovery_topic" ATTR_DISCOVERY_TOPIC = "discovery_topic"
ATTR_PAYLOAD = "payload"
ATTR_QOS = "qos"
ATTR_RETAIN = "retain"
ATTR_TOPIC = "topic"
CONF_BROKER = "broker"
CONF_BIRTH_MESSAGE = "birth_message"
CONF_DISCOVERY = "discovery"
CONF_QOS = ATTR_QOS
CONF_RETAIN = ATTR_RETAIN
CONF_STATE_TOPIC = "state_topic" CONF_STATE_TOPIC = "state_topic"
PROTOCOL_311 = "3.1.1" CONF_WILL_MESSAGE = "will_message"
DEFAULT_DISCOVERY = False
DEFAULT_QOS = 0 DEFAULT_QOS = 0
DEFAULT_RETAIN = False
MQTT_CONNECTED = "mqtt_connected" MQTT_CONNECTED = "mqtt_connected"
MQTT_DISCONNECTED = "mqtt_disconnected" MQTT_DISCONNECTED = "mqtt_disconnected"
PROTOCOL_311 = "3.1.1"

View file

@ -35,8 +35,9 @@ SUPPORTED_COMPONENTS = [
] ]
ALREADY_DISCOVERED = "mqtt_discovered_components" ALREADY_DISCOVERED = "mqtt_discovered_components"
DATA_CONFIG_ENTRY_LOCK = "mqtt_config_entry_lock"
CONFIG_ENTRY_IS_SETUP = "mqtt_config_entry_is_setup" CONFIG_ENTRY_IS_SETUP = "mqtt_config_entry_is_setup"
DATA_CONFIG_ENTRY_LOCK = "mqtt_config_entry_lock"
DISCOVERY_UNSUBSCRIBE = "mqtt_discovery_unsubscribe"
MQTT_DISCOVERY_UPDATED = "mqtt_discovery_updated_{}" MQTT_DISCOVERY_UPDATED = "mqtt_discovery_updated_{}"
MQTT_DISCOVERY_NEW = "mqtt_discovery_new_{}_{}" MQTT_DISCOVERY_NEW = "mqtt_discovery_new_{}_{}"
@ -163,8 +164,15 @@ async def async_start(
hass.data[DATA_CONFIG_ENTRY_LOCK] = asyncio.Lock() hass.data[DATA_CONFIG_ENTRY_LOCK] = asyncio.Lock()
hass.data[CONFIG_ENTRY_IS_SETUP] = set() hass.data[CONFIG_ENTRY_IS_SETUP] = set()
await mqtt.async_subscribe( hass.data[DISCOVERY_UNSUBSCRIBE] = await mqtt.async_subscribe(
hass, f"{discovery_topic}/#", async_device_message_received, 0 hass, f"{discovery_topic}/#", async_device_message_received, 0
) )
return True return True
async def async_stop(hass: HomeAssistantType) -> bool:
"""Stop MQTT Discovery."""
if DISCOVERY_UNSUBSCRIBE in hass.data and hass.data[DISCOVERY_UNSUBSCRIBE]:
hass.data[DISCOVERY_UNSUBSCRIBE]()
hass.data[DISCOVERY_UNSUBSCRIBE] = None

View file

@ -47,5 +47,37 @@
"button_5": "Fifth button", "button_5": "Fifth button",
"button_6": "Sixth button" "button_6": "Sixth button"
} }
},
"options": {
"step": {
"broker": {
"description": "Please enter the connection information of your MQTT broker.",
"data": {
"broker": "Broker",
"port": "[%key:common::config_flow::data::port%]",
"username": "[%key:common::config_flow::data::username%]",
"password": "[%key:common::config_flow::data::password%]"
}
},
"options": {
"description": "Please select MQTT options.",
"data": {
"discovery": "Enable discovery",
"birth_topic": "Birth message topic",
"birth_payload": "Birth message payload",
"birth_qos": "Birth message QoS",
"birth_retain": "Birth message retain",
"will_topic": "Will message topic",
"will_payload": "Will message payload",
"will_qos": "Will message QoS",
"will_retain": "Will message retain"
}
}
},
"error": {
"cannot_connect": "Unable to connect to the broker.",
"bad_birth": "Invalid birth topic.",
"bad_will": "Invalid will topic."
}
} }
} }

View file

@ -0,0 +1,82 @@
"""Utility functions for the MQTT integration."""
from typing import Any
import voluptuous as vol
from homeassistant.const import CONF_PAYLOAD
from homeassistant.helpers import config_validation as cv
from .const import (
ATTR_PAYLOAD,
ATTR_QOS,
ATTR_RETAIN,
ATTR_TOPIC,
DEFAULT_QOS,
DEFAULT_RETAIN,
)
def valid_topic(value: Any) -> str:
"""Validate that this is a valid topic name/filter."""
value = cv.string(value)
try:
raw_value = value.encode("utf-8")
except UnicodeError:
raise vol.Invalid("MQTT topic name/filter must be valid UTF-8 string.")
if not raw_value:
raise vol.Invalid("MQTT topic name/filter must not be empty.")
if len(raw_value) > 65535:
raise vol.Invalid(
"MQTT topic name/filter must not be longer than 65535 encoded bytes."
)
if "\0" in value:
raise vol.Invalid("MQTT topic name/filter must not contain null character.")
return value
def valid_subscribe_topic(value: Any) -> str:
"""Validate that we can subscribe using this MQTT topic."""
value = valid_topic(value)
for i in (i for i, c in enumerate(value) if c == "+"):
if (i > 0 and value[i - 1] != "/") or (
i < len(value) - 1 and value[i + 1] != "/"
):
raise vol.Invalid(
"Single-level wildcard must occupy an entire level of the filter"
)
index = value.find("#")
if index != -1:
if index != len(value) - 1:
# If there are multiple wildcards, this will also trigger
raise vol.Invalid(
"Multi-level wildcard must be the last "
"character in the topic filter."
)
if len(value) > 1 and value[index - 1] != "/":
raise vol.Invalid(
"Multi-level wildcard must be after a topic level separator."
)
return value
def valid_publish_topic(value: Any) -> str:
"""Validate that we can publish using this MQTT topic."""
value = valid_topic(value)
if "+" in value or "#" in value:
raise vol.Invalid("Wildcards can not be used in topic names")
return value
_VALID_QOS_SCHEMA = vol.All(vol.Coerce(int), vol.In([0, 1, 2]))
MQTT_WILL_BIRTH_SCHEMA = vol.Schema(
{
vol.Required(ATTR_TOPIC): valid_publish_topic,
vol.Required(ATTR_PAYLOAD, CONF_PAYLOAD): cv.string,
vol.Optional(ATTR_QOS, default=DEFAULT_QOS): _VALID_QOS_SCHEMA,
vol.Optional(ATTR_RETAIN, default=DEFAULT_RETAIN): cv.boolean,
},
required=True,
)

View file

@ -1008,7 +1008,8 @@ class OptionsFlowManager(data_entry_flow.FlowManager):
entry = self.hass.config_entries.async_get_entry(flow.handler) entry = self.hass.config_entries.async_get_entry(flow.handler)
if entry is None: if entry is None:
raise UnknownEntry(flow.handler) raise UnknownEntry(flow.handler)
self.hass.config_entries.async_update_entry(entry, options=result["data"]) if result["data"] is not None:
self.hass.config_entries.async_update_entry(entry, options=result["data"])
result["result"] = True result["result"] = True
return result return result

View file

@ -1,7 +1,11 @@
"""Test config flow.""" """Test config flow."""
import pytest import pytest
import voluptuous as vol
from homeassistant import data_entry_flow
from homeassistant.components import mqtt
from homeassistant.components.mqtt.discovery import async_start
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from tests.async_mock import patch from tests.async_mock import patch
@ -144,3 +148,340 @@ async def test_hassio_confirm(hass, mock_try_connection, mock_finish_setup):
assert len(mock_try_connection.mock_calls) == 1 assert len(mock_try_connection.mock_calls) == 1
# Check config entry got setup # Check config entry got setup
assert len(mock_finish_setup.mock_calls) == 1 assert len(mock_finish_setup.mock_calls) == 1
async def test_option_flow(hass, mqtt_mock, mock_try_connection):
"""Test config flow options."""
mock_try_connection.return_value = True
config_entry = hass.config_entries.async_entries(mqtt.DOMAIN)[0]
await async_start(hass, "homeassistant", config_entry)
config_entry.data = {
mqtt.CONF_BROKER: "test-broker",
mqtt.CONF_PORT: 1234,
}
mqtt_mock.async_connect.reset_mock()
result = await hass.config_entries.options.async_init(config_entry.entry_id)
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result["step_id"] == "broker"
result = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input={
mqtt.CONF_BROKER: "another-broker",
mqtt.CONF_PORT: 2345,
mqtt.CONF_USERNAME: "user",
mqtt.CONF_PASSWORD: "pass",
},
)
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result["step_id"] == "options"
await hass.async_block_till_done()
assert mqtt_mock.async_connect.call_count == 0
result = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input={
mqtt.CONF_DISCOVERY: True,
"birth_topic": "ha_state/online",
"birth_payload": "online",
"birth_qos": 1,
"birth_retain": True,
"will_topic": "ha_state/offline",
"will_payload": "offline",
"will_qos": 2,
"will_retain": True,
},
)
assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
assert result["data"] is None
assert config_entry.data == {
mqtt.CONF_BROKER: "another-broker",
mqtt.CONF_PORT: 2345,
mqtt.CONF_USERNAME: "user",
mqtt.CONF_PASSWORD: "pass",
mqtt.CONF_DISCOVERY: True,
mqtt.CONF_BIRTH_MESSAGE: {
mqtt.ATTR_TOPIC: "ha_state/online",
mqtt.ATTR_PAYLOAD: "online",
mqtt.ATTR_QOS: 1,
mqtt.ATTR_RETAIN: True,
},
mqtt.CONF_WILL_MESSAGE: {
mqtt.ATTR_TOPIC: "ha_state/offline",
mqtt.ATTR_PAYLOAD: "offline",
mqtt.ATTR_QOS: 2,
mqtt.ATTR_RETAIN: True,
},
}
await hass.async_block_till_done()
assert mqtt_mock.async_connect.call_count == 1
def get_default(schema, key):
"""Get default value for key in voluptuous schema."""
for k in schema.keys():
if k == key:
if k.default == vol.UNDEFINED:
return None
return k.default()
def get_suggested(schema, key):
"""Get suggested value for key in voluptuous schema."""
for k in schema.keys():
if k == key:
if k.description is None or "suggested_value" not in k.description:
return None
return k.description["suggested_value"]
async def test_option_flow_default_suggested_values(
hass, mqtt_mock, mock_try_connection
):
"""Test config flow options has default/suggested values."""
mock_try_connection.return_value = True
config_entry = hass.config_entries.async_entries(mqtt.DOMAIN)[0]
await async_start(hass, "homeassistant", config_entry)
config_entry.data = {
mqtt.CONF_BROKER: "test-broker",
mqtt.CONF_PORT: 1234,
mqtt.CONF_USERNAME: "user",
mqtt.CONF_PASSWORD: "pass",
mqtt.CONF_DISCOVERY: True,
mqtt.CONF_BIRTH_MESSAGE: {
mqtt.ATTR_TOPIC: "ha_state/online",
mqtt.ATTR_PAYLOAD: "online",
mqtt.ATTR_QOS: 1,
mqtt.ATTR_RETAIN: True,
},
mqtt.CONF_WILL_MESSAGE: {
mqtt.ATTR_TOPIC: "ha_state/offline",
mqtt.ATTR_PAYLOAD: "offline",
mqtt.ATTR_QOS: 2,
mqtt.ATTR_RETAIN: False,
},
}
# Test default/suggested values from config
result = await hass.config_entries.options.async_init(config_entry.entry_id)
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result["step_id"] == "broker"
defaults = {
mqtt.CONF_BROKER: "test-broker",
mqtt.CONF_PORT: 1234,
}
suggested = {
mqtt.CONF_USERNAME: "user",
mqtt.CONF_PASSWORD: "pass",
}
for k, v in defaults.items():
assert get_default(result["data_schema"].schema, k) == v
for k, v in suggested.items():
assert get_suggested(result["data_schema"].schema, k) == v
result = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input={
mqtt.CONF_BROKER: "another-broker",
mqtt.CONF_PORT: 2345,
mqtt.CONF_USERNAME: "us3r",
mqtt.CONF_PASSWORD: "p4ss",
},
)
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result["step_id"] == "options"
defaults = {
mqtt.CONF_DISCOVERY: True,
"birth_qos": 1,
"birth_retain": True,
"will_qos": 2,
"will_retain": False,
}
suggested = {
"birth_topic": "ha_state/online",
"birth_payload": "online",
"will_topic": "ha_state/offline",
"will_payload": "offline",
}
for k, v in defaults.items():
assert get_default(result["data_schema"].schema, k) == v
for k, v in suggested.items():
assert get_suggested(result["data_schema"].schema, k) == v
result = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input={
mqtt.CONF_DISCOVERY: False,
"birth_topic": "ha_state/onl1ne",
"birth_payload": "onl1ne",
"birth_qos": 2,
"birth_retain": False,
"will_topic": "ha_state/offl1ne",
"will_payload": "offl1ne",
"will_qos": 1,
"will_retain": True,
},
)
assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
# Test updated default/suggested values from config
result = await hass.config_entries.options.async_init(config_entry.entry_id)
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result["step_id"] == "broker"
defaults = {
mqtt.CONF_BROKER: "another-broker",
mqtt.CONF_PORT: 2345,
}
suggested = {
mqtt.CONF_USERNAME: "us3r",
mqtt.CONF_PASSWORD: "p4ss",
}
for k, v in defaults.items():
assert get_default(result["data_schema"].schema, k) == v
for k, v in suggested.items():
assert get_suggested(result["data_schema"].schema, k) == v
result = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input={mqtt.CONF_BROKER: "another-broker", mqtt.CONF_PORT: 2345},
)
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result["step_id"] == "options"
defaults = {
mqtt.CONF_DISCOVERY: False,
"birth_qos": 2,
"birth_retain": False,
"will_qos": 1,
"will_retain": True,
}
suggested = {
"birth_topic": "ha_state/onl1ne",
"birth_payload": "onl1ne",
"will_topic": "ha_state/offl1ne",
"will_payload": "offl1ne",
}
for k, v in defaults.items():
assert get_default(result["data_schema"].schema, k) == v
for k, v in suggested.items():
assert get_suggested(result["data_schema"].schema, k) == v
result = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input={
mqtt.CONF_DISCOVERY: True,
"birth_topic": "ha_state/onl1ne",
"birth_payload": "onl1ne",
"birth_qos": 2,
"birth_retain": False,
"will_topic": "ha_state/offl1ne",
"will_payload": "offl1ne",
"will_qos": 1,
"will_retain": True,
},
)
assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
async def test_options_user_connection_fails(hass, mock_try_connection):
"""Test if connection cannot be made."""
config_entry = MockConfigEntry(domain=mqtt.DOMAIN)
config_entry.add_to_hass(hass)
config_entry.data = {
mqtt.CONF_BROKER: "test-broker",
mqtt.CONF_PORT: 1234,
}
mock_try_connection.return_value = False
result = await hass.config_entries.options.async_init(config_entry.entry_id)
assert result["type"] == "form"
result = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input={mqtt.CONF_BROKER: "bad-broker", mqtt.CONF_PORT: 2345},
)
assert result["type"] == "form"
assert result["errors"]["base"] == "cannot_connect"
# Check we tried the connection
assert len(mock_try_connection.mock_calls) == 1
# Check config entry did not update
assert config_entry.data == {
mqtt.CONF_BROKER: "test-broker",
mqtt.CONF_PORT: 1234,
}
async def test_options_bad_birth_message_fails(hass, mock_try_connection):
"""Test bad birth message."""
config_entry = MockConfigEntry(domain=mqtt.DOMAIN)
config_entry.add_to_hass(hass)
config_entry.data = {
mqtt.CONF_BROKER: "test-broker",
mqtt.CONF_PORT: 1234,
}
mock_try_connection.return_value = True
result = await hass.config_entries.options.async_init(config_entry.entry_id)
assert result["type"] == "form"
result = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input={mqtt.CONF_BROKER: "another-broker", mqtt.CONF_PORT: 2345},
)
assert result["type"] == "form"
assert result["step_id"] == "options"
result = await hass.config_entries.options.async_configure(
result["flow_id"], user_input={"birth_topic": "ha_state/online/#"},
)
assert result["type"] == "form"
assert result["errors"]["base"] == "bad_birth"
# Check config entry did not update
assert config_entry.data == {
mqtt.CONF_BROKER: "test-broker",
mqtt.CONF_PORT: 1234,
}
async def test_options_bad_will_message_fails(hass, mock_try_connection):
"""Test bad will message."""
config_entry = MockConfigEntry(domain=mqtt.DOMAIN)
config_entry.add_to_hass(hass)
config_entry.data = {
mqtt.CONF_BROKER: "test-broker",
mqtt.CONF_PORT: 1234,
}
mock_try_connection.return_value = True
result = await hass.config_entries.options.async_init(config_entry.entry_id)
assert result["type"] == "form"
result = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input={mqtt.CONF_BROKER: "another-broker", mqtt.CONF_PORT: 2345},
)
assert result["type"] == "form"
assert result["step_id"] == "options"
result = await hass.config_entries.options.async_configure(
result["flow_id"], user_input={"will_topic": "ha_state/offline/#"},
)
assert result["type"] == "form"
assert result["errors"]["base"] == "bad_will"
# Check config entry did not update
assert config_entry.data == {
mqtt.CONF_BROKER: "test-broker",
mqtt.CONF_PORT: 1234,
}

View file

@ -171,26 +171,26 @@ def test_validate_topic():
"""Test topic name/filter validation.""" """Test topic name/filter validation."""
# Invalid UTF-8, must not contain U+D800 to U+DFFF. # Invalid UTF-8, must not contain U+D800 to U+DFFF.
with pytest.raises(vol.Invalid): with pytest.raises(vol.Invalid):
mqtt.valid_topic("\ud800") mqtt.util.valid_topic("\ud800")
with pytest.raises(vol.Invalid): with pytest.raises(vol.Invalid):
mqtt.valid_topic("\udfff") mqtt.util.valid_topic("\udfff")
# Topic MUST NOT be empty # Topic MUST NOT be empty
with pytest.raises(vol.Invalid): with pytest.raises(vol.Invalid):
mqtt.valid_topic("") mqtt.util.valid_topic("")
# Topic MUST NOT be longer than 65535 encoded bytes. # Topic MUST NOT be longer than 65535 encoded bytes.
with pytest.raises(vol.Invalid): with pytest.raises(vol.Invalid):
mqtt.valid_topic("ü" * 32768) mqtt.util.valid_topic("ü" * 32768)
# UTF-8 MUST NOT include null character # UTF-8 MUST NOT include null character
with pytest.raises(vol.Invalid): with pytest.raises(vol.Invalid):
mqtt.valid_topic("bad\0one") mqtt.util.valid_topic("bad\0one")
# Topics "SHOULD NOT" include these special characters # Topics "SHOULD NOT" include these special characters
# (not MUST NOT, RFC2119). The receiver MAY close the connection. # (not MUST NOT, RFC2119). The receiver MAY close the connection.
mqtt.valid_topic("\u0001") mqtt.util.valid_topic("\u0001")
mqtt.valid_topic("\u001F") mqtt.util.valid_topic("\u001F")
mqtt.valid_topic("\u009F") mqtt.util.valid_topic("\u009F")
mqtt.valid_topic("\u009F") mqtt.util.valid_topic("\u009F")
mqtt.valid_topic("\uffff") mqtt.util.valid_topic("\uffff")
def test_validate_subscribe_topic(): def test_validate_subscribe_topic():
@ -587,7 +587,7 @@ async def test_retained_message_on_subscribe_received(
mqtt_client_mock.subscribe.side_effect = side_effect mqtt_client_mock.subscribe.side_effect = side_effect
# Fake that the client is connected # Fake that the client is connected
mqtt_mock.connected = True mqtt_mock().connected = True
calls_a = MagicMock() calls_a = MagicMock()
await mqtt.async_subscribe(hass, "test/state", calls_a) await mqtt.async_subscribe(hass, "test/state", calls_a)
@ -605,7 +605,7 @@ async def test_not_calling_unsubscribe_with_active_subscribers(
): ):
"""Test not calling unsubscribe() when other subscribers are active.""" """Test not calling unsubscribe() when other subscribers are active."""
# Fake that the client is connected # Fake that the client is connected
mqtt_mock.connected = True mqtt_mock().connected = True
unsub = await mqtt.async_subscribe(hass, "test/state", None) unsub = await mqtt.async_subscribe(hass, "test/state", None)
await mqtt.async_subscribe(hass, "test/state", None) await mqtt.async_subscribe(hass, "test/state", None)
@ -620,7 +620,7 @@ async def test_not_calling_unsubscribe_with_active_subscribers(
async def test_restore_subscriptions_on_reconnect(hass, mqtt_client_mock, mqtt_mock): async def test_restore_subscriptions_on_reconnect(hass, mqtt_client_mock, mqtt_mock):
"""Test subscriptions are restored on reconnect.""" """Test subscriptions are restored on reconnect."""
# Fake that the client is connected # Fake that the client is connected
mqtt_mock.connected = True mqtt_mock().connected = True
await mqtt.async_subscribe(hass, "test/state", None) await mqtt.async_subscribe(hass, "test/state", None)
await hass.async_block_till_done() await hass.async_block_till_done()
@ -637,7 +637,7 @@ async def test_restore_all_active_subscriptions_on_reconnect(
): ):
"""Test active subscriptions are restored correctly on reconnect.""" """Test active subscriptions are restored correctly on reconnect."""
# Fake that the client is connected # Fake that the client is connected
mqtt_mock.connected = True mqtt_mock().connected = True
mqtt_client_mock.subscribe.side_effect = ( mqtt_client_mock.subscribe.side_effect = (
(0, 1), (0, 1),
@ -716,81 +716,107 @@ async def test_setup_raises_ConfigEntryNotReady_if_no_connect_broker(hass, caplo
assert "Failed to connect to MQTT server due to exception:" in caplog.text assert "Failed to connect to MQTT server due to exception:" in caplog.text
async def test_setup_uses_certificate_on_certificate_set_to_auto(hass, mock_mqtt): async def test_setup_uses_certificate_on_certificate_set_to_auto(hass):
"""Test setup uses bundled certs when certificate is set to auto.""" """Test setup uses bundled certs when certificate is set to auto."""
entry = MockConfigEntry( calls = []
domain=mqtt.DOMAIN,
data={mqtt.CONF_BROKER: "test-broker", "certificate": "auto"},
)
assert await mqtt.async_setup_entry(hass, entry) def mock_tls_set(certificate, certfile=None, keyfile=None, tls_version=None):
calls.append((certificate, certfile, keyfile, tls_version))
assert mock_mqtt.called with patch("paho.mqtt.client.Client") as mock_client:
mock_client().tls_set = mock_tls_set
entry = MockConfigEntry(
domain=mqtt.DOMAIN,
data={mqtt.CONF_BROKER: "test-broker", "certificate": "auto"},
)
import requests.certs assert await mqtt.async_setup_entry(hass, entry)
expectedCertificate = requests.certs.where() assert calls
assert mock_mqtt.mock_calls[0][2]["certificate"] == expectedCertificate
import certifi
expectedCertificate = certifi.where()
# assert mock_mqtt.mock_calls[0][1][2]["certificate"] == expectedCertificate
assert calls[0][0] == expectedCertificate
async def test_setup_does_not_use_certificate_on_mqtts_port(hass, mock_mqtt): async def test_setup_without_tls_config_uses_tlsv1_under_python36(hass):
"""Test setup doesn't use bundled certs when ssl set."""
entry = MockConfigEntry(
domain=mqtt.DOMAIN, data={mqtt.CONF_BROKER: "test-broker", "port": 8883}
)
assert await mqtt.async_setup_entry(hass, entry)
assert mock_mqtt.called
assert mock_mqtt.mock_calls[0][2]["port"] == 8883
import requests.certs
mqttsCertificateBundle = requests.certs.where()
assert mock_mqtt.mock_calls[0][2]["port"] != mqttsCertificateBundle
async def test_setup_without_tls_config_uses_tlsv1_under_python36(hass, mock_mqtt):
"""Test setup defaults to TLSv1 under python3.6.""" """Test setup defaults to TLSv1 under python3.6."""
entry = MockConfigEntry(domain=mqtt.DOMAIN, data={mqtt.CONF_BROKER: "test-broker"}) calls = []
assert await mqtt.async_setup_entry(hass, entry) def mock_tls_set(certificate, certfile=None, keyfile=None, tls_version=None):
calls.append((certificate, certfile, keyfile, tls_version))
assert mock_mqtt.called with patch("paho.mqtt.client.Client") as mock_client:
mock_client().tls_set = mock_tls_set
entry = MockConfigEntry(
domain=mqtt.DOMAIN,
data={"certificate": "auto", mqtt.CONF_BROKER: "test-broker"},
)
import sys assert await mqtt.async_setup_entry(hass, entry)
if sys.hexversion >= 0x03060000: assert calls
expectedTlsVersion = ssl.PROTOCOL_TLS # pylint: disable=no-member
else:
expectedTlsVersion = ssl.PROTOCOL_TLSv1
assert mock_mqtt.mock_calls[0][2]["tls_version"] == expectedTlsVersion import sys
if sys.hexversion >= 0x03060000:
expectedTlsVersion = ssl.PROTOCOL_TLS # pylint: disable=no-member
else:
expectedTlsVersion = ssl.PROTOCOL_TLSv1
assert calls[0][3] == expectedTlsVersion
async def test_setup_with_tls_config_uses_tls_version1_2(hass, mock_mqtt): async def test_setup_with_tls_config_uses_tls_version1_2(hass):
"""Test setup uses specified TLS version.""" """Test setup uses specified TLS version."""
entry = MockConfigEntry( calls = []
domain=mqtt.DOMAIN, data={mqtt.CONF_BROKER: "test-broker", "tls_version": "1.2"}
)
assert await mqtt.async_setup_entry(hass, entry) def mock_tls_set(certificate, certfile=None, keyfile=None, tls_version=None):
calls.append((certificate, certfile, keyfile, tls_version))
assert mock_mqtt.called with patch("paho.mqtt.client.Client") as mock_client:
mock_client().tls_set = mock_tls_set
entry = MockConfigEntry(
domain=mqtt.DOMAIN,
data={
"certificate": "auto",
mqtt.CONF_BROKER: "test-broker",
"tls_version": "1.2",
},
)
assert mock_mqtt.mock_calls[0][2]["tls_version"] == ssl.PROTOCOL_TLSv1_2 assert await mqtt.async_setup_entry(hass, entry)
assert calls
assert calls[0][3] == ssl.PROTOCOL_TLSv1_2
async def test_setup_with_tls_config_of_v1_under_python36_only_uses_v1(hass, mock_mqtt): async def test_setup_with_tls_config_of_v1_under_python36_only_uses_v1(hass):
"""Test setup uses TLSv1.0 if explicitly chosen.""" """Test setup uses TLSv1.0 if explicitly chosen."""
entry = MockConfigEntry( calls = []
domain=mqtt.DOMAIN, data={mqtt.CONF_BROKER: "test-broker", "tls_version": "1.0"}
)
assert await mqtt.async_setup_entry(hass, entry) def mock_tls_set(certificate, certfile=None, keyfile=None, tls_version=None):
calls.append((certificate, certfile, keyfile, tls_version))
assert mock_mqtt.called with patch("paho.mqtt.client.Client") as mock_client:
assert mock_mqtt.mock_calls[0][2]["tls_version"] == ssl.PROTOCOL_TLSv1 mock_client().tls_set = mock_tls_set
entry = MockConfigEntry(
domain=mqtt.DOMAIN,
data={
"certificate": "auto",
mqtt.CONF_BROKER: "test-broker",
"tls_version": "1.0",
},
)
assert await mqtt.async_setup_entry(hass, entry)
assert calls
assert calls[0][3] == ssl.PROTOCOL_TLSv1
@pytest.mark.parametrize( @pytest.mark.parametrize(

View file

@ -46,8 +46,8 @@ class TestMQTT:
) )
self.hass.block_till_done() self.hass.block_till_done()
assert mock_mqtt.called assert mock_mqtt.called
assert mock_mqtt.mock_calls[1][2]["username"] == "homeassistant" assert mock_mqtt.mock_calls[1][1][2]["username"] == "homeassistant"
assert mock_mqtt.mock_calls[1][2]["password"] == password assert mock_mqtt.mock_calls[1][1][2]["password"] == password
@patch("passlib.apps.custom_app_context", Mock(return_value="")) @patch("passlib.apps.custom_app_context", Mock(return_value=""))
@patch("tempfile.NamedTemporaryFile", Mock(return_value=MagicMock())) @patch("tempfile.NamedTemporaryFile", Mock(return_value=MagicMock()))
@ -69,8 +69,8 @@ class TestMQTT:
) )
self.hass.block_till_done() self.hass.block_till_done()
assert mock_mqtt.called assert mock_mqtt.called
assert mock_mqtt.mock_calls[1][2]["username"] == "homeassistant" assert mock_mqtt.mock_calls[1][1][2]["username"] == "homeassistant"
assert mock_mqtt.mock_calls[1][2]["password"] == password assert mock_mqtt.mock_calls[1][1][2]["password"] == password
@patch("tempfile.NamedTemporaryFile", Mock(return_value=MagicMock())) @patch("tempfile.NamedTemporaryFile", Mock(return_value=MagicMock()))
@patch("hbmqtt.broker.Broker.start", return_value=mock_coro()) @patch("hbmqtt.broker.Broker.start", return_value=mock_coro())

View file

@ -304,8 +304,11 @@ async def mqtt_mock(hass, mqtt_client_mock, mqtt_config):
assert result assert result
await hass.async_block_till_done() await hass.async_block_till_done()
mqtt_component_mock = MagicMock(spec_set=hass.data["mqtt"], wraps=hass.data["mqtt"]) mqtt_component_mock = MagicMock(
hass.data["mqtt"].connected = mqtt_component_mock.connected return_value=hass.data["mqtt"],
spec_set=hass.data["mqtt"],
wraps=hass.data["mqtt"],
)
mqtt_component_mock._mqttc = mqtt_client_mock mqtt_component_mock._mqttc = mqtt_client_mock
hass.data["mqtt"] = mqtt_component_mock hass.data["mqtt"] = mqtt_component_mock