Remove useless threading locks in mqtt (#118737)

This commit is contained in:
J. Nick Koston 2024-06-04 14:21:03 -05:00 committed by GitHub
parent 278751607f
commit 67b3be8432
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 106 additions and 15 deletions

View file

@ -0,0 +1,60 @@
"""Async wrappings for mqtt client."""
from __future__ import annotations
from functools import lru_cache
from types import TracebackType
from typing import Self
from paho.mqtt.client import Client as MQTTClient
_MQTT_LOCK_COUNT = 7
class NullLock:
"""Null lock."""
@lru_cache(maxsize=_MQTT_LOCK_COUNT)
def __enter__(self) -> Self:
"""Enter the lock."""
return self
@lru_cache(maxsize=_MQTT_LOCK_COUNT)
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
"""Exit the lock."""
@lru_cache(maxsize=_MQTT_LOCK_COUNT)
def acquire(self, blocking: bool = False, timeout: int = -1) -> None:
"""Acquire the lock."""
@lru_cache(maxsize=_MQTT_LOCK_COUNT)
def release(self) -> None:
"""Release the lock."""
class AsyncMQTTClient(MQTTClient):
"""Async MQTT Client.
Wrapper around paho.mqtt.client.Client to remove the locking
that is not needed since we are running in an async event loop.
"""
def async_setup(self) -> None:
"""Set up the client.
All the threading locks are replaced with NullLock
since the client is running in an async event loop
and will never run in multiple threads.
"""
self._in_callback_mutex = NullLock()
self._callback_mutex = NullLock()
self._msgtime_mutex = NullLock()
self._out_message_mutex = NullLock()
self._in_message_mutex = NullLock()
self._reconnect_delay_mutex = NullLock()
self._mid_generate_mutex = NullLock()

View file

@ -91,6 +91,8 @@ if TYPE_CHECKING:
# because integrations should be able to optionally rely on MQTT. # because integrations should be able to optionally rely on MQTT.
import paho.mqtt.client as mqtt import paho.mqtt.client as mqtt
from .async_client import AsyncMQTTClient
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
MIN_BUFFER_SIZE = 131072 # Minimum buffer size to use if preferred size fails MIN_BUFFER_SIZE = 131072 # Minimum buffer size to use if preferred size fails
@ -281,6 +283,9 @@ class MqttClientSetup:
# should be able to optionally rely on MQTT. # should be able to optionally rely on MQTT.
import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel
# pylint: disable-next=import-outside-toplevel
from .async_client import AsyncMQTTClient
if (protocol := config.get(CONF_PROTOCOL, DEFAULT_PROTOCOL)) == PROTOCOL_31: if (protocol := config.get(CONF_PROTOCOL, DEFAULT_PROTOCOL)) == PROTOCOL_31:
proto = mqtt.MQTTv31 proto = mqtt.MQTTv31
elif protocol == PROTOCOL_5: elif protocol == PROTOCOL_5:
@ -293,9 +298,10 @@ class MqttClientSetup:
# However, that feature is not mandatory so we generate our own. # However, that feature is not mandatory so we generate our own.
client_id = mqtt.base62(uuid.uuid4().int, padding=22) client_id = mqtt.base62(uuid.uuid4().int, padding=22)
transport = config.get(CONF_TRANSPORT, DEFAULT_TRANSPORT) transport = config.get(CONF_TRANSPORT, DEFAULT_TRANSPORT)
self._client = mqtt.Client( self._client = AsyncMQTTClient(
client_id, protocol=proto, transport=transport, reconnect_on_failure=False client_id, protocol=proto, transport=transport, reconnect_on_failure=False
) )
self._client.async_setup()
# Enable logging # Enable logging
self._client.enable_logger() self._client.enable_logger()
@ -329,7 +335,7 @@ class MqttClientSetup:
self._client.tls_insecure_set(tls_insecure) self._client.tls_insecure_set(tls_insecure)
@property @property
def client(self) -> mqtt.Client: def client(self) -> AsyncMQTTClient:
"""Return the paho MQTT client.""" """Return the paho MQTT client."""
return self._client return self._client
@ -434,7 +440,7 @@ class EnsureJobAfterCooldown:
class MQTT: class MQTT:
"""Home Assistant MQTT client.""" """Home Assistant MQTT client."""
_mqttc: mqtt.Client _mqttc: AsyncMQTTClient
_last_subscribe: float _last_subscribe: float
_mqtt_data: MqttData _mqtt_data: MqttData
@ -533,7 +539,9 @@ class MQTT:
async def async_init_client(self) -> None: async def async_init_client(self) -> None:
"""Initialize paho client.""" """Initialize paho client."""
with async_pause_setup(self.hass, SetupPhases.WAIT_IMPORT_PACKAGES): with async_pause_setup(self.hass, SetupPhases.WAIT_IMPORT_PACKAGES):
await async_import_module(self.hass, "paho.mqtt.client") await async_import_module(
self.hass, "homeassistant.components.mqtt.async_client"
)
mqttc = MqttClientSetup(self.conf).client mqttc = MqttClientSetup(self.conf).client
# on_socket_unregister_write and _async_on_socket_close # on_socket_unregister_write and _async_on_socket_close

View file

@ -121,7 +121,9 @@ def mock_try_connection_success() -> Generator[MqttMockPahoClient, None, None]:
mock_client().on_unsubscribe(mock_client, 0, mid) mock_client().on_unsubscribe(mock_client, 0, mid)
return (0, mid) return (0, mid)
with patch("paho.mqtt.client.Client") as mock_client: with patch(
"homeassistant.components.mqtt.async_client.AsyncMQTTClient"
) as mock_client:
mock_client().loop_start = loop_start mock_client().loop_start = loop_start
mock_client().subscribe = _subscribe mock_client().subscribe = _subscribe
mock_client().unsubscribe = _unsubscribe mock_client().unsubscribe = _unsubscribe
@ -135,7 +137,9 @@ def mock_try_connection_time_out() -> Generator[MagicMock, None, None]:
# Patch prevent waiting 5 sec for a timeout # Patch prevent waiting 5 sec for a timeout
with ( with (
patch("paho.mqtt.client.Client") as mock_client, patch(
"homeassistant.components.mqtt.async_client.AsyncMQTTClient"
) as mock_client,
patch("homeassistant.components.mqtt.config_flow.MQTT_TIMEOUT", 0), patch("homeassistant.components.mqtt.config_flow.MQTT_TIMEOUT", 0),
): ):
mock_client().loop_start = lambda *args: 1 mock_client().loop_start = lambda *args: 1

View file

@ -180,7 +180,9 @@ async def test_mqtt_await_ack_at_disconnect(
mid = 100 mid = 100
rc = 0 rc = 0
with patch("paho.mqtt.client.Client") as mock_client: with patch(
"homeassistant.components.mqtt.async_client.AsyncMQTTClient"
) as mock_client:
mqtt_client = mock_client.return_value mqtt_client = mock_client.return_value
mqtt_client.connect = MagicMock( mqtt_client.connect = MagicMock(
return_value=0, return_value=0,
@ -191,10 +193,15 @@ async def test_mqtt_await_ack_at_disconnect(
mqtt_client.publish = MagicMock(return_value=FakeInfo()) mqtt_client.publish = MagicMock(return_value=FakeInfo())
entry = MockConfigEntry( entry = MockConfigEntry(
domain=mqtt.DOMAIN, domain=mqtt.DOMAIN,
data={"certificate": "auto", mqtt.CONF_BROKER: "test-broker"}, data={
"certificate": "auto",
mqtt.CONF_BROKER: "test-broker",
mqtt.CONF_DISCOVERY: False,
},
) )
entry.add_to_hass(hass) entry.add_to_hass(hass)
assert await hass.config_entries.async_setup(entry.entry_id) assert await hass.config_entries.async_setup(entry.entry_id)
mqtt_client = mock_client.return_value mqtt_client = mock_client.return_value
# publish from MQTT client without awaiting # publish from MQTT client without awaiting
@ -2219,7 +2226,9 @@ async def test_publish_error(
entry.add_to_hass(hass) entry.add_to_hass(hass)
# simulate an Out of memory error # simulate an Out of memory error
with patch("paho.mqtt.client.Client") as mock_client: with patch(
"homeassistant.components.mqtt.async_client.AsyncMQTTClient"
) as mock_client:
mock_client().connect = lambda *args: 1 mock_client().connect = lambda *args: 1
mock_client().publish().rc = 1 mock_client().publish().rc = 1
assert await hass.config_entries.async_setup(entry.entry_id) assert await hass.config_entries.async_setup(entry.entry_id)
@ -2354,7 +2363,9 @@ async def test_setup_mqtt_client_protocol(
protocol: int, protocol: int,
) -> None: ) -> None:
"""Test MQTT client protocol setup.""" """Test MQTT client protocol setup."""
with patch("paho.mqtt.client.Client") as mock_client: with patch(
"homeassistant.components.mqtt.async_client.AsyncMQTTClient"
) as mock_client:
await mqtt_mock_entry() await mqtt_mock_entry()
# check if protocol setup was correctly # check if protocol setup was correctly
@ -2374,7 +2385,9 @@ async def test_handle_mqtt_timeout_on_callback(
mid = 100 mid = 100
rc = 0 rc = 0
with patch("paho.mqtt.client.Client") as mock_client: with patch(
"homeassistant.components.mqtt.async_client.AsyncMQTTClient"
) as mock_client:
def _mock_ack(topic: str, qos: int = 0) -> tuple[int, int]: def _mock_ack(topic: str, qos: int = 0) -> tuple[int, int]:
# Handle ACK for subscribe normally # Handle ACK for subscribe normally
@ -2419,7 +2432,9 @@ async def test_setup_raises_config_entry_not_ready_if_no_connect_broker(
entry = MockConfigEntry(domain=mqtt.DOMAIN, data={mqtt.CONF_BROKER: "test-broker"}) entry = MockConfigEntry(domain=mqtt.DOMAIN, data={mqtt.CONF_BROKER: "test-broker"})
entry.add_to_hass(hass) entry.add_to_hass(hass)
with patch("paho.mqtt.client.Client") as mock_client: with patch(
"homeassistant.components.mqtt.async_client.AsyncMQTTClient"
) as mock_client:
mock_client().connect = MagicMock(side_effect=OSError("Connection error")) mock_client().connect = MagicMock(side_effect=OSError("Connection error"))
assert await hass.config_entries.async_setup(entry.entry_id) assert await hass.config_entries.async_setup(entry.entry_id)
await hass.async_block_till_done() await hass.async_block_till_done()
@ -2454,7 +2469,9 @@ async def test_setup_uses_certificate_on_certificate_set_to_auto_and_insecure(
def mock_tls_insecure_set(insecure_param) -> None: def mock_tls_insecure_set(insecure_param) -> None:
insecure_check["insecure"] = insecure_param insecure_check["insecure"] = insecure_param
with patch("paho.mqtt.client.Client") as mock_client: with patch(
"homeassistant.components.mqtt.async_client.AsyncMQTTClient"
) as mock_client:
mock_client().tls_set = mock_tls_set mock_client().tls_set = mock_tls_set
mock_client().tls_insecure_set = mock_tls_insecure_set mock_client().tls_insecure_set = mock_tls_insecure_set
await mqtt_mock_entry() await mqtt_mock_entry()
@ -4023,7 +4040,7 @@ async def test_link_config_entry(
assert _check_entities() == 2 assert _check_entities() == 2
# reload entry and assert again # reload entry and assert again
with patch("paho.mqtt.client.Client"): with patch("homeassistant.components.mqtt.async_client.AsyncMQTTClient"):
await hass.config_entries.async_reload(mqtt_config_entry.entry_id) await hass.config_entries.async_reload(mqtt_config_entry.entry_id)
await hass.async_block_till_done() await hass.async_block_till_done()

View file

@ -920,7 +920,9 @@ def mqtt_client_mock(hass: HomeAssistant) -> Generator[MqttMockPahoClient, None,
self.mid = mid self.mid = mid
self.rc = 0 self.rc = 0
with patch("paho.mqtt.client.Client") as mock_client: with patch(
"homeassistant.components.mqtt.async_client.AsyncMQTTClient"
) as mock_client:
# The below use a call_soon for the on_publish/on_subscribe/on_unsubscribe # The below use a call_soon for the on_publish/on_subscribe/on_unsubscribe
# callbacks to simulate the behavior of the real MQTT client which will # callbacks to simulate the behavior of the real MQTT client which will
# not be synchronous. # not be synchronous.