diff --git a/homeassistant/components/demo/__init__.py b/homeassistant/components/demo/__init__.py index 13e8e135394..82cb8eff625 100644 --- a/homeassistant/components/demo/__init__.py +++ b/homeassistant/components/demo/__init__.py @@ -36,6 +36,7 @@ COMPONENTS_WITH_CONFIG_ENTRY_DEMO_PLATFORM = [ Platform.SELECT, Platform.SENSOR, Platform.SIREN, + Platform.STT, Platform.SWITCH, Platform.TEXT, Platform.UPDATE, diff --git a/homeassistant/components/demo/stt.py b/homeassistant/components/demo/stt.py index 923092fad20..e1f59fa76ee 100644 --- a/homeassistant/components/demo/stt.py +++ b/homeassistant/components/demo/stt.py @@ -13,8 +13,11 @@ from homeassistant.components.stt import ( SpeechMetadata, SpeechResult, SpeechResultState, + SpeechToTextEntity, ) +from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant +from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType SUPPORT_LANGUAGES = ["en", "de"] @@ -29,6 +32,60 @@ async def async_get_engine( return DemoProvider() +async def async_setup_entry( + hass: HomeAssistant, + config_entry: ConfigEntry, + async_add_entities: AddEntitiesCallback, +) -> None: + """Set up Demo speech platform via config entry.""" + async_add_entities([DemoProviderEntity()]) + + +class DemoProviderEntity(SpeechToTextEntity): + """Demo speech API provider entity.""" + + @property + def supported_languages(self) -> list[str]: + """Return a list of supported languages.""" + return SUPPORT_LANGUAGES + + @property + def supported_formats(self) -> list[AudioFormats]: + """Return a list of supported formats.""" + return [AudioFormats.WAV] + + @property + def supported_codecs(self) -> list[AudioCodecs]: + """Return a list of supported codecs.""" + return [AudioCodecs.PCM] + + @property + def supported_bit_rates(self) -> list[AudioBitRates]: + """Return a list of supported bit rates.""" + return [AudioBitRates.BITRATE_16] + + @property + def supported_sample_rates(self) -> list[AudioSampleRates]: + """Return a list of supported sample rates.""" + return [AudioSampleRates.SAMPLERATE_16000, AudioSampleRates.SAMPLERATE_44100] + + @property + def supported_channels(self) -> list[AudioChannels]: + """Return a list of supported channels.""" + return [AudioChannels.CHANNEL_STEREO] + + async def async_process_audio_stream( + self, metadata: SpeechMetadata, stream: AsyncIterable[bytes] + ) -> SpeechResult: + """Process an audio stream to STT service.""" + + # Read available data + async for _ in stream: + pass + + return SpeechResult("Turn the Kitchen Lights on", SpeechResultState.SUCCESS) + + class DemoProvider(Provider): """Demo speech API provider.""" diff --git a/homeassistant/components/stt/__init__.py b/homeassistant/components/stt/__init__.py index 50eac1dfeb7..83492f5f6bc 100644 --- a/homeassistant/components/stt/__init__.py +++ b/homeassistant/components/stt/__init__.py @@ -1,9 +1,12 @@ """Provide functionality to STT.""" from __future__ import annotations +from abc import abstractmethod import asyncio +from collections.abc import AsyncIterable from dataclasses import asdict -from typing import Any +import logging +from typing import Any, final from aiohttp import web from aiohttp.hdrs import istr @@ -14,10 +17,16 @@ from aiohttp.web_exceptions import ( ) from homeassistant.components.http import HomeAssistantView -from homeassistant.core import HomeAssistant +from homeassistant.config_entries import ConfigEntry +from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN +from homeassistant.core import HomeAssistant, callback +from homeassistant.helpers.entity_component import EntityComponent +from homeassistant.helpers.restore_state import RestoreEntity from homeassistant.helpers.typing import ConfigType +from homeassistant.util import dt as dt_util from .const import ( + DATA_PROVIDERS, DOMAIN, AudioBitRates, AudioChannels, @@ -36,6 +45,7 @@ from .legacy import ( __all__ = [ "async_get_provider", + "async_get_speech_to_text_entity", "AudioBitRates", "AudioChannels", "AudioCodecs", @@ -43,26 +53,158 @@ __all__ = [ "AudioSampleRates", "DOMAIN", "Provider", + "SpeechToTextEntity", "SpeechMetadata", "SpeechResult", "SpeechResultState", ] +_LOGGER = logging.getLogger(__name__) + + +@callback +def async_get_speech_to_text_entity( + hass: HomeAssistant, entity_id: str +) -> SpeechToTextEntity | None: + """Return stt entity.""" + component: EntityComponent[SpeechToTextEntity] = hass.data[DOMAIN] + + return component.get_entity(entity_id) + async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up STT.""" + component = hass.data[DOMAIN] = EntityComponent[SpeechToTextEntity]( + _LOGGER, DOMAIN, hass + ) + + component.register_shutdown() platform_setups = async_setup_legacy(hass, config) if platform_setups: await asyncio.wait([asyncio.create_task(setup) for setup in platform_setups]) - hass.http.register_view(SpeechToTextView(hass.data[DOMAIN])) + hass.http.register_view(SpeechToTextView(hass.data[DATA_PROVIDERS])) return True +async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: + """Set up a config entry.""" + component: EntityComponent[SpeechToTextEntity] = hass.data[DOMAIN] + return await component.async_setup_entry(entry) + + +async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: + """Unload a config entry.""" + component: EntityComponent[SpeechToTextEntity] = hass.data[DOMAIN] + return await component.async_unload_entry(entry) + + +class SpeechToTextEntity(RestoreEntity): + """Represent a single STT provider.""" + + _attr_should_poll = False + __last_processed: str | None = None + + @property + @final + def name(self) -> str: + """Return the name of the provider entity.""" + # Only one entity is allowed per platform for now. + if self.platform is None: + raise RuntimeError("Entity is not added to hass yet.") + + return self.platform.platform_name + + @property + @final + def state(self) -> str | None: + """Return the state of the provider entity.""" + if self.__last_processed is None: + return None + return self.__last_processed + + @property + @abstractmethod + def supported_languages(self) -> list[str]: + """Return a list of supported languages.""" + + @property + @abstractmethod + def supported_formats(self) -> list[AudioFormats]: + """Return a list of supported formats.""" + + @property + @abstractmethod + def supported_codecs(self) -> list[AudioCodecs]: + """Return a list of supported codecs.""" + + @property + @abstractmethod + def supported_bit_rates(self) -> list[AudioBitRates]: + """Return a list of supported bit rates.""" + + @property + @abstractmethod + def supported_sample_rates(self) -> list[AudioSampleRates]: + """Return a list of supported sample rates.""" + + @property + @abstractmethod + def supported_channels(self) -> list[AudioChannels]: + """Return a list of supported channels.""" + + async def async_internal_added_to_hass(self) -> None: + """Call when the provider entity is added to hass.""" + await super().async_internal_added_to_hass() + state = await self.async_get_last_state() + if ( + state is not None + and state.state is not None + and state.state not in (STATE_UNAVAILABLE, STATE_UNKNOWN) + ): + self.__last_processed = state.state + + @final + async def internal_async_process_audio_stream( + self, metadata: SpeechMetadata, stream: AsyncIterable[bytes] + ) -> SpeechResult: + """Process an audio stream to STT service. + + Only streaming content is allowed! + """ + self.__last_processed = dt_util.utcnow().isoformat() + self.async_write_ha_state() + return await self.async_process_audio_stream(metadata=metadata, stream=stream) + + @abstractmethod + async def async_process_audio_stream( + self, metadata: SpeechMetadata, stream: AsyncIterable[bytes] + ) -> SpeechResult: + """Process an audio stream to STT service. + + Only streaming content is allowed! + """ + + @callback + def check_metadata(self, metadata: SpeechMetadata) -> bool: + """Check if given metadata supported by this provider.""" + if ( + metadata.language not in self.supported_languages + or metadata.format not in self.supported_formats + or metadata.codec not in self.supported_codecs + or metadata.bit_rate not in self.supported_bit_rates + or metadata.sample_rate not in self.supported_sample_rates + or metadata.channel not in self.supported_channels + ): + return False + return True + + class SpeechToTextView(HomeAssistantView): """STT view to generate a text from audio stream.""" + _legacy_provider_reported = False requires_auth = True url = "/api/stt/{provider}" name = "api:stt:provider" @@ -73,9 +215,17 @@ class SpeechToTextView(HomeAssistantView): async def post(self, request: web.Request, provider: str) -> web.Response: """Convert Speech (audio) to text.""" - if provider not in self.providers: + hass: HomeAssistant = request.app["hass"] + provider_entity: SpeechToTextEntity | None = None + if ( + not ( + provider_entity := async_get_speech_to_text_entity( + hass, f"{DOMAIN}.{provider}" + ) + ) + and provider not in self.providers + ): raise HTTPNotFound() - stt_provider: Provider = self.providers[provider] # Get metadata try: @@ -83,35 +233,105 @@ class SpeechToTextView(HomeAssistantView): except ValueError as err: raise HTTPBadRequest(text=str(err)) from err - # Check format - if not stt_provider.check_metadata(metadata): - raise HTTPUnsupportedMediaType() + if not provider_entity: + stt_provider = self._get_provider(provider) - # Process audio stream - result = await stt_provider.async_process_audio_stream( - metadata, request.content - ) + # Check format + if not stt_provider.check_metadata(metadata): + raise HTTPUnsupportedMediaType() + + # Process audio stream + result = await stt_provider.async_process_audio_stream( + metadata, request.content + ) + else: + # Check format + if not provider_entity.check_metadata(metadata): + raise HTTPUnsupportedMediaType() + + # Process audio stream + result = await provider_entity.internal_async_process_audio_stream( + metadata, request.content + ) # Return result return self.json(asdict(result)) async def get(self, request: web.Request, provider: str) -> web.Response: """Return provider specific audio information.""" - if provider not in self.providers: + hass: HomeAssistant = request.app["hass"] + if ( + not ( + provider_entity := async_get_speech_to_text_entity( + hass, f"{DOMAIN}.{provider}" + ) + ) + and provider not in self.providers + ): raise HTTPNotFound() - stt_provider: Provider = self.providers[provider] + + if not provider_entity: + stt_provider = self._get_provider(provider) + + return self.json( + { + "languages": stt_provider.supported_languages, + "formats": stt_provider.supported_formats, + "codecs": stt_provider.supported_codecs, + "sample_rates": stt_provider.supported_sample_rates, + "bit_rates": stt_provider.supported_bit_rates, + "channels": stt_provider.supported_channels, + } + ) return self.json( { - "languages": stt_provider.supported_languages, - "formats": stt_provider.supported_formats, - "codecs": stt_provider.supported_codecs, - "sample_rates": stt_provider.supported_sample_rates, - "bit_rates": stt_provider.supported_bit_rates, - "channels": stt_provider.supported_channels, + "languages": provider_entity.supported_languages, + "formats": provider_entity.supported_formats, + "codecs": provider_entity.supported_codecs, + "sample_rates": provider_entity.supported_sample_rates, + "bit_rates": provider_entity.supported_bit_rates, + "channels": provider_entity.supported_channels, } ) + def _get_provider(self, provider: str) -> Provider: + """Get provider. + + Method for legacy providers. + This can be removed when we remove the legacy provider support. + """ + stt_provider = self.providers[provider] + + if not self._legacy_provider_reported: + self._legacy_provider_reported = True + report_issue = self._suggest_report_issue(provider, stt_provider) + # This should raise in Home Assistant Core 2023.9 + _LOGGER.warning( + "Provider %s (%s) is using a legacy implementation, " + "and should be updated to use the SpeechToTextEntity. Please " + "%s", + provider, + type(stt_provider), + report_issue, + ) + + return stt_provider + + def _suggest_report_issue(self, provider: str, provider_instance: object) -> str: + """Suggest to report an issue.""" + report_issue = "" + if "custom_components" in type(provider_instance).__module__: + report_issue = "report it to the custom integration author." + else: + report_issue = ( + "create a bug report at " + "https://github.com/home-assistant/core/issues?q=is%3Aopen+is%3Aissue" + ) + report_issue += f"+label%3A%22integration%3A+{provider}%22" + + return report_issue + def _metadata_from_header(request: web.Request) -> SpeechMetadata: """Extract STT metadata from header. @@ -138,7 +358,7 @@ def _metadata_from_header(request: web.Request) -> SpeechMetadata: for entry in data: key, _, value = entry.strip().partition("=") if key not in fields: - raise ValueError(f"Invalid field {key}") + raise ValueError(f"Invalid field: {key}") args[key] = value for field in fields: @@ -154,5 +374,5 @@ def _metadata_from_header(request: web.Request) -> SpeechMetadata: sample_rate=args["sample_rate"], channel=args["channel"], ) - except TypeError as err: + except ValueError as err: raise ValueError(f"Wrong format of X-Speech-Content: {err}") from err diff --git a/homeassistant/components/stt/const.py b/homeassistant/components/stt/const.py index c111aed82a4..c9f5eb13d17 100644 --- a/homeassistant/components/stt/const.py +++ b/homeassistant/components/stt/const.py @@ -2,6 +2,7 @@ from enum import Enum DOMAIN = "stt" +DATA_PROVIDERS = f"{DOMAIN}_providers" class AudioCodecs(str, Enum): diff --git a/homeassistant/components/stt/legacy.py b/homeassistant/components/stt/legacy.py index 2f826e0be9e..ffa21a257f1 100644 --- a/homeassistant/components/stt/legacy.py +++ b/homeassistant/components/stt/legacy.py @@ -13,6 +13,7 @@ from homeassistant.helpers.typing import ConfigType from homeassistant.setup import async_prepare_setup_platform from .const import ( + DATA_PROVIDERS, DOMAIN, AudioBitRates, AudioChannels, @@ -31,15 +32,15 @@ def async_get_provider( ) -> Provider | None: """Return provider.""" if domain: - return hass.data[DOMAIN].get(domain) + return hass.data[DATA_PROVIDERS].get(domain) - if not hass.data[DOMAIN]: + if not hass.data[DATA_PROVIDERS]: return None - if "cloud" in hass.data[DOMAIN]: - return hass.data[DOMAIN]["cloud"] + if "cloud" in hass.data[DATA_PROVIDERS]: + return hass.data[DATA_PROVIDERS]["cloud"] - return next(iter(hass.data[DOMAIN].values())) + return next(iter(hass.data[DATA_PROVIDERS].values())) @callback @@ -47,7 +48,7 @@ def async_setup_legacy( hass: HomeAssistant, config: ConfigType ) -> list[Coroutine[Any, Any, None]]: """Set up legacy speech to text providers.""" - providers = hass.data[DOMAIN] = {} + providers = hass.data[DATA_PROVIDERS] = {} async def async_setup_platform(p_type, p_config=None, discovery_info=None): """Set up a TTS platform.""" diff --git a/tests/components/demo/test_stt.py b/tests/components/demo/test_stt.py index e51a07ae4cf..7a8582df29b 100644 --- a/tests/components/demo/test_stt.py +++ b/tests/components/demo/test_stt.py @@ -4,18 +4,31 @@ from http import HTTPStatus import pytest from homeassistant.components import stt +from homeassistant.components.demo import DOMAIN as DEMO_DOMAIN +from homeassistant.core import HomeAssistant from homeassistant.setup import async_setup_component +from tests.common import MockConfigEntry from tests.typing import ClientSessionGenerator -@pytest.fixture(autouse=True) -async def setup_comp(hass): - """Set up demo component.""" +@pytest.fixture +async def setup_legacy_platform(hass: HomeAssistant) -> None: + """Set up legacy demo platform.""" assert await async_setup_component(hass, stt.DOMAIN, {"stt": {"platform": "demo"}}) await hass.async_block_till_done() +@pytest.fixture +async def setup_config_entry(hass: HomeAssistant) -> None: + """Set up demo component from config entry.""" + config_entry = MockConfigEntry(domain=DEMO_DOMAIN) + config_entry.add_to_hass(hass) + assert await hass.config_entries.async_setup(config_entry.entry_id) + await hass.async_block_till_done() + + +@pytest.mark.usefixtures("setup_legacy_platform") async def test_demo_settings(hass_client: ClientSessionGenerator) -> None: """Test retrieve settings from demo provider.""" client = await hass_client() @@ -34,6 +47,7 @@ async def test_demo_settings(hass_client: ClientSessionGenerator) -> None: } +@pytest.mark.usefixtures("setup_legacy_platform") async def test_demo_speech_no_metadata(hass_client: ClientSessionGenerator) -> None: """Test retrieve settings from demo provider.""" client = await hass_client() @@ -42,6 +56,7 @@ async def test_demo_speech_no_metadata(hass_client: ClientSessionGenerator) -> N assert response.status == HTTPStatus.BAD_REQUEST +@pytest.mark.usefixtures("setup_legacy_platform") async def test_demo_speech_wrong_metadata(hass_client: ClientSessionGenerator) -> None: """Test retrieve settings from demo provider.""" client = await hass_client() @@ -59,6 +74,7 @@ async def test_demo_speech_wrong_metadata(hass_client: ClientSessionGenerator) - assert response.status == HTTPStatus.UNSUPPORTED_MEDIA_TYPE +@pytest.mark.usefixtures("setup_legacy_platform") async def test_demo_speech(hass_client: ClientSessionGenerator) -> None: """Test retrieve settings from demo provider.""" client = await hass_client() @@ -77,3 +93,26 @@ async def test_demo_speech(hass_client: ClientSessionGenerator) -> None: assert response.status == HTTPStatus.OK assert response_data == {"text": "Turn the Kitchen Lights on", "result": "success"} + + +@pytest.mark.usefixtures("setup_config_entry") +async def test_config_entry_demo_speech( + hass_client: ClientSessionGenerator, hass: HomeAssistant +) -> None: + """Test retrieve settings from demo provider from config entry.""" + client = await hass_client() + + response = await client.post( + "/api/stt/demo", + headers={ + "X-Speech-Content": ( + "format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=2;" + " language=de" + ) + }, + data=b"Test", + ) + response_data = await response.json() + + assert response.status == HTTPStatus.OK + assert response_data == {"text": "Turn the Kitchen Lights on", "result": "success"} diff --git a/tests/components/stt/common.py b/tests/components/stt/common.py index 0fe4d5b80d1..79b58531b54 100644 --- a/tests/components/stt/common.py +++ b/tests/components/stt/common.py @@ -6,7 +6,9 @@ from pathlib import Path from typing import Any from homeassistant.components.stt import Provider +from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant +from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from tests.common import MockPlatform, mock_platform @@ -54,3 +56,19 @@ def mock_stt_platform( mock_platform(hass, f"{integration}.stt", loaded_platform) return loaded_platform + + +def mock_stt_entity_platform( + hass: HomeAssistant, + tmp_path: Path, + integration: str, + async_setup_entry: Callable[ + [HomeAssistant, ConfigEntry, AddEntitiesCallback], + Coroutine[Any, Any, None], + ] + | None = None, +) -> MockPlatform: + """Specialize the mock platform for stt.""" + loaded_platform = MockPlatform(async_setup_entry=async_setup_entry) + mock_platform(hass, f"{integration}.stt", loaded_platform) + return loaded_platform diff --git a/tests/components/stt/test_init.py b/tests/components/stt/test_init.py index e021d73c290..483037e7ee1 100644 --- a/tests/components/stt/test_init.py +++ b/tests/components/stt/test_init.py @@ -1,5 +1,5 @@ """Test STT component setup.""" -from collections.abc import AsyncIterable +from collections.abc import AsyncIterable, Generator from http import HTTPStatus from pathlib import Path from unittest.mock import AsyncMock @@ -7,6 +7,7 @@ from unittest.mock import AsyncMock import pytest from homeassistant.components.stt import ( + DOMAIN, AudioBitRates, AudioChannels, AudioCodecs, @@ -16,17 +17,30 @@ from homeassistant.components.stt import ( SpeechMetadata, SpeechResult, SpeechResultState, + SpeechToTextEntity, async_get_provider, ) -from homeassistant.core import HomeAssistant +from homeassistant.config_entries import ConfigEntry, ConfigEntryState, ConfigFlow +from homeassistant.core import HomeAssistant, State +from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.setup import async_setup_component -from .common import mock_stt_platform +from .common import mock_stt_entity_platform, mock_stt_platform +from tests.common import ( + MockConfigEntry, + MockModule, + mock_config_flow, + mock_integration, + mock_platform, + mock_restore_cache, +) from tests.typing import ClientSessionGenerator +TEST_DOMAIN = "test" -class MockProvider(Provider): + +class BaseProvider: """Mock provider.""" fail_process_audio = False @@ -73,7 +87,15 @@ class MockProvider(Provider): if self.fail_process_audio: return SpeechResult(None, SpeechResultState.ERROR) - return SpeechResult("test", SpeechResultState.SUCCESS) + return SpeechResult("test_result", SpeechResultState.SUCCESS) + + +class MockProvider(BaseProvider, Provider): + """Mock provider.""" + + +class MockProviderEntity(BaseProvider, SpeechToTextEntity): + """Mock provider entity.""" @pytest.fixture @@ -82,26 +104,113 @@ def mock_provider() -> MockProvider: return MockProvider() +@pytest.fixture +def mock_provider_entity() -> MockProviderEntity: + """Test provider entity fixture.""" + return MockProviderEntity() + + +class STTFlow(ConfigFlow): + """Test flow.""" + + @pytest.fixture(autouse=True) +def config_flow_fixture(hass: HomeAssistant) -> Generator[None, None, None]: + """Mock config flow.""" + mock_platform(hass, f"{TEST_DOMAIN}.config_flow") + + with mock_config_flow(TEST_DOMAIN, STTFlow): + yield + + +@pytest.fixture(name="setup") +async def setup_fixture( + hass: HomeAssistant, + tmp_path: Path, + request: pytest.FixtureRequest, +) -> None: + """Set up the test environment.""" + if request.param == "mock_setup": + await mock_setup(hass, tmp_path, MockProvider()) + elif request.param == "mock_config_entry_setup": + await mock_config_entry_setup(hass, tmp_path, MockProviderEntity()) + else: + raise RuntimeError("Invalid setup fixture") + + async def mock_setup( - hass: HomeAssistant, tmp_path: Path, mock_provider: MockProvider + hass: HomeAssistant, + tmp_path: Path, + mock_provider: MockProvider, ) -> None: """Set up a test provider.""" mock_stt_platform( hass, tmp_path, - "test", + TEST_DOMAIN, async_get_engine=AsyncMock(return_value=mock_provider), ) - assert await async_setup_component(hass, "stt", {"stt": {"platform": "test"}}) + assert await async_setup_component(hass, "stt", {"stt": {"platform": TEST_DOMAIN}}) + await hass.async_block_till_done() +async def mock_config_entry_setup( + hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity +) -> MockConfigEntry: + """Set up a test provider via config entry.""" + + async def async_setup_entry_init( + hass: HomeAssistant, config_entry: ConfigEntry + ) -> bool: + """Set up test config entry.""" + await hass.config_entries.async_forward_entry_setup(config_entry, DOMAIN) + return True + + async def async_unload_entry_init( + hass: HomeAssistant, config_entry: ConfigEntry + ) -> bool: + """Unload up test config entry.""" + await hass.config_entries.async_forward_entry_unload(config_entry, DOMAIN) + return True + + mock_integration( + hass, + MockModule( + TEST_DOMAIN, + async_setup_entry=async_setup_entry_init, + async_unload_entry=async_unload_entry_init, + ), + ) + + async def async_setup_entry_platform( + hass: HomeAssistant, + config_entry: ConfigEntry, + async_add_entities: AddEntitiesCallback, + ) -> None: + """Set up test stt platform via config entry.""" + async_add_entities([mock_provider_entity]) + + mock_stt_entity_platform(hass, tmp_path, TEST_DOMAIN, async_setup_entry_platform) + + config_entry = MockConfigEntry(domain=TEST_DOMAIN) + config_entry.add_to_hass(hass) + assert await hass.config_entries.async_setup(config_entry.entry_id) + await hass.async_block_till_done() + + return config_entry + + +@pytest.mark.parametrize( + "setup", ["mock_setup", "mock_config_entry_setup"], indirect=True +) async def test_get_provider_info( - hass: HomeAssistant, hass_client: ClientSessionGenerator + hass: HomeAssistant, + hass_client: ClientSessionGenerator, + setup: str, ) -> None: """Test engine that doesn't exist.""" client = await hass_client() - response = await client.get("/api/stt/test") + response = await client.get(f"/api/stt/{TEST_DOMAIN}") assert response.status == HTTPStatus.OK assert await response.json() == { "languages": ["en"], @@ -113,22 +222,44 @@ async def test_get_provider_info( } -async def test_get_non_existing_provider_info( - hass: HomeAssistant, hass_client: ClientSessionGenerator +@pytest.mark.parametrize( + "setup", ["mock_setup", "mock_config_entry_setup"], indirect=True +) +async def test_non_existing_provider( + hass: HomeAssistant, + hass_client: ClientSessionGenerator, + setup: str, ) -> None: """Test streaming to engine that doesn't exist.""" client = await hass_client() + response = await client.get("/api/stt/not_exist") assert response.status == HTTPStatus.NOT_FOUND + response = await client.post( + "/api/stt/not_exist", + headers={ + "X-Speech-Content": ( + "format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=1;" + " language=en" + ) + }, + ) + assert response.status == HTTPStatus.NOT_FOUND + +@pytest.mark.parametrize( + "setup", ["mock_setup", "mock_config_entry_setup"], indirect=True +) async def test_stream_audio( - hass: HomeAssistant, hass_client: ClientSessionGenerator + hass: HomeAssistant, + hass_client: ClientSessionGenerator, + setup: str, ) -> None: """Test streaming audio and getting response.""" client = await hass_client() response = await client.post( - "/api/stt/test", + f"/api/stt/{TEST_DOMAIN}", headers={ "X-Speech-Content": ( "format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=1;" @@ -137,20 +268,39 @@ async def test_stream_audio( }, ) assert response.status == HTTPStatus.OK - assert await response.json() == {"text": "test", "result": "success"} + assert await response.json() == {"text": "test_result", "result": "success"} +@pytest.mark.parametrize( + "setup", ["mock_setup", "mock_config_entry_setup"], indirect=True +) @pytest.mark.parametrize( ("header", "status", "error"), ( (None, 400, "Missing X-Speech-Content header"), + ( + ( + "format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=100;" + " language=en; unknown=1" + ), + 400, + "Invalid field: unknown", + ), ( ( "format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=100;" " language=en" ), 400, - "100 is not a valid AudioChannels", + "Wrong format of X-Speech-Content: 100 is not a valid AudioChannels", + ), + ( + ( + "format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=bad channel;" + " language=en" + ), + 400, + "Wrong format of X-Speech-Content: invalid literal for int() with base 10: 'bad channel'", ), ( "format=wav; codec=pcm; sample_rate=16000", @@ -165,6 +315,7 @@ async def test_metadata_errors( header: str | None, status: int, error: str, + setup: str, ) -> None: """Test metadata errors.""" client = await hass_client() @@ -172,11 +323,55 @@ async def test_metadata_errors( if header: headers["X-Speech-Content"] = header - response = await client.post("/api/stt/test", headers=headers) + response = await client.post(f"/api/stt/{TEST_DOMAIN}", headers=headers) assert response.status == status assert await response.text() == error -async def test_get_provider(hass: HomeAssistant, mock_provider: MockProvider) -> None: +async def test_get_provider( + hass: HomeAssistant, + tmp_path: Path, + mock_provider: MockProvider, +) -> None: """Test we can get STT providers.""" - assert mock_provider == async_get_provider(hass, "test") + await mock_setup(hass, tmp_path, mock_provider) + assert mock_provider == async_get_provider(hass, TEST_DOMAIN) + + +async def test_config_entry_unload( + hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity +) -> None: + """Test we can unload config entry.""" + config_entry = await mock_config_entry_setup(hass, tmp_path, mock_provider_entity) + assert config_entry.state == ConfigEntryState.LOADED + await hass.config_entries.async_unload(config_entry.entry_id) + assert config_entry.state == ConfigEntryState.NOT_LOADED + + +def test_entity_name_raises_before_addition( + hass: HomeAssistant, + tmp_path: Path, + mock_provider_entity: MockProviderEntity, +) -> None: + """Test entity name raises before addition to Home Assistant.""" + with pytest.raises(RuntimeError): + mock_provider_entity.name # pylint: disable=pointless-statement + + +async def test_restore_state( + hass: HomeAssistant, + tmp_path: Path, + mock_provider_entity: MockProviderEntity, +) -> None: + """Test we restore state in the integration.""" + entity_id = f"{DOMAIN}.{TEST_DOMAIN}" + timestamp = "2023-01-01T23:59:59+00:00" + mock_restore_cache(hass, (State(entity_id, timestamp),)) + + config_entry = await mock_config_entry_setup(hass, tmp_path, mock_provider_entity) + await hass.async_block_till_done() + + assert config_entry.state == ConfigEntryState.LOADED + state = hass.states.get(entity_id) + assert state + assert state.state == timestamp