Make mqtt internal subscription a normal function (#118092)
Co-authored-by: Jan Bouwhuis <jbouwh@users.noreply.github.com>
This commit is contained in:
parent
ecd48cc447
commit
9be829ba1f
30 changed files with 140 additions and 83 deletions
|
@ -39,6 +39,7 @@ from .client import ( # noqa: F401
|
||||||
MQTT,
|
MQTT,
|
||||||
async_publish,
|
async_publish,
|
||||||
async_subscribe,
|
async_subscribe,
|
||||||
|
async_subscribe_internal,
|
||||||
publish,
|
publish,
|
||||||
subscribe,
|
subscribe,
|
||||||
)
|
)
|
||||||
|
@ -311,7 +312,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||||
def collect_msg(msg: ReceiveMessage) -> None:
|
def collect_msg(msg: ReceiveMessage) -> None:
|
||||||
messages.append((msg.topic, str(msg.payload).replace("\n", "")))
|
messages.append((msg.topic, str(msg.payload).replace("\n", "")))
|
||||||
|
|
||||||
unsub = await async_subscribe(hass, call.data["topic"], collect_msg)
|
unsub = async_subscribe_internal(hass, call.data["topic"], collect_msg)
|
||||||
|
|
||||||
def write_dump() -> None:
|
def write_dump() -> None:
|
||||||
with open(hass.config.path("mqtt_dump.txt"), "w", encoding="utf8") as fp:
|
with open(hass.config.path("mqtt_dump.txt"), "w", encoding="utf8") as fp:
|
||||||
|
@ -459,7 +460,7 @@ async def websocket_subscribe(
|
||||||
|
|
||||||
# Perform UTF-8 decoding directly in callback routine
|
# Perform UTF-8 decoding directly in callback routine
|
||||||
qos: int = msg.get("qos", DEFAULT_QOS)
|
qos: int = msg.get("qos", DEFAULT_QOS)
|
||||||
connection.subscriptions[msg["id"]] = await async_subscribe(
|
connection.subscriptions[msg["id"]] = async_subscribe_internal(
|
||||||
hass, msg["topic"], forward_messages, encoding=None, qos=qos
|
hass, msg["topic"], forward_messages, encoding=None, qos=qos
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -226,7 +226,7 @@ class MqttAlarm(MqttEntity, alarm.AlarmControlPanelEntity):
|
||||||
|
|
||||||
async def _subscribe_topics(self) -> None:
|
async def _subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||||
|
|
||||||
async def async_alarm_disarm(self, code: str | None = None) -> None:
|
async def async_alarm_disarm(self, code: str | None = None) -> None:
|
||||||
"""Send disarm command.
|
"""Send disarm command.
|
||||||
|
|
|
@ -254,7 +254,7 @@ class MqttBinarySensor(MqttEntity, BinarySensorEntity, RestoreEntity):
|
||||||
|
|
||||||
async def _subscribe_topics(self) -> None:
|
async def _subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _value_is_expired(self, *_: Any) -> None:
|
def _value_is_expired(self, *_: Any) -> None:
|
||||||
|
|
|
@ -130,7 +130,7 @@ class MqttCamera(MqttEntity, Camera):
|
||||||
|
|
||||||
async def _subscribe_topics(self) -> None:
|
async def _subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||||
|
|
||||||
async def async_camera_image(
|
async def async_camera_image(
|
||||||
self, width: int | None = None, height: int | None = None
|
self, width: int | None = None, height: int | None = None
|
||||||
|
|
|
@ -191,13 +191,25 @@ async def async_subscribe(
|
||||||
|
|
||||||
Call the return value to unsubscribe.
|
Call the return value to unsubscribe.
|
||||||
"""
|
"""
|
||||||
if not mqtt_config_entry_enabled(hass):
|
return async_subscribe_internal(hass, topic, msg_callback, qos, encoding)
|
||||||
raise HomeAssistantError(
|
|
||||||
f"Cannot subscribe to topic '{topic}', MQTT is not enabled",
|
|
||||||
translation_key="mqtt_not_setup_cannot_subscribe",
|
@callback
|
||||||
translation_domain=DOMAIN,
|
def async_subscribe_internal(
|
||||||
translation_placeholders={"topic": topic},
|
hass: HomeAssistant,
|
||||||
)
|
topic: str,
|
||||||
|
msg_callback: Callable[[ReceiveMessage], Coroutine[Any, Any, None] | None],
|
||||||
|
qos: int = DEFAULT_QOS,
|
||||||
|
encoding: str | None = DEFAULT_ENCODING,
|
||||||
|
) -> CALLBACK_TYPE:
|
||||||
|
"""Subscribe to an MQTT topic.
|
||||||
|
|
||||||
|
This function is internal to the MQTT integration
|
||||||
|
and may change at any time. It should not be considered
|
||||||
|
a stable API.
|
||||||
|
|
||||||
|
Call the return value to unsubscribe.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
mqtt_data = hass.data[DATA_MQTT]
|
mqtt_data = hass.data[DATA_MQTT]
|
||||||
except KeyError as exc:
|
except KeyError as exc:
|
||||||
|
@ -208,12 +220,15 @@ async def async_subscribe(
|
||||||
translation_domain=DOMAIN,
|
translation_domain=DOMAIN,
|
||||||
translation_placeholders={"topic": topic},
|
translation_placeholders={"topic": topic},
|
||||||
) from exc
|
) from exc
|
||||||
return await mqtt_data.client.async_subscribe(
|
client = mqtt_data.client
|
||||||
topic,
|
if not client.connected and not mqtt_config_entry_enabled(hass):
|
||||||
msg_callback,
|
raise HomeAssistantError(
|
||||||
qos,
|
f"Cannot subscribe to topic '{topic}', MQTT is not enabled",
|
||||||
encoding,
|
translation_key="mqtt_not_setup_cannot_subscribe",
|
||||||
)
|
translation_domain=DOMAIN,
|
||||||
|
translation_placeholders={"topic": topic},
|
||||||
|
)
|
||||||
|
return client.async_subscribe(topic, msg_callback, qos, encoding)
|
||||||
|
|
||||||
|
|
||||||
@bind_hass
|
@bind_hass
|
||||||
|
@ -845,17 +860,15 @@ class MQTT:
|
||||||
f"'{msg.topic}': '{msg.payload}'" # type: ignore[str-bytes-safe]
|
f"'{msg.topic}': '{msg.payload}'" # type: ignore[str-bytes-safe]
|
||||||
)
|
)
|
||||||
|
|
||||||
async def async_subscribe(
|
@callback
|
||||||
|
def async_subscribe(
|
||||||
self,
|
self,
|
||||||
topic: str,
|
topic: str,
|
||||||
msg_callback: Callable[[ReceiveMessage], Coroutine[Any, Any, None] | None],
|
msg_callback: Callable[[ReceiveMessage], Coroutine[Any, Any, None] | None],
|
||||||
qos: int,
|
qos: int,
|
||||||
encoding: str | None = None,
|
encoding: str | None = None,
|
||||||
) -> Callable[[], None]:
|
) -> Callable[[], None]:
|
||||||
"""Set up a subscription to a topic with the provided qos.
|
"""Set up a subscription to a topic with the provided qos."""
|
||||||
|
|
||||||
This method is a coroutine.
|
|
||||||
"""
|
|
||||||
if not isinstance(topic, str):
|
if not isinstance(topic, str):
|
||||||
raise HomeAssistantError("Topic needs to be a string!")
|
raise HomeAssistantError("Topic needs to be a string!")
|
||||||
|
|
||||||
|
@ -881,18 +894,18 @@ class MQTT:
|
||||||
if self.connected:
|
if self.connected:
|
||||||
self._async_queue_subscriptions(((topic, qos),))
|
self._async_queue_subscriptions(((topic, qos),))
|
||||||
|
|
||||||
@callback
|
return partial(self._async_remove, subscription)
|
||||||
def async_remove() -> None:
|
|
||||||
"""Remove subscription."""
|
|
||||||
self._async_untrack_subscription(subscription)
|
|
||||||
self._matching_subscriptions.cache_clear()
|
|
||||||
if subscription in self._retained_topics:
|
|
||||||
del self._retained_topics[subscription]
|
|
||||||
# Only unsubscribe if currently connected
|
|
||||||
if self.connected:
|
|
||||||
self._async_unsubscribe(topic)
|
|
||||||
|
|
||||||
return async_remove
|
@callback
|
||||||
|
def _async_remove(self, subscription: Subscription) -> None:
|
||||||
|
"""Remove subscription."""
|
||||||
|
self._async_untrack_subscription(subscription)
|
||||||
|
self._matching_subscriptions.cache_clear()
|
||||||
|
if subscription in self._retained_topics:
|
||||||
|
del self._retained_topics[subscription]
|
||||||
|
# Only unsubscribe if currently connected
|
||||||
|
if self.connected:
|
||||||
|
self._async_unsubscribe(subscription.topic)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _async_unsubscribe(self, topic: str) -> None:
|
def _async_unsubscribe(self, topic: str) -> None:
|
||||||
|
|
|
@ -511,7 +511,7 @@ class MqttTemperatureControlEntity(MqttEntity, ABC):
|
||||||
|
|
||||||
async def _subscribe_topics(self) -> None:
|
async def _subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||||
|
|
||||||
async def _publish(self, topic: str, payload: PublishPayloadType) -> None:
|
async def _publish(self, topic: str, payload: PublishPayloadType) -> None:
|
||||||
if self._topic[topic] is not None:
|
if self._topic[topic] is not None:
|
||||||
|
|
|
@ -512,7 +512,7 @@ class MqttCover(MqttEntity, CoverEntity):
|
||||||
|
|
||||||
async def _subscribe_topics(self) -> None:
|
async def _subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||||
|
|
||||||
async def async_open_cover(self, **kwargs: Any) -> None:
|
async def async_open_cover(self, **kwargs: Any) -> None:
|
||||||
"""Move the cover up.
|
"""Move the cover up.
|
||||||
|
|
|
@ -166,7 +166,7 @@ class MqttDeviceTracker(MqttEntity, TrackerEntity):
|
||||||
|
|
||||||
async def _subscribe_topics(self) -> None:
|
async def _subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def latitude(self) -> float | None:
|
def latitude(self) -> float | None:
|
||||||
|
|
|
@ -208,4 +208,4 @@ class MqttEvent(MqttEntity, EventEntity):
|
||||||
|
|
||||||
async def _subscribe_topics(self) -> None:
|
async def _subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||||
|
|
|
@ -477,7 +477,7 @@ class MqttFan(MqttEntity, FanEntity):
|
||||||
|
|
||||||
async def _subscribe_topics(self) -> None:
|
async def _subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_on(self) -> bool | None:
|
def is_on(self) -> bool | None:
|
||||||
|
|
|
@ -447,7 +447,7 @@ class MqttHumidifier(MqttEntity, HumidifierEntity):
|
||||||
|
|
||||||
async def _subscribe_topics(self) -> None:
|
async def _subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||||
|
|
||||||
async def async_turn_on(self, **kwargs: Any) -> None:
|
async def async_turn_on(self, **kwargs: Any) -> None:
|
||||||
"""Turn on the entity.
|
"""Turn on the entity.
|
||||||
|
|
|
@ -214,7 +214,7 @@ class MqttImage(MqttEntity, ImageEntity):
|
||||||
|
|
||||||
async def _subscribe_topics(self) -> None:
|
async def _subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||||
|
|
||||||
async def async_image(self) -> bytes | None:
|
async def async_image(self) -> bytes | None:
|
||||||
"""Return bytes of image."""
|
"""Return bytes of image."""
|
||||||
|
|
|
@ -198,7 +198,7 @@ class MqttLawnMower(MqttEntity, LawnMowerEntity, RestoreEntity):
|
||||||
|
|
||||||
async def _subscribe_topics(self) -> None:
|
async def _subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||||
|
|
||||||
if self._attr_assumed_state and (
|
if self._attr_assumed_state and (
|
||||||
last_state := await self.async_get_last_state()
|
last_state := await self.async_get_last_state()
|
||||||
|
|
|
@ -627,7 +627,7 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
|
||||||
|
|
||||||
async def _subscribe_topics(self) -> None:
|
async def _subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||||
last_state = await self.async_get_last_state()
|
last_state = await self.async_get_last_state()
|
||||||
|
|
||||||
def restore_state(
|
def restore_state(
|
||||||
|
|
|
@ -528,7 +528,7 @@ class MqttLightJson(MqttEntity, LightEntity, RestoreEntity):
|
||||||
|
|
||||||
async def _subscribe_topics(self) -> None:
|
async def _subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||||
|
|
||||||
last_state = await self.async_get_last_state()
|
last_state = await self.async_get_last_state()
|
||||||
if self._optimistic and last_state:
|
if self._optimistic and last_state:
|
||||||
|
|
|
@ -288,7 +288,7 @@ class MqttLightTemplate(MqttEntity, LightEntity, RestoreEntity):
|
||||||
|
|
||||||
async def _subscribe_topics(self) -> None:
|
async def _subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||||
|
|
||||||
last_state = await self.async_get_last_state()
|
last_state = await self.async_get_last_state()
|
||||||
if self._optimistic and last_state:
|
if self._optimistic and last_state:
|
||||||
|
|
|
@ -243,7 +243,7 @@ class MqttLock(MqttEntity, LockEntity):
|
||||||
|
|
||||||
async def _subscribe_topics(self) -> None:
|
async def _subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||||
|
|
||||||
async def async_lock(self, **kwargs: Any) -> None:
|
async def async_lock(self, **kwargs: Any) -> None:
|
||||||
"""Lock the device.
|
"""Lock the device.
|
||||||
|
|
|
@ -114,7 +114,7 @@ from .models import (
|
||||||
from .subscription import (
|
from .subscription import (
|
||||||
EntitySubscription,
|
EntitySubscription,
|
||||||
async_prepare_subscribe_topics,
|
async_prepare_subscribe_topics,
|
||||||
async_subscribe_topics,
|
async_subscribe_topics_internal,
|
||||||
async_unsubscribe_topics,
|
async_unsubscribe_topics,
|
||||||
)
|
)
|
||||||
from .util import mqtt_config_entry_enabled
|
from .util import mqtt_config_entry_enabled
|
||||||
|
@ -413,7 +413,7 @@ class MqttAttributesMixin(Entity):
|
||||||
"""Subscribe MQTT events."""
|
"""Subscribe MQTT events."""
|
||||||
await super().async_added_to_hass()
|
await super().async_added_to_hass()
|
||||||
self._attributes_prepare_subscribe_topics()
|
self._attributes_prepare_subscribe_topics()
|
||||||
await self._attributes_subscribe_topics()
|
self._attributes_subscribe_topics()
|
||||||
|
|
||||||
def attributes_prepare_discovery_update(self, config: DiscoveryInfoType) -> None:
|
def attributes_prepare_discovery_update(self, config: DiscoveryInfoType) -> None:
|
||||||
"""Handle updated discovery message."""
|
"""Handle updated discovery message."""
|
||||||
|
@ -422,7 +422,7 @@ class MqttAttributesMixin(Entity):
|
||||||
|
|
||||||
async def attributes_discovery_update(self, config: DiscoveryInfoType) -> None:
|
async def attributes_discovery_update(self, config: DiscoveryInfoType) -> None:
|
||||||
"""Handle updated discovery message."""
|
"""Handle updated discovery message."""
|
||||||
await self._attributes_subscribe_topics()
|
self._attributes_subscribe_topics()
|
||||||
|
|
||||||
def _attributes_prepare_subscribe_topics(self) -> None:
|
def _attributes_prepare_subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
|
@ -447,9 +447,10 @@ class MqttAttributesMixin(Entity):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _attributes_subscribe_topics(self) -> None:
|
@callback
|
||||||
|
def _attributes_subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await async_subscribe_topics(self.hass, self._attributes_sub_state)
|
async_subscribe_topics_internal(self.hass, self._attributes_sub_state)
|
||||||
|
|
||||||
async def async_will_remove_from_hass(self) -> None:
|
async def async_will_remove_from_hass(self) -> None:
|
||||||
"""Unsubscribe when removed."""
|
"""Unsubscribe when removed."""
|
||||||
|
@ -494,7 +495,7 @@ class MqttAvailabilityMixin(Entity):
|
||||||
"""Subscribe MQTT events."""
|
"""Subscribe MQTT events."""
|
||||||
await super().async_added_to_hass()
|
await super().async_added_to_hass()
|
||||||
self._availability_prepare_subscribe_topics()
|
self._availability_prepare_subscribe_topics()
|
||||||
await self._availability_subscribe_topics()
|
self._availability_subscribe_topics()
|
||||||
self.async_on_remove(
|
self.async_on_remove(
|
||||||
async_dispatcher_connect(self.hass, MQTT_CONNECTED, self.async_mqtt_connect)
|
async_dispatcher_connect(self.hass, MQTT_CONNECTED, self.async_mqtt_connect)
|
||||||
)
|
)
|
||||||
|
@ -511,7 +512,7 @@ class MqttAvailabilityMixin(Entity):
|
||||||
|
|
||||||
async def availability_discovery_update(self, config: DiscoveryInfoType) -> None:
|
async def availability_discovery_update(self, config: DiscoveryInfoType) -> None:
|
||||||
"""Handle updated discovery message."""
|
"""Handle updated discovery message."""
|
||||||
await self._availability_subscribe_topics()
|
self._availability_subscribe_topics()
|
||||||
|
|
||||||
def _availability_setup_from_config(self, config: ConfigType) -> None:
|
def _availability_setup_from_config(self, config: ConfigType) -> None:
|
||||||
"""(Re)Setup."""
|
"""(Re)Setup."""
|
||||||
|
@ -579,9 +580,10 @@ class MqttAvailabilityMixin(Entity):
|
||||||
self._available[topic] = False
|
self._available[topic] = False
|
||||||
self._available_latest = False
|
self._available_latest = False
|
||||||
|
|
||||||
async def _availability_subscribe_topics(self) -> None:
|
@callback
|
||||||
|
def _availability_subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await async_subscribe_topics(self.hass, self._availability_sub_state)
|
async_subscribe_topics_internal(self.hass, self._availability_sub_state)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_mqtt_connect(self) -> None:
|
def async_mqtt_connect(self) -> None:
|
||||||
|
|
|
@ -220,7 +220,7 @@ class MqttNumber(MqttEntity, RestoreNumber):
|
||||||
|
|
||||||
async def _subscribe_topics(self) -> None:
|
async def _subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||||
|
|
||||||
if self._attr_assumed_state and (
|
if self._attr_assumed_state and (
|
||||||
last_number_data := await self.async_get_last_number_data()
|
last_number_data := await self.async_get_last_number_data()
|
||||||
|
|
|
@ -160,7 +160,7 @@ class MqttSelect(MqttEntity, SelectEntity, RestoreEntity):
|
||||||
|
|
||||||
async def _subscribe_topics(self) -> None:
|
async def _subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||||
|
|
||||||
if self._attr_assumed_state and (
|
if self._attr_assumed_state and (
|
||||||
last_state := await self.async_get_last_state()
|
last_state := await self.async_get_last_state()
|
||||||
|
|
|
@ -305,7 +305,7 @@ class MqttSensor(MqttEntity, RestoreSensor):
|
||||||
|
|
||||||
async def _subscribe_topics(self) -> None:
|
async def _subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _value_is_expired(self, *_: datetime) -> None:
|
def _value_is_expired(self, *_: datetime) -> None:
|
||||||
|
|
|
@ -288,7 +288,7 @@ class MqttSiren(MqttEntity, SirenEntity):
|
||||||
|
|
||||||
async def _subscribe_topics(self) -> None:
|
async def _subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def extra_state_attributes(self) -> dict[str, Any] | None:
|
def extra_state_attributes(self) -> dict[str, Any] | None:
|
||||||
|
|
|
@ -2,14 +2,15 @@
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Callable, Coroutine
|
from collections.abc import Callable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from functools import partial
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant, callback
|
||||||
|
|
||||||
from .. import mqtt
|
|
||||||
from . import debug_info
|
from . import debug_info
|
||||||
|
from .client import async_subscribe_internal
|
||||||
from .const import DEFAULT_QOS
|
from .const import DEFAULT_QOS
|
||||||
from .models import MessageCallbackType
|
from .models import MessageCallbackType
|
||||||
|
|
||||||
|
@ -21,7 +22,7 @@ class EntitySubscription:
|
||||||
hass: HomeAssistant
|
hass: HomeAssistant
|
||||||
topic: str | None
|
topic: str | None
|
||||||
message_callback: MessageCallbackType
|
message_callback: MessageCallbackType
|
||||||
subscribe_task: Coroutine[Any, Any, Callable[[], None]] | None
|
should_subscribe: bool | None
|
||||||
unsubscribe_callback: Callable[[], None] | None
|
unsubscribe_callback: Callable[[], None] | None
|
||||||
qos: int = 0
|
qos: int = 0
|
||||||
encoding: str = "utf-8"
|
encoding: str = "utf-8"
|
||||||
|
@ -53,15 +54,16 @@ class EntitySubscription:
|
||||||
self.hass, self.message_callback, self.topic, self.entity_id
|
self.hass, self.message_callback, self.topic, self.entity_id
|
||||||
)
|
)
|
||||||
|
|
||||||
self.subscribe_task = mqtt.async_subscribe(
|
self.should_subscribe = True
|
||||||
hass, self.topic, self.message_callback, self.qos, self.encoding
|
|
||||||
)
|
|
||||||
|
|
||||||
async def subscribe(self) -> None:
|
@callback
|
||||||
|
def subscribe(self) -> None:
|
||||||
"""Subscribe to a topic."""
|
"""Subscribe to a topic."""
|
||||||
if not self.subscribe_task:
|
if not self.should_subscribe or not self.topic:
|
||||||
return
|
return
|
||||||
self.unsubscribe_callback = await self.subscribe_task
|
self.unsubscribe_callback = async_subscribe_internal(
|
||||||
|
self.hass, self.topic, self.message_callback, self.qos, self.encoding
|
||||||
|
)
|
||||||
|
|
||||||
def _should_resubscribe(self, other: EntitySubscription | None) -> bool:
|
def _should_resubscribe(self, other: EntitySubscription | None) -> bool:
|
||||||
"""Check if we should re-subscribe to the topic using the old state."""
|
"""Check if we should re-subscribe to the topic using the old state."""
|
||||||
|
@ -79,6 +81,7 @@ class EntitySubscription:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
def async_prepare_subscribe_topics(
|
def async_prepare_subscribe_topics(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
new_state: dict[str, EntitySubscription] | None,
|
new_state: dict[str, EntitySubscription] | None,
|
||||||
|
@ -107,7 +110,7 @@ def async_prepare_subscribe_topics(
|
||||||
qos=value.get("qos", DEFAULT_QOS),
|
qos=value.get("qos", DEFAULT_QOS),
|
||||||
encoding=value.get("encoding", "utf-8"),
|
encoding=value.get("encoding", "utf-8"),
|
||||||
hass=hass,
|
hass=hass,
|
||||||
subscribe_task=None,
|
should_subscribe=None,
|
||||||
entity_id=value.get("entity_id", None),
|
entity_id=value.get("entity_id", None),
|
||||||
)
|
)
|
||||||
# Get the current subscription state
|
# Get the current subscription state
|
||||||
|
@ -135,12 +138,29 @@ async def async_subscribe_topics(
|
||||||
sub_state: dict[str, EntitySubscription],
|
sub_state: dict[str, EntitySubscription],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""(Re)Subscribe to a set of MQTT topics."""
|
"""(Re)Subscribe to a set of MQTT topics."""
|
||||||
|
async_subscribe_topics_internal(hass, sub_state)
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_subscribe_topics_internal(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
sub_state: dict[str, EntitySubscription],
|
||||||
|
) -> None:
|
||||||
|
"""(Re)Subscribe to a set of MQTT topics.
|
||||||
|
|
||||||
|
This function is internal to the MQTT integration and should not be called
|
||||||
|
from outside the integration.
|
||||||
|
"""
|
||||||
for sub in sub_state.values():
|
for sub in sub_state.values():
|
||||||
await sub.subscribe()
|
sub.subscribe()
|
||||||
|
|
||||||
|
|
||||||
def async_unsubscribe_topics(
|
if TYPE_CHECKING:
|
||||||
hass: HomeAssistant, sub_state: dict[str, EntitySubscription] | None
|
|
||||||
) -> dict[str, EntitySubscription]:
|
def async_unsubscribe_topics(
|
||||||
"""Unsubscribe from all MQTT topics managed by async_subscribe_topics."""
|
hass: HomeAssistant, sub_state: dict[str, EntitySubscription] | None
|
||||||
return async_prepare_subscribe_topics(hass, sub_state, {})
|
) -> dict[str, EntitySubscription]:
|
||||||
|
"""Unsubscribe from all MQTT topics managed by async_subscribe_topics."""
|
||||||
|
|
||||||
|
|
||||||
|
async_unsubscribe_topics = partial(async_prepare_subscribe_topics, topics={})
|
||||||
|
|
|
@ -151,7 +151,7 @@ class MqttSwitch(MqttEntity, SwitchEntity, RestoreEntity):
|
||||||
|
|
||||||
async def _subscribe_topics(self) -> None:
|
async def _subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||||
|
|
||||||
if self._optimistic and (last_state := await self.async_get_last_state()):
|
if self._optimistic and (last_state := await self.async_get_last_state()):
|
||||||
self._attr_is_on = last_state.state == STATE_ON
|
self._attr_is_on = last_state.state == STATE_ON
|
||||||
|
|
|
@ -167,7 +167,7 @@ class MQTTTagScanner(MqttDiscoveryDeviceUpdateMixin):
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||||
|
|
||||||
async def async_tear_down(self) -> None:
|
async def async_tear_down(self) -> None:
|
||||||
"""Cleanup tag scanner."""
|
"""Cleanup tag scanner."""
|
||||||
|
|
|
@ -198,7 +198,7 @@ class MqttTextEntity(MqttEntity, TextEntity):
|
||||||
|
|
||||||
async def _subscribe_topics(self) -> None:
|
async def _subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||||
|
|
||||||
async def async_set_value(self, value: str) -> None:
|
async def async_set_value(self, value: str) -> None:
|
||||||
"""Change the text."""
|
"""Change the text."""
|
||||||
|
|
|
@ -257,7 +257,7 @@ class MqttUpdate(MqttEntity, UpdateEntity, RestoreEntity):
|
||||||
|
|
||||||
async def _subscribe_topics(self) -> None:
|
async def _subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||||
|
|
||||||
async def async_install(
|
async def async_install(
|
||||||
self, version: str | None, backup: bool, **kwargs: Any
|
self, version: str | None, backup: bool, **kwargs: Any
|
||||||
|
|
|
@ -353,7 +353,7 @@ class MqttStateVacuum(MqttEntity, StateVacuumEntity):
|
||||||
|
|
||||||
async def _subscribe_topics(self) -> None:
|
async def _subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||||
|
|
||||||
async def _async_publish_command(self, feature: VacuumEntityFeature) -> None:
|
async def _async_publish_command(self, feature: VacuumEntityFeature) -> None:
|
||||||
"""Publish a command."""
|
"""Publish a command."""
|
||||||
|
|
|
@ -371,7 +371,7 @@ class MqttValve(MqttEntity, ValveEntity):
|
||||||
|
|
||||||
async def _subscribe_topics(self) -> None:
|
async def _subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||||
|
|
||||||
async def async_open_valve(self) -> None:
|
async def async_open_valve(self) -> None:
|
||||||
"""Move the valve up.
|
"""Move the valve up.
|
||||||
|
|
|
@ -1051,6 +1051,27 @@ async def test_subscribe_topic_not_initialize(
|
||||||
await mqtt.async_subscribe(hass, "test-topic", record_calls)
|
await mqtt.async_subscribe(hass, "test-topic", record_calls)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_subscribe_mqtt_config_entry_disabled(
|
||||||
|
hass: HomeAssistant, mqtt_mock: MqttMockHAClient
|
||||||
|
) -> None:
|
||||||
|
"""Test the subscription of a topic when MQTT config entry is disabled."""
|
||||||
|
mqtt_mock.connected = True
|
||||||
|
|
||||||
|
mqtt_config_entry = hass.config_entries.async_entries(mqtt.DOMAIN)[0]
|
||||||
|
assert mqtt_config_entry.state is ConfigEntryState.LOADED
|
||||||
|
|
||||||
|
assert await hass.config_entries.async_unload(mqtt_config_entry.entry_id)
|
||||||
|
assert mqtt_config_entry.state is ConfigEntryState.NOT_LOADED
|
||||||
|
|
||||||
|
await hass.config_entries.async_set_disabled_by(
|
||||||
|
mqtt_config_entry.entry_id, ConfigEntryDisabler.USER
|
||||||
|
)
|
||||||
|
mqtt_mock.connected = False
|
||||||
|
|
||||||
|
with pytest.raises(HomeAssistantError, match=r".*MQTT is not enabled"):
|
||||||
|
await mqtt.async_subscribe(hass, "test-topic", record_calls)
|
||||||
|
|
||||||
|
|
||||||
@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0)
|
@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0)
|
||||||
@patch("homeassistant.components.mqtt.client.UNSUBSCRIBE_COOLDOWN", 0.2)
|
@patch("homeassistant.components.mqtt.client.UNSUBSCRIBE_COOLDOWN", 0.2)
|
||||||
async def test_subscribe_and_resubscribe(
|
async def test_subscribe_and_resubscribe(
|
||||||
|
@ -3824,7 +3845,7 @@ async def test_unload_config_entry(
|
||||||
async def test_publish_or_subscribe_without_valid_config_entry(
|
async def test_publish_or_subscribe_without_valid_config_entry(
|
||||||
hass: HomeAssistant, record_calls: MessageCallbackType
|
hass: HomeAssistant, record_calls: MessageCallbackType
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test internal publish function with bas use cases."""
|
"""Test internal publish function with bad use cases."""
|
||||||
with pytest.raises(HomeAssistantError):
|
with pytest.raises(HomeAssistantError):
|
||||||
await mqtt.async_publish(
|
await mqtt.async_publish(
|
||||||
hass, "some-topic", "test-payload", qos=0, retain=False, encoding=None
|
hass, "some-topic", "test-payload", qos=0, retain=False, encoding=None
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue