Improve tests.common typing (#122257)
This commit is contained in:
parent
90e7d82049
commit
769d7214a3
1 changed files with 18 additions and 10 deletions
|
@ -8,6 +8,8 @@ from collections.abc import (
|
||||||
Callable,
|
Callable,
|
||||||
Coroutine,
|
Coroutine,
|
||||||
Generator,
|
Generator,
|
||||||
|
Iterable,
|
||||||
|
Iterator,
|
||||||
Mapping,
|
Mapping,
|
||||||
Sequence,
|
Sequence,
|
||||||
)
|
)
|
||||||
|
@ -30,6 +32,7 @@ from unittest.mock import AsyncMock, Mock, patch
|
||||||
from aiohttp.test_utils import unused_port as get_test_instance_port # noqa: F401
|
from aiohttp.test_utils import unused_port as get_test_instance_port # noqa: F401
|
||||||
import pytest
|
import pytest
|
||||||
from syrupy import SnapshotAssertion
|
from syrupy import SnapshotAssertion
|
||||||
|
from typing_extensions import TypeVar
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant import auth, bootstrap, config_entries, loader
|
from homeassistant import auth, bootstrap, config_entries, loader
|
||||||
|
@ -90,6 +93,7 @@ from homeassistant.helpers.json import JSONEncoder, _orjson_default_encoder, jso
|
||||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||||
from homeassistant.util.async_ import run_callback_threadsafe
|
from homeassistant.util.async_ import run_callback_threadsafe
|
||||||
import homeassistant.util.dt as dt_util
|
import homeassistant.util.dt as dt_util
|
||||||
|
from homeassistant.util.event_type import EventType
|
||||||
from homeassistant.util.json import (
|
from homeassistant.util.json import (
|
||||||
JsonArrayType,
|
JsonArrayType,
|
||||||
JsonObjectType,
|
JsonObjectType,
|
||||||
|
@ -107,6 +111,8 @@ from .testing_config.custom_components.test_constant_deprecation import (
|
||||||
import_deprecated_constant,
|
import_deprecated_constant,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_DataT = TypeVar("_DataT", bound=Mapping[str, Any], default=dict[str, Any])
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
INSTANCES = []
|
INSTANCES = []
|
||||||
CLIENT_ID = "https://example.com/app"
|
CLIENT_ID = "https://example.com/app"
|
||||||
|
@ -1434,7 +1440,7 @@ async def get_system_health_info(hass: HomeAssistant, domain: str) -> dict[str,
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def mock_config_flow(domain: str, config_flow: type[ConfigFlow]) -> None:
|
def mock_config_flow(domain: str, config_flow: type[ConfigFlow]) -> Iterator[None]:
|
||||||
"""Mock a config flow handler."""
|
"""Mock a config flow handler."""
|
||||||
original_handler = config_entries.HANDLERS.get(domain)
|
original_handler = config_entries.HANDLERS.get(domain)
|
||||||
config_entries.HANDLERS[domain] = config_flow
|
config_entries.HANDLERS[domain] = config_flow
|
||||||
|
@ -1502,12 +1508,14 @@ def mock_platform(
|
||||||
module_cache[platform_path] = module or Mock()
|
module_cache[platform_path] = module or Mock()
|
||||||
|
|
||||||
|
|
||||||
def async_capture_events(hass: HomeAssistant, event_name: str) -> list[Event]:
|
def async_capture_events(
|
||||||
|
hass: HomeAssistant, event_name: EventType[_DataT] | str
|
||||||
|
) -> list[Event[_DataT]]:
|
||||||
"""Create a helper that captures events."""
|
"""Create a helper that captures events."""
|
||||||
events = []
|
events: list[Event[_DataT]] = []
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def capture_events(event: Event) -> None:
|
def capture_events(event: Event[_DataT]) -> None:
|
||||||
events.append(event)
|
events.append(event)
|
||||||
|
|
||||||
hass.bus.async_listen(event_name, capture_events)
|
hass.bus.async_listen(event_name, capture_events)
|
||||||
|
@ -1516,14 +1524,14 @@ def async_capture_events(hass: HomeAssistant, event_name: str) -> list[Event]:
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_mock_signal(
|
def async_mock_signal[*_Ts](
|
||||||
hass: HomeAssistant, signal: SignalType[Any] | str
|
hass: HomeAssistant, signal: SignalType[*_Ts] | str
|
||||||
) -> list[tuple[Any]]:
|
) -> list[tuple[*_Ts]]:
|
||||||
"""Catch all dispatches to a signal."""
|
"""Catch all dispatches to a signal."""
|
||||||
calls = []
|
calls: list[tuple[*_Ts]] = []
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def mock_signal_handler(*args: Any) -> None:
|
def mock_signal_handler(*args: *_Ts) -> None:
|
||||||
"""Mock service call."""
|
"""Mock service call."""
|
||||||
calls.append(args)
|
calls.append(args)
|
||||||
|
|
||||||
|
@ -1723,7 +1731,7 @@ def extract_stack_to_frame(extract_stack: list[Mock]) -> FrameType:
|
||||||
def setup_test_component_platform(
|
def setup_test_component_platform(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
domain: str,
|
domain: str,
|
||||||
entities: Sequence[Entity],
|
entities: Iterable[Entity],
|
||||||
from_config_entry: bool = False,
|
from_config_entry: bool = False,
|
||||||
built_in: bool = True,
|
built_in: bool = True,
|
||||||
) -> MockPlatform:
|
) -> MockPlatform:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue