Remove useless threading locks in mqtt (#118737)
This commit is contained in:
parent
278751607f
commit
67b3be8432
5 changed files with 106 additions and 15 deletions
60
homeassistant/components/mqtt/async_client.py
Normal file
60
homeassistant/components/mqtt/async_client.py
Normal file
|
@ -0,0 +1,60 @@
|
|||
"""Async wrappings for mqtt client."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import lru_cache
|
||||
from types import TracebackType
|
||||
from typing import Self
|
||||
|
||||
from paho.mqtt.client import Client as MQTTClient
|
||||
|
||||
_MQTT_LOCK_COUNT = 7
|
||||
|
||||
|
||||
class NullLock:
|
||||
"""Null lock."""
|
||||
|
||||
@lru_cache(maxsize=_MQTT_LOCK_COUNT)
|
||||
def __enter__(self) -> Self:
|
||||
"""Enter the lock."""
|
||||
return self
|
||||
|
||||
@lru_cache(maxsize=_MQTT_LOCK_COUNT)
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
traceback: TracebackType | None,
|
||||
) -> None:
|
||||
"""Exit the lock."""
|
||||
|
||||
@lru_cache(maxsize=_MQTT_LOCK_COUNT)
|
||||
def acquire(self, blocking: bool = False, timeout: int = -1) -> None:
|
||||
"""Acquire the lock."""
|
||||
|
||||
@lru_cache(maxsize=_MQTT_LOCK_COUNT)
|
||||
def release(self) -> None:
|
||||
"""Release the lock."""
|
||||
|
||||
|
||||
class AsyncMQTTClient(MQTTClient):
|
||||
"""Async MQTT Client.
|
||||
|
||||
Wrapper around paho.mqtt.client.Client to remove the locking
|
||||
that is not needed since we are running in an async event loop.
|
||||
"""
|
||||
|
||||
def async_setup(self) -> None:
|
||||
"""Set up the client.
|
||||
|
||||
All the threading locks are replaced with NullLock
|
||||
since the client is running in an async event loop
|
||||
and will never run in multiple threads.
|
||||
"""
|
||||
self._in_callback_mutex = NullLock()
|
||||
self._callback_mutex = NullLock()
|
||||
self._msgtime_mutex = NullLock()
|
||||
self._out_message_mutex = NullLock()
|
||||
self._in_message_mutex = NullLock()
|
||||
self._reconnect_delay_mutex = NullLock()
|
||||
self._mid_generate_mutex = NullLock()
|
|
@ -91,6 +91,8 @@ if TYPE_CHECKING:
|
|||
# because integrations should be able to optionally rely on MQTT.
|
||||
import paho.mqtt.client as mqtt
|
||||
|
||||
from .async_client import AsyncMQTTClient
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
MIN_BUFFER_SIZE = 131072 # Minimum buffer size to use if preferred size fails
|
||||
|
@ -281,6 +283,9 @@ class MqttClientSetup:
|
|||
# should be able to optionally rely on MQTT.
|
||||
import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel
|
||||
|
||||
# pylint: disable-next=import-outside-toplevel
|
||||
from .async_client import AsyncMQTTClient
|
||||
|
||||
if (protocol := config.get(CONF_PROTOCOL, DEFAULT_PROTOCOL)) == PROTOCOL_31:
|
||||
proto = mqtt.MQTTv31
|
||||
elif protocol == PROTOCOL_5:
|
||||
|
@ -293,9 +298,10 @@ class MqttClientSetup:
|
|||
# However, that feature is not mandatory so we generate our own.
|
||||
client_id = mqtt.base62(uuid.uuid4().int, padding=22)
|
||||
transport = config.get(CONF_TRANSPORT, DEFAULT_TRANSPORT)
|
||||
self._client = mqtt.Client(
|
||||
self._client = AsyncMQTTClient(
|
||||
client_id, protocol=proto, transport=transport, reconnect_on_failure=False
|
||||
)
|
||||
self._client.async_setup()
|
||||
|
||||
# Enable logging
|
||||
self._client.enable_logger()
|
||||
|
@ -329,7 +335,7 @@ class MqttClientSetup:
|
|||
self._client.tls_insecure_set(tls_insecure)
|
||||
|
||||
@property
|
||||
def client(self) -> mqtt.Client:
|
||||
def client(self) -> AsyncMQTTClient:
|
||||
"""Return the paho MQTT client."""
|
||||
return self._client
|
||||
|
||||
|
@ -434,7 +440,7 @@ class EnsureJobAfterCooldown:
|
|||
class MQTT:
|
||||
"""Home Assistant MQTT client."""
|
||||
|
||||
_mqttc: mqtt.Client
|
||||
_mqttc: AsyncMQTTClient
|
||||
_last_subscribe: float
|
||||
_mqtt_data: MqttData
|
||||
|
||||
|
@ -533,7 +539,9 @@ class MQTT:
|
|||
async def async_init_client(self) -> None:
|
||||
"""Initialize paho client."""
|
||||
with async_pause_setup(self.hass, SetupPhases.WAIT_IMPORT_PACKAGES):
|
||||
await async_import_module(self.hass, "paho.mqtt.client")
|
||||
await async_import_module(
|
||||
self.hass, "homeassistant.components.mqtt.async_client"
|
||||
)
|
||||
|
||||
mqttc = MqttClientSetup(self.conf).client
|
||||
# on_socket_unregister_write and _async_on_socket_close
|
||||
|
|
|
@ -121,7 +121,9 @@ def mock_try_connection_success() -> Generator[MqttMockPahoClient, None, None]:
|
|||
mock_client().on_unsubscribe(mock_client, 0, mid)
|
||||
return (0, mid)
|
||||
|
||||
with patch("paho.mqtt.client.Client") as mock_client:
|
||||
with patch(
|
||||
"homeassistant.components.mqtt.async_client.AsyncMQTTClient"
|
||||
) as mock_client:
|
||||
mock_client().loop_start = loop_start
|
||||
mock_client().subscribe = _subscribe
|
||||
mock_client().unsubscribe = _unsubscribe
|
||||
|
@ -135,7 +137,9 @@ def mock_try_connection_time_out() -> Generator[MagicMock, None, None]:
|
|||
|
||||
# Patch prevent waiting 5 sec for a timeout
|
||||
with (
|
||||
patch("paho.mqtt.client.Client") as mock_client,
|
||||
patch(
|
||||
"homeassistant.components.mqtt.async_client.AsyncMQTTClient"
|
||||
) as mock_client,
|
||||
patch("homeassistant.components.mqtt.config_flow.MQTT_TIMEOUT", 0),
|
||||
):
|
||||
mock_client().loop_start = lambda *args: 1
|
||||
|
|
|
@ -180,7 +180,9 @@ async def test_mqtt_await_ack_at_disconnect(
|
|||
mid = 100
|
||||
rc = 0
|
||||
|
||||
with patch("paho.mqtt.client.Client") as mock_client:
|
||||
with patch(
|
||||
"homeassistant.components.mqtt.async_client.AsyncMQTTClient"
|
||||
) as mock_client:
|
||||
mqtt_client = mock_client.return_value
|
||||
mqtt_client.connect = MagicMock(
|
||||
return_value=0,
|
||||
|
@ -191,10 +193,15 @@ async def test_mqtt_await_ack_at_disconnect(
|
|||
mqtt_client.publish = MagicMock(return_value=FakeInfo())
|
||||
entry = MockConfigEntry(
|
||||
domain=mqtt.DOMAIN,
|
||||
data={"certificate": "auto", mqtt.CONF_BROKER: "test-broker"},
|
||||
data={
|
||||
"certificate": "auto",
|
||||
mqtt.CONF_BROKER: "test-broker",
|
||||
mqtt.CONF_DISCOVERY: False,
|
||||
},
|
||||
)
|
||||
entry.add_to_hass(hass)
|
||||
assert await hass.config_entries.async_setup(entry.entry_id)
|
||||
|
||||
mqtt_client = mock_client.return_value
|
||||
|
||||
# publish from MQTT client without awaiting
|
||||
|
@ -2219,7 +2226,9 @@ async def test_publish_error(
|
|||
entry.add_to_hass(hass)
|
||||
|
||||
# simulate an Out of memory error
|
||||
with patch("paho.mqtt.client.Client") as mock_client:
|
||||
with patch(
|
||||
"homeassistant.components.mqtt.async_client.AsyncMQTTClient"
|
||||
) as mock_client:
|
||||
mock_client().connect = lambda *args: 1
|
||||
mock_client().publish().rc = 1
|
||||
assert await hass.config_entries.async_setup(entry.entry_id)
|
||||
|
@ -2354,7 +2363,9 @@ async def test_setup_mqtt_client_protocol(
|
|||
protocol: int,
|
||||
) -> None:
|
||||
"""Test MQTT client protocol setup."""
|
||||
with patch("paho.mqtt.client.Client") as mock_client:
|
||||
with patch(
|
||||
"homeassistant.components.mqtt.async_client.AsyncMQTTClient"
|
||||
) as mock_client:
|
||||
await mqtt_mock_entry()
|
||||
|
||||
# check if protocol setup was correctly
|
||||
|
@ -2374,7 +2385,9 @@ async def test_handle_mqtt_timeout_on_callback(
|
|||
mid = 100
|
||||
rc = 0
|
||||
|
||||
with patch("paho.mqtt.client.Client") as mock_client:
|
||||
with patch(
|
||||
"homeassistant.components.mqtt.async_client.AsyncMQTTClient"
|
||||
) as mock_client:
|
||||
|
||||
def _mock_ack(topic: str, qos: int = 0) -> tuple[int, int]:
|
||||
# Handle ACK for subscribe normally
|
||||
|
@ -2419,7 +2432,9 @@ async def test_setup_raises_config_entry_not_ready_if_no_connect_broker(
|
|||
entry = MockConfigEntry(domain=mqtt.DOMAIN, data={mqtt.CONF_BROKER: "test-broker"})
|
||||
entry.add_to_hass(hass)
|
||||
|
||||
with patch("paho.mqtt.client.Client") as mock_client:
|
||||
with patch(
|
||||
"homeassistant.components.mqtt.async_client.AsyncMQTTClient"
|
||||
) as mock_client:
|
||||
mock_client().connect = MagicMock(side_effect=OSError("Connection error"))
|
||||
assert await hass.config_entries.async_setup(entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
@ -2454,7 +2469,9 @@ async def test_setup_uses_certificate_on_certificate_set_to_auto_and_insecure(
|
|||
def mock_tls_insecure_set(insecure_param) -> None:
|
||||
insecure_check["insecure"] = insecure_param
|
||||
|
||||
with patch("paho.mqtt.client.Client") as mock_client:
|
||||
with patch(
|
||||
"homeassistant.components.mqtt.async_client.AsyncMQTTClient"
|
||||
) as mock_client:
|
||||
mock_client().tls_set = mock_tls_set
|
||||
mock_client().tls_insecure_set = mock_tls_insecure_set
|
||||
await mqtt_mock_entry()
|
||||
|
@ -4023,7 +4040,7 @@ async def test_link_config_entry(
|
|||
assert _check_entities() == 2
|
||||
|
||||
# reload entry and assert again
|
||||
with patch("paho.mqtt.client.Client"):
|
||||
with patch("homeassistant.components.mqtt.async_client.AsyncMQTTClient"):
|
||||
await hass.config_entries.async_reload(mqtt_config_entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
|
|
|
@ -920,7 +920,9 @@ def mqtt_client_mock(hass: HomeAssistant) -> Generator[MqttMockPahoClient, None,
|
|||
self.mid = mid
|
||||
self.rc = 0
|
||||
|
||||
with patch("paho.mqtt.client.Client") as mock_client:
|
||||
with patch(
|
||||
"homeassistant.components.mqtt.async_client.AsyncMQTTClient"
|
||||
) as mock_client:
|
||||
# The below use a call_soon for the on_publish/on_subscribe/on_unsubscribe
|
||||
# callbacks to simulate the behavior of the real MQTT client which will
|
||||
# not be synchronous.
|
||||
|
|
Loading…
Add table
Reference in a new issue