* Refactor zeroconf setup to be async. Most of the setup was calling back into async because we were setting up listeners. Since we only need to jump into the executor to create the zeroconf instance, it's much faster to set up in async. In testing, this cut the setup time in half or better. * Partial revert to after_deps
"""Set up some common test helper things."""
import asyncio
import datetime
import functools
import logging
import ssl
import threading
from aiohttp.test_utils import make_mocked_request
import pytest
import requests_mock as _requests_mock
from homeassistant import core as ha, loader, runner, util
from homeassistant.auth.const import GROUP_ID_ADMIN, GROUP_ID_READ_ONLY
from homeassistant.auth.providers import homeassistant, legacy_api_password
from homeassistant.components import mqtt
from homeassistant.components.websocket_api.auth import (
from homeassistant.components.websocket_api.http import URL
from homeassistant.const import ATTR_NOW, EVENT_TIME_CHANGED
from homeassistant.exceptions import ServiceNotFound
from homeassistant.helpers import event
from homeassistant.setup import async_setup_component
from homeassistant.util import location
from tests.async_mock import MagicMock, Mock, patch
from tests.ignore_uncaught_exceptions import IGNORE_UNCAUGHT_EXCEPTIONS
from tests.common import ( # noqa: E402, isort:skip
mock_storage as mock_storage,
from tests.test_util.aiohttp import mock_aiohttp_client # noqa: E402, isort:skip
# Disable fixtures overriding our beautiful policy.
# asyncio.set_event_loop_policy is replaced at import time with a no-op that
# accepts the policy argument and discards it. Any later attempt (e.g. by a
# fixture or plugin) to install a different event loop policy has no effect.
asyncio.set_event_loop_policy = lambda policy: None
def pytest_configure(config):
    """Register marker for tests that log exceptions.

    Declares the ``no_fail_on_log_exception`` marker with pytest. Tests can
    then use it to opt out of the ``fail_on_log_exception`` fixture without
    pytest reporting an unknown marker.

    Args:
        config: The pytest ``Config`` object passed to this hook.
    """
    # addinivalue_line appends to the ini "markers" list. This is pytest's
    # documented way to register a custom marker from a conftest/plugin. A
    # bare ("markers", "...") tuple would simply be evaluated and discarded.
    config.addinivalue_line(
        "markers", "no_fail_on_log_exception: mark test to not fail on logged exception"
    )
def check_real(func):
    """Force a function to require a keyword _test_real to be passed in.

    Wrap the coroutine function *func*. Awaiting the wrapper raises unless
    the caller passes a truthy ``_test_real`` keyword. That keyword is removed
    before *func* is called, so *func* never sees it.

    Args:
        func: Coroutine function to guard.

    Returns:
        An async wrapper that carries *func*'s name and docstring.

    Raises:
        Exception: When the wrapper is awaited without ``_test_real=True``.
    """

    @functools.wraps(func)
    async def guard_func(*args, **kwargs):
        real = kwargs.pop("_test_real", None)
        if not real:
            # Format eagerly: Exception does not interpolate "%s" the way
            # logging does, so a second argument would leave it unformatted.
            raise Exception(
                f'Forgot to mock or pass "_test_real=True" to {func.__name__}'
            )
        return await func(*args, **kwargs)

    return guard_func
# Guard a few functions that would make network connections.
# location.async_detect_location_info is wrapped by check_real. Awaiting it
# raises unless a test mocks it or explicitly passes _test_real=True.
location.async_detect_location_info = check_real(location.async_detect_location_info)
# util.get_local_ip is replaced by a stub that always returns an empty string.
util.get_local_ip = lambda: ""
def verify_cleanup():
"""Verify that the test has cleaned up resources correctly.

Takes a snapshot of the live threads. Aborts the whole pytest run if
INSTANCES holds two or more entries. Asserts that no threads were started
since the snapshot.
"""
# NOTE(review): this copy has no `yield` between the thread snapshot and the
# checks below, so they run back to back. If this is meant to check after the
# test body runs, a `yield` appears to be missing -- confirm.
threads_before = frozenset(threading.enumerate())
if len(INSTANCES) >= 2:
count = len(INSTANCES)
# NOTE(review): `inst` is never used in the visible lines, which suggests a
# per-instance cleanup statement is missing from this loop body -- confirm.
# INSTANCES is not imported in the visible lines; presumably it comes from
# the truncated `tests.common` import.
for inst in INSTANCES:
pytest.exit(f"Detected non stopped instances ({count}), aborting test run")
# Any thread alive now that was not alive at the snapshot is a leak.
threads = frozenset(threading.enumerate()) - threads_before
assert not threads
def hass_storage():
    """Mock storage for the duration of a test.

    Yields whatever entering ``mock_storage()`` produces. The context is
    exited when the generator finishes or is closed.
    """
    storage_patch = mock_storage()
    with storage_patch as data:
        yield data
def hass(loop, hass_storage, request):
"""Fixture to provide a test instance of Home Assistant.

Creates the instance by running ``async_test_home_assistant(loop)`` to
completion and yields it to the test. Afterwards it iterates the local
``exceptions`` list and re-raises entries from it. The exact filtering is
not fully visible in this copy.
"""
# NOTE(review): the call that installs exc_handle on the loop
# (loop.set_exception_handler) is not visible in this copy -- confirm.
def exc_handle(loop, context):
"""Handle exceptions by rethrowing them, which will fail the test."""
# NOTE(review): nothing in the visible lines appends to `exceptions`; the
# code that records the reported exception appears to be missing -- confirm.
# NOTE(review): loop.get_exception_handler() returns None when no custom
# handler is set, and calling it here would then raise TypeError -- confirm
# a handler is always installed before this runs.
orig_exception_handler(loop, context)
# Presumably filled by exc_handle; iterated after the test finishes.
exceptions = []
hass = loop.run_until_complete(async_test_home_assistant(loop))
# Assigned after exc_handle is defined. That is fine because exc_handle looks
# this name up only when it is called.
orig_exception_handler = loop.get_exception_handler()
yield hass
for ex in exceptions:
# NOTE(review): the condition of this `if (` and its closing parenthesis are
# missing in this copy, so this block does not parse -- confirm.
if (
if isinstance(ex, ServiceNotFound):
raise ex
async def stop_hass():
"""Make sure all hass are stopped.

Patches ``homeassistant.core.HomeAssistant`` with ``mock_hass``. Then, for
each instance in ``created``, it waits for pending work and calls
``async_stop(force=True)`` while ``loop.stop`` is patched.
"""
orig_hass = ha.HomeAssistant
# Iterated below to shut instances down.
created = []
# Stand-in for HomeAssistant: builds a real instance via the saved class.
def mock_hass():
hass_inst = orig_hass()
# NOTE(review): hass_inst is never added to `created` in the visible lines,
# so the shutdown loop below would have nothing to do; an append appears to
# be missing -- confirm.
return hass_inst
# NOTE(review): no `yield` is visible inside this `with` block. As an async
# fixture it presumably yields here so the test runs with the patch active
# -- confirm.
with patch("homeassistant.core.HomeAssistant", mock_hass):
for hass_inst in created:
# NOTE(review): this `if` has no visible body (presumably it skips instances
# that are already stopped) -- confirm.
if hass_inst.state == ha.CoreState.stopped:
# loop.stop is patched with a mock, so any loop.stop() call made while
# shutting this instance down does not actually stop the event loop.
with patch.object(hass_inst.loop, "stop"):
await hass_inst.async_block_till_done()
await hass_inst.async_stop(force=True)
def requests_mock():
    """Provide a requests mocker for the duration of a test.

    Yields the object returned by entering ``_requests_mock.mock()``. The
    mocker is exited when the generator finishes or is closed.
    """
    mocker = _requests_mock.mock()
    with mocker as active_mocker:
        yield active_mocker
def aioclient_mock():
    """Mock aiohttp client calls for the duration of a test.

    Yields the session mock produced by entering ``mock_aiohttp_client()``.
    The mock is exited when the generator finishes or is closed.
    """
    client_patch = mock_aiohttp_client()
    with client_patch as session_mock:
        yield session_mock
def mock_device_tracker_conf():
"""Prevent device tracker from reading/writing data.

Yields the ``devices`` list. One of the patches below is given
``side_effect=lambda *args: devices``, so it returns that same list.
"""
# Shared list handed to the test and returned by one of the patches below.
devices = []
# NOTE(review): mock_update_config has no body in this copy (presumably it
# records `entity` in `devices`) -- confirm. `id` shadows the builtin, but it
# is part of the replacement's signature, so renaming needs care.
async def mock_update_config(path, id, entity):
# NOTE(review): the patch targets and the remaining keyword arguments
# (mock_update_config is otherwise unused, so presumably a side_effect using
# it) are missing from this copy, so this statement does not parse -- confirm.
with patch(
), patch(
side_effect=lambda *args: devices,
yield devices
def hass_access_token(hass, hass_admin_user):
    """Return an access token to access Home Assistant.

    Creates a refresh token for ``hass_admin_user`` bound to ``CLIENT_ID`` by
    running the coroutine to completion on ``hass.loop``. Then creates an
    access token from that refresh token and returns it.
    """
    refresh_token = hass.loop.run_until_complete(
        hass.auth.async_create_refresh_token(hass_admin_user, CLIENT_ID)
    )
    return hass.auth.async_create_access_token(refresh_token)
def hass_owner_user(hass, local_auth):
    """Return a Home Assistant owner user.

    Builds a ``MockUser`` with ``is_owner=True`` and registers it with
    ``add_to_hass``. ``local_auth`` is not used in the body; presumably it is
    requested so the local auth provider is loaded first.
    """
    return MockUser(is_owner=True).add_to_hass(hass)
def hass_admin_user(hass, local_auth):
"""Return a Home Assistant admin user.

The user is a MockUser placed in ``admin_group`` and registered via
add_to_hass.
"""
# NOTE(review): the argument of run_until_complete( and its closing
# parenthesis are missing in this copy, so this does not parse.
# GROUP_ID_ADMIN is imported but unused in the visible lines; presumably the
# missing line looks up the admin group with it -- confirm.
admin_group = hass.loop.run_until_complete(
return MockUser(groups=[admin_group]).add_to_hass(hass)
def hass_read_only_user(hass, local_auth):
"""Return a Home Assistant read only user.

The user is a MockUser placed in ``read_only_group`` and registered via
add_to_hass.
"""
# NOTE(review): the argument of run_until_complete( and its closing
# parenthesis are missing in this copy, so this does not parse.
# GROUP_ID_READ_ONLY is imported but unused in the visible lines; presumably
# the missing line looks up the read-only group with it -- confirm.
read_only_group = hass.loop.run_until_complete(
return MockUser(groups=[read_only_group]).add_to_hass(hass)
def hass_read_only_access_token(hass, hass_read_only_user):
    """Return an access token for the read-only Home Assistant user.

    Creates a refresh token for ``hass_read_only_user`` bound to
    ``CLIENT_ID`` by running the coroutine to completion on ``hass.loop``.
    Then creates an access token from that refresh token and returns it.
    """
    refresh_token = hass.loop.run_until_complete(
        hass.auth.async_create_refresh_token(hass_read_only_user, CLIENT_ID)
    )
    return hass.auth.async_create_access_token(refresh_token)
def legacy_auth(hass):
"""Load legacy API password provider.

Builds a LegacyApiPasswordAuthProvider configured with the api password
"test-password". Registers it in the private ``hass.auth._providers``
mapping under ``(type, id)`` and returns it.
"""
# NOTE(review): this constructor call is unterminated in this copy.
# local_auth passes `hass, hass.auth._store` before its config dict, so
# matching leading arguments appear to be missing here too -- confirm.
prv = legacy_api_password.LegacyApiPasswordAuthProvider(
{"type": "legacy_api_password", "api_password": "test-password"},
hass.auth._providers[(prv.type, prv.id)] = prv
return prv
def local_auth(hass):
    """Load local auth provider.

    Instantiates ``HassAuthProvider`` against ``hass.auth._store`` and
    registers it directly in the private ``hass.auth._providers`` mapping,
    keyed by ``(type, id)``.

    Returns:
        The registered provider.
    """
    prv = homeassistant.HassAuthProvider(
        hass, hass.auth._store, {"type": "homeassistant"}
    )
    hass.auth._providers[(prv.type, prv.id)] = prv
    return prv
def hass_client(hass, aiohttp_client, hass_access_token):
    """Return an authenticated HTTP client.

    Returns a coroutine function. When awaited, it asks the ``aiohttp_client``
    factory for a client for ``hass.http.app`` and passes an
    ``Authorization: Bearer <hass_access_token>`` header.
    """

    async def auth_client():
        """Return an authenticated client."""
        return await aiohttp_client(
            hass.http.app, headers={"Authorization": f"Bearer {hass_access_token}"}
        )

    return auth_client
def current_request(hass):
"""Mock current request.

Patches ``homeassistant.helpers.network.current_request`` so its ``get``
returns a request built by aiohttp's make_mocked_request with a Host header
of ``example.com``. Yields the patched object.
"""
with patch("homeassistant.helpers.network.current_request") as mock_request_context:
# NOTE(review): make_mocked_request( is unterminated here, and its leading
# positional arguments (method, path) appear to be missing. `ssl` is
# imported but unused in the visible lines, so an SSL-context argument may
# also be missing -- confirm.
mocked_request = make_mocked_request(
headers={"Host": "example.com"},
mock_request_context.get = Mock(return_value=mocked_request)
yield mock_request_context
def hass_ws_client(aiohttp_client, hass_access_token, hass):
"""Websocket client fixture connected to websocket server.

Returns ``create_client``, a coroutine function. It sets up the
websocket_api component, opens a websocket to URL, performs the auth
handshake, and returns the websocket with ``.client`` set to the aiohttp
test client.
"""
async def create_client(hass=hass, access_token=hass_access_token):
"""Create a websocket client."""
assert await async_setup_component(hass, "websocket_api", {})
client = await aiohttp_client(hass.http.app)
# NOTE(review): indentation is lost in this copy, so it cannot be determined
# which of the following statements run inside this patch context -- confirm.
with patch("homeassistant.components.http.auth.setup_auth"):
websocket = await client.ws_connect(URL)
auth_resp = await websocket.receive_json()
# The server is expected to open with an auth-required message.
assert auth_resp["type"] == TYPE_AUTH_REQUIRED
# NOTE(review): both send_json( calls below are unterminated. No `else:`
# separates them in this copy, so without one the second send would also run
# when access_token is None; an `else:` appears to be missing -- confirm.
if access_token is None:
await websocket.send_json(
{"type": TYPE_AUTH, "access_token": "incorrect"}
await websocket.send_json(
{"type": TYPE_AUTH, "access_token": access_token}
auth_ok = await websocket.receive_json()
assert auth_ok["type"] == TYPE_AUTH_OK
# wrap in client
websocket.client = client
return websocket
return create_client
def fail_on_log_exception(request, monkeypatch):
"""Fail the test when wrapped code throws.

Covers callbacks wrapped by catch_log_exception and coroutines wrapped by
async_create_catching_coro.
"""
# Tests marked with no_fail_on_log_exception opt out of this behavior.
# NOTE(review): in this copy the `if` has no visible body of its own
# (presumably an early `return`), and log_exception below has no body
# (presumably it re-raises) -- confirm.
if "no_fail_on_log_exception" in request.keywords:
def log_exception(format_err, *args):
# Replace homeassistant.util.logging.log_exception for the duration of the
# test through pytest's monkeypatch fixture.
monkeypatch.setattr("homeassistant.util.logging.log_exception", log_exception)
def mqtt_config():
    """Fixture to allow overriding MQTT config.

    Yields no value (``None``) by default. The ``mqtt_mock`` fixture treats
    ``None`` as "use a minimal config pointing at ``mock-broker``". Override
    this fixture to supply a different MQTT config.
    """
def mqtt_client_mock(hass):
"""Fixture to mock MQTT client.

Patches ``paho.mqtt.client.Client`` and yields the instance the patched
class returns. Its connect/subscribe/unsubscribe/publish are wired to local
fakes.
"""
# Message-id counter shared by the fake MQTT operations.
mid = 0
# Increment the counter and return the new message id (1, 2, 3, ...).
def get_mid():
nonlocal mid
mid += 1
return mid
# Minimal stand-in for the info object returned by publish(). It exposes
# only `mid` and `rc` (always 0).
class FakeInfo:
def __init__(self, mid):
self.mid = mid
self.rc = 0
with patch("paho.mqtt.client.Client") as mock_client:
# Fake publish: routes the message through async_fire_mqtt_message, then
# calls the client's on_publish callback with a fresh message id.
def _async_fire_mqtt_message(topic, payload, qos, retain):
async_fire_mqtt_message(hass, topic, payload, qos, retain)
mid = get_mid()
mock_client.on_publish(0, 0, mid)
return FakeInfo(mid)
# NOTE(review): in this copy _subscribe and _unsubscribe read `mid` without
# assigning it. They therefore see the enclosing counter's current value
# rather than a fresh id; a `mid = get_mid()` line appears to be missing in
# each -- confirm.
def _subscribe(topic, qos=0):
mock_client.on_subscribe(0, 0, mid)
return (0, mid)
def _unsubscribe(topic):
mock_client.on_unsubscribe(0, 0, mid)
return (0, mid)
# This rebinds the name. The helpers above look up `mock_client` only when
# they are called, so after this line they use the client instance rather
# than the patched class.
mock_client = mock_client.return_value
mock_client.connect.return_value = 0
mock_client.subscribe.side_effect = _subscribe
mock_client.unsubscribe.side_effect = _unsubscribe
mock_client.publish.side_effect = _async_fire_mqtt_message
yield mock_client
async def mqtt_mock(hass, mqtt_client_mock, mqtt_config):
"""Fixture to mock MQTT component.

Sets up the MQTT integration with ``mqtt_config``, or a default mock-broker
config when it is None. Then replaces ``hass.data["mqtt"]`` with a MagicMock
whose ``_mqttc`` is the client mock, and returns it.
"""
# Fall back to a minimal config pointing at a fake broker.
if mqtt_config is None:
mqtt_config = {mqtt.CONF_BROKER: "mock-broker"}
result = await async_setup_component(hass, mqtt.DOMAIN, {mqtt.DOMAIN: mqtt_config})
assert result
await hass.async_block_till_done()
# NOTE(review): the MagicMock( call is unterminated and its arguments are
# missing in this copy (presumably it specs/wraps the real component stored
# in hass.data["mqtt"]) -- confirm.
mqtt_component_mock = MagicMock(
mqtt_component_mock._mqttc = mqtt_client_mock
hass.data["mqtt"] = mqtt_component_mock
component = hass.data["mqtt"]
return component
def mock_zeroconf():
    """Mock zeroconf.

    Patches ``homeassistant.components.zeroconf.HaZeroconf`` for the duration
    of the test. Yields the object the patched class returns when called.
    """
    patcher = patch("homeassistant.components.zeroconf.HaZeroconf")
    with patcher as zeroconf_cls:
        yield zeroconf_cls.return_value
def legacy_patchable_time():
"""Allow time to be patchable by using event listeners instead of asyncio loop."""
def async_track_point_in_utc_time(hass, action, point_in_time):
"""Add a listener that fires once after a specific point in UTC time."""
# Ensure point_in_time is UTC
point_in_time = event.dt_util.as_utc(point_in_time)
def point_in_time_listener(event):
"""Listen for matching time_changed events."""
now = event.data[ATTR_NOW]
if now < point_in_time or hasattr(point_in_time_listener, "run"):
# Set variable so that we will never run twice.
# Because the event bus might have to wait till a thread comes
# available to execute this listener it might occur that the
# listener gets lined up twice to be executed. This will make
# sure the second time it does nothing.
setattr(point_in_time_listener, "run", True)
hass.async_run_job(action, now)
async_unsub = hass.bus.async_listen(EVENT_TIME_CHANGED, point_in_time_listener)
return async_unsub
def async_track_utc_time_change(
hass, action, hour=None, minute=None, second=None, local=False
"""Add a listener that will fire if time matches a pattern."""
# We do not have to wrap the function with time pattern matching logic
# if no pattern given
if all(val is None for val in (hour, minute, second)):
def time_change_listener(ev) -> None:
"""Fire every time event that comes in."""
hass.async_run_job(action, ev.data[ATTR_NOW])
return hass.bus.async_listen(EVENT_TIME_CHANGED, time_change_listener)
matching_seconds = event.dt_util.parse_time_expression(second, 0, 59)
matching_minutes = event.dt_util.parse_time_expression(minute, 0, 59)
matching_hours = event.dt_util.parse_time_expression(hour, 0, 23)
next_time = None
def calculate_next(now) -> None:
"""Calculate and set the next time the trigger should fire."""
nonlocal next_time
localized_now = event.dt_util.as_local(now) if local else now
next_time = event.dt_util.find_next_time_expression_time(
localized_now, matching_seconds, matching_minutes, matching_hours
# Make sure rolling back the clock doesn't prevent the timer from
# triggering.
last_now = None
def pattern_time_change_listener(ev) -> None:
"""Listen for matching time_changed events."""
nonlocal next_time, last_now
now = ev.data[ATTR_NOW]
if last_now is None or now < last_now:
# Time rolled back or next time not yet calculated
last_now = now
if next_time <= now:
action, event.dt_util.as_local(now) if local else now
calculate_next(now + datetime.timedelta(seconds=1))
# We can't use async_track_point_in_utc_time here because it would
# break in the case that the system time abruptly jumps backwards.
# Our custom last_now logic takes care of resolving that scenario.
return hass.bus.async_listen(EVENT_TIME_CHANGED, pattern_time_change_listener)
with patch(
), patch(