From 845bf80e725af8c921915906b0f796c7a8164d11 Mon Sep 17 00:00:00 2001 From: Jan Bouwhuis Date: Wed, 23 Feb 2022 12:29:32 +0100 Subject: [PATCH] Mqtt improve test coverage (#66279) Co-authored-by: Martin Hjelmare --- homeassistant/components/mqtt/config_flow.py | 4 +- tests/components/mqtt/test_config_flow.py | 59 ++-- tests/components/mqtt/test_init.py | 286 +++++++++++++++++-- 3 files changed, 310 insertions(+), 39 deletions(-) diff --git a/homeassistant/components/mqtt/config_flow.py b/homeassistant/components/mqtt/config_flow.py index 23e2a0d1e81..3f93e50829a 100644 --- a/homeassistant/components/mqtt/config_flow.py +++ b/homeassistant/components/mqtt/config_flow.py @@ -33,6 +33,8 @@ from .const import ( ) from .util import MQTT_WILL_BIRTH_SCHEMA +MQTT_TIMEOUT = 5 + class FlowHandler(config_entries.ConfigFlow, domain=DOMAIN): """Handle a config flow.""" @@ -337,7 +339,7 @@ def try_connection(broker, port, username, password, protocol="3.1"): client.loop_start() try: - return result.get(timeout=5) + return result.get(timeout=MQTT_TIMEOUT) except queue.Empty: return False finally: diff --git a/tests/components/mqtt/test_config_flow.py b/tests/components/mqtt/test_config_flow.py index f16a0e5e83a..d9aab02e821 100644 --- a/tests/components/mqtt/test_config_flow.py +++ b/tests/components/mqtt/test_config_flow.py @@ -1,5 +1,4 @@ """Test config flow.""" - from unittest.mock import patch import pytest @@ -30,6 +29,31 @@ def mock_try_connection(): yield mock_try +@pytest.fixture +def mock_try_connection_success(): + """Mock the try connection method with success.""" + + def loop_start(): + """Simulate connect on loop start.""" + mock_client().on_connect(mock_client, None, None, 0) + + with patch("paho.mqtt.client.Client") as mock_client: + mock_client().loop_start = loop_start + yield mock_client() + + +@pytest.fixture +def mock_try_connection_time_out(): + """Mock the try connection method with a time out.""" + + # Patch prevent waiting 5 sec for a timeout + with patch("paho.mqtt.client.Client") as mock_client, patch( + "homeassistant.components.mqtt.config_flow.MQTT_TIMEOUT", 0 + ): + mock_client().loop_start = lambda *args: 1 + yield mock_client() + + async def test_user_connection_works( hass, mock_try_connection, mock_finish_setup, mqtt_client_mock ): @@ -57,10 +81,10 @@ async def test_user_connection_works( assert len(mock_finish_setup.mock_calls) == 1 -async def test_user_connection_fails(hass, mock_try_connection, mock_finish_setup): +async def test_user_connection_fails( + hass, mock_try_connection_time_out, mock_finish_setup +): """Test if connection cannot be made.""" - mock_try_connection.return_value = False - result = await hass.config_entries.flow.async_init( "mqtt", context={"source": config_entries.SOURCE_USER} ) @@ -74,7 +98,7 @@ async def test_user_connection_fails(hass, mock_try_connection, mock_finish_setu assert result["errors"]["base"] == "cannot_connect" # Check we tried the connection - assert len(mock_try_connection.mock_calls) == 1 + assert len(mock_try_connection_time_out.mock_calls) # Check config entry did not setup assert len(mock_finish_setup.mock_calls) == 0 @@ -163,7 +187,12 @@ async def test_hassio_ignored(hass: HomeAssistant) -> None: result = await hass.config_entries.flow.async_init( mqtt.DOMAIN, data=HassioServiceInfo( - config={"addon": "Mosquitto", "host": "mock-mosquitto", "port": "1883"} + config={ + "addon": "Mosquitto", + "host": "mock-mosquitto", + "port": "1883", + "protocol": "3.1.1", + } ), context={"source": config_entries.SOURCE_HASSIO}, ) @@ -172,9 +201,7 @@ async def test_hassio_ignored(hass: HomeAssistant) -> None: assert result.get("reason") == "already_configured" -async def test_hassio_confirm( - hass, mock_try_connection, mock_finish_setup, mqtt_client_mock -): +async def test_hassio_confirm(hass, mock_try_connection_success, mock_finish_setup): """Test we can finish a config flow.""" mock_try_connection.return_value = True @@ -196,6 +223,7 @@ async def test_hassio_confirm( assert result["step_id"] == "hassio_confirm" assert result["description_placeholders"] == {"addon": "Mock Addon"} + mock_try_connection_success.reset_mock() result = await hass.config_entries.flow.async_configure( result["flow_id"], {"discovery": True} ) @@ -210,7 +238,7 @@ async def test_hassio_confirm( "discovery": True, } # Check we tried the connection - assert len(mock_try_connection.mock_calls) == 1 + assert len(mock_try_connection_success.mock_calls) # Check config entry got setup assert len(mock_finish_setup.mock_calls) == 1 @@ -368,10 +396,9 @@ def get_suggested(schema, key): async def test_option_flow_default_suggested_values( - hass, mqtt_mock, mock_try_connection + hass, mqtt_mock, mock_try_connection_success ): """Test config flow options has default/suggested values.""" - mock_try_connection.return_value = True config_entry = hass.config_entries.async_entries(mqtt.DOMAIN)[0] config_entry.data = { mqtt.CONF_BROKER: "test-broker", @@ -516,7 +543,7 @@ async def test_option_flow_default_suggested_values( await hass.async_block_till_done() -async def test_options_user_connection_fails(hass, mock_try_connection): +async def test_options_user_connection_fails(hass, mock_try_connection_time_out): """Test if connection cannot be made.""" config_entry = MockConfigEntry(domain=mqtt.DOMAIN) config_entry.add_to_hass(hass) @@ -524,12 +551,10 @@ async def test_options_user_connection_fails(hass, mock_try_connection): mqtt.CONF_BROKER: "test-broker", mqtt.CONF_PORT: 1234, } - - mock_try_connection.return_value = False - result = await hass.config_entries.options.async_init(config_entry.entry_id) assert result["type"] == "form" + mock_try_connection_time_out.reset_mock() result = await hass.config_entries.options.async_configure( result["flow_id"], user_input={mqtt.CONF_BROKER: "bad-broker", mqtt.CONF_PORT: 2345}, @@ -539,7 +564,7 @@ async def test_options_user_connection_fails(hass, mock_try_connection): assert result["errors"]["base"] == "cannot_connect" # Check we tried the connection - assert len(mock_try_connection.mock_calls) == 1 + assert len(mock_try_connection_time_out.mock_calls) # Check config entry did not update assert config_entry.data == { mqtt.CONF_BROKER: "test-broker", diff --git a/tests/components/mqtt/test_init.py b/tests/components/mqtt/test_init.py index a9a96df4f8f..68ba2a040b8 100644 --- a/tests/components/mqtt/test_init.py +++ b/tests/components/mqtt/test_init.py @@ -1,16 +1,21 @@ """The tests for the MQTT component.""" import asyncio from datetime import datetime, timedelta +from functools import partial import json +import logging import ssl from unittest.mock import ANY, AsyncMock, MagicMock, call, mock_open, patch import pytest import voluptuous as vol +import yaml +from homeassistant import config as hass_config from homeassistant.components import mqtt, websocket_api from homeassistant.components.mqtt import debug_info from homeassistant.components.mqtt.mixins import MQTT_ENTITY_DEVICE_INFO_SCHEMA +from homeassistant.components.mqtt.models import ReceiveMessage from homeassistant.const import ( ATTR_ASSUMED_STATE, EVENT_HOMEASSISTANT_STARTED, @@ -34,6 +39,14 @@ from tests.common import ( ) from tests.testing_config.custom_components.test.sensor import DEVICE_CLASSES +_LOGGER = logging.getLogger(__name__) + + +class RecordCallsPartial(partial): + """Wrapper class for partial.""" + + __name__ = "RecordCallPartialTest" + @pytest.fixture(autouse=True) def mock_storage(hass_storage): @@ -675,6 +688,10 @@ async def test_subscribe_topic(hass, mqtt_mock, calls, record_calls): await hass.async_block_till_done() assert len(calls) == 1 + # Cannot unsubscribe twice + with pytest.raises(HomeAssistantError): + unsub() + async def test_subscribe_topic_non_async(hass, mqtt_mock, calls, record_calls): """Test the subscription of a topic using the non-async function.""" @@ -706,13 +723,13 @@ async def test_subscribe_bad_topic(hass, mqtt_mock, calls, record_calls): async def test_subscribe_deprecated(hass, mqtt_mock): """Test the subscription of a topic using deprecated callback signature.""" - calls = [] @callback def record_calls(topic, payload, qos): """Record calls.""" calls.append((topic, payload, qos)) + calls = [] unsub = await mqtt.async_subscribe(hass, "test-topic", record_calls) async_fire_mqtt_message(hass, "test-topic", "test-payload") @@ -728,17 +745,59 @@ async def test_subscribe_deprecated(hass, mqtt_mock): await hass.async_block_till_done() assert len(calls) == 1 + mqtt_mock.async_publish.reset_mock() + + # Test with partial wrapper + calls = [] + unsub = await mqtt.async_subscribe( + hass, "test-topic", RecordCallsPartial(record_calls) + ) + + async_fire_mqtt_message(hass, "test-topic", "test-payload") + + await hass.async_block_till_done() + assert len(calls) == 1 + assert calls[0][0] == "test-topic" + assert calls[0][1] == "test-payload" + + unsub() + + async_fire_mqtt_message(hass, "test-topic", "test-payload") + + await hass.async_block_till_done() + assert len(calls) == 1 async def test_subscribe_deprecated_async(hass, mqtt_mock): - """Test the subscription of a topic using deprecated callback signature.""" - calls = [] + """Test the subscription of a topic using deprecated coroutine signature.""" - async def record_calls(topic, payload, qos): + def async_record_calls(topic, payload, qos): """Record calls.""" calls.append((topic, payload, qos)) - unsub = await mqtt.async_subscribe(hass, "test-topic", record_calls) + calls = [] + unsub = await mqtt.async_subscribe(hass, "test-topic", async_record_calls) + + async_fire_mqtt_message(hass, "test-topic", "test-payload") + + await hass.async_block_till_done() + assert len(calls) == 1 + assert calls[0][0] == "test-topic" + assert calls[0][1] == "test-payload" + + unsub() + + async_fire_mqtt_message(hass, "test-topic", "test-payload") + + await hass.async_block_till_done() + assert len(calls) == 1 + mqtt_mock.async_publish.reset_mock() + + # Test with partial wrapper + calls = [] + unsub = await mqtt.async_subscribe( + hass, "test-topic", RecordCallsPartial(async_record_calls) + ) async_fire_mqtt_message(hass, "test-topic", "test-payload") @@ -1010,9 +1069,9 @@ async def test_restore_subscriptions_on_reconnect(hass, mqtt_client_mock, mqtt_m await hass.async_block_till_done() assert mqtt_client_mock.subscribe.call_count == 1 - mqtt_mock._mqtt_on_disconnect(None, None, 0) + mqtt_client_mock.on_disconnect(None, None, 0) with patch("homeassistant.components.mqtt.DISCOVERY_COOLDOWN", 0): - mqtt_mock._mqtt_on_connect(None, None, None, 0) + mqtt_client_mock.on_connect(None, None, None, 0) await hass.async_block_till_done() assert mqtt_client_mock.subscribe.call_count == 2 @@ -1044,23 +1103,143 @@ async def test_restore_all_active_subscriptions_on_reconnect( await hass.async_block_till_done() assert mqtt_client_mock.unsubscribe.call_count == 0 - mqtt_mock._mqtt_on_disconnect(None, None, 0) + mqtt_client_mock.on_disconnect(None, None, 0) with patch("homeassistant.components.mqtt.DISCOVERY_COOLDOWN", 0): - mqtt_mock._mqtt_on_connect(None, None, None, 0) + mqtt_client_mock.on_connect(None, None, None, 0) await hass.async_block_till_done() expected.append(call("test/state", 1)) assert mqtt_client_mock.subscribe.mock_calls == expected -async def test_setup_logs_error_if_no_connect_broker(hass, caplog): - """Test for setup failure if connection to broker is missing.""" +async def test_initial_setup_logs_error(hass, caplog, mqtt_client_mock): + """Test for setup failure if initial client connection fails.""" entry = MockConfigEntry(domain=mqtt.DOMAIN, data={mqtt.CONF_BROKER: "test-broker"}) + mqtt_client_mock.connect.return_value = 1 + assert await mqtt.async_setup_entry(hass, entry) + await hass.async_block_till_done() + assert "Failed to connect to MQTT server:" in caplog.text + + +async def test_logs_error_if_no_connect_broker( + hass, caplog, mqtt_mock, mqtt_client_mock +): + """Test for setup failure if connection to broker is missing.""" + # test with rc = 3 -> broker unavailable + mqtt_client_mock.on_connect(mqtt_client_mock, None, None, 3) + await hass.async_block_till_done() + assert ( + "Unable to connect to the MQTT broker: Connection Refused: broker unavailable." + in caplog.text + ) + + +@patch("homeassistant.components.mqtt.TIMEOUT_ACK", 0.3) +async def test_handle_mqtt_on_callback(hass, caplog, mqtt_mock, mqtt_client_mock): + """Test receiving an ACK callback before waiting for it.""" + # Simulate an ACK for mid == 1, this will call mqtt_mock._mqtt_handle_mid(mid) + mqtt_client_mock.on_publish(mqtt_client_mock, None, 1) + await hass.async_block_till_done() + # Make sure the ACK has been received + await hass.async_block_till_done() + # Now call publish without call back, this will call _wait_for_mid(msg_info.mid) + await mqtt.async_publish(hass, "no_callback/test-topic", "test-payload") + # Since the mid event was already set, we should not see any timeout + await hass.async_block_till_done() + assert ( + "Transmitting message on no_callback/test-topic: 'test-payload', mid: 1" + in caplog.text + ) + assert "No ACK from MQTT server" not in caplog.text + + +async def test_publish_error(hass, caplog): + """Test publish error.""" + entry = MockConfigEntry(domain=mqtt.DOMAIN, data={mqtt.CONF_BROKER: "test-broker"}) + + # simulate an Out of memory error with patch("paho.mqtt.client.Client") as mock_client: mock_client().connect = lambda *args: 1 + mock_client().publish().rc = 1 assert await mqtt.async_setup_entry(hass, entry) - assert "Failed to connect to MQTT server:" in caplog.text + await hass.async_block_till_done() + with pytest.raises(HomeAssistantError): + await mqtt.async_publish( + hass, "some-topic", b"test-payload", qos=0, retain=False, encoding=None + ) + assert "Failed to connect to MQTT server: Out of memory." in caplog.text + + +async def test_handle_message_callback(hass, caplog, mqtt_mock, mqtt_client_mock): + """Test for handling an incoming message callback.""" + msg = ReceiveMessage("some-topic", b"test-payload", 0, False) + mqtt_client_mock.on_connect(mqtt_client_mock, None, None, 0) + await mqtt.async_subscribe(hass, "some-topic", lambda *args: 0) + mqtt_client_mock.on_message(mock_mqtt, None, msg) + + await hass.async_block_till_done() + await hass.async_block_till_done() + assert "Received message on some-topic: b'test-payload'" in caplog.text + + +async def test_setup_override_configuration(hass, caplog, tmp_path): + """Test override setup from configuration entry.""" + calls_username_password_set = [] + + def mock_usename_password_set(username, password): + calls_username_password_set.append((username, password)) + + # Mock password setup from config + config = { + "username": "someuser", + "password": "someyamlconfiguredpassword", + "protocol": "3.1", + } + new_yaml_config_file = tmp_path / "configuration.yaml" + new_yaml_config = yaml.dump({mqtt.DOMAIN: config}) + new_yaml_config_file.write_text(new_yaml_config) + assert new_yaml_config_file.read_text() == new_yaml_config + + with patch.object(hass_config, "YAML_CONFIG_FILE", new_yaml_config_file): + # Mock config entry + entry = MockConfigEntry( + domain=mqtt.DOMAIN, + data={mqtt.CONF_BROKER: "test-broker", "password": "somepassword"}, + ) + + with patch("paho.mqtt.client.Client") as mock_client: + mock_client().username_pw_set = mock_usename_password_set + mock_client.on_connect(return_value=0) + await async_setup_component(hass, mqtt.DOMAIN, {mqtt.DOMAIN: config}) + await entry.async_setup(hass) + await hass.async_block_till_done() + + assert ( + "Data in your configuration entry is going to override your configuration.yaml:" + in caplog.text + ) + + # Check if the protocol was set to 3.1 from configuration.yaml + assert mock_client.call_args[1]["protocol"] == 3 + + # Check if the password override worked + assert calls_username_password_set[0][0] == "someuser" + assert calls_username_password_set[0][1] == "somepassword" + + +async def test_setup_mqtt_client_protocol(hass): + """Test MQTT client protocol setup.""" + entry = MockConfigEntry( + domain=mqtt.DOMAIN, + data={mqtt.CONF_BROKER: "test-broker", mqtt.CONF_PROTOCOL: "3.1"}, + ) + with patch("paho.mqtt.client.Client") as mock_client: + mock_client.on_connect(return_value=0) + assert await mqtt.async_setup_entry(hass, entry) + + # check if protocol setup was correctly + assert mock_client.call_args[1]["protocol"] == 3 async def test_setup_raises_ConfigEntryNotReady_if_no_connect_broker(hass, caplog): @@ -1073,18 +1252,29 @@ async def test_setup_raises_ConfigEntryNotReady_if_no_connect_broker(hass, caplo assert "Failed to connect to MQTT server due to exception:" in caplog.text -async def test_setup_uses_certificate_on_certificate_set_to_auto(hass): - """Test setup uses bundled certs when certificate is set to auto.""" +@pytest.mark.parametrize("insecure", [None, False, True]) +async def test_setup_uses_certificate_on_certificate_set_to_auto_and_insecure( + hass, insecure +): + """Test setup uses bundled certs when certificate is set to auto and insecure.""" calls = [] + insecure_check = {"insecure": "not set"} def mock_tls_set(certificate, certfile=None, keyfile=None, tls_version=None): calls.append((certificate, certfile, keyfile, tls_version)) + def mock_tls_insecure_set(insecure_param): + insecure_check["insecure"] = insecure_param + + config_item_data = {mqtt.CONF_BROKER: "test-broker", "certificate": "auto"} + if insecure is not None: + config_item_data["tls_insecure"] = insecure with patch("paho.mqtt.client.Client") as mock_client: mock_client().tls_set = mock_tls_set + mock_client().tls_insecure_set = mock_tls_insecure_set entry = MockConfigEntry( domain=mqtt.DOMAIN, - data={mqtt.CONF_BROKER: "test-broker", "certificate": "auto"}, + data=config_item_data, ) assert await mqtt.async_setup_entry(hass, entry) @@ -1097,6 +1287,13 @@ async def test_setup_uses_certificate_on_certificate_set_to_auto(hass): # assert mock_mqtt.mock_calls[0][1][2]["certificate"] == expectedCertificate assert calls[0][0] == expectedCertificate + # test if insecure is set + assert ( + insecure_check["insecure"] == insecure + if insecure is not None + else insecure_check["insecure"] == "not set" + ) + async def test_setup_without_tls_config_uses_tlsv1_under_python36(hass): """Test setup defaults to TLSv1 under python3.6.""" @@ -1150,7 +1347,7 @@ async def test_custom_birth_message(hass, mqtt_client_mock, mqtt_mock): with patch("homeassistant.components.mqtt.DISCOVERY_COOLDOWN", 0.1): await mqtt.async_subscribe(hass, "birth", wait_birth) - mqtt_mock._mqtt_on_connect(None, None, 0, 0) + mqtt_client_mock.on_connect(None, None, 0, 0) await hass.async_block_till_done() await birth.wait() mqtt_client_mock.publish.assert_called_with("birth", "birth", 0, False) @@ -1180,7 +1377,7 @@ async def test_default_birth_message(hass, mqtt_client_mock, mqtt_mock): with patch("homeassistant.components.mqtt.DISCOVERY_COOLDOWN", 0.1): await mqtt.async_subscribe(hass, "homeassistant/status", wait_birth) - mqtt_mock._mqtt_on_connect(None, None, 0, 0) + mqtt_client_mock.on_connect(None, None, 0, 0) await hass.async_block_till_done() await birth.wait() mqtt_client_mock.publish.assert_called_with( @@ -1195,7 +1392,7 @@ async def test_default_birth_message(hass, mqtt_client_mock, mqtt_mock): async def test_no_birth_message(hass, mqtt_client_mock, mqtt_mock): """Test disabling birth message.""" with patch("homeassistant.components.mqtt.DISCOVERY_COOLDOWN", 0.1): - mqtt_mock._mqtt_on_connect(None, None, 0, 0) + mqtt_client_mock.on_connect(None, None, 0, 0) await hass.async_block_till_done() await asyncio.sleep(0.2) mqtt_client_mock.publish.assert_not_called() @@ -1215,7 +1412,7 @@ async def test_no_birth_message(hass, mqtt_client_mock, mqtt_mock): } ], ) -async def test_delayed_birth_message(hass, mqtt_client_mock, mqtt_config): +async def test_delayed_birth_message(hass, mqtt_client_mock, mqtt_config, mqtt_mock): """Test sending birth message does not happen until Home Assistant starts.""" hass.state = CoreState.starting birth = asyncio.Event() @@ -1244,7 +1441,7 @@ async def test_delayed_birth_message(hass, mqtt_client_mock, mqtt_config): with patch("homeassistant.components.mqtt.DISCOVERY_COOLDOWN", 0.1): await mqtt.async_subscribe(hass, "homeassistant/status", wait_birth) - mqtt_mock._mqtt_on_connect(None, None, 0, 0) + mqtt_client_mock.on_connect(None, None, 0, 0) await hass.async_block_till_done() with pytest.raises(asyncio.TimeoutError): await asyncio.wait_for(birth.wait(), 0.2) @@ -1313,7 +1510,7 @@ async def test_mqtt_subscribes_topics_on_connect(hass, mqtt_client_mock, mqtt_mo await mqtt.async_subscribe(hass, "still/pending", None, 1) hass.add_job = MagicMock() - mqtt_mock._mqtt_on_connect(None, None, 0, 0) + mqtt_client_mock.on_connect(None, None, 0, 0) await hass.async_block_till_done() @@ -1391,6 +1588,18 @@ async def test_mqtt_ws_subscription(hass, hass_ws_client, mqtt_mock): assert response["success"] +async def test_mqtt_ws_subscription_not_admin( + hass, hass_ws_client, mqtt_mock, hass_read_only_access_token +): + """Test MQTT websocket user is not admin.""" + client = await hass_ws_client(hass, access_token=hass_read_only_access_token) + await client.send_json({"id": 5, "type": "mqtt/subscribe", "topic": "test-topic"}) + response = await client.receive_json() + assert response["success"] is False + assert response["error"]["code"] == "unauthorized" + assert response["error"]["message"] == "Unauthorized" + + async def test_dump_service(hass, mqtt_mock): """Test that we can dump a topic.""" mopen = mock_open() @@ -2117,3 +2326,38 @@ async def test_service_info_compatibility(hass, caplog): with patch("homeassistant.helpers.frame._REPORTED_INTEGRATIONS", set()): assert discovery_info["topic"] == "tasmota/discovery/DC4F220848A2/config" assert "Detected integration that accessed discovery_info['topic']" in caplog.text + + +async def test_subscribe_connection_status(hass, mqtt_mock, mqtt_client_mock): + """Test connextion status subscription.""" + mqtt_connected_calls = [] + + @callback + async def async_mqtt_connected(status): + """Update state on connection/disconnection to MQTT broker.""" + mqtt_connected_calls.append(status) + + mqtt_mock.connected = True + + unsub = mqtt.async_subscribe_connection_status(hass, async_mqtt_connected) + await hass.async_block_till_done() + + # Mock connection status + mqtt_client_mock.on_connect(None, None, 0, 0) + await hass.async_block_till_done() + assert mqtt.is_connected(hass) is True + + # Mock disconnect status + mqtt_client_mock.on_disconnect(None, None, 0) + await hass.async_block_till_done() + + # Unsubscribe + unsub() + + mqtt_client_mock.on_connect(None, None, 0, 0) + await hass.async_block_till_done() + + # Check calls + assert len(mqtt_connected_calls) == 2 + assert mqtt_connected_calls[0] is True + assert mqtt_connected_calls[1] is False