Support reconfiguring MQTT config entry (#36537)
This commit is contained in:
parent
ee816ed3dd
commit
747490ab34
13 changed files with 881 additions and 249 deletions
|
@ -19,7 +19,7 @@ DEFAULT_QOS = 0
|
||||||
TRIGGER_SCHEMA = vol.Schema(
|
TRIGGER_SCHEMA = vol.Schema(
|
||||||
{
|
{
|
||||||
vol.Required(CONF_PLATFORM): mqtt.DOMAIN,
|
vol.Required(CONF_PLATFORM): mqtt.DOMAIN,
|
||||||
vol.Required(CONF_TOPIC): mqtt.valid_subscribe_topic,
|
vol.Required(CONF_TOPIC): mqtt.util.valid_subscribe_topic,
|
||||||
vol.Optional(CONF_PAYLOAD): cv.string,
|
vol.Optional(CONF_PAYLOAD): cv.string,
|
||||||
vol.Optional(CONF_ENCODING, default=DEFAULT_ENCODING): cv.string,
|
vol.Optional(CONF_ENCODING, default=DEFAULT_ENCODING): cv.string,
|
||||||
vol.Optional(CONF_QOS, default=DEFAULT_QOS): vol.All(
|
vol.Optional(CONF_QOS, default=DEFAULT_QOS): vol.All(
|
||||||
|
|
|
@ -274,7 +274,7 @@ async def system_options_update(hass, connection, msg):
|
||||||
{"type": "config_entries/update", "entry_id": str, vol.Optional("title"): str}
|
{"type": "config_entries/update", "entry_id": str, vol.Optional("title"): str}
|
||||||
)
|
)
|
||||||
async def config_entry_update(hass, connection, msg):
|
async def config_entry_update(hass, connection, msg):
|
||||||
"""Update config entry system options."""
|
"""Update config entry."""
|
||||||
changes = dict(msg)
|
changes = dict(msg)
|
||||||
changes.pop("id")
|
changes.pop("id")
|
||||||
changes.pop("type")
|
changes.pop("type")
|
||||||
|
|
|
@ -12,7 +12,7 @@ import sys
|
||||||
from typing import Any, Callable, List, Optional, Union
|
from typing import Any, Callable, List, Optional, Union
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
import requests.certs
|
import certifi
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant import config_entries
|
from homeassistant import config_entries
|
||||||
|
@ -46,11 +46,20 @@ from . import debug_info, discovery, server
|
||||||
from .const import (
|
from .const import (
|
||||||
ATTR_DISCOVERY_HASH,
|
ATTR_DISCOVERY_HASH,
|
||||||
ATTR_DISCOVERY_TOPIC,
|
ATTR_DISCOVERY_TOPIC,
|
||||||
|
ATTR_PAYLOAD,
|
||||||
|
ATTR_QOS,
|
||||||
|
ATTR_RETAIN,
|
||||||
|
ATTR_TOPIC,
|
||||||
|
CONF_BIRTH_MESSAGE,
|
||||||
CONF_BROKER,
|
CONF_BROKER,
|
||||||
CONF_DISCOVERY,
|
CONF_DISCOVERY,
|
||||||
|
CONF_QOS,
|
||||||
|
CONF_RETAIN,
|
||||||
CONF_STATE_TOPIC,
|
CONF_STATE_TOPIC,
|
||||||
|
CONF_WILL_MESSAGE,
|
||||||
DEFAULT_DISCOVERY,
|
DEFAULT_DISCOVERY,
|
||||||
DEFAULT_QOS,
|
DEFAULT_QOS,
|
||||||
|
DEFAULT_RETAIN,
|
||||||
MQTT_CONNECTED,
|
MQTT_CONNECTED,
|
||||||
MQTT_DISCONNECTED,
|
MQTT_DISCONNECTED,
|
||||||
PROTOCOL_311,
|
PROTOCOL_311,
|
||||||
|
@ -59,6 +68,7 @@ from .debug_info import log_messages
|
||||||
from .discovery import MQTT_DISCOVERY_UPDATED, clear_discovery_hash, set_discovery_hash
|
from .discovery import MQTT_DISCOVERY_UPDATED, clear_discovery_hash, set_discovery_hash
|
||||||
from .models import Message, MessageCallbackType, PublishPayloadType
|
from .models import Message, MessageCallbackType, PublishPayloadType
|
||||||
from .subscription import async_subscribe_topics, async_unsubscribe_topics
|
from .subscription import async_subscribe_topics, async_unsubscribe_topics
|
||||||
|
from .util import _VALID_QOS_SCHEMA, valid_publish_topic, valid_subscribe_topic
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -80,17 +90,12 @@ CONF_CLIENT_CERT = "client_cert"
|
||||||
CONF_TLS_INSECURE = "tls_insecure"
|
CONF_TLS_INSECURE = "tls_insecure"
|
||||||
CONF_TLS_VERSION = "tls_version"
|
CONF_TLS_VERSION = "tls_version"
|
||||||
|
|
||||||
CONF_BIRTH_MESSAGE = "birth_message"
|
|
||||||
CONF_WILL_MESSAGE = "will_message"
|
|
||||||
|
|
||||||
CONF_COMMAND_TOPIC = "command_topic"
|
CONF_COMMAND_TOPIC = "command_topic"
|
||||||
CONF_AVAILABILITY_TOPIC = "availability_topic"
|
CONF_AVAILABILITY_TOPIC = "availability_topic"
|
||||||
CONF_PAYLOAD_AVAILABLE = "payload_available"
|
CONF_PAYLOAD_AVAILABLE = "payload_available"
|
||||||
CONF_PAYLOAD_NOT_AVAILABLE = "payload_not_available"
|
CONF_PAYLOAD_NOT_AVAILABLE = "payload_not_available"
|
||||||
CONF_JSON_ATTRS_TOPIC = "json_attributes_topic"
|
CONF_JSON_ATTRS_TOPIC = "json_attributes_topic"
|
||||||
CONF_JSON_ATTRS_TEMPLATE = "json_attributes_template"
|
CONF_JSON_ATTRS_TEMPLATE = "json_attributes_template"
|
||||||
CONF_QOS = "qos"
|
|
||||||
CONF_RETAIN = "retain"
|
|
||||||
|
|
||||||
CONF_UNIQUE_ID = "unique_id"
|
CONF_UNIQUE_ID = "unique_id"
|
||||||
CONF_IDENTIFIERS = "identifiers"
|
CONF_IDENTIFIERS = "identifiers"
|
||||||
|
@ -105,18 +110,13 @@ PROTOCOL_31 = "3.1"
|
||||||
|
|
||||||
DEFAULT_PORT = 1883
|
DEFAULT_PORT = 1883
|
||||||
DEFAULT_KEEPALIVE = 60
|
DEFAULT_KEEPALIVE = 60
|
||||||
DEFAULT_RETAIN = False
|
|
||||||
DEFAULT_PROTOCOL = PROTOCOL_311
|
DEFAULT_PROTOCOL = PROTOCOL_311
|
||||||
DEFAULT_DISCOVERY_PREFIX = "homeassistant"
|
DEFAULT_DISCOVERY_PREFIX = "homeassistant"
|
||||||
DEFAULT_TLS_PROTOCOL = "auto"
|
DEFAULT_TLS_PROTOCOL = "auto"
|
||||||
DEFAULT_PAYLOAD_AVAILABLE = "online"
|
DEFAULT_PAYLOAD_AVAILABLE = "online"
|
||||||
DEFAULT_PAYLOAD_NOT_AVAILABLE = "offline"
|
DEFAULT_PAYLOAD_NOT_AVAILABLE = "offline"
|
||||||
|
|
||||||
ATTR_TOPIC = "topic"
|
|
||||||
ATTR_PAYLOAD = "payload"
|
|
||||||
ATTR_PAYLOAD_TEMPLATE = "payload_template"
|
ATTR_PAYLOAD_TEMPLATE = "payload_template"
|
||||||
ATTR_QOS = CONF_QOS
|
|
||||||
ATTR_RETAIN = CONF_RETAIN
|
|
||||||
|
|
||||||
MAX_RECONNECT_WAIT = 300 # seconds
|
MAX_RECONNECT_WAIT = 300 # seconds
|
||||||
|
|
||||||
|
@ -125,59 +125,6 @@ CONNECTION_FAILED = "connection_failed"
|
||||||
CONNECTION_FAILED_RECOVERABLE = "connection_failed_recoverable"
|
CONNECTION_FAILED_RECOVERABLE = "connection_failed_recoverable"
|
||||||
|
|
||||||
|
|
||||||
def valid_topic(value: Any) -> str:
|
|
||||||
"""Validate that this is a valid topic name/filter."""
|
|
||||||
value = cv.string(value)
|
|
||||||
try:
|
|
||||||
raw_value = value.encode("utf-8")
|
|
||||||
except UnicodeError:
|
|
||||||
raise vol.Invalid("MQTT topic name/filter must be valid UTF-8 string.")
|
|
||||||
if not raw_value:
|
|
||||||
raise vol.Invalid("MQTT topic name/filter must not be empty.")
|
|
||||||
if len(raw_value) > 65535:
|
|
||||||
raise vol.Invalid(
|
|
||||||
"MQTT topic name/filter must not be longer than 65535 encoded bytes."
|
|
||||||
)
|
|
||||||
if "\0" in value:
|
|
||||||
raise vol.Invalid("MQTT topic name/filter must not contain null character.")
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
def valid_subscribe_topic(value: Any) -> str:
|
|
||||||
"""Validate that we can subscribe using this MQTT topic."""
|
|
||||||
value = valid_topic(value)
|
|
||||||
for i in (i for i, c in enumerate(value) if c == "+"):
|
|
||||||
if (i > 0 and value[i - 1] != "/") or (
|
|
||||||
i < len(value) - 1 and value[i + 1] != "/"
|
|
||||||
):
|
|
||||||
raise vol.Invalid(
|
|
||||||
"Single-level wildcard must occupy an entire level of the filter"
|
|
||||||
)
|
|
||||||
|
|
||||||
index = value.find("#")
|
|
||||||
if index != -1:
|
|
||||||
if index != len(value) - 1:
|
|
||||||
# If there are multiple wildcards, this will also trigger
|
|
||||||
raise vol.Invalid(
|
|
||||||
"Multi-level wildcard must be the last "
|
|
||||||
"character in the topic filter."
|
|
||||||
)
|
|
||||||
if len(value) > 1 and value[index - 1] != "/":
|
|
||||||
raise vol.Invalid(
|
|
||||||
"Multi-level wildcard must be after a topic level separator."
|
|
||||||
)
|
|
||||||
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
def valid_publish_topic(value: Any) -> str:
|
|
||||||
"""Validate that we can publish using this MQTT topic."""
|
|
||||||
value = valid_topic(value)
|
|
||||||
if "+" in value or "#" in value:
|
|
||||||
raise vol.Invalid("Wildcards can not be used in topic names")
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
def validate_device_has_at_least_one_identifier(value: ConfigType) -> ConfigType:
|
def validate_device_has_at_least_one_identifier(value: ConfigType) -> ConfigType:
|
||||||
"""Validate that a device info entry has at least one identifying value."""
|
"""Validate that a device info entry has at least one identifying value."""
|
||||||
if not value.get(CONF_IDENTIFIERS) and not value.get(CONF_CONNECTIONS):
|
if not value.get(CONF_IDENTIFIERS) and not value.get(CONF_CONNECTIONS):
|
||||||
|
@ -188,8 +135,6 @@ def validate_device_has_at_least_one_identifier(value: ConfigType) -> ConfigType
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
_VALID_QOS_SCHEMA = vol.All(vol.Coerce(int), vol.In([0, 1, 2]))
|
|
||||||
|
|
||||||
CLIENT_KEY_AUTH_MSG = (
|
CLIENT_KEY_AUTH_MSG = (
|
||||||
"client_key and client_cert must both be present in "
|
"client_key and client_cert must both be present in "
|
||||||
"the MQTT broker configuration"
|
"the MQTT broker configuration"
|
||||||
|
@ -554,6 +499,11 @@ async def async_setup(hass: HomeAssistantType, config: ConfigType) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _merge_config(entry, conf):
|
||||||
|
"""Merge configuration.yaml config with config entry."""
|
||||||
|
return {**conf, **entry.data}
|
||||||
|
|
||||||
|
|
||||||
async def async_setup_entry(hass, entry):
|
async def async_setup_entry(hass, entry):
|
||||||
"""Load a config entry."""
|
"""Load a config entry."""
|
||||||
conf = hass.data.get(DATA_MQTT_CONFIG)
|
conf = hass.data.get(DATA_MQTT_CONFIG)
|
||||||
|
@ -574,76 +524,9 @@ async def async_setup_entry(hass, entry):
|
||||||
entry.data,
|
entry.data,
|
||||||
)
|
)
|
||||||
|
|
||||||
conf.update(entry.data)
|
conf = _merge_config(entry, conf)
|
||||||
|
|
||||||
broker = conf[CONF_BROKER]
|
hass.data[DATA_MQTT] = MQTT(hass, entry, conf,)
|
||||||
port = conf[CONF_PORT]
|
|
||||||
client_id = conf.get(CONF_CLIENT_ID)
|
|
||||||
keepalive = conf[CONF_KEEPALIVE]
|
|
||||||
username = conf.get(CONF_USERNAME)
|
|
||||||
password = conf.get(CONF_PASSWORD)
|
|
||||||
certificate = conf.get(CONF_CERTIFICATE)
|
|
||||||
client_key = conf.get(CONF_CLIENT_KEY)
|
|
||||||
client_cert = conf.get(CONF_CLIENT_CERT)
|
|
||||||
tls_insecure = conf.get(CONF_TLS_INSECURE)
|
|
||||||
protocol = conf[CONF_PROTOCOL]
|
|
||||||
|
|
||||||
# For cloudmqtt.com, secured connection, auto fill in certificate
|
|
||||||
if (
|
|
||||||
certificate is None
|
|
||||||
and 19999 < conf[CONF_PORT] < 30000
|
|
||||||
and broker.endswith(".cloudmqtt.com")
|
|
||||||
):
|
|
||||||
certificate = os.path.join(
|
|
||||||
os.path.dirname(__file__), "addtrustexternalcaroot.crt"
|
|
||||||
)
|
|
||||||
|
|
||||||
# When the certificate is set to auto, use bundled certs from requests
|
|
||||||
elif certificate == "auto":
|
|
||||||
certificate = requests.certs.where()
|
|
||||||
|
|
||||||
if CONF_WILL_MESSAGE in conf:
|
|
||||||
will_message = Message(**conf[CONF_WILL_MESSAGE])
|
|
||||||
else:
|
|
||||||
will_message = None
|
|
||||||
|
|
||||||
if CONF_BIRTH_MESSAGE in conf:
|
|
||||||
birth_message = Message(**conf[CONF_BIRTH_MESSAGE])
|
|
||||||
else:
|
|
||||||
birth_message = None
|
|
||||||
|
|
||||||
# Be able to override versions other than TLSv1.0 under Python3.6
|
|
||||||
conf_tls_version: str = conf.get(CONF_TLS_VERSION)
|
|
||||||
if conf_tls_version == "1.2":
|
|
||||||
tls_version = ssl.PROTOCOL_TLSv1_2
|
|
||||||
elif conf_tls_version == "1.1":
|
|
||||||
tls_version = ssl.PROTOCOL_TLSv1_1
|
|
||||||
elif conf_tls_version == "1.0":
|
|
||||||
tls_version = ssl.PROTOCOL_TLSv1
|
|
||||||
else:
|
|
||||||
# Python3.6 supports automatic negotiation of highest TLS version
|
|
||||||
if sys.hexversion >= 0x03060000:
|
|
||||||
tls_version = ssl.PROTOCOL_TLS # pylint: disable=no-member
|
|
||||||
else:
|
|
||||||
tls_version = ssl.PROTOCOL_TLSv1
|
|
||||||
|
|
||||||
hass.data[DATA_MQTT] = MQTT(
|
|
||||||
hass,
|
|
||||||
broker=broker,
|
|
||||||
port=port,
|
|
||||||
client_id=client_id,
|
|
||||||
keepalive=keepalive,
|
|
||||||
username=username,
|
|
||||||
password=password,
|
|
||||||
certificate=certificate,
|
|
||||||
client_key=client_key,
|
|
||||||
client_cert=client_cert,
|
|
||||||
tls_insecure=tls_insecure,
|
|
||||||
protocol=protocol,
|
|
||||||
will_message=will_message,
|
|
||||||
birth_message=birth_message,
|
|
||||||
tls_version=tls_version,
|
|
||||||
)
|
|
||||||
|
|
||||||
await hass.data[DATA_MQTT].async_connect()
|
await hass.data[DATA_MQTT].async_connect()
|
||||||
|
|
||||||
|
@ -732,53 +615,101 @@ class Subscription:
|
||||||
class MQTT:
|
class MQTT:
|
||||||
"""Home Assistant MQTT client."""
|
"""Home Assistant MQTT client."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, hass: HomeAssistantType, config_entry, conf,) -> None:
|
||||||
self,
|
|
||||||
hass: HomeAssistantType,
|
|
||||||
broker: str,
|
|
||||||
port: int,
|
|
||||||
client_id: Optional[str],
|
|
||||||
keepalive: Optional[int],
|
|
||||||
username: Optional[str],
|
|
||||||
password: Optional[str],
|
|
||||||
certificate: Optional[str],
|
|
||||||
client_key: Optional[str],
|
|
||||||
client_cert: Optional[str],
|
|
||||||
tls_insecure: Optional[bool],
|
|
||||||
protocol: Optional[str],
|
|
||||||
will_message: Optional[Message],
|
|
||||||
birth_message: Optional[Message],
|
|
||||||
tls_version: Optional[int],
|
|
||||||
) -> None:
|
|
||||||
"""Initialize Home Assistant MQTT client."""
|
"""Initialize Home Assistant MQTT client."""
|
||||||
# We don't import them on the top because some integrations
|
# We don't import on the top because some integrations
|
||||||
# should be able to optionally rely on MQTT.
|
# should be able to optionally rely on MQTT.
|
||||||
# pylint: disable=import-outside-toplevel
|
import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel
|
||||||
import paho.mqtt.client as mqtt
|
|
||||||
|
|
||||||
self.hass = hass
|
self.hass = hass
|
||||||
self.broker = broker
|
self.config_entry = config_entry
|
||||||
self.port = port
|
self.conf = conf
|
||||||
self.keepalive = keepalive
|
|
||||||
self.subscriptions: List[Subscription] = []
|
self.subscriptions: List[Subscription] = []
|
||||||
self.birth_message = birth_message
|
|
||||||
self.connected = False
|
self.connected = False
|
||||||
self._mqttc: mqtt.Client = None
|
self._mqttc: mqtt.Client = None
|
||||||
self._paho_lock = asyncio.Lock()
|
self._paho_lock = asyncio.Lock()
|
||||||
|
|
||||||
if protocol == PROTOCOL_31:
|
self.init_client()
|
||||||
|
self.config_entry.add_update_listener(self.async_config_entry_updated)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def async_config_entry_updated(hass, entry) -> None:
|
||||||
|
"""Handle signals of config entry being updated.
|
||||||
|
|
||||||
|
This is a static method because a class method (bound method), can not be used with weak references.
|
||||||
|
Causes for this is config entry options changing.
|
||||||
|
"""
|
||||||
|
self = hass.data[DATA_MQTT]
|
||||||
|
|
||||||
|
conf = hass.data.get(DATA_MQTT_CONFIG)
|
||||||
|
if conf is None:
|
||||||
|
conf = CONFIG_SCHEMA({DOMAIN: dict(entry.data)})[DOMAIN]
|
||||||
|
|
||||||
|
self.conf = _merge_config(entry, conf)
|
||||||
|
await self.async_disconnect()
|
||||||
|
self.init_client()
|
||||||
|
await self.async_connect()
|
||||||
|
|
||||||
|
await discovery.async_stop(hass)
|
||||||
|
if self.conf.get(CONF_DISCOVERY):
|
||||||
|
await _async_setup_discovery(hass, self.conf, entry)
|
||||||
|
|
||||||
|
def init_client(self):
|
||||||
|
"""Initialize paho client."""
|
||||||
|
# We don't import on the top because some integrations
|
||||||
|
# should be able to optionally rely on MQTT.
|
||||||
|
import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel
|
||||||
|
|
||||||
|
if self.conf[CONF_PROTOCOL] == PROTOCOL_31:
|
||||||
proto: int = mqtt.MQTTv31
|
proto: int = mqtt.MQTTv31
|
||||||
else:
|
else:
|
||||||
proto = mqtt.MQTTv311
|
proto = mqtt.MQTTv311
|
||||||
|
|
||||||
|
client_id = self.conf.get(CONF_CLIENT_ID)
|
||||||
if client_id is None:
|
if client_id is None:
|
||||||
self._mqttc = mqtt.Client(protocol=proto)
|
self._mqttc = mqtt.Client(protocol=proto)
|
||||||
else:
|
else:
|
||||||
self._mqttc = mqtt.Client(client_id, protocol=proto)
|
self._mqttc = mqtt.Client(client_id, protocol=proto)
|
||||||
|
|
||||||
|
username = self.conf.get(CONF_USERNAME)
|
||||||
|
password = self.conf.get(CONF_PASSWORD)
|
||||||
if username is not None:
|
if username is not None:
|
||||||
self._mqttc.username_pw_set(username, password)
|
self._mqttc.username_pw_set(username, password)
|
||||||
|
|
||||||
|
certificate = self.conf.get(CONF_CERTIFICATE)
|
||||||
|
|
||||||
|
# For cloudmqtt.com, secured connection, auto fill in certificate
|
||||||
|
if (
|
||||||
|
certificate is None
|
||||||
|
and 19999 < self.conf[CONF_PORT] < 30000
|
||||||
|
and self.conf[CONF_BROKER].endswith(".cloudmqtt.com")
|
||||||
|
):
|
||||||
|
certificate = os.path.join(
|
||||||
|
os.path.dirname(__file__), "addtrustexternalcaroot.crt"
|
||||||
|
)
|
||||||
|
|
||||||
|
# When the certificate is set to auto, use bundled certs from certifi
|
||||||
|
elif certificate == "auto":
|
||||||
|
certificate = certifi.where()
|
||||||
|
|
||||||
|
# Be able to override versions other than TLSv1.0 under Python3.6
|
||||||
|
conf_tls_version: str = self.conf.get(CONF_TLS_VERSION)
|
||||||
|
if conf_tls_version == "1.2":
|
||||||
|
tls_version = ssl.PROTOCOL_TLSv1_2
|
||||||
|
elif conf_tls_version == "1.1":
|
||||||
|
tls_version = ssl.PROTOCOL_TLSv1_1
|
||||||
|
elif conf_tls_version == "1.0":
|
||||||
|
tls_version = ssl.PROTOCOL_TLSv1
|
||||||
|
else:
|
||||||
|
# Python3.6 supports automatic negotiation of highest TLS version
|
||||||
|
if sys.hexversion >= 0x03060000:
|
||||||
|
tls_version = ssl.PROTOCOL_TLS # pylint: disable=no-member
|
||||||
|
else:
|
||||||
|
tls_version = ssl.PROTOCOL_TLSv1
|
||||||
|
|
||||||
|
client_key = self.conf.get(CONF_CLIENT_KEY)
|
||||||
|
client_cert = self.conf.get(CONF_CLIENT_CERT)
|
||||||
|
tls_insecure = self.conf.get(CONF_TLS_INSECURE)
|
||||||
if certificate is not None:
|
if certificate is not None:
|
||||||
self._mqttc.tls_set(
|
self._mqttc.tls_set(
|
||||||
certificate,
|
certificate,
|
||||||
|
@ -794,6 +725,11 @@ class MQTT:
|
||||||
self._mqttc.on_disconnect = self._mqtt_on_disconnect
|
self._mqttc.on_disconnect = self._mqtt_on_disconnect
|
||||||
self._mqttc.on_message = self._mqtt_on_message
|
self._mqttc.on_message = self._mqtt_on_message
|
||||||
|
|
||||||
|
if CONF_WILL_MESSAGE in self.conf:
|
||||||
|
will_message = Message(**self.conf[CONF_WILL_MESSAGE])
|
||||||
|
else:
|
||||||
|
will_message = None
|
||||||
|
|
||||||
if will_message is not None:
|
if will_message is not None:
|
||||||
self._mqttc.will_set( # pylint: disable=no-value-for-parameter
|
self._mqttc.will_set( # pylint: disable=no-value-for-parameter
|
||||||
*attr.astuple(
|
*attr.astuple(
|
||||||
|
@ -813,14 +749,17 @@ class MQTT:
|
||||||
)
|
)
|
||||||
|
|
||||||
async def async_connect(self) -> str:
|
async def async_connect(self) -> str:
|
||||||
"""Connect to the host. Does process messages yet."""
|
"""Connect to the host. Does not process messages yet."""
|
||||||
# pylint: disable=import-outside-toplevel
|
# pylint: disable=import-outside-toplevel
|
||||||
import paho.mqtt.client as mqtt
|
import paho.mqtt.client as mqtt
|
||||||
|
|
||||||
result: int = None
|
result: int = None
|
||||||
try:
|
try:
|
||||||
result = await self.hass.async_add_executor_job(
|
result = await self.hass.async_add_executor_job(
|
||||||
self._mqttc.connect, self.broker, self.port, self.keepalive
|
self._mqttc.connect,
|
||||||
|
self.conf[CONF_BROKER],
|
||||||
|
self.conf[CONF_PORT],
|
||||||
|
self.conf[CONF_KEEPALIVE],
|
||||||
)
|
)
|
||||||
except OSError as err:
|
except OSError as err:
|
||||||
_LOGGER.error("Failed to connect to MQTT server due to exception: %s", err)
|
_LOGGER.error("Failed to connect to MQTT server due to exception: %s", err)
|
||||||
|
@ -922,7 +861,12 @@ class MQTT:
|
||||||
|
|
||||||
self.connected = True
|
self.connected = True
|
||||||
dispatcher_send(self.hass, MQTT_CONNECTED)
|
dispatcher_send(self.hass, MQTT_CONNECTED)
|
||||||
_LOGGER.info("Connected to MQTT server (%s)", result_code)
|
_LOGGER.info(
|
||||||
|
"Connected to MQTT server %s:%s (%s)",
|
||||||
|
self.conf[CONF_BROKER],
|
||||||
|
self.conf[CONF_PORT],
|
||||||
|
result_code,
|
||||||
|
)
|
||||||
|
|
||||||
# Group subscriptions to only re-subscribe once for each topic.
|
# Group subscriptions to only re-subscribe once for each topic.
|
||||||
keyfunc = attrgetter("topic")
|
keyfunc = attrgetter("topic")
|
||||||
|
@ -931,11 +875,12 @@ class MQTT:
|
||||||
max_qos = max(subscription.qos for subscription in subs)
|
max_qos = max(subscription.qos for subscription in subs)
|
||||||
self.hass.add_job(self._async_perform_subscription, topic, max_qos)
|
self.hass.add_job(self._async_perform_subscription, topic, max_qos)
|
||||||
|
|
||||||
if self.birth_message:
|
if CONF_BIRTH_MESSAGE in self.conf:
|
||||||
|
birth_message = Message(**self.conf[CONF_BIRTH_MESSAGE])
|
||||||
self.hass.add_job(
|
self.hass.add_job(
|
||||||
self.async_publish( # pylint: disable=no-value-for-parameter
|
self.async_publish( # pylint: disable=no-value-for-parameter
|
||||||
*attr.astuple(
|
*attr.astuple(
|
||||||
self.birth_message,
|
birth_message,
|
||||||
filter=lambda attr, value: attr.name
|
filter=lambda attr, value: attr.name
|
||||||
not in ["subscribed_topic", "timestamp"],
|
not in ["subscribed_topic", "timestamp"],
|
||||||
)
|
)
|
||||||
|
@ -990,7 +935,12 @@ class MQTT:
|
||||||
"""Disconnected callback."""
|
"""Disconnected callback."""
|
||||||
self.connected = False
|
self.connected = False
|
||||||
dispatcher_send(self.hass, MQTT_DISCONNECTED)
|
dispatcher_send(self.hass, MQTT_DISCONNECTED)
|
||||||
_LOGGER.warning("Disconnected from MQTT server (%s)", result_code)
|
_LOGGER.warning(
|
||||||
|
"Disconnected from MQTT server %s:%s (%s)",
|
||||||
|
self.conf[CONF_BROKER],
|
||||||
|
self.conf[CONF_PORT],
|
||||||
|
result_code,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _raise_on_error(result_code: int) -> None:
|
def _raise_on_error(result_code: int) -> None:
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
"""Config flow for MQTT."""
|
"""Config flow for MQTT."""
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
import logging
|
||||||
import queue
|
import queue
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
@ -13,7 +14,22 @@ from homeassistant.const import (
|
||||||
CONF_USERNAME,
|
CONF_USERNAME,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .const import CONF_BROKER, CONF_DISCOVERY, DEFAULT_DISCOVERY
|
from .const import (
|
||||||
|
ATTR_PAYLOAD,
|
||||||
|
ATTR_QOS,
|
||||||
|
ATTR_RETAIN,
|
||||||
|
ATTR_TOPIC,
|
||||||
|
CONF_BIRTH_MESSAGE,
|
||||||
|
CONF_BROKER,
|
||||||
|
CONF_DISCOVERY,
|
||||||
|
CONF_WILL_MESSAGE,
|
||||||
|
DEFAULT_DISCOVERY,
|
||||||
|
DEFAULT_QOS,
|
||||||
|
DEFAULT_RETAIN,
|
||||||
|
)
|
||||||
|
from .util import MQTT_WILL_BIRTH_SCHEMA
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@config_entries.HANDLERS.register("mqtt")
|
@config_entries.HANDLERS.register("mqtt")
|
||||||
|
@ -25,6 +41,11 @@ class FlowHandler(config_entries.ConfigFlow):
|
||||||
|
|
||||||
_hassio_discovery = None
|
_hassio_discovery = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def async_get_options_flow(config_entry):
|
||||||
|
"""Get the options flow for this handler."""
|
||||||
|
return MQTTOptionsFlowHandler(config_entry)
|
||||||
|
|
||||||
async def async_step_user(self, user_input=None):
|
async def async_step_user(self, user_input=None):
|
||||||
"""Handle a flow initialized by the user."""
|
"""Handle a flow initialized by the user."""
|
||||||
if self._async_current_entries():
|
if self._async_current_entries():
|
||||||
|
@ -123,6 +144,163 @@ class FlowHandler(config_entries.ConfigFlow):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MQTTOptionsFlowHandler(config_entries.OptionsFlow):
|
||||||
|
"""Handle MQTT options."""
|
||||||
|
|
||||||
|
def __init__(self, config_entry):
|
||||||
|
"""Initialize MQTT options flow."""
|
||||||
|
self.config_entry = config_entry
|
||||||
|
self.broker_config = {}
|
||||||
|
self.options = dict(config_entry.options)
|
||||||
|
|
||||||
|
async def async_step_init(self, user_input=None):
|
||||||
|
"""Manage the MQTT options."""
|
||||||
|
return await self.async_step_broker()
|
||||||
|
|
||||||
|
async def async_step_broker(self, user_input=None):
|
||||||
|
"""Manage the MQTT options."""
|
||||||
|
errors = {}
|
||||||
|
current_config = self.config_entry.data
|
||||||
|
if user_input is not None:
|
||||||
|
can_connect = await self.hass.async_add_executor_job(
|
||||||
|
try_connection,
|
||||||
|
user_input[CONF_BROKER],
|
||||||
|
user_input[CONF_PORT],
|
||||||
|
user_input.get(CONF_USERNAME),
|
||||||
|
user_input.get(CONF_PASSWORD),
|
||||||
|
)
|
||||||
|
|
||||||
|
if can_connect:
|
||||||
|
self.broker_config.update(user_input)
|
||||||
|
return await self.async_step_options()
|
||||||
|
|
||||||
|
errors["base"] = "cannot_connect"
|
||||||
|
|
||||||
|
fields = OrderedDict()
|
||||||
|
fields[vol.Required(CONF_BROKER, default=current_config[CONF_BROKER])] = str
|
||||||
|
fields[vol.Required(CONF_PORT, default=current_config[CONF_PORT])] = vol.Coerce(
|
||||||
|
int
|
||||||
|
)
|
||||||
|
fields[
|
||||||
|
vol.Optional(
|
||||||
|
CONF_USERNAME,
|
||||||
|
description={"suggested_value": current_config.get(CONF_USERNAME)},
|
||||||
|
)
|
||||||
|
] = str
|
||||||
|
fields[
|
||||||
|
vol.Optional(
|
||||||
|
CONF_PASSWORD,
|
||||||
|
description={"suggested_value": current_config.get(CONF_PASSWORD)},
|
||||||
|
)
|
||||||
|
] = str
|
||||||
|
|
||||||
|
return self.async_show_form(
|
||||||
|
step_id="broker", data_schema=vol.Schema(fields), errors=errors,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def async_step_options(self, user_input=None):
|
||||||
|
"""Manage the MQTT options."""
|
||||||
|
errors = {}
|
||||||
|
current_config = self.config_entry.data
|
||||||
|
options_config = {}
|
||||||
|
if user_input is not None:
|
||||||
|
bad_birth = False
|
||||||
|
bad_will = False
|
||||||
|
|
||||||
|
if "birth_topic" in user_input:
|
||||||
|
birth_message = {
|
||||||
|
ATTR_TOPIC: user_input["birth_topic"],
|
||||||
|
ATTR_PAYLOAD: user_input.get("birth_payload", ""),
|
||||||
|
ATTR_QOS: user_input["birth_qos"],
|
||||||
|
ATTR_RETAIN: user_input["birth_retain"],
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
birth_message = MQTT_WILL_BIRTH_SCHEMA(birth_message)
|
||||||
|
options_config[CONF_BIRTH_MESSAGE] = birth_message
|
||||||
|
except vol.Invalid:
|
||||||
|
errors["base"] = "bad_birth"
|
||||||
|
bad_birth = True
|
||||||
|
|
||||||
|
if "will_topic" in user_input:
|
||||||
|
will_message = {
|
||||||
|
ATTR_TOPIC: user_input["will_topic"],
|
||||||
|
ATTR_PAYLOAD: user_input.get("will_payload", ""),
|
||||||
|
ATTR_QOS: user_input["will_qos"],
|
||||||
|
ATTR_RETAIN: user_input["will_retain"],
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
will_message = MQTT_WILL_BIRTH_SCHEMA(will_message)
|
||||||
|
options_config[CONF_WILL_MESSAGE] = will_message
|
||||||
|
except vol.Invalid:
|
||||||
|
errors["base"] = "bad_will"
|
||||||
|
bad_will = True
|
||||||
|
|
||||||
|
options_config[CONF_DISCOVERY] = user_input[CONF_DISCOVERY]
|
||||||
|
|
||||||
|
if not bad_birth and not bad_will:
|
||||||
|
updated_config = {}
|
||||||
|
updated_config.update(self.broker_config)
|
||||||
|
updated_config.update(options_config)
|
||||||
|
self.hass.config_entries.async_update_entry(
|
||||||
|
self.config_entry, data=updated_config
|
||||||
|
)
|
||||||
|
return self.async_create_entry(title="", data=None)
|
||||||
|
|
||||||
|
birth_topic = None
|
||||||
|
birth_payload = None
|
||||||
|
birth_qos = DEFAULT_QOS
|
||||||
|
birth_retain = DEFAULT_RETAIN
|
||||||
|
if CONF_BIRTH_MESSAGE in current_config:
|
||||||
|
birth_topic = current_config[CONF_BIRTH_MESSAGE][ATTR_TOPIC]
|
||||||
|
birth_payload = current_config[CONF_BIRTH_MESSAGE][ATTR_PAYLOAD]
|
||||||
|
birth_qos = current_config[CONF_BIRTH_MESSAGE].get(ATTR_QOS, DEFAULT_QOS)
|
||||||
|
birth_retain = current_config[CONF_BIRTH_MESSAGE].get(
|
||||||
|
ATTR_RETAIN, DEFAULT_RETAIN
|
||||||
|
)
|
||||||
|
|
||||||
|
will_topic = None
|
||||||
|
will_payload = None
|
||||||
|
will_qos = DEFAULT_QOS
|
||||||
|
will_retain = DEFAULT_RETAIN
|
||||||
|
if CONF_WILL_MESSAGE in current_config:
|
||||||
|
will_topic = current_config[CONF_WILL_MESSAGE][ATTR_TOPIC]
|
||||||
|
will_payload = current_config[CONF_WILL_MESSAGE][ATTR_PAYLOAD]
|
||||||
|
will_qos = current_config[CONF_WILL_MESSAGE].get(ATTR_QOS, DEFAULT_QOS)
|
||||||
|
will_retain = current_config[CONF_WILL_MESSAGE].get(
|
||||||
|
ATTR_RETAIN, DEFAULT_RETAIN
|
||||||
|
)
|
||||||
|
|
||||||
|
fields = OrderedDict()
|
||||||
|
fields[
|
||||||
|
vol.Optional(
|
||||||
|
CONF_DISCOVERY,
|
||||||
|
default=current_config.get(CONF_DISCOVERY, DEFAULT_DISCOVERY),
|
||||||
|
)
|
||||||
|
] = bool
|
||||||
|
fields[
|
||||||
|
vol.Optional("birth_topic", description={"suggested_value": birth_topic})
|
||||||
|
] = str
|
||||||
|
fields[
|
||||||
|
vol.Optional(
|
||||||
|
"birth_payload", description={"suggested_value": birth_payload}
|
||||||
|
)
|
||||||
|
] = str
|
||||||
|
fields[vol.Optional("birth_qos", default=birth_qos)] = vol.In([0, 1, 2])
|
||||||
|
fields[vol.Optional("birth_retain", default=birth_retain)] = bool
|
||||||
|
fields[
|
||||||
|
vol.Optional("will_topic", description={"suggested_value": will_topic})
|
||||||
|
] = str
|
||||||
|
fields[
|
||||||
|
vol.Optional("will_payload", description={"suggested_value": will_payload})
|
||||||
|
] = str
|
||||||
|
fields[vol.Optional("will_qos", default=will_qos)] = vol.In([0, 1, 2])
|
||||||
|
fields[vol.Optional("will_retain", default=will_retain)] = bool
|
||||||
|
|
||||||
|
return self.async_show_form(
|
||||||
|
step_id="options", data_schema=vol.Schema(fields), errors=errors,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def try_connection(broker, port, username, password, protocol="3.1"):
|
def try_connection(broker, port, username, password, protocol="3.1"):
|
||||||
"""Test if we can connect to an MQTT broker."""
|
"""Test if we can connect to an MQTT broker."""
|
||||||
# pylint: disable=import-outside-toplevel
|
# pylint: disable=import-outside-toplevel
|
||||||
|
|
|
@ -1,14 +1,25 @@
|
||||||
"""Constants used by multiple MQTT modules."""
|
"""Constants used by multiple MQTT modules."""
|
||||||
CONF_BROKER = "broker"
|
|
||||||
CONF_DISCOVERY = "discovery"
|
|
||||||
DEFAULT_DISCOVERY = False
|
|
||||||
|
|
||||||
ATTR_DISCOVERY_HASH = "discovery_hash"
|
ATTR_DISCOVERY_HASH = "discovery_hash"
|
||||||
ATTR_DISCOVERY_PAYLOAD = "discovery_payload"
|
ATTR_DISCOVERY_PAYLOAD = "discovery_payload"
|
||||||
ATTR_DISCOVERY_TOPIC = "discovery_topic"
|
ATTR_DISCOVERY_TOPIC = "discovery_topic"
|
||||||
|
ATTR_PAYLOAD = "payload"
|
||||||
|
ATTR_QOS = "qos"
|
||||||
|
ATTR_RETAIN = "retain"
|
||||||
|
ATTR_TOPIC = "topic"
|
||||||
|
|
||||||
|
CONF_BROKER = "broker"
|
||||||
|
CONF_BIRTH_MESSAGE = "birth_message"
|
||||||
|
CONF_DISCOVERY = "discovery"
|
||||||
|
CONF_QOS = ATTR_QOS
|
||||||
|
CONF_RETAIN = ATTR_RETAIN
|
||||||
CONF_STATE_TOPIC = "state_topic"
|
CONF_STATE_TOPIC = "state_topic"
|
||||||
PROTOCOL_311 = "3.1.1"
|
CONF_WILL_MESSAGE = "will_message"
|
||||||
|
|
||||||
|
DEFAULT_DISCOVERY = False
|
||||||
DEFAULT_QOS = 0
|
DEFAULT_QOS = 0
|
||||||
|
DEFAULT_RETAIN = False
|
||||||
|
|
||||||
MQTT_CONNECTED = "mqtt_connected"
|
MQTT_CONNECTED = "mqtt_connected"
|
||||||
MQTT_DISCONNECTED = "mqtt_disconnected"
|
MQTT_DISCONNECTED = "mqtt_disconnected"
|
||||||
|
|
||||||
|
PROTOCOL_311 = "3.1.1"
|
||||||
|
|
|
@ -35,8 +35,9 @@ SUPPORTED_COMPONENTS = [
|
||||||
]
|
]
|
||||||
|
|
||||||
ALREADY_DISCOVERED = "mqtt_discovered_components"
|
ALREADY_DISCOVERED = "mqtt_discovered_components"
|
||||||
DATA_CONFIG_ENTRY_LOCK = "mqtt_config_entry_lock"
|
|
||||||
CONFIG_ENTRY_IS_SETUP = "mqtt_config_entry_is_setup"
|
CONFIG_ENTRY_IS_SETUP = "mqtt_config_entry_is_setup"
|
||||||
|
DATA_CONFIG_ENTRY_LOCK = "mqtt_config_entry_lock"
|
||||||
|
DISCOVERY_UNSUBSCRIBE = "mqtt_discovery_unsubscribe"
|
||||||
MQTT_DISCOVERY_UPDATED = "mqtt_discovery_updated_{}"
|
MQTT_DISCOVERY_UPDATED = "mqtt_discovery_updated_{}"
|
||||||
MQTT_DISCOVERY_NEW = "mqtt_discovery_new_{}_{}"
|
MQTT_DISCOVERY_NEW = "mqtt_discovery_new_{}_{}"
|
||||||
|
|
||||||
|
@ -163,8 +164,15 @@ async def async_start(
|
||||||
hass.data[DATA_CONFIG_ENTRY_LOCK] = asyncio.Lock()
|
hass.data[DATA_CONFIG_ENTRY_LOCK] = asyncio.Lock()
|
||||||
hass.data[CONFIG_ENTRY_IS_SETUP] = set()
|
hass.data[CONFIG_ENTRY_IS_SETUP] = set()
|
||||||
|
|
||||||
await mqtt.async_subscribe(
|
hass.data[DISCOVERY_UNSUBSCRIBE] = await mqtt.async_subscribe(
|
||||||
hass, f"{discovery_topic}/#", async_device_message_received, 0
|
hass, f"{discovery_topic}/#", async_device_message_received, 0
|
||||||
)
|
)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
async def async_stop(hass: HomeAssistantType) -> bool:
|
||||||
|
"""Stop MQTT Discovery."""
|
||||||
|
if DISCOVERY_UNSUBSCRIBE in hass.data and hass.data[DISCOVERY_UNSUBSCRIBE]:
|
||||||
|
hass.data[DISCOVERY_UNSUBSCRIBE]()
|
||||||
|
hass.data[DISCOVERY_UNSUBSCRIBE] = None
|
||||||
|
|
|
@ -47,5 +47,37 @@
|
||||||
"button_5": "Fifth button",
|
"button_5": "Fifth button",
|
||||||
"button_6": "Sixth button"
|
"button_6": "Sixth button"
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
"options": {
|
||||||
|
"step": {
|
||||||
|
"broker": {
|
||||||
|
"description": "Please enter the connection information of your MQTT broker.",
|
||||||
|
"data": {
|
||||||
|
"broker": "Broker",
|
||||||
|
"port": "[%key:common::config_flow::data::port%]",
|
||||||
|
"username": "[%key:common::config_flow::data::username%]",
|
||||||
|
"password": "[%key:common::config_flow::data::password%]"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"options": {
|
||||||
|
"description": "Please select MQTT options.",
|
||||||
|
"data": {
|
||||||
|
"discovery": "Enable discovery",
|
||||||
|
"birth_topic": "Birth message topic",
|
||||||
|
"birth_payload": "Birth message payload",
|
||||||
|
"birth_qos": "Birth message QoS",
|
||||||
|
"birth_retain": "Birth message retain",
|
||||||
|
"will_topic": "Will message topic",
|
||||||
|
"will_payload": "Will message payload",
|
||||||
|
"will_qos": "Will message QoS",
|
||||||
|
"will_retain": "Will message retain"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"error": {
|
||||||
|
"cannot_connect": "Unable to connect to the broker.",
|
||||||
|
"bad_birth": "Invalid birth topic.",
|
||||||
|
"bad_will": "Invalid will topic."
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
82
homeassistant/components/mqtt/util.py
Normal file
82
homeassistant/components/mqtt/util.py
Normal file
|
@ -0,0 +1,82 @@
|
||||||
|
"""Utility functions for the MQTT integration."""
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import voluptuous as vol
|
||||||
|
|
||||||
|
from homeassistant.const import CONF_PAYLOAD
|
||||||
|
from homeassistant.helpers import config_validation as cv
|
||||||
|
|
||||||
|
from .const import (
|
||||||
|
ATTR_PAYLOAD,
|
||||||
|
ATTR_QOS,
|
||||||
|
ATTR_RETAIN,
|
||||||
|
ATTR_TOPIC,
|
||||||
|
DEFAULT_QOS,
|
||||||
|
DEFAULT_RETAIN,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def valid_topic(value: Any) -> str:
|
||||||
|
"""Validate that this is a valid topic name/filter."""
|
||||||
|
value = cv.string(value)
|
||||||
|
try:
|
||||||
|
raw_value = value.encode("utf-8")
|
||||||
|
except UnicodeError:
|
||||||
|
raise vol.Invalid("MQTT topic name/filter must be valid UTF-8 string.")
|
||||||
|
if not raw_value:
|
||||||
|
raise vol.Invalid("MQTT topic name/filter must not be empty.")
|
||||||
|
if len(raw_value) > 65535:
|
||||||
|
raise vol.Invalid(
|
||||||
|
"MQTT topic name/filter must not be longer than 65535 encoded bytes."
|
||||||
|
)
|
||||||
|
if "\0" in value:
|
||||||
|
raise vol.Invalid("MQTT topic name/filter must not contain null character.")
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def valid_subscribe_topic(value: Any) -> str:
|
||||||
|
"""Validate that we can subscribe using this MQTT topic."""
|
||||||
|
value = valid_topic(value)
|
||||||
|
for i in (i for i, c in enumerate(value) if c == "+"):
|
||||||
|
if (i > 0 and value[i - 1] != "/") or (
|
||||||
|
i < len(value) - 1 and value[i + 1] != "/"
|
||||||
|
):
|
||||||
|
raise vol.Invalid(
|
||||||
|
"Single-level wildcard must occupy an entire level of the filter"
|
||||||
|
)
|
||||||
|
|
||||||
|
index = value.find("#")
|
||||||
|
if index != -1:
|
||||||
|
if index != len(value) - 1:
|
||||||
|
# If there are multiple wildcards, this will also trigger
|
||||||
|
raise vol.Invalid(
|
||||||
|
"Multi-level wildcard must be the last "
|
||||||
|
"character in the topic filter."
|
||||||
|
)
|
||||||
|
if len(value) > 1 and value[index - 1] != "/":
|
||||||
|
raise vol.Invalid(
|
||||||
|
"Multi-level wildcard must be after a topic level separator."
|
||||||
|
)
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def valid_publish_topic(value: Any) -> str:
|
||||||
|
"""Validate that we can publish using this MQTT topic."""
|
||||||
|
value = valid_topic(value)
|
||||||
|
if "+" in value or "#" in value:
|
||||||
|
raise vol.Invalid("Wildcards can not be used in topic names")
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
_VALID_QOS_SCHEMA = vol.All(vol.Coerce(int), vol.In([0, 1, 2]))
|
||||||
|
|
||||||
|
MQTT_WILL_BIRTH_SCHEMA = vol.Schema(
|
||||||
|
{
|
||||||
|
vol.Required(ATTR_TOPIC): valid_publish_topic,
|
||||||
|
vol.Required(ATTR_PAYLOAD, CONF_PAYLOAD): cv.string,
|
||||||
|
vol.Optional(ATTR_QOS, default=DEFAULT_QOS): _VALID_QOS_SCHEMA,
|
||||||
|
vol.Optional(ATTR_RETAIN, default=DEFAULT_RETAIN): cv.boolean,
|
||||||
|
},
|
||||||
|
required=True,
|
||||||
|
)
|
|
@ -1008,7 +1008,8 @@ class OptionsFlowManager(data_entry_flow.FlowManager):
|
||||||
entry = self.hass.config_entries.async_get_entry(flow.handler)
|
entry = self.hass.config_entries.async_get_entry(flow.handler)
|
||||||
if entry is None:
|
if entry is None:
|
||||||
raise UnknownEntry(flow.handler)
|
raise UnknownEntry(flow.handler)
|
||||||
self.hass.config_entries.async_update_entry(entry, options=result["data"])
|
if result["data"] is not None:
|
||||||
|
self.hass.config_entries.async_update_entry(entry, options=result["data"])
|
||||||
|
|
||||||
result["result"] = True
|
result["result"] = True
|
||||||
return result
|
return result
|
||||||
|
|
|
@ -1,7 +1,11 @@
|
||||||
"""Test config flow."""
|
"""Test config flow."""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import voluptuous as vol
|
||||||
|
|
||||||
|
from homeassistant import data_entry_flow
|
||||||
|
from homeassistant.components import mqtt
|
||||||
|
from homeassistant.components.mqtt.discovery import async_start
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
from tests.async_mock import patch
|
from tests.async_mock import patch
|
||||||
|
@ -144,3 +148,340 @@ async def test_hassio_confirm(hass, mock_try_connection, mock_finish_setup):
|
||||||
assert len(mock_try_connection.mock_calls) == 1
|
assert len(mock_try_connection.mock_calls) == 1
|
||||||
# Check config entry got setup
|
# Check config entry got setup
|
||||||
assert len(mock_finish_setup.mock_calls) == 1
|
assert len(mock_finish_setup.mock_calls) == 1
|
||||||
|
|
||||||
|
|
||||||
|
async def test_option_flow(hass, mqtt_mock, mock_try_connection):
|
||||||
|
"""Test config flow options."""
|
||||||
|
mock_try_connection.return_value = True
|
||||||
|
config_entry = hass.config_entries.async_entries(mqtt.DOMAIN)[0]
|
||||||
|
await async_start(hass, "homeassistant", config_entry)
|
||||||
|
config_entry.data = {
|
||||||
|
mqtt.CONF_BROKER: "test-broker",
|
||||||
|
mqtt.CONF_PORT: 1234,
|
||||||
|
}
|
||||||
|
|
||||||
|
mqtt_mock.async_connect.reset_mock()
|
||||||
|
|
||||||
|
result = await hass.config_entries.options.async_init(config_entry.entry_id)
|
||||||
|
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||||
|
assert result["step_id"] == "broker"
|
||||||
|
|
||||||
|
result = await hass.config_entries.options.async_configure(
|
||||||
|
result["flow_id"],
|
||||||
|
user_input={
|
||||||
|
mqtt.CONF_BROKER: "another-broker",
|
||||||
|
mqtt.CONF_PORT: 2345,
|
||||||
|
mqtt.CONF_USERNAME: "user",
|
||||||
|
mqtt.CONF_PASSWORD: "pass",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||||
|
assert result["step_id"] == "options"
|
||||||
|
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert mqtt_mock.async_connect.call_count == 0
|
||||||
|
|
||||||
|
result = await hass.config_entries.options.async_configure(
|
||||||
|
result["flow_id"],
|
||||||
|
user_input={
|
||||||
|
mqtt.CONF_DISCOVERY: True,
|
||||||
|
"birth_topic": "ha_state/online",
|
||||||
|
"birth_payload": "online",
|
||||||
|
"birth_qos": 1,
|
||||||
|
"birth_retain": True,
|
||||||
|
"will_topic": "ha_state/offline",
|
||||||
|
"will_payload": "offline",
|
||||||
|
"will_qos": 2,
|
||||||
|
"will_retain": True,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
||||||
|
assert result["data"] is None
|
||||||
|
assert config_entry.data == {
|
||||||
|
mqtt.CONF_BROKER: "another-broker",
|
||||||
|
mqtt.CONF_PORT: 2345,
|
||||||
|
mqtt.CONF_USERNAME: "user",
|
||||||
|
mqtt.CONF_PASSWORD: "pass",
|
||||||
|
mqtt.CONF_DISCOVERY: True,
|
||||||
|
mqtt.CONF_BIRTH_MESSAGE: {
|
||||||
|
mqtt.ATTR_TOPIC: "ha_state/online",
|
||||||
|
mqtt.ATTR_PAYLOAD: "online",
|
||||||
|
mqtt.ATTR_QOS: 1,
|
||||||
|
mqtt.ATTR_RETAIN: True,
|
||||||
|
},
|
||||||
|
mqtt.CONF_WILL_MESSAGE: {
|
||||||
|
mqtt.ATTR_TOPIC: "ha_state/offline",
|
||||||
|
mqtt.ATTR_PAYLOAD: "offline",
|
||||||
|
mqtt.ATTR_QOS: 2,
|
||||||
|
mqtt.ATTR_RETAIN: True,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert mqtt_mock.async_connect.call_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
def get_default(schema, key):
|
||||||
|
"""Get default value for key in voluptuous schema."""
|
||||||
|
for k in schema.keys():
|
||||||
|
if k == key:
|
||||||
|
if k.default == vol.UNDEFINED:
|
||||||
|
return None
|
||||||
|
return k.default()
|
||||||
|
|
||||||
|
|
||||||
|
def get_suggested(schema, key):
|
||||||
|
"""Get suggested value for key in voluptuous schema."""
|
||||||
|
for k in schema.keys():
|
||||||
|
if k == key:
|
||||||
|
if k.description is None or "suggested_value" not in k.description:
|
||||||
|
return None
|
||||||
|
return k.description["suggested_value"]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_option_flow_default_suggested_values(
|
||||||
|
hass, mqtt_mock, mock_try_connection
|
||||||
|
):
|
||||||
|
"""Test config flow options has default/suggested values."""
|
||||||
|
mock_try_connection.return_value = True
|
||||||
|
config_entry = hass.config_entries.async_entries(mqtt.DOMAIN)[0]
|
||||||
|
await async_start(hass, "homeassistant", config_entry)
|
||||||
|
config_entry.data = {
|
||||||
|
mqtt.CONF_BROKER: "test-broker",
|
||||||
|
mqtt.CONF_PORT: 1234,
|
||||||
|
mqtt.CONF_USERNAME: "user",
|
||||||
|
mqtt.CONF_PASSWORD: "pass",
|
||||||
|
mqtt.CONF_DISCOVERY: True,
|
||||||
|
mqtt.CONF_BIRTH_MESSAGE: {
|
||||||
|
mqtt.ATTR_TOPIC: "ha_state/online",
|
||||||
|
mqtt.ATTR_PAYLOAD: "online",
|
||||||
|
mqtt.ATTR_QOS: 1,
|
||||||
|
mqtt.ATTR_RETAIN: True,
|
||||||
|
},
|
||||||
|
mqtt.CONF_WILL_MESSAGE: {
|
||||||
|
mqtt.ATTR_TOPIC: "ha_state/offline",
|
||||||
|
mqtt.ATTR_PAYLOAD: "offline",
|
||||||
|
mqtt.ATTR_QOS: 2,
|
||||||
|
mqtt.ATTR_RETAIN: False,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Test default/suggested values from config
|
||||||
|
result = await hass.config_entries.options.async_init(config_entry.entry_id)
|
||||||
|
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||||
|
assert result["step_id"] == "broker"
|
||||||
|
defaults = {
|
||||||
|
mqtt.CONF_BROKER: "test-broker",
|
||||||
|
mqtt.CONF_PORT: 1234,
|
||||||
|
}
|
||||||
|
suggested = {
|
||||||
|
mqtt.CONF_USERNAME: "user",
|
||||||
|
mqtt.CONF_PASSWORD: "pass",
|
||||||
|
}
|
||||||
|
for k, v in defaults.items():
|
||||||
|
assert get_default(result["data_schema"].schema, k) == v
|
||||||
|
for k, v in suggested.items():
|
||||||
|
assert get_suggested(result["data_schema"].schema, k) == v
|
||||||
|
|
||||||
|
result = await hass.config_entries.options.async_configure(
|
||||||
|
result["flow_id"],
|
||||||
|
user_input={
|
||||||
|
mqtt.CONF_BROKER: "another-broker",
|
||||||
|
mqtt.CONF_PORT: 2345,
|
||||||
|
mqtt.CONF_USERNAME: "us3r",
|
||||||
|
mqtt.CONF_PASSWORD: "p4ss",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||||
|
assert result["step_id"] == "options"
|
||||||
|
defaults = {
|
||||||
|
mqtt.CONF_DISCOVERY: True,
|
||||||
|
"birth_qos": 1,
|
||||||
|
"birth_retain": True,
|
||||||
|
"will_qos": 2,
|
||||||
|
"will_retain": False,
|
||||||
|
}
|
||||||
|
suggested = {
|
||||||
|
"birth_topic": "ha_state/online",
|
||||||
|
"birth_payload": "online",
|
||||||
|
"will_topic": "ha_state/offline",
|
||||||
|
"will_payload": "offline",
|
||||||
|
}
|
||||||
|
for k, v in defaults.items():
|
||||||
|
assert get_default(result["data_schema"].schema, k) == v
|
||||||
|
for k, v in suggested.items():
|
||||||
|
assert get_suggested(result["data_schema"].schema, k) == v
|
||||||
|
|
||||||
|
result = await hass.config_entries.options.async_configure(
|
||||||
|
result["flow_id"],
|
||||||
|
user_input={
|
||||||
|
mqtt.CONF_DISCOVERY: False,
|
||||||
|
"birth_topic": "ha_state/onl1ne",
|
||||||
|
"birth_payload": "onl1ne",
|
||||||
|
"birth_qos": 2,
|
||||||
|
"birth_retain": False,
|
||||||
|
"will_topic": "ha_state/offl1ne",
|
||||||
|
"will_payload": "offl1ne",
|
||||||
|
"will_qos": 1,
|
||||||
|
"will_retain": True,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
||||||
|
|
||||||
|
# Test updated default/suggested values from config
|
||||||
|
result = await hass.config_entries.options.async_init(config_entry.entry_id)
|
||||||
|
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||||
|
assert result["step_id"] == "broker"
|
||||||
|
defaults = {
|
||||||
|
mqtt.CONF_BROKER: "another-broker",
|
||||||
|
mqtt.CONF_PORT: 2345,
|
||||||
|
}
|
||||||
|
suggested = {
|
||||||
|
mqtt.CONF_USERNAME: "us3r",
|
||||||
|
mqtt.CONF_PASSWORD: "p4ss",
|
||||||
|
}
|
||||||
|
for k, v in defaults.items():
|
||||||
|
assert get_default(result["data_schema"].schema, k) == v
|
||||||
|
for k, v in suggested.items():
|
||||||
|
assert get_suggested(result["data_schema"].schema, k) == v
|
||||||
|
|
||||||
|
result = await hass.config_entries.options.async_configure(
|
||||||
|
result["flow_id"],
|
||||||
|
user_input={mqtt.CONF_BROKER: "another-broker", mqtt.CONF_PORT: 2345},
|
||||||
|
)
|
||||||
|
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||||
|
assert result["step_id"] == "options"
|
||||||
|
defaults = {
|
||||||
|
mqtt.CONF_DISCOVERY: False,
|
||||||
|
"birth_qos": 2,
|
||||||
|
"birth_retain": False,
|
||||||
|
"will_qos": 1,
|
||||||
|
"will_retain": True,
|
||||||
|
}
|
||||||
|
suggested = {
|
||||||
|
"birth_topic": "ha_state/onl1ne",
|
||||||
|
"birth_payload": "onl1ne",
|
||||||
|
"will_topic": "ha_state/offl1ne",
|
||||||
|
"will_payload": "offl1ne",
|
||||||
|
}
|
||||||
|
for k, v in defaults.items():
|
||||||
|
assert get_default(result["data_schema"].schema, k) == v
|
||||||
|
for k, v in suggested.items():
|
||||||
|
assert get_suggested(result["data_schema"].schema, k) == v
|
||||||
|
|
||||||
|
result = await hass.config_entries.options.async_configure(
|
||||||
|
result["flow_id"],
|
||||||
|
user_input={
|
||||||
|
mqtt.CONF_DISCOVERY: True,
|
||||||
|
"birth_topic": "ha_state/onl1ne",
|
||||||
|
"birth_payload": "onl1ne",
|
||||||
|
"birth_qos": 2,
|
||||||
|
"birth_retain": False,
|
||||||
|
"will_topic": "ha_state/offl1ne",
|
||||||
|
"will_payload": "offl1ne",
|
||||||
|
"will_qos": 1,
|
||||||
|
"will_retain": True,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
||||||
|
|
||||||
|
|
||||||
|
async def test_options_user_connection_fails(hass, mock_try_connection):
|
||||||
|
"""Test if connection cannot be made."""
|
||||||
|
config_entry = MockConfigEntry(domain=mqtt.DOMAIN)
|
||||||
|
config_entry.add_to_hass(hass)
|
||||||
|
config_entry.data = {
|
||||||
|
mqtt.CONF_BROKER: "test-broker",
|
||||||
|
mqtt.CONF_PORT: 1234,
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_try_connection.return_value = False
|
||||||
|
|
||||||
|
result = await hass.config_entries.options.async_init(config_entry.entry_id)
|
||||||
|
assert result["type"] == "form"
|
||||||
|
|
||||||
|
result = await hass.config_entries.options.async_configure(
|
||||||
|
result["flow_id"],
|
||||||
|
user_input={mqtt.CONF_BROKER: "bad-broker", mqtt.CONF_PORT: 2345},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["type"] == "form"
|
||||||
|
assert result["errors"]["base"] == "cannot_connect"
|
||||||
|
|
||||||
|
# Check we tried the connection
|
||||||
|
assert len(mock_try_connection.mock_calls) == 1
|
||||||
|
# Check config entry did not update
|
||||||
|
assert config_entry.data == {
|
||||||
|
mqtt.CONF_BROKER: "test-broker",
|
||||||
|
mqtt.CONF_PORT: 1234,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_options_bad_birth_message_fails(hass, mock_try_connection):
|
||||||
|
"""Test bad birth message."""
|
||||||
|
config_entry = MockConfigEntry(domain=mqtt.DOMAIN)
|
||||||
|
config_entry.add_to_hass(hass)
|
||||||
|
config_entry.data = {
|
||||||
|
mqtt.CONF_BROKER: "test-broker",
|
||||||
|
mqtt.CONF_PORT: 1234,
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_try_connection.return_value = True
|
||||||
|
|
||||||
|
result = await hass.config_entries.options.async_init(config_entry.entry_id)
|
||||||
|
assert result["type"] == "form"
|
||||||
|
|
||||||
|
result = await hass.config_entries.options.async_configure(
|
||||||
|
result["flow_id"],
|
||||||
|
user_input={mqtt.CONF_BROKER: "another-broker", mqtt.CONF_PORT: 2345},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["type"] == "form"
|
||||||
|
assert result["step_id"] == "options"
|
||||||
|
|
||||||
|
result = await hass.config_entries.options.async_configure(
|
||||||
|
result["flow_id"], user_input={"birth_topic": "ha_state/online/#"},
|
||||||
|
)
|
||||||
|
assert result["type"] == "form"
|
||||||
|
assert result["errors"]["base"] == "bad_birth"
|
||||||
|
|
||||||
|
# Check config entry did not update
|
||||||
|
assert config_entry.data == {
|
||||||
|
mqtt.CONF_BROKER: "test-broker",
|
||||||
|
mqtt.CONF_PORT: 1234,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_options_bad_will_message_fails(hass, mock_try_connection):
|
||||||
|
"""Test bad will message."""
|
||||||
|
config_entry = MockConfigEntry(domain=mqtt.DOMAIN)
|
||||||
|
config_entry.add_to_hass(hass)
|
||||||
|
config_entry.data = {
|
||||||
|
mqtt.CONF_BROKER: "test-broker",
|
||||||
|
mqtt.CONF_PORT: 1234,
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_try_connection.return_value = True
|
||||||
|
|
||||||
|
result = await hass.config_entries.options.async_init(config_entry.entry_id)
|
||||||
|
assert result["type"] == "form"
|
||||||
|
|
||||||
|
result = await hass.config_entries.options.async_configure(
|
||||||
|
result["flow_id"],
|
||||||
|
user_input={mqtt.CONF_BROKER: "another-broker", mqtt.CONF_PORT: 2345},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["type"] == "form"
|
||||||
|
assert result["step_id"] == "options"
|
||||||
|
|
||||||
|
result = await hass.config_entries.options.async_configure(
|
||||||
|
result["flow_id"], user_input={"will_topic": "ha_state/offline/#"},
|
||||||
|
)
|
||||||
|
assert result["type"] == "form"
|
||||||
|
assert result["errors"]["base"] == "bad_will"
|
||||||
|
|
||||||
|
# Check config entry did not update
|
||||||
|
assert config_entry.data == {
|
||||||
|
mqtt.CONF_BROKER: "test-broker",
|
||||||
|
mqtt.CONF_PORT: 1234,
|
||||||
|
}
|
||||||
|
|
|
@ -171,26 +171,26 @@ def test_validate_topic():
|
||||||
"""Test topic name/filter validation."""
|
"""Test topic name/filter validation."""
|
||||||
# Invalid UTF-8, must not contain U+D800 to U+DFFF.
|
# Invalid UTF-8, must not contain U+D800 to U+DFFF.
|
||||||
with pytest.raises(vol.Invalid):
|
with pytest.raises(vol.Invalid):
|
||||||
mqtt.valid_topic("\ud800")
|
mqtt.util.valid_topic("\ud800")
|
||||||
with pytest.raises(vol.Invalid):
|
with pytest.raises(vol.Invalid):
|
||||||
mqtt.valid_topic("\udfff")
|
mqtt.util.valid_topic("\udfff")
|
||||||
# Topic MUST NOT be empty
|
# Topic MUST NOT be empty
|
||||||
with pytest.raises(vol.Invalid):
|
with pytest.raises(vol.Invalid):
|
||||||
mqtt.valid_topic("")
|
mqtt.util.valid_topic("")
|
||||||
# Topic MUST NOT be longer than 65535 encoded bytes.
|
# Topic MUST NOT be longer than 65535 encoded bytes.
|
||||||
with pytest.raises(vol.Invalid):
|
with pytest.raises(vol.Invalid):
|
||||||
mqtt.valid_topic("ü" * 32768)
|
mqtt.util.valid_topic("ü" * 32768)
|
||||||
# UTF-8 MUST NOT include null character
|
# UTF-8 MUST NOT include null character
|
||||||
with pytest.raises(vol.Invalid):
|
with pytest.raises(vol.Invalid):
|
||||||
mqtt.valid_topic("bad\0one")
|
mqtt.util.valid_topic("bad\0one")
|
||||||
|
|
||||||
# Topics "SHOULD NOT" include these special characters
|
# Topics "SHOULD NOT" include these special characters
|
||||||
# (not MUST NOT, RFC2119). The receiver MAY close the connection.
|
# (not MUST NOT, RFC2119). The receiver MAY close the connection.
|
||||||
mqtt.valid_topic("\u0001")
|
mqtt.util.valid_topic("\u0001")
|
||||||
mqtt.valid_topic("\u001F")
|
mqtt.util.valid_topic("\u001F")
|
||||||
mqtt.valid_topic("\u009F")
|
mqtt.util.valid_topic("\u009F")
|
||||||
mqtt.valid_topic("\u009F")
|
mqtt.util.valid_topic("\u009F")
|
||||||
mqtt.valid_topic("\uffff")
|
mqtt.util.valid_topic("\uffff")
|
||||||
|
|
||||||
|
|
||||||
def test_validate_subscribe_topic():
|
def test_validate_subscribe_topic():
|
||||||
|
@ -587,7 +587,7 @@ async def test_retained_message_on_subscribe_received(
|
||||||
mqtt_client_mock.subscribe.side_effect = side_effect
|
mqtt_client_mock.subscribe.side_effect = side_effect
|
||||||
|
|
||||||
# Fake that the client is connected
|
# Fake that the client is connected
|
||||||
mqtt_mock.connected = True
|
mqtt_mock().connected = True
|
||||||
|
|
||||||
calls_a = MagicMock()
|
calls_a = MagicMock()
|
||||||
await mqtt.async_subscribe(hass, "test/state", calls_a)
|
await mqtt.async_subscribe(hass, "test/state", calls_a)
|
||||||
|
@ -605,7 +605,7 @@ async def test_not_calling_unsubscribe_with_active_subscribers(
|
||||||
):
|
):
|
||||||
"""Test not calling unsubscribe() when other subscribers are active."""
|
"""Test not calling unsubscribe() when other subscribers are active."""
|
||||||
# Fake that the client is connected
|
# Fake that the client is connected
|
||||||
mqtt_mock.connected = True
|
mqtt_mock().connected = True
|
||||||
|
|
||||||
unsub = await mqtt.async_subscribe(hass, "test/state", None)
|
unsub = await mqtt.async_subscribe(hass, "test/state", None)
|
||||||
await mqtt.async_subscribe(hass, "test/state", None)
|
await mqtt.async_subscribe(hass, "test/state", None)
|
||||||
|
@ -620,7 +620,7 @@ async def test_not_calling_unsubscribe_with_active_subscribers(
|
||||||
async def test_restore_subscriptions_on_reconnect(hass, mqtt_client_mock, mqtt_mock):
|
async def test_restore_subscriptions_on_reconnect(hass, mqtt_client_mock, mqtt_mock):
|
||||||
"""Test subscriptions are restored on reconnect."""
|
"""Test subscriptions are restored on reconnect."""
|
||||||
# Fake that the client is connected
|
# Fake that the client is connected
|
||||||
mqtt_mock.connected = True
|
mqtt_mock().connected = True
|
||||||
|
|
||||||
await mqtt.async_subscribe(hass, "test/state", None)
|
await mqtt.async_subscribe(hass, "test/state", None)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
@ -637,7 +637,7 @@ async def test_restore_all_active_subscriptions_on_reconnect(
|
||||||
):
|
):
|
||||||
"""Test active subscriptions are restored correctly on reconnect."""
|
"""Test active subscriptions are restored correctly on reconnect."""
|
||||||
# Fake that the client is connected
|
# Fake that the client is connected
|
||||||
mqtt_mock.connected = True
|
mqtt_mock().connected = True
|
||||||
|
|
||||||
mqtt_client_mock.subscribe.side_effect = (
|
mqtt_client_mock.subscribe.side_effect = (
|
||||||
(0, 1),
|
(0, 1),
|
||||||
|
@ -716,81 +716,107 @@ async def test_setup_raises_ConfigEntryNotReady_if_no_connect_broker(hass, caplo
|
||||||
assert "Failed to connect to MQTT server due to exception:" in caplog.text
|
assert "Failed to connect to MQTT server due to exception:" in caplog.text
|
||||||
|
|
||||||
|
|
||||||
async def test_setup_uses_certificate_on_certificate_set_to_auto(hass, mock_mqtt):
|
async def test_setup_uses_certificate_on_certificate_set_to_auto(hass):
|
||||||
"""Test setup uses bundled certs when certificate is set to auto."""
|
"""Test setup uses bundled certs when certificate is set to auto."""
|
||||||
entry = MockConfigEntry(
|
calls = []
|
||||||
domain=mqtt.DOMAIN,
|
|
||||||
data={mqtt.CONF_BROKER: "test-broker", "certificate": "auto"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert await mqtt.async_setup_entry(hass, entry)
|
def mock_tls_set(certificate, certfile=None, keyfile=None, tls_version=None):
|
||||||
|
calls.append((certificate, certfile, keyfile, tls_version))
|
||||||
|
|
||||||
assert mock_mqtt.called
|
with patch("paho.mqtt.client.Client") as mock_client:
|
||||||
|
mock_client().tls_set = mock_tls_set
|
||||||
|
entry = MockConfigEntry(
|
||||||
|
domain=mqtt.DOMAIN,
|
||||||
|
data={mqtt.CONF_BROKER: "test-broker", "certificate": "auto"},
|
||||||
|
)
|
||||||
|
|
||||||
import requests.certs
|
assert await mqtt.async_setup_entry(hass, entry)
|
||||||
|
|
||||||
expectedCertificate = requests.certs.where()
|
assert calls
|
||||||
assert mock_mqtt.mock_calls[0][2]["certificate"] == expectedCertificate
|
|
||||||
|
import certifi
|
||||||
|
|
||||||
|
expectedCertificate = certifi.where()
|
||||||
|
# assert mock_mqtt.mock_calls[0][1][2]["certificate"] == expectedCertificate
|
||||||
|
assert calls[0][0] == expectedCertificate
|
||||||
|
|
||||||
|
|
||||||
async def test_setup_does_not_use_certificate_on_mqtts_port(hass, mock_mqtt):
|
async def test_setup_without_tls_config_uses_tlsv1_under_python36(hass):
|
||||||
"""Test setup doesn't use bundled certs when ssl set."""
|
|
||||||
entry = MockConfigEntry(
|
|
||||||
domain=mqtt.DOMAIN, data={mqtt.CONF_BROKER: "test-broker", "port": 8883}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert await mqtt.async_setup_entry(hass, entry)
|
|
||||||
|
|
||||||
assert mock_mqtt.called
|
|
||||||
assert mock_mqtt.mock_calls[0][2]["port"] == 8883
|
|
||||||
|
|
||||||
import requests.certs
|
|
||||||
|
|
||||||
mqttsCertificateBundle = requests.certs.where()
|
|
||||||
assert mock_mqtt.mock_calls[0][2]["port"] != mqttsCertificateBundle
|
|
||||||
|
|
||||||
|
|
||||||
async def test_setup_without_tls_config_uses_tlsv1_under_python36(hass, mock_mqtt):
|
|
||||||
"""Test setup defaults to TLSv1 under python3.6."""
|
"""Test setup defaults to TLSv1 under python3.6."""
|
||||||
entry = MockConfigEntry(domain=mqtt.DOMAIN, data={mqtt.CONF_BROKER: "test-broker"})
|
calls = []
|
||||||
|
|
||||||
assert await mqtt.async_setup_entry(hass, entry)
|
def mock_tls_set(certificate, certfile=None, keyfile=None, tls_version=None):
|
||||||
|
calls.append((certificate, certfile, keyfile, tls_version))
|
||||||
|
|
||||||
assert mock_mqtt.called
|
with patch("paho.mqtt.client.Client") as mock_client:
|
||||||
|
mock_client().tls_set = mock_tls_set
|
||||||
|
entry = MockConfigEntry(
|
||||||
|
domain=mqtt.DOMAIN,
|
||||||
|
data={"certificate": "auto", mqtt.CONF_BROKER: "test-broker"},
|
||||||
|
)
|
||||||
|
|
||||||
import sys
|
assert await mqtt.async_setup_entry(hass, entry)
|
||||||
|
|
||||||
if sys.hexversion >= 0x03060000:
|
assert calls
|
||||||
expectedTlsVersion = ssl.PROTOCOL_TLS # pylint: disable=no-member
|
|
||||||
else:
|
|
||||||
expectedTlsVersion = ssl.PROTOCOL_TLSv1
|
|
||||||
|
|
||||||
assert mock_mqtt.mock_calls[0][2]["tls_version"] == expectedTlsVersion
|
import sys
|
||||||
|
|
||||||
|
if sys.hexversion >= 0x03060000:
|
||||||
|
expectedTlsVersion = ssl.PROTOCOL_TLS # pylint: disable=no-member
|
||||||
|
else:
|
||||||
|
expectedTlsVersion = ssl.PROTOCOL_TLSv1
|
||||||
|
|
||||||
|
assert calls[0][3] == expectedTlsVersion
|
||||||
|
|
||||||
|
|
||||||
async def test_setup_with_tls_config_uses_tls_version1_2(hass, mock_mqtt):
|
async def test_setup_with_tls_config_uses_tls_version1_2(hass):
|
||||||
"""Test setup uses specified TLS version."""
|
"""Test setup uses specified TLS version."""
|
||||||
entry = MockConfigEntry(
|
calls = []
|
||||||
domain=mqtt.DOMAIN, data={mqtt.CONF_BROKER: "test-broker", "tls_version": "1.2"}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert await mqtt.async_setup_entry(hass, entry)
|
def mock_tls_set(certificate, certfile=None, keyfile=None, tls_version=None):
|
||||||
|
calls.append((certificate, certfile, keyfile, tls_version))
|
||||||
|
|
||||||
assert mock_mqtt.called
|
with patch("paho.mqtt.client.Client") as mock_client:
|
||||||
|
mock_client().tls_set = mock_tls_set
|
||||||
|
entry = MockConfigEntry(
|
||||||
|
domain=mqtt.DOMAIN,
|
||||||
|
data={
|
||||||
|
"certificate": "auto",
|
||||||
|
mqtt.CONF_BROKER: "test-broker",
|
||||||
|
"tls_version": "1.2",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
assert mock_mqtt.mock_calls[0][2]["tls_version"] == ssl.PROTOCOL_TLSv1_2
|
assert await mqtt.async_setup_entry(hass, entry)
|
||||||
|
|
||||||
|
assert calls
|
||||||
|
|
||||||
|
assert calls[0][3] == ssl.PROTOCOL_TLSv1_2
|
||||||
|
|
||||||
|
|
||||||
async def test_setup_with_tls_config_of_v1_under_python36_only_uses_v1(hass, mock_mqtt):
|
async def test_setup_with_tls_config_of_v1_under_python36_only_uses_v1(hass):
|
||||||
"""Test setup uses TLSv1.0 if explicitly chosen."""
|
"""Test setup uses TLSv1.0 if explicitly chosen."""
|
||||||
entry = MockConfigEntry(
|
calls = []
|
||||||
domain=mqtt.DOMAIN, data={mqtt.CONF_BROKER: "test-broker", "tls_version": "1.0"}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert await mqtt.async_setup_entry(hass, entry)
|
def mock_tls_set(certificate, certfile=None, keyfile=None, tls_version=None):
|
||||||
|
calls.append((certificate, certfile, keyfile, tls_version))
|
||||||
|
|
||||||
assert mock_mqtt.called
|
with patch("paho.mqtt.client.Client") as mock_client:
|
||||||
assert mock_mqtt.mock_calls[0][2]["tls_version"] == ssl.PROTOCOL_TLSv1
|
mock_client().tls_set = mock_tls_set
|
||||||
|
entry = MockConfigEntry(
|
||||||
|
domain=mqtt.DOMAIN,
|
||||||
|
data={
|
||||||
|
"certificate": "auto",
|
||||||
|
mqtt.CONF_BROKER: "test-broker",
|
||||||
|
"tls_version": "1.0",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert await mqtt.async_setup_entry(hass, entry)
|
||||||
|
|
||||||
|
assert calls
|
||||||
|
|
||||||
|
assert calls[0][3] == ssl.PROTOCOL_TLSv1
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
|
|
@ -46,8 +46,8 @@ class TestMQTT:
|
||||||
)
|
)
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
assert mock_mqtt.called
|
assert mock_mqtt.called
|
||||||
assert mock_mqtt.mock_calls[1][2]["username"] == "homeassistant"
|
assert mock_mqtt.mock_calls[1][1][2]["username"] == "homeassistant"
|
||||||
assert mock_mqtt.mock_calls[1][2]["password"] == password
|
assert mock_mqtt.mock_calls[1][1][2]["password"] == password
|
||||||
|
|
||||||
@patch("passlib.apps.custom_app_context", Mock(return_value=""))
|
@patch("passlib.apps.custom_app_context", Mock(return_value=""))
|
||||||
@patch("tempfile.NamedTemporaryFile", Mock(return_value=MagicMock()))
|
@patch("tempfile.NamedTemporaryFile", Mock(return_value=MagicMock()))
|
||||||
|
@ -69,8 +69,8 @@ class TestMQTT:
|
||||||
)
|
)
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
assert mock_mqtt.called
|
assert mock_mqtt.called
|
||||||
assert mock_mqtt.mock_calls[1][2]["username"] == "homeassistant"
|
assert mock_mqtt.mock_calls[1][1][2]["username"] == "homeassistant"
|
||||||
assert mock_mqtt.mock_calls[1][2]["password"] == password
|
assert mock_mqtt.mock_calls[1][1][2]["password"] == password
|
||||||
|
|
||||||
@patch("tempfile.NamedTemporaryFile", Mock(return_value=MagicMock()))
|
@patch("tempfile.NamedTemporaryFile", Mock(return_value=MagicMock()))
|
||||||
@patch("hbmqtt.broker.Broker.start", return_value=mock_coro())
|
@patch("hbmqtt.broker.Broker.start", return_value=mock_coro())
|
||||||
|
|
|
@ -304,8 +304,11 @@ async def mqtt_mock(hass, mqtt_client_mock, mqtt_config):
|
||||||
assert result
|
assert result
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
mqtt_component_mock = MagicMock(spec_set=hass.data["mqtt"], wraps=hass.data["mqtt"])
|
mqtt_component_mock = MagicMock(
|
||||||
hass.data["mqtt"].connected = mqtt_component_mock.connected
|
return_value=hass.data["mqtt"],
|
||||||
|
spec_set=hass.data["mqtt"],
|
||||||
|
wraps=hass.data["mqtt"],
|
||||||
|
)
|
||||||
mqtt_component_mock._mqttc = mqtt_client_mock
|
mqtt_component_mock._mqttc = mqtt_client_mock
|
||||||
|
|
||||||
hass.data["mqtt"] = mqtt_component_mock
|
hass.data["mqtt"] = mqtt_component_mock
|
||||||
|
|
Loading…
Add table
Reference in a new issue