Remove useless threading locks in mqtt (#118737)

This commit is contained in:
J. Nick Koston 2024-06-04 14:21:03 -05:00 committed by GitHub
parent 278751607f
commit 67b3be8432
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 106 additions and 15 deletions

View file

@ -0,0 +1,60 @@
"""Async wrappings for mqtt client."""
from __future__ import annotations
from functools import lru_cache
from types import TracebackType
from typing import Self
from paho.mqtt.client import Client as MQTTClient
_MQTT_LOCK_COUNT = 7
class NullLock:
"""Null lock."""
@lru_cache(maxsize=_MQTT_LOCK_COUNT)
def __enter__(self) -> Self:
"""Enter the lock."""
return self
@lru_cache(maxsize=_MQTT_LOCK_COUNT)
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
"""Exit the lock."""
@lru_cache(maxsize=_MQTT_LOCK_COUNT)
def acquire(self, blocking: bool = False, timeout: int = -1) -> None:
"""Acquire the lock."""
@lru_cache(maxsize=_MQTT_LOCK_COUNT)
def release(self) -> None:
"""Release the lock."""
class AsyncMQTTClient(MQTTClient):
"""Async MQTT Client.
Wrapper around paho.mqtt.client.Client to remove the locking
that is not needed since we are running in an async event loop.
"""
def async_setup(self) -> None:
"""Set up the client.
All the threading locks are replaced with NullLock
since the client is running in an async event loop
and will never run in multiple threads.
"""
self._in_callback_mutex = NullLock()
self._callback_mutex = NullLock()
self._msgtime_mutex = NullLock()
self._out_message_mutex = NullLock()
self._in_message_mutex = NullLock()
self._reconnect_delay_mutex = NullLock()
self._mid_generate_mutex = NullLock()

View file

@ -91,6 +91,8 @@ if TYPE_CHECKING:
# because integrations should be able to optionally rely on MQTT.
import paho.mqtt.client as mqtt
from .async_client import AsyncMQTTClient
_LOGGER = logging.getLogger(__name__)
MIN_BUFFER_SIZE = 131072 # Minimum buffer size to use if preferred size fails
@ -281,6 +283,9 @@ class MqttClientSetup:
# should be able to optionally rely on MQTT.
import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel
# pylint: disable-next=import-outside-toplevel
from .async_client import AsyncMQTTClient
if (protocol := config.get(CONF_PROTOCOL, DEFAULT_PROTOCOL)) == PROTOCOL_31:
proto = mqtt.MQTTv31
elif protocol == PROTOCOL_5:
@ -293,9 +298,10 @@ class MqttClientSetup:
# However, that feature is not mandatory so we generate our own.
client_id = mqtt.base62(uuid.uuid4().int, padding=22)
transport = config.get(CONF_TRANSPORT, DEFAULT_TRANSPORT)
self._client = mqtt.Client(
self._client = AsyncMQTTClient(
client_id, protocol=proto, transport=transport, reconnect_on_failure=False
)
self._client.async_setup()
# Enable logging
self._client.enable_logger()
@ -329,7 +335,7 @@ class MqttClientSetup:
self._client.tls_insecure_set(tls_insecure)
@property
def client(self) -> mqtt.Client:
def client(self) -> AsyncMQTTClient:
"""Return the paho MQTT client."""
return self._client
@ -434,7 +440,7 @@ class EnsureJobAfterCooldown:
class MQTT:
"""Home Assistant MQTT client."""
_mqttc: mqtt.Client
_mqttc: AsyncMQTTClient
_last_subscribe: float
_mqtt_data: MqttData
@ -533,7 +539,9 @@ class MQTT:
async def async_init_client(self) -> None:
"""Initialize paho client."""
with async_pause_setup(self.hass, SetupPhases.WAIT_IMPORT_PACKAGES):
await async_import_module(self.hass, "paho.mqtt.client")
await async_import_module(
self.hass, "homeassistant.components.mqtt.async_client"
)
mqttc = MqttClientSetup(self.conf).client
# on_socket_unregister_write and _async_on_socket_close

View file

@ -121,7 +121,9 @@ def mock_try_connection_success() -> Generator[MqttMockPahoClient, None, None]:
mock_client().on_unsubscribe(mock_client, 0, mid)
return (0, mid)
with patch("paho.mqtt.client.Client") as mock_client:
with patch(
"homeassistant.components.mqtt.async_client.AsyncMQTTClient"
) as mock_client:
mock_client().loop_start = loop_start
mock_client().subscribe = _subscribe
mock_client().unsubscribe = _unsubscribe
@ -135,7 +137,9 @@ def mock_try_connection_time_out() -> Generator[MagicMock, None, None]:
# Patch prevent waiting 5 sec for a timeout
with (
patch("paho.mqtt.client.Client") as mock_client,
patch(
"homeassistant.components.mqtt.async_client.AsyncMQTTClient"
) as mock_client,
patch("homeassistant.components.mqtt.config_flow.MQTT_TIMEOUT", 0),
):
mock_client().loop_start = lambda *args: 1

View file

@ -180,7 +180,9 @@ async def test_mqtt_await_ack_at_disconnect(
mid = 100
rc = 0
with patch("paho.mqtt.client.Client") as mock_client:
with patch(
"homeassistant.components.mqtt.async_client.AsyncMQTTClient"
) as mock_client:
mqtt_client = mock_client.return_value
mqtt_client.connect = MagicMock(
return_value=0,
@ -191,10 +193,15 @@ async def test_mqtt_await_ack_at_disconnect(
mqtt_client.publish = MagicMock(return_value=FakeInfo())
entry = MockConfigEntry(
domain=mqtt.DOMAIN,
data={"certificate": "auto", mqtt.CONF_BROKER: "test-broker"},
data={
"certificate": "auto",
mqtt.CONF_BROKER: "test-broker",
mqtt.CONF_DISCOVERY: False,
},
)
entry.add_to_hass(hass)
assert await hass.config_entries.async_setup(entry.entry_id)
mqtt_client = mock_client.return_value
# publish from MQTT client without awaiting
@ -2219,7 +2226,9 @@ async def test_publish_error(
entry.add_to_hass(hass)
# simulate an Out of memory error
with patch("paho.mqtt.client.Client") as mock_client:
with patch(
"homeassistant.components.mqtt.async_client.AsyncMQTTClient"
) as mock_client:
mock_client().connect = lambda *args: 1
mock_client().publish().rc = 1
assert await hass.config_entries.async_setup(entry.entry_id)
@ -2354,7 +2363,9 @@ async def test_setup_mqtt_client_protocol(
protocol: int,
) -> None:
"""Test MQTT client protocol setup."""
with patch("paho.mqtt.client.Client") as mock_client:
with patch(
"homeassistant.components.mqtt.async_client.AsyncMQTTClient"
) as mock_client:
await mqtt_mock_entry()
# check if protocol setup was correctly
@ -2374,7 +2385,9 @@ async def test_handle_mqtt_timeout_on_callback(
mid = 100
rc = 0
with patch("paho.mqtt.client.Client") as mock_client:
with patch(
"homeassistant.components.mqtt.async_client.AsyncMQTTClient"
) as mock_client:
def _mock_ack(topic: str, qos: int = 0) -> tuple[int, int]:
# Handle ACK for subscribe normally
@ -2419,7 +2432,9 @@ async def test_setup_raises_config_entry_not_ready_if_no_connect_broker(
entry = MockConfigEntry(domain=mqtt.DOMAIN, data={mqtt.CONF_BROKER: "test-broker"})
entry.add_to_hass(hass)
with patch("paho.mqtt.client.Client") as mock_client:
with patch(
"homeassistant.components.mqtt.async_client.AsyncMQTTClient"
) as mock_client:
mock_client().connect = MagicMock(side_effect=OSError("Connection error"))
assert await hass.config_entries.async_setup(entry.entry_id)
await hass.async_block_till_done()
@ -2454,7 +2469,9 @@ async def test_setup_uses_certificate_on_certificate_set_to_auto_and_insecure(
def mock_tls_insecure_set(insecure_param) -> None:
insecure_check["insecure"] = insecure_param
with patch("paho.mqtt.client.Client") as mock_client:
with patch(
"homeassistant.components.mqtt.async_client.AsyncMQTTClient"
) as mock_client:
mock_client().tls_set = mock_tls_set
mock_client().tls_insecure_set = mock_tls_insecure_set
await mqtt_mock_entry()
@ -4023,7 +4040,7 @@ async def test_link_config_entry(
assert _check_entities() == 2
# reload entry and assert again
with patch("paho.mqtt.client.Client"):
with patch("homeassistant.components.mqtt.async_client.AsyncMQTTClient"):
await hass.config_entries.async_reload(mqtt_config_entry.entry_id)
await hass.async_block_till_done()

View file

@ -920,7 +920,9 @@ def mqtt_client_mock(hass: HomeAssistant) -> Generator[MqttMockPahoClient, None,
self.mid = mid
self.rc = 0
with patch("paho.mqtt.client.Client") as mock_client:
with patch(
"homeassistant.components.mqtt.async_client.AsyncMQTTClient"
) as mock_client:
# The below use a call_soon for the on_publish/on_subscribe/on_unsubscribe
# callbacks to simulate the behavior of the real MQTT client which will
# not be synchronous.