Improve tests.common typing (#122257)
This commit is contained in:
parent
90e7d82049
commit
769d7214a3
1 changed files with 18 additions and 10 deletions
|
@ -8,6 +8,8 @@ from collections.abc import (
|
|||
Callable,
|
||||
Coroutine,
|
||||
Generator,
|
||||
Iterable,
|
||||
Iterator,
|
||||
Mapping,
|
||||
Sequence,
|
||||
)
|
||||
|
@ -30,6 +32,7 @@ from unittest.mock import AsyncMock, Mock, patch
|
|||
from aiohttp.test_utils import unused_port as get_test_instance_port # noqa: F401
|
||||
import pytest
|
||||
from syrupy import SnapshotAssertion
|
||||
from typing_extensions import TypeVar
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant import auth, bootstrap, config_entries, loader
|
||||
|
@ -90,6 +93,7 @@ from homeassistant.helpers.json import JSONEncoder, _orjson_default_encoder, jso
|
|||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||
from homeassistant.util.async_ import run_callback_threadsafe
|
||||
import homeassistant.util.dt as dt_util
|
||||
from homeassistant.util.event_type import EventType
|
||||
from homeassistant.util.json import (
|
||||
JsonArrayType,
|
||||
JsonObjectType,
|
||||
|
@ -107,6 +111,8 @@ from .testing_config.custom_components.test_constant_deprecation import (
|
|||
import_deprecated_constant,
|
||||
)
|
||||
|
||||
_DataT = TypeVar("_DataT", bound=Mapping[str, Any], default=dict[str, Any])
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
INSTANCES = []
|
||||
CLIENT_ID = "https://example.com/app"
|
||||
|
@ -1434,7 +1440,7 @@ async def get_system_health_info(hass: HomeAssistant, domain: str) -> dict[str,
|
|||
|
||||
|
||||
@contextmanager
|
||||
def mock_config_flow(domain: str, config_flow: type[ConfigFlow]) -> None:
|
||||
def mock_config_flow(domain: str, config_flow: type[ConfigFlow]) -> Iterator[None]:
|
||||
"""Mock a config flow handler."""
|
||||
original_handler = config_entries.HANDLERS.get(domain)
|
||||
config_entries.HANDLERS[domain] = config_flow
|
||||
|
@ -1502,12 +1508,14 @@ def mock_platform(
|
|||
module_cache[platform_path] = module or Mock()
|
||||
|
||||
|
||||
def async_capture_events(hass: HomeAssistant, event_name: str) -> list[Event]:
|
||||
def async_capture_events(
|
||||
hass: HomeAssistant, event_name: EventType[_DataT] | str
|
||||
) -> list[Event[_DataT]]:
|
||||
"""Create a helper that captures events."""
|
||||
events = []
|
||||
events: list[Event[_DataT]] = []
|
||||
|
||||
@callback
|
||||
def capture_events(event: Event) -> None:
|
||||
def capture_events(event: Event[_DataT]) -> None:
|
||||
events.append(event)
|
||||
|
||||
hass.bus.async_listen(event_name, capture_events)
|
||||
|
@ -1516,14 +1524,14 @@ def async_capture_events(hass: HomeAssistant, event_name: str) -> list[Event]:
|
|||
|
||||
|
||||
@callback
|
||||
def async_mock_signal(
|
||||
hass: HomeAssistant, signal: SignalType[Any] | str
|
||||
) -> list[tuple[Any]]:
|
||||
def async_mock_signal[*_Ts](
|
||||
hass: HomeAssistant, signal: SignalType[*_Ts] | str
|
||||
) -> list[tuple[*_Ts]]:
|
||||
"""Catch all dispatches to a signal."""
|
||||
calls = []
|
||||
calls: list[tuple[*_Ts]] = []
|
||||
|
||||
@callback
|
||||
def mock_signal_handler(*args: Any) -> None:
|
||||
def mock_signal_handler(*args: *_Ts) -> None:
|
||||
"""Mock service call."""
|
||||
calls.append(args)
|
||||
|
||||
|
@ -1723,7 +1731,7 @@ def extract_stack_to_frame(extract_stack: list[Mock]) -> FrameType:
|
|||
def setup_test_component_platform(
|
||||
hass: HomeAssistant,
|
||||
domain: str,
|
||||
entities: Sequence[Entity],
|
||||
entities: Iterable[Entity],
|
||||
from_config_entry: bool = False,
|
||||
built_in: bool = True,
|
||||
) -> MockPlatform:
|
||||
|
|
Loading…
Add table
Reference in a new issue