2023-01-25 11:30:13 -05:00
|
|
|
"""Tests for the OpenAI integration."""
|
|
|
|
from unittest.mock import patch
|
|
|
|
|
2023-01-25 22:17:19 -05:00
|
|
|
from openai import error
|
2023-03-28 23:37:43 -04:00
|
|
|
from syrupy.assertion import SnapshotAssertion
|
2023-01-25 22:17:19 -05:00
|
|
|
|
2023-01-25 11:30:13 -05:00
|
|
|
from homeassistant.components import conversation
|
2023-02-02 21:20:10 +01:00
|
|
|
from homeassistant.core import Context, HomeAssistant
|
2023-03-01 03:59:44 +01:00
|
|
|
from homeassistant.helpers import area_registry as ar, device_registry as dr, intent
|
2023-01-25 11:30:13 -05:00
|
|
|
|
2023-02-02 21:20:10 +01:00
|
|
|
from tests.common import MockConfigEntry
|
|
|
|
|
2023-01-25 11:30:13 -05:00
|
|
|
|
2023-03-01 03:59:44 +01:00
|
|
|
async def test_default_prompt(
    hass: HomeAssistant,
    mock_init_component,
    area_registry: ar.AreaRegistry,
    device_registry: dr.DeviceRegistry,
    snapshot: SnapshotAssertion,
) -> None:
    """Verify the messages sent to OpenAI when the default prompt is used.

    Seed the area and device registries with a variety of entries:
    - empty areas;
    - service-type devices;
    - a disabled device;
    - a device without a name;
    - a device with integer metadata.

    Then run a conversation turn and compare the ``messages`` keyword passed to
    ``openai.ChatCompletion.acreate`` with the stored snapshot.
    """
    # Areas that end up with no devices assigned.
    for idx in range(3):
        area_registry.async_create(f"{idx}Empty Area")

    def add_device(connection_id, **fields):
        """Register a device on config entry "1234" with one "test" connection."""
        return device_registry.async_get_or_create(
            config_entry_id="1234",
            connections={("test", connection_id)},
            **fields,
        )

    # NOTE: creation order matters; the snapshot captures the rendered prompt.
    add_device(
        "1234",
        name="Test Device",
        manufacturer="Test Manufacturer",
        model="Test Model",
        suggested_area="Test Area",
    )
    for idx in range(3):
        add_device(
            f"{idx}abcd",
            name="Test Service",
            manufacturer="Test Manufacturer",
            model="Test Model",
            suggested_area="Test Area",
            entry_type=dr.DeviceEntryType.SERVICE,
        )
    add_device(
        "5678",
        name="Test Device 2",
        manufacturer="Test Manufacturer 2",
        model="Device 2",
        suggested_area="Test Area 2",
    )
    add_device(
        "9876",
        name="Test Device 3",
        manufacturer="Test Manufacturer 3",
        model="Test Model 3A",
        suggested_area="Test Area 2",
    )
    add_device("qwer", name="Test Device 4", suggested_area="Test Area 2")

    # Same metadata as "Test Device 3", but disabled by the user afterwards.
    disabled_device = add_device(
        "9876-disabled",
        name="Test Device 3",
        manufacturer="Test Manufacturer 3",
        model="Test Model 3A",
        suggested_area="Test Area 2",
    )
    device_registry.async_update_device(
        disabled_device.id, disabled_by=dr.DeviceEntryDisabler.USER
    )

    # Device registered without a name.
    add_device(
        "9876-no-name",
        manufacturer="Test Manufacturer NoName",
        model="Test Model NoName",
        suggested_area="Test Area 2",
    )
    # Non-string metadata values.
    add_device(
        "9876-integer-values",
        name=1,
        manufacturer=2,
        model=3,
        suggested_area="Test Area 2",
    )

    canned_completion = {
        "choices": [
            {
                "message": {
                    "role": "assistant",
                    "content": "Hello, how can I help you?",
                }
            }
        ]
    }
    with patch(
        "openai.ChatCompletion.acreate", return_value=canned_completion
    ) as mock_create:
        result = await conversation.async_converse(hass, "hello", None, Context())

    assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
    first_call = mock_create.mock_calls[0]
    assert first_call.kwargs["messages"] == snapshot
|
2023-01-25 22:17:19 -05:00
|
|
|
|
|
|
|
|
2023-02-15 10:50:02 +01:00
|
|
|
async def test_error_handling(hass: HomeAssistant, mock_init_component) -> None:
    """Test that an OpenAI service error is returned as a conversation error.

    ``openai.ChatCompletion.acreate`` is patched to raise
    ``error.ServiceUnavailableError``. The conversation response must then be
    of type ERROR with error code ``"unknown"``.
    """
    with patch(
        "openai.ChatCompletion.acreate", side_effect=error.ServiceUnavailableError
    ):
        result = await conversation.async_converse(hass, "hello", None, Context())

    # ``result`` is passed as the assert message so a failure shows the response.
    assert result.response.response_type == intent.IntentResponseType.ERROR, result
    assert result.response.error_code == "unknown", result
|
|
|
|
|
|
|
|
|
2023-02-02 21:20:10 +01:00
|
|
|
async def test_template_error(
    hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> None:
    """Verify that a malformed prompt template yields an error response.

    The entry's ``prompt`` option is set to a Jinja template whose
    ``{% if %}`` block is never closed with ``{% endif %}``. The entry is then
    set up with the OpenAI calls patched out, and a conversation turn must
    come back as an ERROR with code ``"unknown"``.
    """
    broken_prompt = "talk like a {% if True %}smarthome{% else %}pirate please."
    hass.config_entries.async_update_entry(
        mock_config_entry, options={"prompt": broken_prompt}
    )

    with patch("openai.Engine.list"), patch("openai.ChatCompletion.acreate"):
        await hass.config_entries.async_setup(mock_config_entry.entry_id)
        await hass.async_block_till_done()
        result = await conversation.async_converse(hass, "hello", None, Context())

    response = result.response
    assert response.response_type == intent.IntentResponseType.ERROR, result
    assert response.error_code == "unknown", result
|
2023-04-18 22:11:04 +02:00
|
|
|
|
|
|
|
|
|
|
|
async def test_conversation_agent(
    hass: HomeAssistant,
    mock_config_entry: MockConfigEntry,
    mock_init_component,
) -> None:
    """Check the supported languages reported by the OpenAI conversation agent.

    The agent registered for ``mock_config_entry`` is fetched through the
    conversation agent manager. It must report ``["*"]``, the wildcard entry,
    as its supported languages.
    """
    agent_manager = conversation._get_agent_manager(hass)
    openai_agent = await agent_manager.async_get_agent(mock_config_entry.entry_id)
    assert openai_agent.supported_languages == ["*"]
|