Improve tests.common typing (#122257)

This commit is contained in:
Marc Mueller 2024-07-20 17:34:43 +02:00 committed by GitHub
parent 90e7d82049
commit 769d7214a3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -8,6 +8,8 @@ from collections.abc import (
Callable, Callable,
Coroutine, Coroutine,
Generator, Generator,
Iterable,
Iterator,
Mapping, Mapping,
Sequence, Sequence,
) )
@ -30,6 +32,7 @@ from unittest.mock import AsyncMock, Mock, patch
from aiohttp.test_utils import unused_port as get_test_instance_port # noqa: F401 from aiohttp.test_utils import unused_port as get_test_instance_port # noqa: F401
import pytest import pytest
from syrupy import SnapshotAssertion from syrupy import SnapshotAssertion
from typing_extensions import TypeVar
import voluptuous as vol import voluptuous as vol
from homeassistant import auth, bootstrap, config_entries, loader from homeassistant import auth, bootstrap, config_entries, loader
@ -90,6 +93,7 @@ from homeassistant.helpers.json import JSONEncoder, _orjson_default_encoder, jso
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from homeassistant.util.async_ import run_callback_threadsafe from homeassistant.util.async_ import run_callback_threadsafe
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from homeassistant.util.event_type import EventType
from homeassistant.util.json import ( from homeassistant.util.json import (
JsonArrayType, JsonArrayType,
JsonObjectType, JsonObjectType,
@ -107,6 +111,8 @@ from .testing_config.custom_components.test_constant_deprecation import (
import_deprecated_constant, import_deprecated_constant,
) )
_DataT = TypeVar("_DataT", bound=Mapping[str, Any], default=dict[str, Any])
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
INSTANCES = [] INSTANCES = []
CLIENT_ID = "https://example.com/app" CLIENT_ID = "https://example.com/app"
@ -1434,7 +1440,7 @@ async def get_system_health_info(hass: HomeAssistant, domain: str) -> dict[str,
@contextmanager @contextmanager
def mock_config_flow(domain: str, config_flow: type[ConfigFlow]) -> None: def mock_config_flow(domain: str, config_flow: type[ConfigFlow]) -> Iterator[None]:
"""Mock a config flow handler.""" """Mock a config flow handler."""
original_handler = config_entries.HANDLERS.get(domain) original_handler = config_entries.HANDLERS.get(domain)
config_entries.HANDLERS[domain] = config_flow config_entries.HANDLERS[domain] = config_flow
@ -1502,12 +1508,14 @@ def mock_platform(
module_cache[platform_path] = module or Mock() module_cache[platform_path] = module or Mock()
def async_capture_events(hass: HomeAssistant, event_name: str) -> list[Event]: def async_capture_events(
hass: HomeAssistant, event_name: EventType[_DataT] | str
) -> list[Event[_DataT]]:
"""Create a helper that captures events.""" """Create a helper that captures events."""
events = [] events: list[Event[_DataT]] = []
@callback @callback
def capture_events(event: Event) -> None: def capture_events(event: Event[_DataT]) -> None:
events.append(event) events.append(event)
hass.bus.async_listen(event_name, capture_events) hass.bus.async_listen(event_name, capture_events)
@ -1516,14 +1524,14 @@ def async_capture_events(hass: HomeAssistant, event_name: str) -> list[Event]:
@callback @callback
def async_mock_signal( def async_mock_signal[*_Ts](
hass: HomeAssistant, signal: SignalType[Any] | str hass: HomeAssistant, signal: SignalType[*_Ts] | str
) -> list[tuple[Any]]: ) -> list[tuple[*_Ts]]:
"""Catch all dispatches to a signal.""" """Catch all dispatches to a signal."""
calls = [] calls: list[tuple[*_Ts]] = []
@callback @callback
def mock_signal_handler(*args: Any) -> None: def mock_signal_handler(*args: *_Ts) -> None:
"""Mock service call.""" """Mock service call."""
calls.append(args) calls.append(args)
@ -1723,7 +1731,7 @@ def extract_stack_to_frame(extract_stack: list[Mock]) -> FrameType:
def setup_test_component_platform( def setup_test_component_platform(
hass: HomeAssistant, hass: HomeAssistant,
domain: str, domain: str,
entities: Sequence[Entity], entities: Iterable[Entity],
from_config_entry: bool = False, from_config_entry: bool = False,
built_in: bool = True, built_in: bool = True,
) -> MockPlatform: ) -> MockPlatform: