Move imports in mqtt component (#27835)

* move imports to top-level in mqtt server

* move imports to top-level in mqtt configflow

* move imports to top-level in mqtt init

* move imports to top-level in mqtt vacuum

* move imports to top-level in mqtt light
This commit is contained in:
Malte Franken 2019-10-18 11:04:27 +11:00 committed by Paulus Schoutsen
parent 7637ceb880
commit 1a5b4c105a
19 changed files with 130 additions and 144 deletions

View file

@ -1,5 +1,6 @@
"""Support for MQTT message handling.""" """Support for MQTT message handling."""
import asyncio import asyncio
import sys
from functools import partial, wraps from functools import partial, wraps
import inspect import inspect
from itertools import groupby from itertools import groupby
@ -15,6 +16,8 @@ from typing import Any, Callable, List, Optional, Union
import attr import attr
import requests.certs import requests.certs
import voluptuous as vol import voluptuous as vol
import paho.mqtt.client as mqtt
from paho.mqtt.matcher import MQTTMatcher
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.components import websocket_api from homeassistant.components import websocket_api
@ -36,6 +39,7 @@ from homeassistant.exceptions import (
ConfigEntryNotReady, ConfigEntryNotReady,
) )
from homeassistant.helpers import config_validation as cv, template from homeassistant.helpers import config_validation as cv, template
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity import Entity
from homeassistant.helpers.typing import ConfigType, HomeAssistantType, ServiceDataType from homeassistant.helpers.typing import ConfigType, HomeAssistantType, ServiceDataType
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
@ -50,7 +54,12 @@ from .const import (
DEFAULT_DISCOVERY, DEFAULT_DISCOVERY,
CONF_STATE_TOPIC, CONF_STATE_TOPIC,
ATTR_DISCOVERY_HASH, ATTR_DISCOVERY_HASH,
PROTOCOL_311,
DEFAULT_QOS,
) )
from .discovery import MQTT_DISCOVERY_UPDATED, clear_discovery_hash
from .models import PublishPayloadType, Message, MessageCallbackType
from .subscription import async_subscribe_topics, async_unsubscribe_topics
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -95,11 +104,9 @@ CONF_VIA_DEVICE = "via_device"
CONF_DEPRECATED_VIA_HUB = "via_hub" CONF_DEPRECATED_VIA_HUB = "via_hub"
PROTOCOL_31 = "3.1" PROTOCOL_31 = "3.1"
PROTOCOL_311 = "3.1.1"
DEFAULT_PORT = 1883 DEFAULT_PORT = 1883
DEFAULT_KEEPALIVE = 60 DEFAULT_KEEPALIVE = 60
DEFAULT_QOS = 0
DEFAULT_RETAIN = False DEFAULT_RETAIN = False
DEFAULT_PROTOCOL = PROTOCOL_311 DEFAULT_PROTOCOL = PROTOCOL_311
DEFAULT_DISCOVERY_PREFIX = "homeassistant" DEFAULT_DISCOVERY_PREFIX = "homeassistant"
@ -329,23 +336,9 @@ MQTT_PUBLISH_SCHEMA = vol.Schema(
# pylint: disable=invalid-name # pylint: disable=invalid-name
PublishPayloadType = Union[str, bytes, int, float, None]
SubscribePayloadType = Union[str, bytes] # Only bytes if encoding is None SubscribePayloadType = Union[str, bytes] # Only bytes if encoding is None
@attr.s(slots=True, frozen=True)
class Message:
"""MQTT Message."""
topic = attr.ib(type=str)
payload = attr.ib(type=PublishPayloadType)
qos = attr.ib(type=int)
retain = attr.ib(type=bool)
MessageCallbackType = Callable[[Message], None]
def _build_publish_data(topic: Any, qos: int, retain: bool) -> ServiceDataType: def _build_publish_data(topic: Any, qos: int, retain: bool) -> ServiceDataType:
"""Build the arguments for the publish service without the payload.""" """Build the arguments for the publish service without the payload."""
data = {ATTR_TOPIC: topic} data = {ATTR_TOPIC: topic}
@ -629,8 +622,6 @@ async def async_setup_entry(hass, entry):
elif conf_tls_version == "1.0": elif conf_tls_version == "1.0":
tls_version = ssl.PROTOCOL_TLSv1 tls_version = ssl.PROTOCOL_TLSv1
else: else:
import sys
# Python3.6 supports automatic negotiation of highest TLS version # Python3.6 supports automatic negotiation of highest TLS version
if sys.hexversion >= 0x03060000: if sys.hexversion >= 0x03060000:
tls_version = ssl.PROTOCOL_TLS # pylint: disable=no-member tls_version = ssl.PROTOCOL_TLS # pylint: disable=no-member
@ -735,8 +726,6 @@ class MQTT:
tls_version: Optional[int], tls_version: Optional[int],
) -> None: ) -> None:
"""Initialize Home Assistant MQTT client.""" """Initialize Home Assistant MQTT client."""
import paho.mqtt.client as mqtt
self.hass = hass self.hass = hass
self.broker = broker self.broker = broker
self.port = port self.port = port
@ -808,8 +797,6 @@ class MQTT:
return CONNECTION_FAILED_RECOVERABLE return CONNECTION_FAILED_RECOVERABLE
if result != 0: if result != 0:
import paho.mqtt.client as mqtt
_LOGGER.error("Failed to connect: %s", mqtt.error_string(result)) _LOGGER.error("Failed to connect: %s", mqtt.error_string(result))
return CONNECTION_FAILED return CONNECTION_FAILED
@ -891,8 +878,6 @@ class MQTT:
Resubscribe to all topics we were subscribed to and publish birth Resubscribe to all topics we were subscribed to and publish birth
message. message.
""" """
import paho.mqtt.client as mqtt
if result_code != mqtt.CONNACK_ACCEPTED: if result_code != mqtt.CONNACK_ACCEPTED:
_LOGGER.error( _LOGGER.error(
"Unable to connect to the MQTT broker: %s", "Unable to connect to the MQTT broker: %s",
@ -984,8 +969,6 @@ class MQTT:
def _raise_on_error(result_code: int) -> None: def _raise_on_error(result_code: int) -> None:
"""Raise error if error result.""" """Raise error if error result."""
if result_code != 0: if result_code != 0:
import paho.mqtt.client as mqtt
raise HomeAssistantError( raise HomeAssistantError(
"Error talking to MQTT: {}".format(mqtt.error_string(result_code)) "Error talking to MQTT: {}".format(mqtt.error_string(result_code))
) )
@ -993,8 +976,6 @@ def _raise_on_error(result_code: int) -> None:
def _match_topic(subscription: str, topic: str) -> bool: def _match_topic(subscription: str, topic: str) -> bool:
"""Test if topic matches subscription.""" """Test if topic matches subscription."""
from paho.mqtt.matcher import MQTTMatcher
matcher = MQTTMatcher() matcher = MQTTMatcher()
matcher[subscription] = True matcher[subscription] = True
try: try:
@ -1028,8 +1009,6 @@ class MqttAttributes(Entity):
async def _attributes_subscribe_topics(self): async def _attributes_subscribe_topics(self):
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
from .subscription import async_subscribe_topics
attr_tpl = self._attributes_config.get(CONF_JSON_ATTRS_TEMPLATE) attr_tpl = self._attributes_config.get(CONF_JSON_ATTRS_TEMPLATE)
if attr_tpl is not None: if attr_tpl is not None:
attr_tpl.hass = self.hass attr_tpl.hass = self.hass
@ -1065,8 +1044,6 @@ class MqttAttributes(Entity):
async def async_will_remove_from_hass(self): async def async_will_remove_from_hass(self):
"""Unsubscribe when removed.""" """Unsubscribe when removed."""
from .subscription import async_unsubscribe_topics
self._attributes_sub_state = await async_unsubscribe_topics( self._attributes_sub_state = await async_unsubscribe_topics(
self.hass, self._attributes_sub_state self.hass, self._attributes_sub_state
) )
@ -1102,7 +1079,6 @@ class MqttAvailability(Entity):
async def _availability_subscribe_topics(self): async def _availability_subscribe_topics(self):
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
from .subscription import async_subscribe_topics
@callback @callback
def availability_message_received(msg: Message) -> None: def availability_message_received(msg: Message) -> None:
@ -1128,8 +1104,6 @@ class MqttAvailability(Entity):
async def async_will_remove_from_hass(self): async def async_will_remove_from_hass(self):
"""Unsubscribe when removed.""" """Unsubscribe when removed."""
from .subscription import async_unsubscribe_topics
self._availability_sub_state = await async_unsubscribe_topics( self._availability_sub_state = await async_unsubscribe_topics(
self.hass, self._availability_sub_state self.hass, self._availability_sub_state
) )
@ -1154,9 +1128,6 @@ class MqttDiscoveryUpdate(Entity):
"""Subscribe to discovery updates.""" """Subscribe to discovery updates."""
await super().async_added_to_hass() await super().async_added_to_hass()
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from .discovery import MQTT_DISCOVERY_UPDATED, clear_discovery_hash
@callback @callback
def discovery_callback(payload): def discovery_callback(payload):
"""Handle discovery update.""" """Handle discovery update."""

View file

@ -3,6 +3,7 @@ from collections import OrderedDict
import queue import queue
import voluptuous as vol import voluptuous as vol
import paho.mqtt.client as mqtt
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.const import ( from homeassistant.const import (
@ -125,8 +126,6 @@ class FlowHandler(config_entries.ConfigFlow):
def try_connection(broker, port, username, password, protocol="3.1"): def try_connection(broker, port, username, password, protocol="3.1"):
"""Test if we can connect to an MQTT broker.""" """Test if we can connect to an MQTT broker."""
import paho.mqtt.client as mqtt
if protocol == "3.1": if protocol == "3.1":
proto = mqtt.MQTTv31 proto = mqtt.MQTTv31
else: else:

View file

@ -5,3 +5,5 @@ DEFAULT_DISCOVERY = False
ATTR_DISCOVERY_HASH = "discovery_hash" ATTR_DISCOVERY_HASH = "discovery_hash"
CONF_STATE_TOPIC = "state_topic" CONF_STATE_TOPIC = "state_topic"
PROTOCOL_311 = "3.1.1"
DEFAULT_QOS = 0

View file

@ -16,34 +16,24 @@ from homeassistant.components.mqtt.discovery import (
) )
from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.typing import HomeAssistantType, ConfigType from homeassistant.helpers.typing import HomeAssistantType, ConfigType
from .schema import CONF_SCHEMA, MQTT_LIGHT_SCHEMA_SCHEMA
from .schema_basic import PLATFORM_SCHEMA_BASIC, async_setup_entity_basic
from .schema_json import PLATFORM_SCHEMA_JSON, async_setup_entity_json
from .schema_template import PLATFORM_SCHEMA_TEMPLATE, async_setup_entity_template
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
CONF_SCHEMA = "schema"
def validate_mqtt_light(value): def validate_mqtt_light(value):
"""Validate MQTT light schema.""" """Validate MQTT light schema."""
from . import schema_basic
from . import schema_json
from . import schema_template
schemas = { schemas = {
"basic": schema_basic.PLATFORM_SCHEMA_BASIC, "basic": PLATFORM_SCHEMA_BASIC,
"json": schema_json.PLATFORM_SCHEMA_JSON, "json": PLATFORM_SCHEMA_JSON,
"template": schema_template.PLATFORM_SCHEMA_TEMPLATE, "template": PLATFORM_SCHEMA_TEMPLATE,
} }
return schemas[value[CONF_SCHEMA]](value) return schemas[value[CONF_SCHEMA]](value)
MQTT_LIGHT_SCHEMA_SCHEMA = vol.Schema(
{
vol.Optional(CONF_SCHEMA, default="basic"): vol.All(
vol.Lower, vol.Any("basic", "json", "template")
)
}
)
PLATFORM_SCHEMA = vol.All( PLATFORM_SCHEMA = vol.All(
MQTT_LIGHT_SCHEMA_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA), validate_mqtt_light MQTT_LIGHT_SCHEMA_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA), validate_mqtt_light
) )
@ -81,14 +71,10 @@ async def _async_setup_entity(
config, async_add_entities, config_entry=None, discovery_hash=None config, async_add_entities, config_entry=None, discovery_hash=None
): ):
"""Set up a MQTT Light.""" """Set up a MQTT Light."""
from . import schema_basic
from . import schema_json
from . import schema_template
setup_entity = { setup_entity = {
"basic": schema_basic.async_setup_entity_basic, "basic": async_setup_entity_basic,
"json": schema_json.async_setup_entity_json, "json": async_setup_entity_json,
"template": schema_template.async_setup_entity_template, "template": async_setup_entity_template,
} }
await setup_entity[config[CONF_SCHEMA]]( await setup_entity[config[CONF_SCHEMA]](
config, async_add_entities, config_entry, discovery_hash config, async_add_entities, config_entry, discovery_hash

View file

@ -0,0 +1,12 @@
"""Shared schema code."""
import voluptuous as vol
CONF_SCHEMA = "schema"
MQTT_LIGHT_SCHEMA_SCHEMA = vol.Schema(
{
vol.Optional(CONF_SCHEMA, default="basic"): vol.All(
vol.Lower, vol.Any("basic", "json", "template")
)
}
)

View file

@ -56,7 +56,7 @@ from homeassistant.helpers.restore_state import RestoreEntity
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
import homeassistant.util.color as color_util import homeassistant.util.color as color_util
from . import MQTT_LIGHT_SCHEMA_SCHEMA from .schema import MQTT_LIGHT_SCHEMA_SCHEMA
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)

View file

@ -59,7 +59,7 @@ from homeassistant.helpers.restore_state import RestoreEntity
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
import homeassistant.util.color as color_util import homeassistant.util.color as color_util
from . import MQTT_LIGHT_SCHEMA_SCHEMA from .schema import MQTT_LIGHT_SCHEMA_SCHEMA
from .schema_basic import CONF_BRIGHTNESS_SCALE from .schema_basic import CONF_BRIGHTNESS_SCALE
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)

View file

@ -49,7 +49,7 @@ import homeassistant.helpers.config_validation as cv
import homeassistant.util.color as color_util import homeassistant.util.color as color_util
from homeassistant.helpers.restore_state import RestoreEntity from homeassistant.helpers.restore_state import RestoreEntity
from . import MQTT_LIGHT_SCHEMA_SCHEMA from .schema import MQTT_LIGHT_SCHEMA_SCHEMA
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)

View file

@ -0,0 +1,20 @@
"""Modesl used by multiple MQTT modules."""
from typing import Union, Callable
import attr
# pylint: disable=invalid-name
PublishPayloadType = Union[str, bytes, int, float, None]
@attr.s(slots=True, frozen=True)
class Message:
"""MQTT Message."""
topic = attr.ib(type=str)
payload = attr.ib(type=PublishPayloadType)
qos = attr.ib(type=int)
retain = attr.ib(type=bool)
MessageCallbackType = Callable[[Message], None]

View file

@ -4,10 +4,14 @@ import logging
import tempfile import tempfile
import voluptuous as vol import voluptuous as vol
from hbmqtt.broker import Broker, BrokerException
from passlib.apps import custom_app_context
from homeassistant.const import EVENT_HOMEASSISTANT_STOP from homeassistant.const import EVENT_HOMEASSISTANT_STOP
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from .const import PROTOCOL_311
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
# None allows custom config to be created through generate_config # None allows custom config to be created through generate_config
@ -33,8 +37,6 @@ def async_start(hass, password, server_config):
This method is a coroutine. This method is a coroutine.
""" """
from hbmqtt.broker import Broker, BrokerException
passwd = tempfile.NamedTemporaryFile() passwd = tempfile.NamedTemporaryFile()
gen_server_config, client_config = generate_config(hass, passwd, password) gen_server_config, client_config = generate_config(hass, passwd, password)
@ -63,8 +65,6 @@ def async_start(hass, password, server_config):
def generate_config(hass, passwd, password): def generate_config(hass, passwd, password):
"""Generate a configuration based on current Home Assistant instance.""" """Generate a configuration based on current Home Assistant instance."""
from . import PROTOCOL_311
config = { config = {
"listeners": { "listeners": {
"default": { "default": {
@ -83,8 +83,6 @@ def generate_config(hass, passwd, password):
username = "homeassistant" username = "homeassistant"
# Encrypt with what hbmqtt uses to verify # Encrypt with what hbmqtt uses to verify
from passlib.apps import custom_app_context
passwd.write( passwd.write(
"homeassistant:{}\n".format(custom_app_context.encrypt(password)).encode( "homeassistant:{}\n".format(custom_app_context.encrypt(password)).encode(
"utf-8" "utf-8"

View file

@ -8,7 +8,8 @@ from homeassistant.components import mqtt
from homeassistant.helpers.typing import HomeAssistantType from homeassistant.helpers.typing import HomeAssistantType
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
from . import DEFAULT_QOS, MessageCallbackType from .const import DEFAULT_QOS
from .models import MessageCallbackType
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)

View file

@ -15,51 +15,19 @@ from homeassistant.components.mqtt.discovery import (
clear_discovery_hash, clear_discovery_hash,
) )
from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.dispatcher import async_dispatcher_connect
from .schema import CONF_SCHEMA, LEGACY, STATE, MQTT_VACUUM_SCHEMA
from .schema_legacy import PLATFORM_SCHEMA_LEGACY, async_setup_entity_legacy
from .schema_state import PLATFORM_SCHEMA_STATE, async_setup_entity_state
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
CONF_SCHEMA = "schema"
LEGACY = "legacy"
STATE = "state"
def validate_mqtt_vacuum(value): def validate_mqtt_vacuum(value):
"""Validate MQTT vacuum schema.""" """Validate MQTT vacuum schema."""
from . import schema_legacy schemas = {LEGACY: PLATFORM_SCHEMA_LEGACY, STATE: PLATFORM_SCHEMA_STATE}
from . import schema_state
schemas = {
LEGACY: schema_legacy.PLATFORM_SCHEMA_LEGACY,
STATE: schema_state.PLATFORM_SCHEMA_STATE,
}
return schemas[value[CONF_SCHEMA]](value) return schemas[value[CONF_SCHEMA]](value)
def services_to_strings(services, service_to_string):
"""Convert SUPPORT_* service bitmask to list of service strings."""
strings = []
for service in service_to_string:
if service & services:
strings.append(service_to_string[service])
return strings
def strings_to_services(strings, string_to_service):
"""Convert service strings to SUPPORT_* service bitmask."""
services = 0
for string in strings:
services |= string_to_service[string]
return services
MQTT_VACUUM_SCHEMA = vol.Schema(
{
vol.Optional(CONF_SCHEMA, default=LEGACY): vol.All(
vol.Lower, vol.Any(LEGACY, STATE)
)
}
)
PLATFORM_SCHEMA = vol.All( PLATFORM_SCHEMA = vol.All(
MQTT_VACUUM_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA), validate_mqtt_vacuum MQTT_VACUUM_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA), validate_mqtt_vacuum
) )
@ -95,13 +63,7 @@ async def _async_setup_entity(
config, async_add_entities, config_entry, discovery_hash=None config, async_add_entities, config_entry, discovery_hash=None
): ):
"""Set up the MQTT vacuum.""" """Set up the MQTT vacuum."""
from . import schema_legacy setup_entity = {LEGACY: async_setup_entity_legacy, STATE: async_setup_entity_state}
from . import schema_state
setup_entity = {
LEGACY: schema_legacy.async_setup_entity_legacy,
STATE: schema_state.async_setup_entity_state,
}
await setup_entity[config[CONF_SCHEMA]]( await setup_entity[config[CONF_SCHEMA]](
config, async_add_entities, config_entry, discovery_hash config, async_add_entities, config_entry, discovery_hash
) )

View file

@ -0,0 +1,31 @@
"""Shared schema code."""
import voluptuous as vol
CONF_SCHEMA = "schema"
LEGACY = "legacy"
STATE = "state"
MQTT_VACUUM_SCHEMA = vol.Schema(
{
vol.Optional(CONF_SCHEMA, default=LEGACY): vol.All(
vol.Lower, vol.Any(LEGACY, STATE)
)
}
)
def services_to_strings(services, service_to_string):
"""Convert SUPPORT_* service bitmask to list of service strings."""
strings = []
for service in service_to_string:
if service & services:
strings.append(service_to_string[service])
return strings
def strings_to_services(strings, string_to_service):
"""Convert service strings to SUPPORT_* service bitmask."""
services = 0
for string in strings:
services |= string_to_service[string]
return services

View file

@ -33,7 +33,7 @@ from homeassistant.components.mqtt import (
subscription, subscription,
) )
from . import MQTT_VACUUM_SCHEMA, services_to_strings, strings_to_services from .schema import MQTT_VACUUM_SCHEMA, services_to_strings, strings_to_services
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)

View file

@ -46,7 +46,7 @@ from homeassistant.components.mqtt import (
CONF_QOS, CONF_QOS,
) )
from . import MQTT_VACUUM_SCHEMA, services_to_strings, strings_to_services from .schema import MQTT_VACUUM_SCHEMA, services_to_strings, strings_to_services
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)

View file

@ -27,6 +27,7 @@ from homeassistant.auth import (
) )
from homeassistant.auth.permissions import system_policies from homeassistant.auth.permissions import system_policies
from homeassistant.components import mqtt, recorder from homeassistant.components import mqtt, recorder
from homeassistant.components.mqtt.models import Message
from homeassistant.config import async_process_component_config from homeassistant.config import async_process_component_config
from homeassistant.const import ( from homeassistant.const import (
ATTR_DISCOVERED, ATTR_DISCOVERED,
@ -271,7 +272,7 @@ def async_fire_mqtt_message(hass, topic, payload, qos=0, retain=False):
"""Fire the MQTT message.""" """Fire the MQTT message."""
if isinstance(payload, str): if isinstance(payload, str):
payload = payload.encode("utf-8") payload = payload.encode("utf-8")
msg = mqtt.Message(topic, payload, qos, retain) msg = Message(topic, payload, qos, retain)
hass.data["mqtt"]._mqtt_handle_message(msg) hass.data["mqtt"]._mqtt_handle_message(msg)

View file

@ -5,10 +5,8 @@ import json
from homeassistant.components import mqtt, vacuum from homeassistant.components import mqtt, vacuum
from homeassistant.components.mqtt import CONF_COMMAND_TOPIC from homeassistant.components.mqtt import CONF_COMMAND_TOPIC
from homeassistant.components.mqtt.discovery import async_start from homeassistant.components.mqtt.discovery import async_start
from homeassistant.components.mqtt.vacuum import ( from homeassistant.components.mqtt.vacuum import schema_legacy as mqttvacuum
schema_legacy as mqttvacuum, from homeassistant.components.mqtt.vacuum.schema import services_to_strings
services_to_strings,
)
from homeassistant.components.mqtt.vacuum.schema_legacy import ( from homeassistant.components.mqtt.vacuum.schema_legacy import (
ALL_SERVICES, ALL_SERVICES,
SERVICE_TO_STRING, SERVICE_TO_STRING,
@ -80,7 +78,7 @@ async def test_default_supported_features(hass, mqtt_mock):
async def test_all_commands(hass, mqtt_mock): async def test_all_commands(hass, mqtt_mock):
"""Test simple commands to the vacuum.""" """Test simple commands to the vacuum."""
config = deepcopy(DEFAULT_CONFIG) config = deepcopy(DEFAULT_CONFIG)
config[mqttvacuum.CONF_SUPPORTED_FEATURES] = mqttvacuum.services_to_strings( config[mqttvacuum.CONF_SUPPORTED_FEATURES] = services_to_strings(
ALL_SERVICES, SERVICE_TO_STRING ALL_SERVICES, SERVICE_TO_STRING
) )
@ -221,7 +219,7 @@ async def test_attributes_without_supported_features(hass, mqtt_mock):
async def test_status(hass, mqtt_mock): async def test_status(hass, mqtt_mock):
"""Test status updates from the vacuum.""" """Test status updates from the vacuum."""
config = deepcopy(DEFAULT_CONFIG) config = deepcopy(DEFAULT_CONFIG)
config[mqttvacuum.CONF_SUPPORTED_FEATURES] = mqttvacuum.services_to_strings( config[mqttvacuum.CONF_SUPPORTED_FEATURES] = services_to_strings(
ALL_SERVICES, SERVICE_TO_STRING ALL_SERVICES, SERVICE_TO_STRING
) )
@ -260,7 +258,7 @@ async def test_status(hass, mqtt_mock):
async def test_status_battery(hass, mqtt_mock): async def test_status_battery(hass, mqtt_mock):
"""Test status updates from the vacuum.""" """Test status updates from the vacuum."""
config = deepcopy(DEFAULT_CONFIG) config = deepcopy(DEFAULT_CONFIG)
config[mqttvacuum.CONF_SUPPORTED_FEATURES] = mqttvacuum.services_to_strings( config[mqttvacuum.CONF_SUPPORTED_FEATURES] = services_to_strings(
ALL_SERVICES, SERVICE_TO_STRING ALL_SERVICES, SERVICE_TO_STRING
) )
@ -277,7 +275,7 @@ async def test_status_battery(hass, mqtt_mock):
async def test_status_cleaning(hass, mqtt_mock): async def test_status_cleaning(hass, mqtt_mock):
"""Test status updates from the vacuum.""" """Test status updates from the vacuum."""
config = deepcopy(DEFAULT_CONFIG) config = deepcopy(DEFAULT_CONFIG)
config[mqttvacuum.CONF_SUPPORTED_FEATURES] = mqttvacuum.services_to_strings( config[mqttvacuum.CONF_SUPPORTED_FEATURES] = services_to_strings(
ALL_SERVICES, SERVICE_TO_STRING ALL_SERVICES, SERVICE_TO_STRING
) )
@ -294,7 +292,7 @@ async def test_status_cleaning(hass, mqtt_mock):
async def test_status_docked(hass, mqtt_mock): async def test_status_docked(hass, mqtt_mock):
"""Test status updates from the vacuum.""" """Test status updates from the vacuum."""
config = deepcopy(DEFAULT_CONFIG) config = deepcopy(DEFAULT_CONFIG)
config[mqttvacuum.CONF_SUPPORTED_FEATURES] = mqttvacuum.services_to_strings( config[mqttvacuum.CONF_SUPPORTED_FEATURES] = services_to_strings(
ALL_SERVICES, SERVICE_TO_STRING ALL_SERVICES, SERVICE_TO_STRING
) )
@ -311,7 +309,7 @@ async def test_status_docked(hass, mqtt_mock):
async def test_status_charging(hass, mqtt_mock): async def test_status_charging(hass, mqtt_mock):
"""Test status updates from the vacuum.""" """Test status updates from the vacuum."""
config = deepcopy(DEFAULT_CONFIG) config = deepcopy(DEFAULT_CONFIG)
config[mqttvacuum.CONF_SUPPORTED_FEATURES] = mqttvacuum.services_to_strings( config[mqttvacuum.CONF_SUPPORTED_FEATURES] = services_to_strings(
ALL_SERVICES, SERVICE_TO_STRING ALL_SERVICES, SERVICE_TO_STRING
) )
@ -328,7 +326,7 @@ async def test_status_charging(hass, mqtt_mock):
async def test_status_fan_speed(hass, mqtt_mock): async def test_status_fan_speed(hass, mqtt_mock):
"""Test status updates from the vacuum.""" """Test status updates from the vacuum."""
config = deepcopy(DEFAULT_CONFIG) config = deepcopy(DEFAULT_CONFIG)
config[mqttvacuum.CONF_SUPPORTED_FEATURES] = mqttvacuum.services_to_strings( config[mqttvacuum.CONF_SUPPORTED_FEATURES] = services_to_strings(
ALL_SERVICES, SERVICE_TO_STRING ALL_SERVICES, SERVICE_TO_STRING
) )
@ -345,7 +343,7 @@ async def test_status_fan_speed(hass, mqtt_mock):
async def test_status_error(hass, mqtt_mock): async def test_status_error(hass, mqtt_mock):
"""Test status updates from the vacuum.""" """Test status updates from the vacuum."""
config = deepcopy(DEFAULT_CONFIG) config = deepcopy(DEFAULT_CONFIG)
config[mqttvacuum.CONF_SUPPORTED_FEATURES] = mqttvacuum.services_to_strings( config[mqttvacuum.CONF_SUPPORTED_FEATURES] = services_to_strings(
ALL_SERVICES, SERVICE_TO_STRING ALL_SERVICES, SERVICE_TO_STRING
) )
@ -371,7 +369,7 @@ async def test_battery_template(hass, mqtt_mock):
config = deepcopy(DEFAULT_CONFIG) config = deepcopy(DEFAULT_CONFIG)
config.update( config.update(
{ {
mqttvacuum.CONF_SUPPORTED_FEATURES: mqttvacuum.services_to_strings( mqttvacuum.CONF_SUPPORTED_FEATURES: services_to_strings(
ALL_SERVICES, SERVICE_TO_STRING ALL_SERVICES, SERVICE_TO_STRING
), ),
mqttvacuum.CONF_BATTERY_LEVEL_TOPIC: "retroroomba/battery_level", mqttvacuum.CONF_BATTERY_LEVEL_TOPIC: "retroroomba/battery_level",
@ -390,7 +388,7 @@ async def test_battery_template(hass, mqtt_mock):
async def test_status_invalid_json(hass, mqtt_mock): async def test_status_invalid_json(hass, mqtt_mock):
"""Test to make sure nothing breaks if the vacuum sends bad JSON.""" """Test to make sure nothing breaks if the vacuum sends bad JSON."""
config = deepcopy(DEFAULT_CONFIG) config = deepcopy(DEFAULT_CONFIG)
config[mqttvacuum.CONF_SUPPORTED_FEATURES] = mqttvacuum.services_to_strings( config[mqttvacuum.CONF_SUPPORTED_FEATURES] = services_to_strings(
ALL_SERVICES, SERVICE_TO_STRING ALL_SERVICES, SERVICE_TO_STRING
) )

View file

@ -19,9 +19,13 @@ class TestMQTT:
"""Stop everything that was started.""" """Stop everything that was started."""
self.hass.stop() self.hass.stop()
@patch("passlib.apps.custom_app_context", Mock(return_value="")) @patch(
"homeassistant.components.mqtt.server.custom_app_context", Mock(return_value="")
)
@patch("tempfile.NamedTemporaryFile", Mock(return_value=MagicMock())) @patch("tempfile.NamedTemporaryFile", Mock(return_value=MagicMock()))
@patch("hbmqtt.broker.Broker", Mock(return_value=MagicMock())) @patch(
"homeassistant.components.mqtt.server.Broker", Mock(return_value=MagicMock())
)
@patch("hbmqtt.broker.Broker.start", Mock(return_value=mock_coro())) @patch("hbmqtt.broker.Broker.start", Mock(return_value=mock_coro()))
@patch("homeassistant.components.mqtt.MQTT") @patch("homeassistant.components.mqtt.MQTT")
def test_creating_config_with_pass_and_no_http_pass(self, mock_mqtt): def test_creating_config_with_pass_and_no_http_pass(self, mock_mqtt):
@ -41,9 +45,13 @@ class TestMQTT:
assert mock_mqtt.mock_calls[1][2]["username"] == "homeassistant" assert mock_mqtt.mock_calls[1][2]["username"] == "homeassistant"
assert mock_mqtt.mock_calls[1][2]["password"] == password assert mock_mqtt.mock_calls[1][2]["password"] == password
@patch("passlib.apps.custom_app_context", Mock(return_value="")) @patch(
"homeassistant.components.mqtt.server.custom_app_context", Mock(return_value="")
)
@patch("tempfile.NamedTemporaryFile", Mock(return_value=MagicMock())) @patch("tempfile.NamedTemporaryFile", Mock(return_value=MagicMock()))
@patch("hbmqtt.broker.Broker", Mock(return_value=MagicMock())) @patch(
"homeassistant.components.mqtt.server.Broker", Mock(return_value=MagicMock())
)
@patch("hbmqtt.broker.Broker.start", Mock(return_value=mock_coro())) @patch("hbmqtt.broker.Broker.start", Mock(return_value=mock_coro()))
@patch("homeassistant.components.mqtt.MQTT") @patch("homeassistant.components.mqtt.MQTT")
def test_creating_config_with_pass_and_http_pass(self, mock_mqtt): def test_creating_config_with_pass_and_http_pass(self, mock_mqtt):

View file

@ -5,11 +5,8 @@ import json
from homeassistant.components import mqtt, vacuum from homeassistant.components import mqtt, vacuum
from homeassistant.components.mqtt import CONF_COMMAND_TOPIC, CONF_STATE_TOPIC from homeassistant.components.mqtt import CONF_COMMAND_TOPIC, CONF_STATE_TOPIC
from homeassistant.components.mqtt.discovery import async_start from homeassistant.components.mqtt.discovery import async_start
from homeassistant.components.mqtt.vacuum import ( from homeassistant.components.mqtt.vacuum import CONF_SCHEMA, schema_state as mqttvacuum
CONF_SCHEMA, from homeassistant.components.mqtt.vacuum.schema import services_to_strings
schema_state as mqttvacuum,
services_to_strings,
)
from homeassistant.components.mqtt.vacuum.schema_state import SERVICE_TO_STRING from homeassistant.components.mqtt.vacuum.schema_state import SERVICE_TO_STRING
from homeassistant.components.vacuum import ( from homeassistant.components.vacuum import (
ATTR_BATTERY_ICON, ATTR_BATTERY_ICON,
@ -259,7 +256,7 @@ async def test_no_fan_vacuum(hass, mqtt_mock):
async def test_status_invalid_json(hass, mqtt_mock): async def test_status_invalid_json(hass, mqtt_mock):
"""Test to make sure nothing breaks if the vacuum sends bad JSON.""" """Test to make sure nothing breaks if the vacuum sends bad JSON."""
config = deepcopy(DEFAULT_CONFIG) config = deepcopy(DEFAULT_CONFIG)
config[mqttvacuum.CONF_SUPPORTED_FEATURES] = mqttvacuum.services_to_strings( config[mqttvacuum.CONF_SUPPORTED_FEATURES] = services_to_strings(
mqttvacuum.ALL_SERVICES, SERVICE_TO_STRING mqttvacuum.ALL_SERVICES, SERVICE_TO_STRING
) )