Move MqttServiceInfo to init.py (#60905)

Co-authored-by: epenet <epenet@users.noreply.github.com>
Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
epenet 2021-12-03 19:34:48 +01:00 committed by GitHub
parent 74d1c340d7
commit b65b25c1bb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 66 additions and 72 deletions

View file

@ -2,6 +2,8 @@
from __future__ import annotations
import asyncio
from dataclasses import dataclass
import datetime as dt
from functools import lru_cache, partial, wraps
import inspect
from itertools import groupby
@ -38,9 +40,11 @@ from homeassistant.core import (
ServiceCall,
callback,
)
from homeassistant.data_entry_flow import BaseServiceInfo
from homeassistant.exceptions import HomeAssistantError, TemplateError, Unauthorized
from homeassistant.helpers import config_validation as cv, event, template
from homeassistant.helpers.dispatcher import async_dispatcher_connect, dispatcher_send
from homeassistant.helpers.frame import report
from homeassistant.helpers.typing import ConfigType, ServiceDataType
from homeassistant.loader import bind_hass
from homeassistant.util import dt as dt_util
@ -246,6 +250,36 @@ MQTT_PUBLISH_SCHEMA = vol.All(
SubscribePayloadType = Union[str, bytes] # Only bytes if encoding is None
@dataclass
class MqttServiceInfo(BaseServiceInfo):
"""Prepared info from mqtt entries."""
topic: str
payload: ReceivePayloadType
qos: int
retain: bool
subscribed_topic: str
timestamp: dt.datetime
# Used to prevent log flooding. To be removed in 2022.6
_warning_logged: bool = False
def __getitem__(self, name: str) -> Any:
"""
Allow property access by name for compatibility reason.
Deprecated, and will be removed in version 2022.6.
"""
if not self._warning_logged:
report(
f"accessed discovery_info['{name}'] instead of discovery_info.{name}; this will fail in version 2022.6",
exclude_integrations={"mqtt"},
error_if_core=False,
)
self._warning_logged = True
return getattr(self, name)
def _build_publish_data(topic: Any, qos: int, retain: bool) -> ServiceDataType:
"""Build the arguments for the publish service without the payload."""
data = {ATTR_TOPIC: topic}

View file

@ -1,24 +1,20 @@
"""Support for MQTT discovery."""
import asyncio
from collections import deque
from dataclasses import dataclass
import datetime as dt
import functools
import json
import logging
import re
import time
from typing import Any
from homeassistant.const import CONF_DEVICE, CONF_PLATFORM
from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import RESULT_TYPE_ABORT, BaseServiceInfo
from homeassistant.data_entry_flow import RESULT_TYPE_ABORT
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.dispatcher import (
async_dispatcher_connect,
async_dispatcher_send,
)
from homeassistant.helpers.frame import report
from homeassistant.loader import async_get_mqtt
from .. import mqtt
@ -31,7 +27,6 @@ from .const import (
CONF_TOPIC,
DOMAIN,
)
from .models import ReceivePayloadType
_LOGGER = logging.getLogger(__name__)
@ -91,36 +86,6 @@ class MQTTConfig(dict):
"""Dummy class to allow adding attributes."""
@dataclass
class MqttServiceInfo(BaseServiceInfo):
"""Prepared info from mqtt entries."""
topic: str
payload: ReceivePayloadType
qos: int
retain: bool
subscribed_topic: str
timestamp: dt.datetime
# Used to prevent log flooding. To be removed in 2022.6
_warning_logged: bool = False
def __getitem__(self, name: str) -> Any:
"""
Allow property access by name for compatibility reason.
Deprecated, and will be removed in version 2022.6.
"""
if not self._warning_logged:
report(
f"accessed discovery_info['{name}'] instead of discovery_info.{name}; this will fail in version 2022.6",
exclude_integrations={"mqtt"},
error_if_core=False,
)
self._warning_logged = True
return getattr(self, name)
async def async_start( # noqa: C901
hass: HomeAssistant, discovery_topic, config_entry=None
) -> None:
@ -323,7 +288,7 @@ async def async_start( # noqa: C901
if key not in hass.data[INTEGRATION_UNSUBSCRIBE]:
return
data = MqttServiceInfo(
data = mqtt.MqttServiceInfo(
topic=msg.topic,
payload=msg.payload,
qos=msg.qos,

View file

@ -6,7 +6,7 @@ from typing import Any
import voluptuous as vol
from homeassistant import config_entries
from homeassistant.components.mqtt import discovery as mqtt, valid_subscribe_topic
from homeassistant.components.mqtt import MqttServiceInfo, valid_subscribe_topic
from homeassistant.data_entry_flow import FlowResult
from .const import CONF_DISCOVERY_PREFIX, DEFAULT_PREFIX, DOMAIN
@ -21,7 +21,7 @@ class FlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
"""Initialize flow."""
self._prefix = DEFAULT_PREFIX
async def async_step_mqtt(self, discovery_info: mqtt.MqttServiceInfo) -> FlowResult:
async def async_step_mqtt(self, discovery_info: MqttServiceInfo) -> FlowResult:
"""Handle a flow initialized by MQTT discovery."""
if self._async_in_progress() or self._async_current_entries():
return self.async_abort(reason="single_instance_allowed")

View file

@ -35,7 +35,7 @@ import homeassistant.util.uuid as uuid_util
if TYPE_CHECKING:
from homeassistant.components.dhcp import DhcpServiceInfo
from homeassistant.components.hassio import HassioServiceInfo
from homeassistant.components.mqtt.discovery import MqttServiceInfo
from homeassistant.components.mqtt import MqttServiceInfo
from homeassistant.components.ssdp import SsdpServiceInfo
from homeassistant.components.usb import UsbServiceInfo
from homeassistant.components.zeroconf import ZeroconfServiceInfo

View file

@ -5,8 +5,7 @@ import logging
from typing import Any, Awaitable, Callable, Union
from homeassistant import config_entries
from homeassistant.components import dhcp, ssdp, zeroconf
from homeassistant.components.mqtt import discovery as mqtt
from homeassistant.components import dhcp, mqtt, ssdp, zeroconf
from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResult
from homeassistant.helpers.typing import UNDEFINED, DiscoveryInfoType, UndefinedType

View file

@ -11,11 +11,7 @@ from homeassistant.components.mqtt.abbreviations import (
ABBREVIATIONS,
DEVICE_ABBREVIATIONS,
)
from homeassistant.components.mqtt.discovery import (
ALREADY_DISCOVERED,
MqttServiceInfo,
async_start,
)
from homeassistant.components.mqtt.discovery import ALREADY_DISCOVERED, async_start
from homeassistant.const import (
EVENT_STATE_CHANGED,
STATE_OFF,
@ -909,27 +905,3 @@ async def test_mqtt_discovery_unsubscribe_once(hass, mqtt_client_mock, mqtt_mock
await hass.async_block_till_done()
await hass.async_block_till_done()
mqtt_client_mock.unsubscribe.assert_called_once_with("comp/discovery/#")
async def test_service_info_compatibility(hass, caplog):
"""Test compatibility with old-style dict.
To be removed in 2022.6
"""
discovery_info = MqttServiceInfo(
topic="tasmota/discovery/DC4F220848A2/config",
payload="",
qos=0,
retain=False,
subscribed_topic="tasmota/discovery/#",
timestamp=None,
)
# Ensure first call get logged
assert discovery_info["topic"] == "tasmota/discovery/DC4F220848A2/config"
assert "Detected code that accessed discovery_info['topic']" in caplog.text
# Ensure second call doesn't get logged
caplog.clear()
assert discovery_info["topic"] == "tasmota/discovery/DC4F220848A2/config"
assert "Detected code that accessed discovery_info['topic']" not in caplog.text

View file

@ -1812,3 +1812,27 @@ async def test_publish_json_from_template(hass, mqtt_mock):
assert mqtt_mock.async_publish.called
assert mqtt_mock.async_publish.call_args[0][1] == test_str
async def test_service_info_compatibility(hass, caplog):
"""Test compatibility with old-style dict.
To be removed in 2022.6
"""
discovery_info = mqtt.MqttServiceInfo(
topic="tasmota/discovery/DC4F220848A2/config",
payload="",
qos=0,
retain=False,
subscribed_topic="tasmota/discovery/#",
timestamp=None,
)
# Ensure first call get logged
assert discovery_info["topic"] == "tasmota/discovery/DC4F220848A2/config"
assert "Detected code that accessed discovery_info['topic']" in caplog.text
# Ensure second call doesn't get logged
caplog.clear()
assert discovery_info["topic"] == "tasmota/discovery/DC4F220848A2/config"
assert "Detected code that accessed discovery_info['topic']" not in caplog.text

View file

@ -1,6 +1,6 @@
"""Test config flow."""
from homeassistant import config_entries
from homeassistant.components.mqtt import discovery as mqtt
from homeassistant.components import mqtt
from tests.common import MockConfigEntry