Use the shared zeroconf instance when attempting to create another Zeroconf instance (#38744)

This commit is contained in:
J. Nick Koston 2020-08-12 09:08:33 -05:00 committed by GitHub
parent 34cb12d3c9
commit 444df4a7d2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 187 additions and 17 deletions

View file

@ -30,6 +30,8 @@ from homeassistant.helpers.network import NoURLAvailableError, get_url
from homeassistant.helpers.singleton import singleton
from homeassistant.loader import async_get_homekit, async_get_zeroconf
from .usage import install_multiple_zeroconf_catcher
_LOGGER = logging.getLogger(__name__)
DOMAIN = "zeroconf"
@ -135,6 +137,8 @@ def setup(hass, config):
ipv6=zc_config.get(CONF_IPV6, DEFAULT_IPV6),
)
install_multiple_zeroconf_catcher(zeroconf)
# Get instance UUID
uuid = asyncio.run_coroutine_threadsafe(
hass.helpers.instance_id.async_get(), hass.loop

View file

@ -0,0 +1,50 @@
"""Zeroconf usage utility to warn about multiple instances."""
import logging
import zeroconf
from homeassistant.helpers.frame import (
MissingIntegrationFrame,
get_integration_frame,
report_integration,
)
_LOGGER = logging.getLogger(__name__)
def install_multiple_zeroconf_catcher(hass_zc) -> None:
"""Wrap the Zeroconf class to return the shared instance if multiple instances are detected."""
def new_zeroconf_new(self, *k, **kw):
_report(
"attempted to create another Zeroconf instance. Please use the shared Zeroconf via await homeassistant.components.zeroconf.async_get_instance(hass)",
)
return hass_zc
def new_zeroconf_init(self, *k, **kw):
return
zeroconf.Zeroconf.__new__ = new_zeroconf_new
zeroconf.Zeroconf.__init__ = new_zeroconf_init
def _report(what: str) -> None:
"""Report incorrect usage.
Async friendly.
"""
integration_frame = None
try:
integration_frame = get_integration_frame(exclude_integrations={"zeroconf"})
except MissingIntegrationFrame:
pass
if not integration_frame:
_LOGGER.warning(
"Detected code that %s. Please report this issue.", what, stack_info=True
)
return
report_integration(what, integration_frame)

View file

@ -3,7 +3,7 @@ import asyncio
import functools
import logging
from traceback import FrameSummary, extract_stack
from typing import Any, Callable, Tuple, TypeVar, cast
from typing import Any, Callable, Optional, Tuple, TypeVar, cast
from homeassistant.exceptions import HomeAssistantError
@ -12,15 +12,24 @@ _LOGGER = logging.getLogger(__name__)
CALLABLE_T = TypeVar("CALLABLE_T", bound=Callable) # pylint: disable=invalid-name
def get_integration_frame() -> Tuple[FrameSummary, str, str]:
def get_integration_frame(
exclude_integrations: Optional[set] = None,
) -> Tuple[FrameSummary, str, str]:
"""Return the frame, integration and integration path of the current stack frame."""
found_frame = None
if not exclude_integrations:
exclude_integrations = set()
for frame in reversed(extract_stack()):
for path in ("custom_components/", "homeassistant/components/"):
try:
index = frame.filename.index(path)
found_frame = frame
start = index + len(path)
end = frame.filename.index("/", start)
integration = frame.filename[start:end]
if integration not in exclude_integrations:
found_frame = frame
break
except ValueError:
continue
@ -31,11 +40,6 @@ def get_integration_frame() -> Tuple[FrameSummary, str, str]:
if found_frame is None:
raise MissingIntegrationFrame
start = index + len(path)
end = found_frame.filename.index("/", start)
integration = found_frame.filename[start:end]
return found_frame, integration, path
@ -49,11 +53,24 @@ def report(what: str) -> None:
Async friendly.
"""
try:
found_frame, integration, path = get_integration_frame()
integration_frame = get_integration_frame()
except MissingIntegrationFrame:
# Did not source from an integration? Hard error.
raise RuntimeError(f"Detected code that {what}. Please report this issue.")
report_integration(what, integration_frame)
def report_integration(
what: str, integration_frame: Tuple[FrameSummary, str, str]
) -> None:
"""Report incorrect usage in an integration.
Async friendly.
"""
found_frame, integration, path = integration_frame
index = found_frame.filename.index(path)
if path == "custom_components/":
extra = " to the custom component author"

View file

@ -1,8 +1,15 @@
"""Fixtures for component testing."""
import pytest
from homeassistant.components import zeroconf
from tests.async_mock import patch
zeroconf.orig_install_multiple_zeroconf_catcher = (
zeroconf.install_multiple_zeroconf_catcher
)
zeroconf.install_multiple_zeroconf_catcher = lambda zc: None
@pytest.fixture(autouse=True)
def prevent_io():

View file

@ -0,0 +1,11 @@
"""conftest for zeroconf."""
import pytest
from tests.async_mock import patch
@pytest.fixture
def mock_zeroconf():
"""Mock zeroconf."""
with patch("homeassistant.components.zeroconf.HaZeroconf") as mock_zc:
yield mock_zc.return_value

View file

@ -1,5 +1,4 @@
"""Test Zeroconf component setup process."""
import pytest
from zeroconf import InterfaceChoice, IPVersion, ServiceInfo, ServiceStateChange
from homeassistant.components import zeroconf
@ -22,13 +21,6 @@ HOMEKIT_STATUS_UNPAIRED = b"1"
HOMEKIT_STATUS_PAIRED = b"0"
@pytest.fixture
def mock_zeroconf():
"""Mock zeroconf."""
with patch("homeassistant.components.zeroconf.HaZeroconf") as mock_zc:
yield mock_zc.return_value
def service_update_mock(zeroconf, services, handlers):
"""Call service update handler."""
for service in services:

View file

@ -0,0 +1,56 @@
"""Test Zeroconf multiple instance protection."""
import zeroconf
from homeassistant.components.zeroconf import async_get_instance
from homeassistant.components.zeroconf.usage import install_multiple_zeroconf_catcher
from tests.async_mock import Mock, patch
async def test_multiple_zeroconf_instances(hass, mock_zeroconf, caplog):
"""Test creating multiple zeroconf throws without an integration."""
zeroconf_instance = await async_get_instance(hass)
install_multiple_zeroconf_catcher(zeroconf_instance)
new_zeroconf_instance = zeroconf.Zeroconf()
assert new_zeroconf_instance == zeroconf_instance
assert "Zeroconf" in caplog.text
async def test_multiple_zeroconf_instances_gives_shared(hass, mock_zeroconf, caplog):
"""Test creating multiple zeroconf gives the shared instance to an integration."""
zeroconf_instance = await async_get_instance(hass)
install_multiple_zeroconf_catcher(zeroconf_instance)
correct_frame = Mock(
filename="/config/custom_components/burncpu/light.py",
lineno="23",
line="self.light.is_on",
)
with patch(
"homeassistant.helpers.frame.extract_stack",
return_value=[
Mock(
filename="/home/dev/homeassistant/core.py",
lineno="23",
line="do_something()",
),
correct_frame,
Mock(
filename="/home/dev/homeassistant/components/zeroconf/usage.py",
lineno="23",
line="self.light.is_on",
),
Mock(filename="/home/dev/mdns/lights.py", lineno="2", line="something()",),
],
):
assert zeroconf.Zeroconf() == zeroconf_instance
assert "custom_components/burncpu/light.py" in caplog.text
assert "23" in caplog.text
assert "self.light.is_on" in caplog.text

View file

@ -36,6 +36,39 @@ async def test_extract_frame_integration(caplog):
assert found_frame == correct_frame
async def test_extract_frame_integration_with_excluded_intergration(caplog):
"""Test extracting the current frame from integration context."""
correct_frame = Mock(
filename="/home/dev/homeassistant/components/mdns/light.py",
lineno="23",
line="self.light.is_on",
)
with patch(
"homeassistant.helpers.frame.extract_stack",
return_value=[
Mock(
filename="/home/dev/homeassistant/core.py",
lineno="23",
line="do_something()",
),
correct_frame,
Mock(
filename="/home/dev/homeassistant/components/zeroconf/usage.py",
lineno="23",
line="self.light.is_on",
),
Mock(filename="/home/dev/mdns/lights.py", lineno="2", line="something()",),
],
):
found_frame, integration, path = frame.get_integration_frame(
exclude_integrations={"zeroconf"}
)
assert integration == "mdns"
assert path == "homeassistant/components/"
assert found_frame == correct_frame
async def test_extract_frame_no_integration(caplog):
"""Test extracting the current frame without integration context."""
with patch(