Remove useless threading locks in mqtt (#118737)
This commit is contained in:
parent
278751607f
commit
67b3be8432
5 changed files with 106 additions and 15 deletions
60
homeassistant/components/mqtt/async_client.py
Normal file
60
homeassistant/components/mqtt/async_client.py
Normal file
|
@ -0,0 +1,60 @@
|
||||||
|
"""Async wrappings for mqtt client."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from functools import lru_cache
|
||||||
|
from types import TracebackType
|
||||||
|
from typing import Self
|
||||||
|
|
||||||
|
from paho.mqtt.client import Client as MQTTClient
|
||||||
|
|
||||||
|
_MQTT_LOCK_COUNT = 7
|
||||||
|
|
||||||
|
|
||||||
|
class NullLock:
|
||||||
|
"""Null lock."""
|
||||||
|
|
||||||
|
@lru_cache(maxsize=_MQTT_LOCK_COUNT)
|
||||||
|
def __enter__(self) -> Self:
|
||||||
|
"""Enter the lock."""
|
||||||
|
return self
|
||||||
|
|
||||||
|
@lru_cache(maxsize=_MQTT_LOCK_COUNT)
|
||||||
|
def __exit__(
|
||||||
|
self,
|
||||||
|
exc_type: type[BaseException] | None,
|
||||||
|
exc_value: BaseException | None,
|
||||||
|
traceback: TracebackType | None,
|
||||||
|
) -> None:
|
||||||
|
"""Exit the lock."""
|
||||||
|
|
||||||
|
@lru_cache(maxsize=_MQTT_LOCK_COUNT)
|
||||||
|
def acquire(self, blocking: bool = False, timeout: int = -1) -> None:
|
||||||
|
"""Acquire the lock."""
|
||||||
|
|
||||||
|
@lru_cache(maxsize=_MQTT_LOCK_COUNT)
|
||||||
|
def release(self) -> None:
|
||||||
|
"""Release the lock."""
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncMQTTClient(MQTTClient):
|
||||||
|
"""Async MQTT Client.
|
||||||
|
|
||||||
|
Wrapper around paho.mqtt.client.Client to remove the locking
|
||||||
|
that is not needed since we are running in an async event loop.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def async_setup(self) -> None:
|
||||||
|
"""Set up the client.
|
||||||
|
|
||||||
|
All the threading locks are replaced with NullLock
|
||||||
|
since the client is running in an async event loop
|
||||||
|
and will never run in multiple threads.
|
||||||
|
"""
|
||||||
|
self._in_callback_mutex = NullLock()
|
||||||
|
self._callback_mutex = NullLock()
|
||||||
|
self._msgtime_mutex = NullLock()
|
||||||
|
self._out_message_mutex = NullLock()
|
||||||
|
self._in_message_mutex = NullLock()
|
||||||
|
self._reconnect_delay_mutex = NullLock()
|
||||||
|
self._mid_generate_mutex = NullLock()
|
|
@ -91,6 +91,8 @@ if TYPE_CHECKING:
|
||||||
# because integrations should be able to optionally rely on MQTT.
|
# because integrations should be able to optionally rely on MQTT.
|
||||||
import paho.mqtt.client as mqtt
|
import paho.mqtt.client as mqtt
|
||||||
|
|
||||||
|
from .async_client import AsyncMQTTClient
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
MIN_BUFFER_SIZE = 131072 # Minimum buffer size to use if preferred size fails
|
MIN_BUFFER_SIZE = 131072 # Minimum buffer size to use if preferred size fails
|
||||||
|
@ -281,6 +283,9 @@ class MqttClientSetup:
|
||||||
# should be able to optionally rely on MQTT.
|
# should be able to optionally rely on MQTT.
|
||||||
import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel
|
import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel
|
||||||
|
|
||||||
|
# pylint: disable-next=import-outside-toplevel
|
||||||
|
from .async_client import AsyncMQTTClient
|
||||||
|
|
||||||
if (protocol := config.get(CONF_PROTOCOL, DEFAULT_PROTOCOL)) == PROTOCOL_31:
|
if (protocol := config.get(CONF_PROTOCOL, DEFAULT_PROTOCOL)) == PROTOCOL_31:
|
||||||
proto = mqtt.MQTTv31
|
proto = mqtt.MQTTv31
|
||||||
elif protocol == PROTOCOL_5:
|
elif protocol == PROTOCOL_5:
|
||||||
|
@ -293,9 +298,10 @@ class MqttClientSetup:
|
||||||
# However, that feature is not mandatory so we generate our own.
|
# However, that feature is not mandatory so we generate our own.
|
||||||
client_id = mqtt.base62(uuid.uuid4().int, padding=22)
|
client_id = mqtt.base62(uuid.uuid4().int, padding=22)
|
||||||
transport = config.get(CONF_TRANSPORT, DEFAULT_TRANSPORT)
|
transport = config.get(CONF_TRANSPORT, DEFAULT_TRANSPORT)
|
||||||
self._client = mqtt.Client(
|
self._client = AsyncMQTTClient(
|
||||||
client_id, protocol=proto, transport=transport, reconnect_on_failure=False
|
client_id, protocol=proto, transport=transport, reconnect_on_failure=False
|
||||||
)
|
)
|
||||||
|
self._client.async_setup()
|
||||||
|
|
||||||
# Enable logging
|
# Enable logging
|
||||||
self._client.enable_logger()
|
self._client.enable_logger()
|
||||||
|
@ -329,7 +335,7 @@ class MqttClientSetup:
|
||||||
self._client.tls_insecure_set(tls_insecure)
|
self._client.tls_insecure_set(tls_insecure)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def client(self) -> mqtt.Client:
|
def client(self) -> AsyncMQTTClient:
|
||||||
"""Return the paho MQTT client."""
|
"""Return the paho MQTT client."""
|
||||||
return self._client
|
return self._client
|
||||||
|
|
||||||
|
@ -434,7 +440,7 @@ class EnsureJobAfterCooldown:
|
||||||
class MQTT:
|
class MQTT:
|
||||||
"""Home Assistant MQTT client."""
|
"""Home Assistant MQTT client."""
|
||||||
|
|
||||||
_mqttc: mqtt.Client
|
_mqttc: AsyncMQTTClient
|
||||||
_last_subscribe: float
|
_last_subscribe: float
|
||||||
_mqtt_data: MqttData
|
_mqtt_data: MqttData
|
||||||
|
|
||||||
|
@ -533,7 +539,9 @@ class MQTT:
|
||||||
async def async_init_client(self) -> None:
|
async def async_init_client(self) -> None:
|
||||||
"""Initialize paho client."""
|
"""Initialize paho client."""
|
||||||
with async_pause_setup(self.hass, SetupPhases.WAIT_IMPORT_PACKAGES):
|
with async_pause_setup(self.hass, SetupPhases.WAIT_IMPORT_PACKAGES):
|
||||||
await async_import_module(self.hass, "paho.mqtt.client")
|
await async_import_module(
|
||||||
|
self.hass, "homeassistant.components.mqtt.async_client"
|
||||||
|
)
|
||||||
|
|
||||||
mqttc = MqttClientSetup(self.conf).client
|
mqttc = MqttClientSetup(self.conf).client
|
||||||
# on_socket_unregister_write and _async_on_socket_close
|
# on_socket_unregister_write and _async_on_socket_close
|
||||||
|
|
|
@ -121,7 +121,9 @@ def mock_try_connection_success() -> Generator[MqttMockPahoClient, None, None]:
|
||||||
mock_client().on_unsubscribe(mock_client, 0, mid)
|
mock_client().on_unsubscribe(mock_client, 0, mid)
|
||||||
return (0, mid)
|
return (0, mid)
|
||||||
|
|
||||||
with patch("paho.mqtt.client.Client") as mock_client:
|
with patch(
|
||||||
|
"homeassistant.components.mqtt.async_client.AsyncMQTTClient"
|
||||||
|
) as mock_client:
|
||||||
mock_client().loop_start = loop_start
|
mock_client().loop_start = loop_start
|
||||||
mock_client().subscribe = _subscribe
|
mock_client().subscribe = _subscribe
|
||||||
mock_client().unsubscribe = _unsubscribe
|
mock_client().unsubscribe = _unsubscribe
|
||||||
|
@ -135,7 +137,9 @@ def mock_try_connection_time_out() -> Generator[MagicMock, None, None]:
|
||||||
|
|
||||||
# Patch prevent waiting 5 sec for a timeout
|
# Patch prevent waiting 5 sec for a timeout
|
||||||
with (
|
with (
|
||||||
patch("paho.mqtt.client.Client") as mock_client,
|
patch(
|
||||||
|
"homeassistant.components.mqtt.async_client.AsyncMQTTClient"
|
||||||
|
) as mock_client,
|
||||||
patch("homeassistant.components.mqtt.config_flow.MQTT_TIMEOUT", 0),
|
patch("homeassistant.components.mqtt.config_flow.MQTT_TIMEOUT", 0),
|
||||||
):
|
):
|
||||||
mock_client().loop_start = lambda *args: 1
|
mock_client().loop_start = lambda *args: 1
|
||||||
|
|
|
@ -180,7 +180,9 @@ async def test_mqtt_await_ack_at_disconnect(
|
||||||
mid = 100
|
mid = 100
|
||||||
rc = 0
|
rc = 0
|
||||||
|
|
||||||
with patch("paho.mqtt.client.Client") as mock_client:
|
with patch(
|
||||||
|
"homeassistant.components.mqtt.async_client.AsyncMQTTClient"
|
||||||
|
) as mock_client:
|
||||||
mqtt_client = mock_client.return_value
|
mqtt_client = mock_client.return_value
|
||||||
mqtt_client.connect = MagicMock(
|
mqtt_client.connect = MagicMock(
|
||||||
return_value=0,
|
return_value=0,
|
||||||
|
@ -191,10 +193,15 @@ async def test_mqtt_await_ack_at_disconnect(
|
||||||
mqtt_client.publish = MagicMock(return_value=FakeInfo())
|
mqtt_client.publish = MagicMock(return_value=FakeInfo())
|
||||||
entry = MockConfigEntry(
|
entry = MockConfigEntry(
|
||||||
domain=mqtt.DOMAIN,
|
domain=mqtt.DOMAIN,
|
||||||
data={"certificate": "auto", mqtt.CONF_BROKER: "test-broker"},
|
data={
|
||||||
|
"certificate": "auto",
|
||||||
|
mqtt.CONF_BROKER: "test-broker",
|
||||||
|
mqtt.CONF_DISCOVERY: False,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
entry.add_to_hass(hass)
|
entry.add_to_hass(hass)
|
||||||
assert await hass.config_entries.async_setup(entry.entry_id)
|
assert await hass.config_entries.async_setup(entry.entry_id)
|
||||||
|
|
||||||
mqtt_client = mock_client.return_value
|
mqtt_client = mock_client.return_value
|
||||||
|
|
||||||
# publish from MQTT client without awaiting
|
# publish from MQTT client without awaiting
|
||||||
|
@ -2219,7 +2226,9 @@ async def test_publish_error(
|
||||||
entry.add_to_hass(hass)
|
entry.add_to_hass(hass)
|
||||||
|
|
||||||
# simulate an Out of memory error
|
# simulate an Out of memory error
|
||||||
with patch("paho.mqtt.client.Client") as mock_client:
|
with patch(
|
||||||
|
"homeassistant.components.mqtt.async_client.AsyncMQTTClient"
|
||||||
|
) as mock_client:
|
||||||
mock_client().connect = lambda *args: 1
|
mock_client().connect = lambda *args: 1
|
||||||
mock_client().publish().rc = 1
|
mock_client().publish().rc = 1
|
||||||
assert await hass.config_entries.async_setup(entry.entry_id)
|
assert await hass.config_entries.async_setup(entry.entry_id)
|
||||||
|
@ -2354,7 +2363,9 @@ async def test_setup_mqtt_client_protocol(
|
||||||
protocol: int,
|
protocol: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test MQTT client protocol setup."""
|
"""Test MQTT client protocol setup."""
|
||||||
with patch("paho.mqtt.client.Client") as mock_client:
|
with patch(
|
||||||
|
"homeassistant.components.mqtt.async_client.AsyncMQTTClient"
|
||||||
|
) as mock_client:
|
||||||
await mqtt_mock_entry()
|
await mqtt_mock_entry()
|
||||||
|
|
||||||
# check if protocol setup was correctly
|
# check if protocol setup was correctly
|
||||||
|
@ -2374,7 +2385,9 @@ async def test_handle_mqtt_timeout_on_callback(
|
||||||
mid = 100
|
mid = 100
|
||||||
rc = 0
|
rc = 0
|
||||||
|
|
||||||
with patch("paho.mqtt.client.Client") as mock_client:
|
with patch(
|
||||||
|
"homeassistant.components.mqtt.async_client.AsyncMQTTClient"
|
||||||
|
) as mock_client:
|
||||||
|
|
||||||
def _mock_ack(topic: str, qos: int = 0) -> tuple[int, int]:
|
def _mock_ack(topic: str, qos: int = 0) -> tuple[int, int]:
|
||||||
# Handle ACK for subscribe normally
|
# Handle ACK for subscribe normally
|
||||||
|
@ -2419,7 +2432,9 @@ async def test_setup_raises_config_entry_not_ready_if_no_connect_broker(
|
||||||
entry = MockConfigEntry(domain=mqtt.DOMAIN, data={mqtt.CONF_BROKER: "test-broker"})
|
entry = MockConfigEntry(domain=mqtt.DOMAIN, data={mqtt.CONF_BROKER: "test-broker"})
|
||||||
entry.add_to_hass(hass)
|
entry.add_to_hass(hass)
|
||||||
|
|
||||||
with patch("paho.mqtt.client.Client") as mock_client:
|
with patch(
|
||||||
|
"homeassistant.components.mqtt.async_client.AsyncMQTTClient"
|
||||||
|
) as mock_client:
|
||||||
mock_client().connect = MagicMock(side_effect=OSError("Connection error"))
|
mock_client().connect = MagicMock(side_effect=OSError("Connection error"))
|
||||||
assert await hass.config_entries.async_setup(entry.entry_id)
|
assert await hass.config_entries.async_setup(entry.entry_id)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
@ -2454,7 +2469,9 @@ async def test_setup_uses_certificate_on_certificate_set_to_auto_and_insecure(
|
||||||
def mock_tls_insecure_set(insecure_param) -> None:
|
def mock_tls_insecure_set(insecure_param) -> None:
|
||||||
insecure_check["insecure"] = insecure_param
|
insecure_check["insecure"] = insecure_param
|
||||||
|
|
||||||
with patch("paho.mqtt.client.Client") as mock_client:
|
with patch(
|
||||||
|
"homeassistant.components.mqtt.async_client.AsyncMQTTClient"
|
||||||
|
) as mock_client:
|
||||||
mock_client().tls_set = mock_tls_set
|
mock_client().tls_set = mock_tls_set
|
||||||
mock_client().tls_insecure_set = mock_tls_insecure_set
|
mock_client().tls_insecure_set = mock_tls_insecure_set
|
||||||
await mqtt_mock_entry()
|
await mqtt_mock_entry()
|
||||||
|
@ -4023,7 +4040,7 @@ async def test_link_config_entry(
|
||||||
assert _check_entities() == 2
|
assert _check_entities() == 2
|
||||||
|
|
||||||
# reload entry and assert again
|
# reload entry and assert again
|
||||||
with patch("paho.mqtt.client.Client"):
|
with patch("homeassistant.components.mqtt.async_client.AsyncMQTTClient"):
|
||||||
await hass.config_entries.async_reload(mqtt_config_entry.entry_id)
|
await hass.config_entries.async_reload(mqtt_config_entry.entry_id)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
|
|
@ -920,7 +920,9 @@ def mqtt_client_mock(hass: HomeAssistant) -> Generator[MqttMockPahoClient, None,
|
||||||
self.mid = mid
|
self.mid = mid
|
||||||
self.rc = 0
|
self.rc = 0
|
||||||
|
|
||||||
with patch("paho.mqtt.client.Client") as mock_client:
|
with patch(
|
||||||
|
"homeassistant.components.mqtt.async_client.AsyncMQTTClient"
|
||||||
|
) as mock_client:
|
||||||
# The below use a call_soon for the on_publish/on_subscribe/on_unsubscribe
|
# The below use a call_soon for the on_publish/on_subscribe/on_unsubscribe
|
||||||
# callbacks to simulate the behavior of the real MQTT client which will
|
# callbacks to simulate the behavior of the real MQTT client which will
|
||||||
# not be synchronous.
|
# not be synchronous.
|
||||||
|
|
Loading…
Add table
Reference in a new issue