Improve tests.common typing ()

This commit is contained in:
Marc Mueller 2024-07-20 17:34:43 +02:00 committed by GitHub
parent 90e7d82049
commit 769d7214a3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -8,6 +8,8 @@ from collections.abc import (
Callable,
Coroutine,
Generator,
Iterable,
Iterator,
Mapping,
Sequence,
)
@ -30,6 +32,7 @@ from unittest.mock import AsyncMock, Mock, patch
from aiohttp.test_utils import unused_port as get_test_instance_port # noqa: F401
import pytest
from syrupy import SnapshotAssertion
from typing_extensions import TypeVar
import voluptuous as vol
from homeassistant import auth, bootstrap, config_entries, loader
@ -90,6 +93,7 @@ from homeassistant.helpers.json import JSONEncoder, _orjson_default_encoder, jso
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from homeassistant.util.async_ import run_callback_threadsafe
import homeassistant.util.dt as dt_util
from homeassistant.util.event_type import EventType
from homeassistant.util.json import (
JsonArrayType,
JsonObjectType,
@ -107,6 +111,8 @@ from .testing_config.custom_components.test_constant_deprecation import (
import_deprecated_constant,
)
_DataT = TypeVar("_DataT", bound=Mapping[str, Any], default=dict[str, Any])
_LOGGER = logging.getLogger(__name__)
INSTANCES = []
CLIENT_ID = "https://example.com/app"
@ -1434,7 +1440,7 @@ async def get_system_health_info(hass: HomeAssistant, domain: str) -> dict[str,
@contextmanager
def mock_config_flow(domain: str, config_flow: type[ConfigFlow]) -> None:
def mock_config_flow(domain: str, config_flow: type[ConfigFlow]) -> Iterator[None]:
"""Mock a config flow handler."""
original_handler = config_entries.HANDLERS.get(domain)
config_entries.HANDLERS[domain] = config_flow
@ -1502,12 +1508,14 @@ def mock_platform(
module_cache[platform_path] = module or Mock()
def async_capture_events(hass: HomeAssistant, event_name: str) -> list[Event]:
def async_capture_events(
hass: HomeAssistant, event_name: EventType[_DataT] | str
) -> list[Event[_DataT]]:
"""Create a helper that captures events."""
events = []
events: list[Event[_DataT]] = []
@callback
def capture_events(event: Event) -> None:
def capture_events(event: Event[_DataT]) -> None:
events.append(event)
hass.bus.async_listen(event_name, capture_events)
@ -1516,14 +1524,14 @@ def async_capture_events(hass: HomeAssistant, event_name: str) -> list[Event]:
@callback
def async_mock_signal(
hass: HomeAssistant, signal: SignalType[Any] | str
) -> list[tuple[Any]]:
def async_mock_signal[*_Ts](
hass: HomeAssistant, signal: SignalType[*_Ts] | str
) -> list[tuple[*_Ts]]:
"""Catch all dispatches to a signal."""
calls = []
calls: list[tuple[*_Ts]] = []
@callback
def mock_signal_handler(*args: Any) -> None:
def mock_signal_handler(*args: *_Ts) -> None:
"""Mock service call."""
calls.append(args)
@ -1723,7 +1731,7 @@ def extract_stack_to_frame(extract_stack: list[Mock]) -> FrameType:
def setup_test_component_platform(
hass: HomeAssistant,
domain: str,
entities: Sequence[Entity],
entities: Iterable[Entity],
from_config_entry: bool = False,
built_in: bool = True,
) -> MockPlatform: