Convert TTS tests to async (#33517)
* Convert TTS tests to async * Address comments
This commit is contained in:
parent
254394ecab
commit
cb5de0e090
3 changed files with 501 additions and 532 deletions
|
@ -133,7 +133,7 @@ async def async_setup(hass, config):
|
||||||
hass, p_config, discovery_info
|
hass, p_config, discovery_info
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
provider = await hass.async_add_job(
|
provider = await hass.async_add_executor_job(
|
||||||
platform.get_engine, hass, p_config, discovery_info
|
platform.get_engine, hass, p_config, discovery_info
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -226,41 +226,17 @@ class SpeechManager:
|
||||||
self.time_memory = time_memory
|
self.time_memory = time_memory
|
||||||
self.base_url = base_url
|
self.base_url = base_url
|
||||||
|
|
||||||
def init_tts_cache_dir(cache_dir):
|
|
||||||
"""Init cache folder."""
|
|
||||||
if not os.path.isabs(cache_dir):
|
|
||||||
cache_dir = self.hass.config.path(cache_dir)
|
|
||||||
if not os.path.isdir(cache_dir):
|
|
||||||
_LOGGER.info("Create cache dir %s.", cache_dir)
|
|
||||||
os.mkdir(cache_dir)
|
|
||||||
return cache_dir
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.cache_dir = await self.hass.async_add_job(
|
self.cache_dir = await self.hass.async_add_executor_job(
|
||||||
init_tts_cache_dir, cache_dir
|
_init_tts_cache_dir, self.hass, cache_dir
|
||||||
)
|
)
|
||||||
except OSError as err:
|
except OSError as err:
|
||||||
raise HomeAssistantError(f"Can't init cache dir {err}")
|
raise HomeAssistantError(f"Can't init cache dir {err}")
|
||||||
|
|
||||||
def get_cache_files():
|
|
||||||
"""Return a dict of given engine files."""
|
|
||||||
cache = {}
|
|
||||||
|
|
||||||
folder_data = os.listdir(self.cache_dir)
|
|
||||||
for file_data in folder_data:
|
|
||||||
record = _RE_VOICE_FILE.match(file_data)
|
|
||||||
if record:
|
|
||||||
key = KEY_PATTERN.format(
|
|
||||||
record.group(1),
|
|
||||||
record.group(2),
|
|
||||||
record.group(3),
|
|
||||||
record.group(4),
|
|
||||||
)
|
|
||||||
cache[key.lower()] = file_data.lower()
|
|
||||||
return cache
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
cache_files = await self.hass.async_add_job(get_cache_files)
|
cache_files = await self.hass.async_add_executor_job(
|
||||||
|
_get_cache_files, self.cache_dir
|
||||||
|
)
|
||||||
except OSError as err:
|
except OSError as err:
|
||||||
raise HomeAssistantError(f"Can't read cache dir {err}")
|
raise HomeAssistantError(f"Can't read cache dir {err}")
|
||||||
|
|
||||||
|
@ -273,13 +249,13 @@ class SpeechManager:
|
||||||
|
|
||||||
def remove_files():
|
def remove_files():
|
||||||
"""Remove files from filesystem."""
|
"""Remove files from filesystem."""
|
||||||
for _, filename in self.file_cache.items():
|
for filename in self.file_cache.values():
|
||||||
try:
|
try:
|
||||||
os.remove(os.path.join(self.cache_dir, filename))
|
os.remove(os.path.join(self.cache_dir, filename))
|
||||||
except OSError as err:
|
except OSError as err:
|
||||||
_LOGGER.warning("Can't remove cache file '%s': %s", filename, err)
|
_LOGGER.warning("Can't remove cache file '%s': %s", filename, err)
|
||||||
|
|
||||||
await self.hass.async_add_job(remove_files)
|
await self.hass.async_add_executor_job(remove_files)
|
||||||
self.file_cache = {}
|
self.file_cache = {}
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
|
@ -312,6 +288,7 @@ class SpeechManager:
|
||||||
merged_options.update(options)
|
merged_options.update(options)
|
||||||
options = merged_options
|
options = merged_options
|
||||||
options = options or provider.default_options
|
options = options or provider.default_options
|
||||||
|
|
||||||
if options is not None:
|
if options is not None:
|
||||||
invalid_opts = [
|
invalid_opts = [
|
||||||
opt_name
|
opt_name
|
||||||
|
@ -378,10 +355,10 @@ class SpeechManager:
|
||||||
speech.write(data)
|
speech.write(data)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self.hass.async_add_job(save_speech)
|
await self.hass.async_add_executor_job(save_speech)
|
||||||
self.file_cache[key] = filename
|
self.file_cache[key] = filename
|
||||||
except OSError:
|
except OSError as err:
|
||||||
_LOGGER.error("Can't write %s", filename)
|
_LOGGER.error("Can't write %s: %s", filename, err)
|
||||||
|
|
||||||
async def async_file_to_mem(self, key):
|
async def async_file_to_mem(self, key):
|
||||||
"""Load voice from file cache into memory.
|
"""Load voice from file cache into memory.
|
||||||
|
@ -400,7 +377,7 @@ class SpeechManager:
|
||||||
return speech.read()
|
return speech.read()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
data = await self.hass.async_add_job(load_speech)
|
data = await self.hass.async_add_executor_job(load_speech)
|
||||||
except OSError:
|
except OSError:
|
||||||
del self.file_cache[key]
|
del self.file_cache[key]
|
||||||
raise HomeAssistantError(f"Can't read {voice_file}")
|
raise HomeAssistantError(f"Can't read {voice_file}")
|
||||||
|
@ -506,11 +483,36 @@ class Provider:
|
||||||
|
|
||||||
Return a tuple of file extension and data as bytes.
|
Return a tuple of file extension and data as bytes.
|
||||||
"""
|
"""
|
||||||
return await self.hass.async_add_job(
|
return await self.hass.async_add_executor_job(
|
||||||
ft.partial(self.get_tts_audio, message, language, options=options)
|
ft.partial(self.get_tts_audio, message, language, options=options)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _init_tts_cache_dir(hass, cache_dir):
|
||||||
|
"""Init cache folder."""
|
||||||
|
if not os.path.isabs(cache_dir):
|
||||||
|
cache_dir = hass.config.path(cache_dir)
|
||||||
|
if not os.path.isdir(cache_dir):
|
||||||
|
_LOGGER.info("Create cache dir %s", cache_dir)
|
||||||
|
os.mkdir(cache_dir)
|
||||||
|
return cache_dir
|
||||||
|
|
||||||
|
|
||||||
|
def _get_cache_files(cache_dir):
|
||||||
|
"""Return a dict of given engine files."""
|
||||||
|
cache = {}
|
||||||
|
|
||||||
|
folder_data = os.listdir(cache_dir)
|
||||||
|
for file_data in folder_data:
|
||||||
|
record = _RE_VOICE_FILE.match(file_data)
|
||||||
|
if record:
|
||||||
|
key = KEY_PATTERN.format(
|
||||||
|
record.group(1), record.group(2), record.group(3), record.group(4),
|
||||||
|
)
|
||||||
|
cache[key.lower()] = file_data.lower()
|
||||||
|
return cache
|
||||||
|
|
||||||
|
|
||||||
class TextToSpeechUrlView(HomeAssistantView):
|
class TextToSpeechUrlView(HomeAssistantView):
|
||||||
"""TTS view to get a url to a generated speech file."""
|
"""TTS view to get a url to a generated speech file."""
|
||||||
|
|
||||||
|
|
|
@ -14,6 +14,8 @@ import threading
|
||||||
from unittest.mock import MagicMock, Mock, patch
|
from unittest.mock import MagicMock, Mock, patch
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
from aiohttp.test_utils import unused_port as get_test_instance_port # noqa
|
||||||
|
|
||||||
from homeassistant import auth, config_entries, core as ha, loader
|
from homeassistant import auth, config_entries, core as ha, loader
|
||||||
from homeassistant.auth import (
|
from homeassistant.auth import (
|
||||||
auth_store,
|
auth_store,
|
||||||
|
@ -37,7 +39,6 @@ from homeassistant.const import (
|
||||||
EVENT_PLATFORM_DISCOVERED,
|
EVENT_PLATFORM_DISCOVERED,
|
||||||
EVENT_STATE_CHANGED,
|
EVENT_STATE_CHANGED,
|
||||||
EVENT_TIME_CHANGED,
|
EVENT_TIME_CHANGED,
|
||||||
SERVER_PORT,
|
|
||||||
STATE_OFF,
|
STATE_OFF,
|
||||||
STATE_ON,
|
STATE_ON,
|
||||||
)
|
)
|
||||||
|
@ -59,7 +60,6 @@ import homeassistant.util.dt as date_util
|
||||||
from homeassistant.util.unit_system import METRIC_SYSTEM
|
from homeassistant.util.unit_system import METRIC_SYSTEM
|
||||||
import homeassistant.util.yaml.loader as yaml_loader
|
import homeassistant.util.yaml.loader as yaml_loader
|
||||||
|
|
||||||
_TEST_INSTANCE_PORT = SERVER_PORT
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
INSTANCES = []
|
INSTANCES = []
|
||||||
CLIENT_ID = "https://example.com/app"
|
CLIENT_ID = "https://example.com/app"
|
||||||
|
@ -217,18 +217,6 @@ async def async_test_home_assistant(loop):
|
||||||
return hass
|
return hass
|
||||||
|
|
||||||
|
|
||||||
def get_test_instance_port():
|
|
||||||
"""Return unused port for running test instance.
|
|
||||||
|
|
||||||
The socket that holds the default port does not get released when we stop
|
|
||||||
HA in a different test case. Until I have figured out what is going on,
|
|
||||||
let's run each test on a different port.
|
|
||||||
"""
|
|
||||||
global _TEST_INSTANCE_PORT
|
|
||||||
_TEST_INSTANCE_PORT += 1
|
|
||||||
return _TEST_INSTANCE_PORT
|
|
||||||
|
|
||||||
|
|
||||||
def async_mock_service(hass, domain, service, schema=None):
|
def async_mock_service(hass, domain, service, schema=None):
|
||||||
"""Set up a fake service & return a calls log list to this service."""
|
"""Set up a fake service & return a calls log list to this service."""
|
||||||
calls = []
|
calls = []
|
||||||
|
|
File diff suppressed because it is too large
Load diff
Loading…
Add table
Reference in a new issue