Make mqtt internal subscription a normal function (#118092)

Co-authored-by: Jan Bouwhuis <jbouwh@users.noreply.github.com>
This commit is contained in:
J. Nick Koston 2024-05-25 11:34:24 -10:00 committed by GitHub
parent ecd48cc447
commit 9be829ba1f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
30 changed files with 140 additions and 83 deletions

View file

@ -2,14 +2,15 @@
from __future__ import annotations
from collections.abc import Callable, Coroutine
from collections.abc import Callable
from dataclasses import dataclass
from functools import partial
from typing import TYPE_CHECKING, Any
from homeassistant.core import HomeAssistant
from homeassistant.core import HomeAssistant, callback
from .. import mqtt
from . import debug_info
from .client import async_subscribe_internal
from .const import DEFAULT_QOS
from .models import MessageCallbackType
@ -21,7 +22,7 @@ class EntitySubscription:
hass: HomeAssistant
topic: str | None
message_callback: MessageCallbackType
subscribe_task: Coroutine[Any, Any, Callable[[], None]] | None
should_subscribe: bool | None
unsubscribe_callback: Callable[[], None] | None
qos: int = 0
encoding: str = "utf-8"
@ -53,15 +54,16 @@ class EntitySubscription:
self.hass, self.message_callback, self.topic, self.entity_id
)
self.subscribe_task = mqtt.async_subscribe(
hass, self.topic, self.message_callback, self.qos, self.encoding
)
self.should_subscribe = True
async def subscribe(self) -> None:
@callback
def subscribe(self) -> None:
"""Subscribe to a topic."""
if not self.subscribe_task:
if not self.should_subscribe or not self.topic:
return
self.unsubscribe_callback = await self.subscribe_task
self.unsubscribe_callback = async_subscribe_internal(
self.hass, self.topic, self.message_callback, self.qos, self.encoding
)
def _should_resubscribe(self, other: EntitySubscription | None) -> bool:
"""Check if we should re-subscribe to the topic using the old state."""
@ -79,6 +81,7 @@ class EntitySubscription:
)
@callback
def async_prepare_subscribe_topics(
hass: HomeAssistant,
new_state: dict[str, EntitySubscription] | None,
@ -107,7 +110,7 @@ def async_prepare_subscribe_topics(
qos=value.get("qos", DEFAULT_QOS),
encoding=value.get("encoding", "utf-8"),
hass=hass,
subscribe_task=None,
should_subscribe=None,
entity_id=value.get("entity_id", None),
)
# Get the current subscription state
@ -135,12 +138,29 @@ async def async_subscribe_topics(
sub_state: dict[str, EntitySubscription],
) -> None:
"""(Re)Subscribe to a set of MQTT topics."""
async_subscribe_topics_internal(hass, sub_state)
@callback
def async_subscribe_topics_internal(
hass: HomeAssistant,
sub_state: dict[str, EntitySubscription],
) -> None:
"""(Re)Subscribe to a set of MQTT topics.
This function is internal to the MQTT integration and should not be called
from outside the integration.
"""
for sub in sub_state.values():
await sub.subscribe()
sub.subscribe()
def async_unsubscribe_topics(
hass: HomeAssistant, sub_state: dict[str, EntitySubscription] | None
) -> dict[str, EntitySubscription]:
"""Unsubscribe from all MQTT topics managed by async_subscribe_topics."""
return async_prepare_subscribe_topics(hass, sub_state, {})
if TYPE_CHECKING:
def async_unsubscribe_topics(
hass: HomeAssistant, sub_state: dict[str, EntitySubscription] | None
) -> dict[str, EntitySubscription]:
"""Unsubscribe from all MQTT topics managed by async_subscribe_topics."""
async_unsubscribe_topics = partial(async_prepare_subscribe_topics, topics={})