From 747490ab343cd9487a50b3def847320a0661784e Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Tue, 23 Jun 2020 02:49:01 +0200 Subject: [PATCH] Support reconfiguring MQTT config entry (#36537) --- homeassistant/components/automation/mqtt.py | 2 +- .../components/config/config_entries.py | 2 +- homeassistant/components/mqtt/__init__.py | 284 ++++++--------- homeassistant/components/mqtt/config_flow.py | 180 ++++++++- homeassistant/components/mqtt/const.py | 21 +- homeassistant/components/mqtt/discovery.py | 12 +- homeassistant/components/mqtt/strings.json | 32 ++ homeassistant/components/mqtt/util.py | 82 +++++ homeassistant/config_entries.py | 3 +- tests/components/mqtt/test_config_flow.py | 341 ++++++++++++++++++ tests/components/mqtt/test_init.py | 156 ++++---- tests/components/mqtt/test_server.py | 8 +- tests/conftest.py | 7 +- 13 files changed, 881 insertions(+), 249 deletions(-) create mode 100644 homeassistant/components/mqtt/util.py diff --git a/homeassistant/components/automation/mqtt.py b/homeassistant/components/automation/mqtt.py index 046cbba2873..8bb8ad46041 100644 --- a/homeassistant/components/automation/mqtt.py +++ b/homeassistant/components/automation/mqtt.py @@ -19,7 +19,7 @@ DEFAULT_QOS = 0 TRIGGER_SCHEMA = vol.Schema( { vol.Required(CONF_PLATFORM): mqtt.DOMAIN, - vol.Required(CONF_TOPIC): mqtt.valid_subscribe_topic, + vol.Required(CONF_TOPIC): mqtt.util.valid_subscribe_topic, vol.Optional(CONF_PAYLOAD): cv.string, vol.Optional(CONF_ENCODING, default=DEFAULT_ENCODING): cv.string, vol.Optional(CONF_QOS, default=DEFAULT_QOS): vol.All( diff --git a/homeassistant/components/config/config_entries.py b/homeassistant/components/config/config_entries.py index 10ef2aeecb0..32934d4e970 100644 --- a/homeassistant/components/config/config_entries.py +++ b/homeassistant/components/config/config_entries.py @@ -274,7 +274,7 @@ async def system_options_update(hass, connection, msg): {"type": "config_entries/update", "entry_id": str, vol.Optional("title"): str} ) async def config_entry_update(hass, connection, msg): - """Update config entry system options.""" + """Update config entry.""" changes = dict(msg) changes.pop("id") changes.pop("type") diff --git a/homeassistant/components/mqtt/__init__.py b/homeassistant/components/mqtt/__init__.py index 54f745d5bb2..f631ee45f87 100644 --- a/homeassistant/components/mqtt/__init__.py +++ b/homeassistant/components/mqtt/__init__.py @@ -12,7 +12,7 @@ import sys from typing import Any, Callable, List, Optional, Union import attr -import requests.certs +import certifi import voluptuous as vol from homeassistant import config_entries @@ -46,11 +46,20 @@ from . import debug_info, discovery, server from .const import ( ATTR_DISCOVERY_HASH, ATTR_DISCOVERY_TOPIC, + ATTR_PAYLOAD, + ATTR_QOS, + ATTR_RETAIN, + ATTR_TOPIC, + CONF_BIRTH_MESSAGE, CONF_BROKER, CONF_DISCOVERY, + CONF_QOS, + CONF_RETAIN, CONF_STATE_TOPIC, + CONF_WILL_MESSAGE, DEFAULT_DISCOVERY, DEFAULT_QOS, + DEFAULT_RETAIN, MQTT_CONNECTED, MQTT_DISCONNECTED, PROTOCOL_311, @@ -59,6 +68,7 @@ from .debug_info import log_messages from .discovery import MQTT_DISCOVERY_UPDATED, clear_discovery_hash, set_discovery_hash from .models import Message, MessageCallbackType, PublishPayloadType from .subscription import async_subscribe_topics, async_unsubscribe_topics +from .util import _VALID_QOS_SCHEMA, valid_publish_topic, valid_subscribe_topic _LOGGER = logging.getLogger(__name__) @@ -80,17 +90,12 @@ CONF_CLIENT_CERT = "client_cert" CONF_TLS_INSECURE = "tls_insecure" CONF_TLS_VERSION = "tls_version" -CONF_BIRTH_MESSAGE = "birth_message" -CONF_WILL_MESSAGE = "will_message" - CONF_COMMAND_TOPIC = "command_topic" CONF_AVAILABILITY_TOPIC = "availability_topic" CONF_PAYLOAD_AVAILABLE = "payload_available" CONF_PAYLOAD_NOT_AVAILABLE = "payload_not_available" CONF_JSON_ATTRS_TOPIC = "json_attributes_topic" CONF_JSON_ATTRS_TEMPLATE = "json_attributes_template" -CONF_QOS = "qos" -CONF_RETAIN = "retain" CONF_UNIQUE_ID = "unique_id" CONF_IDENTIFIERS = "identifiers" @@ -105,18 +110,13 @@ PROTOCOL_31 = "3.1" DEFAULT_PORT = 1883 DEFAULT_KEEPALIVE = 60 -DEFAULT_RETAIN = False DEFAULT_PROTOCOL = PROTOCOL_311 DEFAULT_DISCOVERY_PREFIX = "homeassistant" DEFAULT_TLS_PROTOCOL = "auto" DEFAULT_PAYLOAD_AVAILABLE = "online" DEFAULT_PAYLOAD_NOT_AVAILABLE = "offline" -ATTR_TOPIC = "topic" -ATTR_PAYLOAD = "payload" ATTR_PAYLOAD_TEMPLATE = "payload_template" -ATTR_QOS = CONF_QOS -ATTR_RETAIN = CONF_RETAIN MAX_RECONNECT_WAIT = 300 # seconds @@ -125,59 +125,6 @@ CONNECTION_FAILED = "connection_failed" CONNECTION_FAILED_RECOVERABLE = "connection_failed_recoverable" -def valid_topic(value: Any) -> str: - """Validate that this is a valid topic name/filter.""" - value = cv.string(value) - try: - raw_value = value.encode("utf-8") - except UnicodeError: - raise vol.Invalid("MQTT topic name/filter must be valid UTF-8 string.") - if not raw_value: - raise vol.Invalid("MQTT topic name/filter must not be empty.") - if len(raw_value) > 65535: - raise vol.Invalid( - "MQTT topic name/filter must not be longer than 65535 encoded bytes." - ) - if "\0" in value: - raise vol.Invalid("MQTT topic name/filter must not contain null character.") - return value - - -def valid_subscribe_topic(value: Any) -> str: - """Validate that we can subscribe using this MQTT topic.""" - value = valid_topic(value) - for i in (i for i, c in enumerate(value) if c == "+"): - if (i > 0 and value[i - 1] != "/") or ( - i < len(value) - 1 and value[i + 1] != "/" - ): - raise vol.Invalid( - "Single-level wildcard must occupy an entire level of the filter" - ) - - index = value.find("#") - if index != -1: - if index != len(value) - 1: - # If there are multiple wildcards, this will also trigger - raise vol.Invalid( - "Multi-level wildcard must be the last " - "character in the topic filter." - ) - if len(value) > 1 and value[index - 1] != "/": - raise vol.Invalid( - "Multi-level wildcard must be after a topic level separator." - ) - - return value - - -def valid_publish_topic(value: Any) -> str: - """Validate that we can publish using this MQTT topic.""" - value = valid_topic(value) - if "+" in value or "#" in value: - raise vol.Invalid("Wildcards can not be used in topic names") - return value - - def validate_device_has_at_least_one_identifier(value: ConfigType) -> ConfigType: """Validate that a device info entry has at least one identifying value.""" if not value.get(CONF_IDENTIFIERS) and not value.get(CONF_CONNECTIONS): @@ -188,8 +135,6 @@ def validate_device_has_at_least_one_identifier(value: ConfigType) -> ConfigType return value -_VALID_QOS_SCHEMA = vol.All(vol.Coerce(int), vol.In([0, 1, 2])) - CLIENT_KEY_AUTH_MSG = ( "client_key and client_cert must both be present in " "the MQTT broker configuration" @@ -554,6 +499,11 @@ async def async_setup(hass: HomeAssistantType, config: ConfigType) -> bool: return True +def _merge_config(entry, conf): + """Merge configuration.yaml config with config entry.""" + return {**conf, **entry.data} + + async def async_setup_entry(hass, entry): """Load a config entry.""" conf = hass.data.get(DATA_MQTT_CONFIG) @@ -574,76 +524,9 @@ async def async_setup_entry(hass, entry): entry.data, ) - conf.update(entry.data) + conf = _merge_config(entry, conf) - broker = conf[CONF_BROKER] - port = conf[CONF_PORT] - client_id = conf.get(CONF_CLIENT_ID) - keepalive = conf[CONF_KEEPALIVE] - username = conf.get(CONF_USERNAME) - password = conf.get(CONF_PASSWORD) - certificate = conf.get(CONF_CERTIFICATE) - client_key = conf.get(CONF_CLIENT_KEY) - client_cert = conf.get(CONF_CLIENT_CERT) - tls_insecure = conf.get(CONF_TLS_INSECURE) - protocol = conf[CONF_PROTOCOL] - - # For cloudmqtt.com, secured connection, auto fill in certificate - if ( - certificate is None - and 19999 < conf[CONF_PORT] < 30000 - and broker.endswith(".cloudmqtt.com") - ): - certificate = os.path.join( - os.path.dirname(__file__), "addtrustexternalcaroot.crt" - ) - - # When the certificate is set to auto, use bundled certs from requests - elif certificate == "auto": - certificate = requests.certs.where() - - if CONF_WILL_MESSAGE in conf: - will_message = Message(**conf[CONF_WILL_MESSAGE]) - else: - will_message = None - - if CONF_BIRTH_MESSAGE in conf: - birth_message = Message(**conf[CONF_BIRTH_MESSAGE]) - else: - birth_message = None - - # Be able to override versions other than TLSv1.0 under Python3.6 - conf_tls_version: str = conf.get(CONF_TLS_VERSION) - if conf_tls_version == "1.2": - tls_version = ssl.PROTOCOL_TLSv1_2 - elif conf_tls_version == "1.1": - tls_version = ssl.PROTOCOL_TLSv1_1 - elif conf_tls_version == "1.0": - tls_version = ssl.PROTOCOL_TLSv1 - else: - # Python3.6 supports automatic negotiation of highest TLS version - if sys.hexversion >= 0x03060000: - tls_version = ssl.PROTOCOL_TLS # pylint: disable=no-member - else: - tls_version = ssl.PROTOCOL_TLSv1 - - hass.data[DATA_MQTT] = MQTT( - hass, - broker=broker, - port=port, - client_id=client_id, - keepalive=keepalive, - username=username, - password=password, - certificate=certificate, - client_key=client_key, - client_cert=client_cert, - tls_insecure=tls_insecure, - protocol=protocol, - will_message=will_message, - birth_message=birth_message, - tls_version=tls_version, - ) + hass.data[DATA_MQTT] = MQTT(hass, entry, conf,) await hass.data[DATA_MQTT].async_connect() @@ -732,53 +615,101 @@ class Subscription: class MQTT: """Home Assistant MQTT client.""" - def __init__( - self, - hass: HomeAssistantType, - broker: str, - port: int, - client_id: Optional[str], - keepalive: Optional[int], - username: Optional[str], - password: Optional[str], - certificate: Optional[str], - client_key: Optional[str], - client_cert: Optional[str], - tls_insecure: Optional[bool], - protocol: Optional[str], - will_message: Optional[Message], - birth_message: Optional[Message], - tls_version: Optional[int], - ) -> None: + def __init__(self, hass: HomeAssistantType, config_entry, conf,) -> None: """Initialize Home Assistant MQTT client.""" - # We don't import them on the top because some integrations + # We don't import on the top because some integrations # should be able to optionally rely on MQTT. - # pylint: disable=import-outside-toplevel - import paho.mqtt.client as mqtt + import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel self.hass = hass - self.broker = broker - self.port = port - self.keepalive = keepalive + self.config_entry = config_entry + self.conf = conf self.subscriptions: List[Subscription] = [] - self.birth_message = birth_message self.connected = False self._mqttc: mqtt.Client = None self._paho_lock = asyncio.Lock() - if protocol == PROTOCOL_31: + self.init_client() + self.config_entry.add_update_listener(self.async_config_entry_updated) + + @staticmethod + async def async_config_entry_updated(hass, entry) -> None: + """Handle signals of config entry being updated. + + This is a static method because a class method (bound method), can not be used with weak references. + Causes for this is config entry options changing. + """ + self = hass.data[DATA_MQTT] + + conf = hass.data.get(DATA_MQTT_CONFIG) + if conf is None: + conf = CONFIG_SCHEMA({DOMAIN: dict(entry.data)})[DOMAIN] + + self.conf = _merge_config(entry, conf) + await self.async_disconnect() + self.init_client() + await self.async_connect() + + await discovery.async_stop(hass) + if self.conf.get(CONF_DISCOVERY): + await _async_setup_discovery(hass, self.conf, entry) + + def init_client(self): + """Initialize paho client.""" + # We don't import on the top because some integrations + # should be able to optionally rely on MQTT. + import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel + + if self.conf[CONF_PROTOCOL] == PROTOCOL_31: proto: int = mqtt.MQTTv31 else: proto = mqtt.MQTTv311 + client_id = self.conf.get(CONF_CLIENT_ID) if client_id is None: self._mqttc = mqtt.Client(protocol=proto) else: self._mqttc = mqtt.Client(client_id, protocol=proto) + username = self.conf.get(CONF_USERNAME) + password = self.conf.get(CONF_PASSWORD) if username is not None: self._mqttc.username_pw_set(username, password) + certificate = self.conf.get(CONF_CERTIFICATE) + + # For cloudmqtt.com, secured connection, auto fill in certificate + if ( + certificate is None + and 19999 < self.conf[CONF_PORT] < 30000 + and self.conf[CONF_BROKER].endswith(".cloudmqtt.com") + ): + certificate = os.path.join( + os.path.dirname(__file__), "addtrustexternalcaroot.crt" + ) + + # When the certificate is set to auto, use bundled certs from certifi + elif certificate == "auto": + certificate = certifi.where() + + # Be able to override versions other than TLSv1.0 under Python3.6 + conf_tls_version: str = self.conf.get(CONF_TLS_VERSION) + if conf_tls_version == "1.2": + tls_version = ssl.PROTOCOL_TLSv1_2 + elif conf_tls_version == "1.1": + tls_version = ssl.PROTOCOL_TLSv1_1 + elif conf_tls_version == "1.0": + tls_version = ssl.PROTOCOL_TLSv1 + else: + # Python3.6 supports automatic negotiation of highest TLS version + if sys.hexversion >= 0x03060000: + tls_version = ssl.PROTOCOL_TLS # pylint: disable=no-member + else: + tls_version = ssl.PROTOCOL_TLSv1 + + client_key = self.conf.get(CONF_CLIENT_KEY) + client_cert = self.conf.get(CONF_CLIENT_CERT) + tls_insecure = self.conf.get(CONF_TLS_INSECURE) if certificate is not None: self._mqttc.tls_set( certificate, @@ -794,6 +725,11 @@ class MQTT: self._mqttc.on_disconnect = self._mqtt_on_disconnect self._mqttc.on_message = self._mqtt_on_message + if CONF_WILL_MESSAGE in self.conf: + will_message = Message(**self.conf[CONF_WILL_MESSAGE]) + else: + will_message = None + if will_message is not None: self._mqttc.will_set( # pylint: disable=no-value-for-parameter *attr.astuple( @@ -813,14 +749,17 @@ class MQTT: ) async def async_connect(self) -> str: - """Connect to the host. Does process messages yet.""" + """Connect to the host. Does not process messages yet.""" # pylint: disable=import-outside-toplevel import paho.mqtt.client as mqtt result: int = None try: result = await self.hass.async_add_executor_job( - self._mqttc.connect, self.broker, self.port, self.keepalive + self._mqttc.connect, + self.conf[CONF_BROKER], + self.conf[CONF_PORT], + self.conf[CONF_KEEPALIVE], ) except OSError as err: _LOGGER.error("Failed to connect to MQTT server due to exception: %s", err) @@ -922,7 +861,12 @@ class MQTT: self.connected = True dispatcher_send(self.hass, MQTT_CONNECTED) - _LOGGER.info("Connected to MQTT server (%s)", result_code) + _LOGGER.info( + "Connected to MQTT server %s:%s (%s)", + self.conf[CONF_BROKER], + self.conf[CONF_PORT], + result_code, + ) # Group subscriptions to only re-subscribe once for each topic. keyfunc = attrgetter("topic") @@ -931,11 +875,12 @@ class MQTT: max_qos = max(subscription.qos for subscription in subs) self.hass.add_job(self._async_perform_subscription, topic, max_qos) - if self.birth_message: + if CONF_BIRTH_MESSAGE in self.conf: + birth_message = Message(**self.conf[CONF_BIRTH_MESSAGE]) self.hass.add_job( self.async_publish( # pylint: disable=no-value-for-parameter *attr.astuple( - self.birth_message, + birth_message, filter=lambda attr, value: attr.name not in ["subscribed_topic", "timestamp"], ) @@ -990,7 +935,12 @@ class MQTT: """Disconnected callback.""" self.connected = False dispatcher_send(self.hass, MQTT_DISCONNECTED) - _LOGGER.warning("Disconnected from MQTT server (%s)", result_code) + _LOGGER.warning( + "Disconnected from MQTT server %s:%s (%s)", + self.conf[CONF_BROKER], + self.conf[CONF_PORT], + result_code, + ) def _raise_on_error(result_code: int) -> None: diff --git a/homeassistant/components/mqtt/config_flow.py b/homeassistant/components/mqtt/config_flow.py index 76c1889e629..2f4feaed5e9 100644 --- a/homeassistant/components/mqtt/config_flow.py +++ b/homeassistant/components/mqtt/config_flow.py @@ -1,5 +1,6 @@ """Config flow for MQTT.""" from collections import OrderedDict +import logging import queue import voluptuous as vol @@ -13,7 +14,22 @@ from homeassistant.const import ( CONF_USERNAME, ) -from .const import CONF_BROKER, CONF_DISCOVERY, DEFAULT_DISCOVERY +from .const import ( + ATTR_PAYLOAD, + ATTR_QOS, + ATTR_RETAIN, + ATTR_TOPIC, + CONF_BIRTH_MESSAGE, + CONF_BROKER, + CONF_DISCOVERY, + CONF_WILL_MESSAGE, + DEFAULT_DISCOVERY, + DEFAULT_QOS, + DEFAULT_RETAIN, +) +from .util import MQTT_WILL_BIRTH_SCHEMA + +_LOGGER = logging.getLogger(__name__) @config_entries.HANDLERS.register("mqtt") @@ -25,6 +41,11 @@ class FlowHandler(config_entries.ConfigFlow): _hassio_discovery = None + @staticmethod + def async_get_options_flow(config_entry): + """Get the options flow for this handler.""" + return MQTTOptionsFlowHandler(config_entry) + async def async_step_user(self, user_input=None): """Handle a flow initialized by the user.""" if self._async_current_entries(): @@ -123,6 +144,163 @@ class FlowHandler(config_entries.ConfigFlow): ) +class MQTTOptionsFlowHandler(config_entries.OptionsFlow): + """Handle MQTT options.""" + + def __init__(self, config_entry): + """Initialize MQTT options flow.""" + self.config_entry = config_entry + self.broker_config = {} + self.options = dict(config_entry.options) + + async def async_step_init(self, user_input=None): + """Manage the MQTT options.""" + return await self.async_step_broker() + + async def async_step_broker(self, user_input=None): + """Manage the MQTT options.""" + errors = {} + current_config = self.config_entry.data + if user_input is not None: + can_connect = await self.hass.async_add_executor_job( + try_connection, + user_input[CONF_BROKER], + user_input[CONF_PORT], + user_input.get(CONF_USERNAME), + user_input.get(CONF_PASSWORD), + ) + + if can_connect: + self.broker_config.update(user_input) + return await self.async_step_options() + + errors["base"] = "cannot_connect" + + fields = OrderedDict() + fields[vol.Required(CONF_BROKER, default=current_config[CONF_BROKER])] = str + fields[vol.Required(CONF_PORT, default=current_config[CONF_PORT])] = vol.Coerce( + int + ) + fields[ + vol.Optional( + CONF_USERNAME, + description={"suggested_value": current_config.get(CONF_USERNAME)}, + ) + ] = str + fields[ + vol.Optional( + CONF_PASSWORD, + description={"suggested_value": current_config.get(CONF_PASSWORD)}, + ) + ] = str + + return self.async_show_form( + step_id="broker", data_schema=vol.Schema(fields), errors=errors, + ) + + async def async_step_options(self, user_input=None): + """Manage the MQTT options.""" + errors = {} + current_config = self.config_entry.data + options_config = {} + if user_input is not None: + bad_birth = False + bad_will = False + + if "birth_topic" in user_input: + birth_message = { + ATTR_TOPIC: user_input["birth_topic"], + ATTR_PAYLOAD: user_input.get("birth_payload", ""), + ATTR_QOS: user_input["birth_qos"], + ATTR_RETAIN: user_input["birth_retain"], + } + try: + birth_message = MQTT_WILL_BIRTH_SCHEMA(birth_message) + options_config[CONF_BIRTH_MESSAGE] = birth_message + except vol.Invalid: + errors["base"] = "bad_birth" + bad_birth = True + + if "will_topic" in user_input: + will_message = { + ATTR_TOPIC: user_input["will_topic"], + ATTR_PAYLOAD: user_input.get("will_payload", ""), + ATTR_QOS: user_input["will_qos"], + ATTR_RETAIN: user_input["will_retain"], + } + try: + will_message = MQTT_WILL_BIRTH_SCHEMA(will_message) + options_config[CONF_WILL_MESSAGE] = will_message + except vol.Invalid: + errors["base"] = "bad_will" + bad_will = True + + options_config[CONF_DISCOVERY] = user_input[CONF_DISCOVERY] + + if not bad_birth and not bad_will: + updated_config = {} + updated_config.update(self.broker_config) + updated_config.update(options_config) + self.hass.config_entries.async_update_entry( + self.config_entry, data=updated_config + ) + return self.async_create_entry(title="", data=None) + + birth_topic = None + birth_payload = None + birth_qos = DEFAULT_QOS + birth_retain = DEFAULT_RETAIN + if CONF_BIRTH_MESSAGE in current_config: + birth_topic = current_config[CONF_BIRTH_MESSAGE][ATTR_TOPIC] + birth_payload = current_config[CONF_BIRTH_MESSAGE][ATTR_PAYLOAD] + birth_qos = current_config[CONF_BIRTH_MESSAGE].get(ATTR_QOS, DEFAULT_QOS) + birth_retain = current_config[CONF_BIRTH_MESSAGE].get( + ATTR_RETAIN, DEFAULT_RETAIN + ) + + will_topic = None + will_payload = None + will_qos = DEFAULT_QOS + will_retain = DEFAULT_RETAIN + if CONF_WILL_MESSAGE in current_config: + will_topic = current_config[CONF_WILL_MESSAGE][ATTR_TOPIC] + will_payload = current_config[CONF_WILL_MESSAGE][ATTR_PAYLOAD] + will_qos = current_config[CONF_WILL_MESSAGE].get(ATTR_QOS, DEFAULT_QOS) + will_retain = current_config[CONF_WILL_MESSAGE].get( + ATTR_RETAIN, DEFAULT_RETAIN + ) + + fields = OrderedDict() + fields[ + vol.Optional( + CONF_DISCOVERY, + default=current_config.get(CONF_DISCOVERY, DEFAULT_DISCOVERY), + ) + ] = bool + fields[ + vol.Optional("birth_topic", description={"suggested_value": birth_topic}) + ] = str + fields[ + vol.Optional( + "birth_payload", description={"suggested_value": birth_payload} + ) + ] = str + fields[vol.Optional("birth_qos", default=birth_qos)] = vol.In([0, 1, 2]) + fields[vol.Optional("birth_retain", default=birth_retain)] = bool + fields[ + vol.Optional("will_topic", description={"suggested_value": will_topic}) + ] = str + fields[ + vol.Optional("will_payload", description={"suggested_value": will_payload}) + ] = str + fields[vol.Optional("will_qos", default=will_qos)] = vol.In([0, 1, 2]) + fields[vol.Optional("will_retain", default=will_retain)] = bool + + return self.async_show_form( + step_id="options", data_schema=vol.Schema(fields), errors=errors, + ) + + def try_connection(broker, port, username, password, protocol="3.1"): """Test if we can connect to an MQTT broker.""" # pylint: disable=import-outside-toplevel diff --git a/homeassistant/components/mqtt/const.py b/homeassistant/components/mqtt/const.py index acb24f4bdda..62d2643bc91 100644 --- a/homeassistant/components/mqtt/const.py +++ b/homeassistant/components/mqtt/const.py @@ -1,14 +1,25 @@ """Constants used by multiple MQTT modules.""" -CONF_BROKER = "broker" -CONF_DISCOVERY = "discovery" -DEFAULT_DISCOVERY = False - ATTR_DISCOVERY_HASH = "discovery_hash" ATTR_DISCOVERY_PAYLOAD = "discovery_payload" ATTR_DISCOVERY_TOPIC = "discovery_topic" +ATTR_PAYLOAD = "payload" +ATTR_QOS = "qos" +ATTR_RETAIN = "retain" +ATTR_TOPIC = "topic" + +CONF_BROKER = "broker" +CONF_BIRTH_MESSAGE = "birth_message" +CONF_DISCOVERY = "discovery" +CONF_QOS = ATTR_QOS +CONF_RETAIN = ATTR_RETAIN CONF_STATE_TOPIC = "state_topic" -PROTOCOL_311 = "3.1.1" +CONF_WILL_MESSAGE = "will_message" + +DEFAULT_DISCOVERY = False DEFAULT_QOS = 0 +DEFAULT_RETAIN = False MQTT_CONNECTED = "mqtt_connected" MQTT_DISCONNECTED = "mqtt_disconnected" + +PROTOCOL_311 = "3.1.1" diff --git a/homeassistant/components/mqtt/discovery.py b/homeassistant/components/mqtt/discovery.py index 0ab108cabde..281172b6332 100644 --- a/homeassistant/components/mqtt/discovery.py +++ b/homeassistant/components/mqtt/discovery.py @@ -35,8 +35,9 @@ SUPPORTED_COMPONENTS = [ ] ALREADY_DISCOVERED = "mqtt_discovered_components" -DATA_CONFIG_ENTRY_LOCK = "mqtt_config_entry_lock" CONFIG_ENTRY_IS_SETUP = "mqtt_config_entry_is_setup" +DATA_CONFIG_ENTRY_LOCK = "mqtt_config_entry_lock" +DISCOVERY_UNSUBSCRIBE = "mqtt_discovery_unsubscribe" MQTT_DISCOVERY_UPDATED = "mqtt_discovery_updated_{}" MQTT_DISCOVERY_NEW = "mqtt_discovery_new_{}_{}" @@ -163,8 +164,15 @@ async def async_start( hass.data[DATA_CONFIG_ENTRY_LOCK] = asyncio.Lock() hass.data[CONFIG_ENTRY_IS_SETUP] = set() - await mqtt.async_subscribe( + hass.data[DISCOVERY_UNSUBSCRIBE] = await mqtt.async_subscribe( hass, f"{discovery_topic}/#", async_device_message_received, 0 ) return True + + +async def async_stop(hass: HomeAssistantType) -> bool: + """Stop MQTT Discovery.""" + if DISCOVERY_UNSUBSCRIBE in hass.data and hass.data[DISCOVERY_UNSUBSCRIBE]: + hass.data[DISCOVERY_UNSUBSCRIBE]() + hass.data[DISCOVERY_UNSUBSCRIBE] = None diff --git a/homeassistant/components/mqtt/strings.json b/homeassistant/components/mqtt/strings.json index 305f3a206a7..d10bc8bc4e6 100644 --- a/homeassistant/components/mqtt/strings.json +++ b/homeassistant/components/mqtt/strings.json @@ -47,5 +47,37 @@ "button_5": "Fifth button", "button_6": "Sixth button" } + }, + "options": { + "step": { + "broker": { + "description": "Please enter the connection information of your MQTT broker.", + "data": { + "broker": "Broker", + "port": "[%key:common::config_flow::data::port%]", + "username": "[%key:common::config_flow::data::username%]", + "password": "[%key:common::config_flow::data::password%]" + } + }, + "options": { + "description": "Please select MQTT options.", + "data": { + "discovery": "Enable discovery", + "birth_topic": "Birth message topic", + "birth_payload": "Birth message payload", + "birth_qos": "Birth message QoS", + "birth_retain": "Birth message retain", + "will_topic": "Will message topic", + "will_payload": "Will message payload", + "will_qos": "Will message QoS", + "will_retain": "Will message retain" + } + } + }, + "error": { + "cannot_connect": "Unable to connect to the broker.", + "bad_birth": "Invalid birth topic.", + "bad_will": "Invalid will topic." + } } } \ No newline at end of file diff --git a/homeassistant/components/mqtt/util.py b/homeassistant/components/mqtt/util.py new file mode 100644 index 00000000000..568dbabd7b0 --- /dev/null +++ b/homeassistant/components/mqtt/util.py @@ -0,0 +1,82 @@ +"""Utility functions for the MQTT integration.""" +from typing import Any + +import voluptuous as vol + +from homeassistant.const import CONF_PAYLOAD +from homeassistant.helpers import config_validation as cv + +from .const import ( + ATTR_PAYLOAD, + ATTR_QOS, + ATTR_RETAIN, + ATTR_TOPIC, + DEFAULT_QOS, + DEFAULT_RETAIN, +) + + +def valid_topic(value: Any) -> str: + """Validate that this is a valid topic name/filter.""" + value = cv.string(value) + try: + raw_value = value.encode("utf-8") + except UnicodeError: + raise vol.Invalid("MQTT topic name/filter must be valid UTF-8 string.") + if not raw_value: + raise vol.Invalid("MQTT topic name/filter must not be empty.") + if len(raw_value) > 65535: + raise vol.Invalid( + "MQTT topic name/filter must not be longer than 65535 encoded bytes." + ) + if "\0" in value: + raise vol.Invalid("MQTT topic name/filter must not contain null character.") + return value + + +def valid_subscribe_topic(value: Any) -> str: + """Validate that we can subscribe using this MQTT topic.""" + value = valid_topic(value) + for i in (i for i, c in enumerate(value) if c == "+"): + if (i > 0 and value[i - 1] != "/") or ( + i < len(value) - 1 and value[i + 1] != "/" + ): + raise vol.Invalid( + "Single-level wildcard must occupy an entire level of the filter" + ) + + index = value.find("#") + if index != -1: + if index != len(value) - 1: + # If there are multiple wildcards, this will also trigger + raise vol.Invalid( + "Multi-level wildcard must be the last " + "character in the topic filter." + ) + if len(value) > 1 and value[index - 1] != "/": + raise vol.Invalid( + "Multi-level wildcard must be after a topic level separator." + ) + + return value + + +def valid_publish_topic(value: Any) -> str: + """Validate that we can publish using this MQTT topic.""" + value = valid_topic(value) + if "+" in value or "#" in value: + raise vol.Invalid("Wildcards can not be used in topic names") + return value + + +_VALID_QOS_SCHEMA = vol.All(vol.Coerce(int), vol.In([0, 1, 2])) + +MQTT_WILL_BIRTH_SCHEMA = vol.Schema( + { + vol.Required(ATTR_TOPIC): valid_publish_topic, + vol.Required(ATTR_PAYLOAD, CONF_PAYLOAD): cv.string, + vol.Optional(ATTR_QOS, default=DEFAULT_QOS): _VALID_QOS_SCHEMA, + vol.Optional(ATTR_RETAIN, default=DEFAULT_RETAIN): cv.boolean, + }, + required=True, +) diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index 2f57cb50543..68442689d3b 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -1008,7 +1008,8 @@ class OptionsFlowManager(data_entry_flow.FlowManager): entry = self.hass.config_entries.async_get_entry(flow.handler) if entry is None: raise UnknownEntry(flow.handler) - self.hass.config_entries.async_update_entry(entry, options=result["data"]) + if result["data"] is not None: + self.hass.config_entries.async_update_entry(entry, options=result["data"]) result["result"] = True return result diff --git a/tests/components/mqtt/test_config_flow.py b/tests/components/mqtt/test_config_flow.py index 0990accec9f..581395b702a 100644 --- a/tests/components/mqtt/test_config_flow.py +++ b/tests/components/mqtt/test_config_flow.py @@ -1,7 +1,11 @@ """Test config flow.""" import pytest +import voluptuous as vol +from homeassistant import data_entry_flow +from homeassistant.components import mqtt +from homeassistant.components.mqtt.discovery import async_start from homeassistant.setup import async_setup_component from tests.async_mock import patch @@ -144,3 +148,340 @@ async def test_hassio_confirm(hass, mock_try_connection, mock_finish_setup): assert len(mock_try_connection.mock_calls) == 1 # Check config entry got setup assert len(mock_finish_setup.mock_calls) == 1 + + +async def test_option_flow(hass, mqtt_mock, mock_try_connection): + """Test config flow options.""" + mock_try_connection.return_value = True + config_entry = hass.config_entries.async_entries(mqtt.DOMAIN)[0] + await async_start(hass, "homeassistant", config_entry) + config_entry.data = { + mqtt.CONF_BROKER: "test-broker", + mqtt.CONF_PORT: 1234, + } + + mqtt_mock.async_connect.reset_mock() + + result = await hass.config_entries.options.async_init(config_entry.entry_id) + assert result["type"] == data_entry_flow.RESULT_TYPE_FORM + assert result["step_id"] == "broker" + + result = await hass.config_entries.options.async_configure( + result["flow_id"], + user_input={ + mqtt.CONF_BROKER: "another-broker", + mqtt.CONF_PORT: 2345, + mqtt.CONF_USERNAME: "user", + mqtt.CONF_PASSWORD: "pass", + }, + ) + assert result["type"] == data_entry_flow.RESULT_TYPE_FORM + assert result["step_id"] == "options" + + await hass.async_block_till_done() + assert mqtt_mock.async_connect.call_count == 0 + + result = await hass.config_entries.options.async_configure( + result["flow_id"], + user_input={ + mqtt.CONF_DISCOVERY: True, + "birth_topic": "ha_state/online", + "birth_payload": "online", + "birth_qos": 1, + "birth_retain": True, + "will_topic": "ha_state/offline", + "will_payload": "offline", + "will_qos": 2, + "will_retain": True, + }, + ) + assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY + assert result["data"] is None + assert config_entry.data == { + mqtt.CONF_BROKER: "another-broker", + mqtt.CONF_PORT: 2345, + mqtt.CONF_USERNAME: "user", + mqtt.CONF_PASSWORD: "pass", + mqtt.CONF_DISCOVERY: True, + mqtt.CONF_BIRTH_MESSAGE: { + mqtt.ATTR_TOPIC: "ha_state/online", + mqtt.ATTR_PAYLOAD: "online", + mqtt.ATTR_QOS: 1, + mqtt.ATTR_RETAIN: True, + }, + mqtt.CONF_WILL_MESSAGE: { + mqtt.ATTR_TOPIC: "ha_state/offline", + mqtt.ATTR_PAYLOAD: "offline", + mqtt.ATTR_QOS: 2, + mqtt.ATTR_RETAIN: True, + }, + } + + await hass.async_block_till_done() + assert mqtt_mock.async_connect.call_count == 1 + + +def get_default(schema, key): + """Get default value for key in voluptuous schema.""" + for k in schema.keys(): + if k == key: + if k.default == vol.UNDEFINED: + return None + return k.default() + + +def get_suggested(schema, key): + """Get suggested value for key in voluptuous schema.""" + for k in schema.keys(): + if k == key: + if k.description is None or "suggested_value" not in k.description: + return None + return k.description["suggested_value"] + + +async def test_option_flow_default_suggested_values( + hass, mqtt_mock, mock_try_connection +): + """Test config flow options has default/suggested values.""" + mock_try_connection.return_value = True + config_entry = hass.config_entries.async_entries(mqtt.DOMAIN)[0] + await async_start(hass, "homeassistant", config_entry) + config_entry.data = { + mqtt.CONF_BROKER: "test-broker", + mqtt.CONF_PORT: 1234, + mqtt.CONF_USERNAME: "user", + mqtt.CONF_PASSWORD: "pass", + mqtt.CONF_DISCOVERY: True, + mqtt.CONF_BIRTH_MESSAGE: { + mqtt.ATTR_TOPIC: "ha_state/online", + mqtt.ATTR_PAYLOAD: "online", + mqtt.ATTR_QOS: 1, + mqtt.ATTR_RETAIN: True, + }, + mqtt.CONF_WILL_MESSAGE: { + mqtt.ATTR_TOPIC: "ha_state/offline", + mqtt.ATTR_PAYLOAD: "offline", + mqtt.ATTR_QOS: 2, + mqtt.ATTR_RETAIN: False, + }, + } + + # Test default/suggested values from config + result = await hass.config_entries.options.async_init(config_entry.entry_id) + assert result["type"] == data_entry_flow.RESULT_TYPE_FORM + assert result["step_id"] == "broker" + defaults = { + mqtt.CONF_BROKER: "test-broker", + mqtt.CONF_PORT: 1234, + } + suggested = { + mqtt.CONF_USERNAME: "user", + mqtt.CONF_PASSWORD: "pass", + } + for k, v in defaults.items(): + assert get_default(result["data_schema"].schema, k) == v + for k, v in suggested.items(): + assert get_suggested(result["data_schema"].schema, k) == v + + result = await hass.config_entries.options.async_configure( + result["flow_id"], + user_input={ + mqtt.CONF_BROKER: "another-broker", + mqtt.CONF_PORT: 2345, + mqtt.CONF_USERNAME: "us3r", + mqtt.CONF_PASSWORD: "p4ss", + }, + ) + assert result["type"] == data_entry_flow.RESULT_TYPE_FORM + assert result["step_id"] == "options" + defaults = { + mqtt.CONF_DISCOVERY: True, + "birth_qos": 1, + "birth_retain": True, + "will_qos": 2, + "will_retain": False, + } + suggested = { + "birth_topic": "ha_state/online", + "birth_payload": "online", + "will_topic": "ha_state/offline", + "will_payload": "offline", + } + for k, v in defaults.items(): + assert get_default(result["data_schema"].schema, k) == v + for k, v in suggested.items(): + assert get_suggested(result["data_schema"].schema, k) == v + + result = await hass.config_entries.options.async_configure( + result["flow_id"], + user_input={ + mqtt.CONF_DISCOVERY: False, + "birth_topic": "ha_state/onl1ne", + "birth_payload": "onl1ne", + "birth_qos": 2, + "birth_retain": False, + "will_topic": "ha_state/offl1ne", + "will_payload": "offl1ne", + "will_qos": 1, + "will_retain": True, + }, + ) + assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY + + # Test updated default/suggested values from config + result = await hass.config_entries.options.async_init(config_entry.entry_id) + assert result["type"] == data_entry_flow.RESULT_TYPE_FORM + assert result["step_id"] == "broker" + defaults = { + mqtt.CONF_BROKER: "another-broker", + mqtt.CONF_PORT: 2345, + } + suggested = { + mqtt.CONF_USERNAME: "us3r", + mqtt.CONF_PASSWORD: "p4ss", + } + for k, v in defaults.items(): + assert get_default(result["data_schema"].schema, k) == v + for k, v in suggested.items(): + assert get_suggested(result["data_schema"].schema, k) == v + + result = await hass.config_entries.options.async_configure( + result["flow_id"], + user_input={mqtt.CONF_BROKER: "another-broker", mqtt.CONF_PORT: 2345}, + ) + assert result["type"] == data_entry_flow.RESULT_TYPE_FORM + assert result["step_id"] == "options" + defaults = { + mqtt.CONF_DISCOVERY: False, + "birth_qos": 2, + "birth_retain": False, + "will_qos": 1, + "will_retain": True, + } + suggested = { + "birth_topic": "ha_state/onl1ne", + "birth_payload": "onl1ne", + "will_topic": "ha_state/offl1ne", + "will_payload": "offl1ne", + } + for k, v in defaults.items(): + assert get_default(result["data_schema"].schema, k) == v + for k, v in suggested.items(): + assert get_suggested(result["data_schema"].schema, k) == v + + result = await hass.config_entries.options.async_configure( + result["flow_id"], + user_input={ + mqtt.CONF_DISCOVERY: True, + "birth_topic": "ha_state/onl1ne", + "birth_payload": "onl1ne", + "birth_qos": 2, + "birth_retain": False, + "will_topic": "ha_state/offl1ne", + "will_payload": "offl1ne", + "will_qos": 1, + "will_retain": True, + }, + ) + assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY + + +async def test_options_user_connection_fails(hass, mock_try_connection): + """Test if connection cannot be made.""" + config_entry = MockConfigEntry(domain=mqtt.DOMAIN) + config_entry.add_to_hass(hass) + config_entry.data = { + mqtt.CONF_BROKER: "test-broker", + mqtt.CONF_PORT: 1234, + } + + mock_try_connection.return_value = False + + result = await hass.config_entries.options.async_init(config_entry.entry_id) + assert result["type"] == "form" + + result = await hass.config_entries.options.async_configure( + result["flow_id"], + user_input={mqtt.CONF_BROKER: "bad-broker", mqtt.CONF_PORT: 2345}, + ) + + assert result["type"] == "form" + assert result["errors"]["base"] == "cannot_connect" + + # Check we tried the connection + assert len(mock_try_connection.mock_calls) == 1 + # Check config entry did not update + assert config_entry.data == { + mqtt.CONF_BROKER: "test-broker", + mqtt.CONF_PORT: 1234, + } + + +async def test_options_bad_birth_message_fails(hass, mock_try_connection): + """Test bad birth message.""" + config_entry = MockConfigEntry(domain=mqtt.DOMAIN) + config_entry.add_to_hass(hass) + config_entry.data = { + mqtt.CONF_BROKER: "test-broker", + mqtt.CONF_PORT: 1234, + } + + mock_try_connection.return_value = True + + result = await hass.config_entries.options.async_init(config_entry.entry_id) + assert result["type"] == "form" + + result = await hass.config_entries.options.async_configure( + result["flow_id"], + user_input={mqtt.CONF_BROKER: "another-broker", mqtt.CONF_PORT: 2345}, + ) + + assert result["type"] == "form" + assert result["step_id"] == "options" + + result = await hass.config_entries.options.async_configure( + result["flow_id"], user_input={"birth_topic": "ha_state/online/#"}, + ) + assert result["type"] == "form" + assert result["errors"]["base"] == "bad_birth" + + # Check config entry did not update + assert config_entry.data == { + mqtt.CONF_BROKER: "test-broker", + mqtt.CONF_PORT: 1234, + } + + +async def test_options_bad_will_message_fails(hass, mock_try_connection): + """Test bad will message.""" + config_entry = MockConfigEntry(domain=mqtt.DOMAIN) + config_entry.add_to_hass(hass) + config_entry.data = { + mqtt.CONF_BROKER: "test-broker", + mqtt.CONF_PORT: 1234, + } + + mock_try_connection.return_value = True + + result = await hass.config_entries.options.async_init(config_entry.entry_id) + assert result["type"] == "form" + + result = await hass.config_entries.options.async_configure( + result["flow_id"], + user_input={mqtt.CONF_BROKER: "another-broker", mqtt.CONF_PORT: 2345}, + ) + + assert result["type"] == "form" + assert result["step_id"] == "options" + + result = await hass.config_entries.options.async_configure( + result["flow_id"], user_input={"will_topic": "ha_state/offline/#"}, + ) + assert result["type"] == "form" + assert result["errors"]["base"] == "bad_will" + + # Check config entry did not update + assert config_entry.data == { + mqtt.CONF_BROKER: "test-broker", + mqtt.CONF_PORT: 1234, + } diff --git a/tests/components/mqtt/test_init.py b/tests/components/mqtt/test_init.py index 89b5a7423f8..247f616f379 100644 --- a/tests/components/mqtt/test_init.py +++ b/tests/components/mqtt/test_init.py @@ -171,26 +171,26 @@ def test_validate_topic(): """Test topic name/filter validation.""" # Invalid UTF-8, must not contain U+D800 to U+DFFF. with pytest.raises(vol.Invalid): - mqtt.valid_topic("\ud800") + mqtt.util.valid_topic("\ud800") with pytest.raises(vol.Invalid): - mqtt.valid_topic("\udfff") + mqtt.util.valid_topic("\udfff") # Topic MUST NOT be empty with pytest.raises(vol.Invalid): - mqtt.valid_topic("") + mqtt.util.valid_topic("") # Topic MUST NOT be longer than 65535 encoded bytes. with pytest.raises(vol.Invalid): - mqtt.valid_topic("ü" * 32768) + mqtt.util.valid_topic("ü" * 32768) # UTF-8 MUST NOT include null character with pytest.raises(vol.Invalid): - mqtt.valid_topic("bad\0one") + mqtt.util.valid_topic("bad\0one") # Topics "SHOULD NOT" include these special characters # (not MUST NOT, RFC2119). The receiver MAY close the connection. - mqtt.valid_topic("\u0001") - mqtt.valid_topic("\u001F") - mqtt.valid_topic("\u009F") - mqtt.valid_topic("\u009F") - mqtt.valid_topic("\uffff") + mqtt.util.valid_topic("\u0001") + mqtt.util.valid_topic("\u001F") + mqtt.util.valid_topic("\u009F") + mqtt.util.valid_topic("\u009F") + mqtt.util.valid_topic("\uffff") def test_validate_subscribe_topic(): @@ -587,7 +587,7 @@ async def test_retained_message_on_subscribe_received( mqtt_client_mock.subscribe.side_effect = side_effect # Fake that the client is connected - mqtt_mock.connected = True + mqtt_mock().connected = True calls_a = MagicMock() await mqtt.async_subscribe(hass, "test/state", calls_a) @@ -605,7 +605,7 @@ async def test_not_calling_unsubscribe_with_active_subscribers( ): """Test not calling unsubscribe() when other subscribers are active.""" # Fake that the client is connected - mqtt_mock.connected = True + mqtt_mock().connected = True unsub = await mqtt.async_subscribe(hass, "test/state", None) await mqtt.async_subscribe(hass, "test/state", None) @@ -620,7 +620,7 @@ async def test_not_calling_unsubscribe_with_active_subscribers( async def test_restore_subscriptions_on_reconnect(hass, mqtt_client_mock, mqtt_mock): """Test subscriptions are restored on reconnect.""" # Fake that the client is connected - mqtt_mock.connected = True + mqtt_mock().connected = True await mqtt.async_subscribe(hass, "test/state", None) await hass.async_block_till_done() @@ -637,7 +637,7 @@ async def test_restore_all_active_subscriptions_on_reconnect( ): """Test active subscriptions are restored correctly on reconnect.""" # Fake that the client is connected - mqtt_mock.connected = True + mqtt_mock().connected = True mqtt_client_mock.subscribe.side_effect = ( (0, 1), @@ -716,81 +716,107 @@ async def test_setup_raises_ConfigEntryNotReady_if_no_connect_broker(hass, caplo assert "Failed to connect to MQTT server due to exception:" in caplog.text -async def test_setup_uses_certificate_on_certificate_set_to_auto(hass, mock_mqtt): +async def test_setup_uses_certificate_on_certificate_set_to_auto(hass): """Test setup uses bundled certs when certificate is set to auto.""" - entry = MockConfigEntry( - domain=mqtt.DOMAIN, - data={mqtt.CONF_BROKER: "test-broker", "certificate": "auto"}, - ) + calls = [] - assert await mqtt.async_setup_entry(hass, entry) + def mock_tls_set(certificate, certfile=None, keyfile=None, tls_version=None): + calls.append((certificate, certfile, keyfile, tls_version)) - assert mock_mqtt.called + with patch("paho.mqtt.client.Client") as mock_client: + mock_client().tls_set = mock_tls_set + entry = MockConfigEntry( + domain=mqtt.DOMAIN, + data={mqtt.CONF_BROKER: "test-broker", "certificate": "auto"}, + ) - import requests.certs + assert await mqtt.async_setup_entry(hass, entry) - expectedCertificate = requests.certs.where() - assert mock_mqtt.mock_calls[0][2]["certificate"] == expectedCertificate + assert calls + + import certifi + + expectedCertificate = certifi.where() + # assert mock_mqtt.mock_calls[0][1][2]["certificate"] == expectedCertificate + assert calls[0][0] == expectedCertificate -async def test_setup_does_not_use_certificate_on_mqtts_port(hass, mock_mqtt): - """Test setup doesn't use bundled certs when ssl set.""" - entry = MockConfigEntry( - domain=mqtt.DOMAIN, data={mqtt.CONF_BROKER: "test-broker", "port": 8883} - ) - - assert await mqtt.async_setup_entry(hass, entry) - - assert mock_mqtt.called - assert mock_mqtt.mock_calls[0][2]["port"] == 8883 - - import requests.certs - - mqttsCertificateBundle = requests.certs.where() - assert mock_mqtt.mock_calls[0][2]["port"] != mqttsCertificateBundle - - -async def test_setup_without_tls_config_uses_tlsv1_under_python36(hass, mock_mqtt): +async def test_setup_without_tls_config_uses_tlsv1_under_python36(hass): """Test setup defaults to TLSv1 under python3.6.""" - entry = MockConfigEntry(domain=mqtt.DOMAIN, data={mqtt.CONF_BROKER: "test-broker"}) + calls = [] - assert await mqtt.async_setup_entry(hass, entry) + def mock_tls_set(certificate, certfile=None, keyfile=None, tls_version=None): + calls.append((certificate, certfile, keyfile, tls_version)) - assert mock_mqtt.called + with patch("paho.mqtt.client.Client") as mock_client: + mock_client().tls_set = mock_tls_set + entry = MockConfigEntry( + domain=mqtt.DOMAIN, + data={"certificate": "auto", mqtt.CONF_BROKER: "test-broker"}, + ) - import sys + assert await mqtt.async_setup_entry(hass, entry) - if sys.hexversion >= 0x03060000: - expectedTlsVersion = ssl.PROTOCOL_TLS # pylint: disable=no-member - else: - expectedTlsVersion = ssl.PROTOCOL_TLSv1 + assert calls - assert mock_mqtt.mock_calls[0][2]["tls_version"] == expectedTlsVersion + import sys + + if sys.hexversion >= 0x03060000: + expectedTlsVersion = ssl.PROTOCOL_TLS # pylint: disable=no-member + else: + expectedTlsVersion = ssl.PROTOCOL_TLSv1 + + assert calls[0][3] == expectedTlsVersion -async def test_setup_with_tls_config_uses_tls_version1_2(hass, mock_mqtt): +async def test_setup_with_tls_config_uses_tls_version1_2(hass): """Test setup uses specified TLS version.""" - entry = MockConfigEntry( - domain=mqtt.DOMAIN, data={mqtt.CONF_BROKER: "test-broker", "tls_version": "1.2"} - ) + calls = [] - assert await mqtt.async_setup_entry(hass, entry) + def mock_tls_set(certificate, certfile=None, keyfile=None, tls_version=None): + calls.append((certificate, certfile, keyfile, tls_version)) - assert mock_mqtt.called + with patch("paho.mqtt.client.Client") as mock_client: + mock_client().tls_set = mock_tls_set + entry = MockConfigEntry( + domain=mqtt.DOMAIN, + data={ + "certificate": "auto", + mqtt.CONF_BROKER: "test-broker", + "tls_version": "1.2", + }, + ) - assert mock_mqtt.mock_calls[0][2]["tls_version"] == ssl.PROTOCOL_TLSv1_2 + assert await mqtt.async_setup_entry(hass, entry) + + assert calls + + assert calls[0][3] == ssl.PROTOCOL_TLSv1_2 -async def test_setup_with_tls_config_of_v1_under_python36_only_uses_v1(hass, mock_mqtt): +async def test_setup_with_tls_config_of_v1_under_python36_only_uses_v1(hass): """Test setup uses TLSv1.0 if explicitly chosen.""" - entry = MockConfigEntry( - domain=mqtt.DOMAIN, data={mqtt.CONF_BROKER: "test-broker", "tls_version": "1.0"} - ) + calls = [] - assert await mqtt.async_setup_entry(hass, entry) + def mock_tls_set(certificate, certfile=None, keyfile=None, tls_version=None): + calls.append((certificate, certfile, keyfile, tls_version)) - assert mock_mqtt.called - assert mock_mqtt.mock_calls[0][2]["tls_version"] == ssl.PROTOCOL_TLSv1 + with patch("paho.mqtt.client.Client") as mock_client: + mock_client().tls_set = mock_tls_set + entry = MockConfigEntry( + domain=mqtt.DOMAIN, + data={ + "certificate": "auto", + mqtt.CONF_BROKER: "test-broker", + "tls_version": "1.0", + }, + ) + + assert await mqtt.async_setup_entry(hass, entry) + + assert calls + + assert calls[0][3] == ssl.PROTOCOL_TLSv1 @pytest.mark.parametrize( diff --git a/tests/components/mqtt/test_server.py b/tests/components/mqtt/test_server.py index b3320d6aaca..95f61e7c82b 100644 --- a/tests/components/mqtt/test_server.py +++ b/tests/components/mqtt/test_server.py @@ -46,8 +46,8 @@ class TestMQTT: ) self.hass.block_till_done() assert mock_mqtt.called - assert mock_mqtt.mock_calls[1][2]["username"] == "homeassistant" - assert mock_mqtt.mock_calls[1][2]["password"] == password + assert mock_mqtt.mock_calls[1][1][2]["username"] == "homeassistant" + assert mock_mqtt.mock_calls[1][1][2]["password"] == password @patch("passlib.apps.custom_app_context", Mock(return_value="")) @patch("tempfile.NamedTemporaryFile", Mock(return_value=MagicMock())) @@ -69,8 +69,8 @@ class TestMQTT: ) self.hass.block_till_done() assert mock_mqtt.called - assert mock_mqtt.mock_calls[1][2]["username"] == "homeassistant" - assert mock_mqtt.mock_calls[1][2]["password"] == password + assert mock_mqtt.mock_calls[1][1][2]["username"] == "homeassistant" + assert mock_mqtt.mock_calls[1][1][2]["password"] == password @patch("tempfile.NamedTemporaryFile", Mock(return_value=MagicMock())) @patch("hbmqtt.broker.Broker.start", return_value=mock_coro()) diff --git a/tests/conftest.py b/tests/conftest.py index 118774eb8e5..a2fa8e8b2fd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -304,8 +304,11 @@ async def mqtt_mock(hass, mqtt_client_mock, mqtt_config): assert result await hass.async_block_till_done() - mqtt_component_mock = MagicMock(spec_set=hass.data["mqtt"], wraps=hass.data["mqtt"]) - hass.data["mqtt"].connected = mqtt_component_mock.connected + mqtt_component_mock = MagicMock( + return_value=hass.data["mqtt"], + spec_set=hass.data["mqtt"], + wraps=hass.data["mqtt"], + ) mqtt_component_mock._mqttc = mqtt_client_mock hass.data["mqtt"] = mqtt_component_mock