Google gen updates (#117893)
* Add a recommended model for Google Gen AI * Add recommended settings to Google Gen AI * Revert no API msg * Use correct default settings * Make sure options are cleared when using recommended * Update snapshots * address comments
This commit is contained in:
parent
c0bcf00bf8
commit
d1af40f1eb
7 changed files with 212 additions and 110 deletions
|
@ -8,11 +8,11 @@
|
|||
dict({
|
||||
'generation_config': dict({
|
||||
'max_output_tokens': 150,
|
||||
'temperature': 0.9,
|
||||
'top_k': 1,
|
||||
'top_p': 1.0,
|
||||
'temperature': 1.0,
|
||||
'top_k': 64,
|
||||
'top_p': 0.95,
|
||||
}),
|
||||
'model_name': 'models/gemini-pro',
|
||||
'model_name': 'models/gemini-1.5-flash-latest',
|
||||
'tools': None,
|
||||
}),
|
||||
),
|
||||
|
@ -67,11 +67,11 @@
|
|||
dict({
|
||||
'generation_config': dict({
|
||||
'max_output_tokens': 150,
|
||||
'temperature': 0.9,
|
||||
'top_k': 1,
|
||||
'top_p': 1.0,
|
||||
'temperature': 1.0,
|
||||
'top_k': 64,
|
||||
'top_p': 0.95,
|
||||
}),
|
||||
'model_name': 'models/gemini-pro',
|
||||
'model_name': 'models/gemini-1.5-flash-latest',
|
||||
'tools': None,
|
||||
}),
|
||||
),
|
||||
|
@ -126,11 +126,11 @@
|
|||
dict({
|
||||
'generation_config': dict({
|
||||
'max_output_tokens': 150,
|
||||
'temperature': 0.9,
|
||||
'top_k': 1,
|
||||
'top_p': 1.0,
|
||||
'temperature': 1.0,
|
||||
'top_k': 64,
|
||||
'top_p': 0.95,
|
||||
}),
|
||||
'model_name': 'models/gemini-pro',
|
||||
'model_name': 'models/gemini-1.5-flash-latest',
|
||||
'tools': None,
|
||||
}),
|
||||
),
|
||||
|
@ -185,11 +185,11 @@
|
|||
dict({
|
||||
'generation_config': dict({
|
||||
'max_output_tokens': 150,
|
||||
'temperature': 0.9,
|
||||
'top_k': 1,
|
||||
'top_p': 1.0,
|
||||
'temperature': 1.0,
|
||||
'top_k': 64,
|
||||
'top_p': 0.95,
|
||||
}),
|
||||
'model_name': 'models/gemini-pro',
|
||||
'model_name': 'models/gemini-1.5-flash-latest',
|
||||
'tools': None,
|
||||
}),
|
||||
),
|
||||
|
|
|
@ -10,13 +10,17 @@ from homeassistant import config_entries
|
|||
from homeassistant.components.google_generative_ai_conversation.const import (
|
||||
CONF_CHAT_MODEL,
|
||||
CONF_MAX_TOKENS,
|
||||
CONF_PROMPT,
|
||||
CONF_RECOMMENDED,
|
||||
CONF_TEMPERATURE,
|
||||
CONF_TONE_PROMPT,
|
||||
CONF_TOP_K,
|
||||
CONF_TOP_P,
|
||||
DEFAULT_CHAT_MODEL,
|
||||
DEFAULT_MAX_TOKENS,
|
||||
DEFAULT_TOP_K,
|
||||
DEFAULT_TOP_P,
|
||||
DOMAIN,
|
||||
RECOMMENDED_CHAT_MODEL,
|
||||
RECOMMENDED_MAX_TOKENS,
|
||||
RECOMMENDED_TOP_K,
|
||||
RECOMMENDED_TOP_P,
|
||||
)
|
||||
from homeassistant.const import CONF_LLM_HASS_API
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
@ -42,7 +46,7 @@ def mock_models():
|
|||
model_10_pro.name = "models/gemini-pro"
|
||||
with patch(
|
||||
"homeassistant.components.google_generative_ai_conversation.config_flow.genai.list_models",
|
||||
return_value=[model_10_pro],
|
||||
return_value=[model_15_flash, model_10_pro],
|
||||
):
|
||||
yield
|
||||
|
||||
|
@ -84,36 +88,89 @@ async def test_form(hass: HomeAssistant) -> None:
|
|||
"api_key": "bla",
|
||||
}
|
||||
assert result2["options"] == {
|
||||
CONF_RECOMMENDED: True,
|
||||
CONF_LLM_HASS_API: llm.LLM_API_ASSIST,
|
||||
CONF_TONE_PROMPT: "",
|
||||
}
|
||||
assert len(mock_setup_entry.mock_calls) == 1
|
||||
|
||||
|
||||
async def test_options(
|
||||
hass: HomeAssistant, mock_config_entry, mock_init_component, mock_models
|
||||
@pytest.mark.parametrize(
|
||||
("current_options", "new_options", "expected_options"),
|
||||
[
|
||||
(
|
||||
{
|
||||
CONF_RECOMMENDED: True,
|
||||
CONF_LLM_HASS_API: "none",
|
||||
CONF_TONE_PROMPT: "bla",
|
||||
},
|
||||
{
|
||||
CONF_RECOMMENDED: False,
|
||||
CONF_PROMPT: "Speak like a pirate",
|
||||
CONF_TEMPERATURE: 0.3,
|
||||
},
|
||||
{
|
||||
CONF_RECOMMENDED: False,
|
||||
CONF_PROMPT: "Speak like a pirate",
|
||||
CONF_TEMPERATURE: 0.3,
|
||||
CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
|
||||
CONF_TOP_P: RECOMMENDED_TOP_P,
|
||||
CONF_TOP_K: RECOMMENDED_TOP_K,
|
||||
CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
|
||||
},
|
||||
),
|
||||
(
|
||||
{
|
||||
CONF_RECOMMENDED: False,
|
||||
CONF_PROMPT: "Speak like a pirate",
|
||||
CONF_TEMPERATURE: 0.3,
|
||||
CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
|
||||
CONF_TOP_P: RECOMMENDED_TOP_P,
|
||||
CONF_TOP_K: RECOMMENDED_TOP_K,
|
||||
CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
|
||||
},
|
||||
{
|
||||
CONF_RECOMMENDED: True,
|
||||
CONF_LLM_HASS_API: "assist",
|
||||
CONF_TONE_PROMPT: "",
|
||||
},
|
||||
{
|
||||
CONF_RECOMMENDED: True,
|
||||
CONF_LLM_HASS_API: "assist",
|
||||
CONF_TONE_PROMPT: "",
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_options_switching(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry,
|
||||
mock_init_component,
|
||||
mock_models,
|
||||
current_options,
|
||||
new_options,
|
||||
expected_options,
|
||||
) -> None:
|
||||
"""Test the options form."""
|
||||
hass.config_entries.async_update_entry(mock_config_entry, options=current_options)
|
||||
options_flow = await hass.config_entries.options.async_init(
|
||||
mock_config_entry.entry_id
|
||||
)
|
||||
if current_options.get(CONF_RECOMMENDED) != new_options.get(CONF_RECOMMENDED):
|
||||
options_flow = await hass.config_entries.options.async_configure(
|
||||
options_flow["flow_id"],
|
||||
{
|
||||
**current_options,
|
||||
CONF_RECOMMENDED: new_options[CONF_RECOMMENDED],
|
||||
},
|
||||
)
|
||||
options = await hass.config_entries.options.async_configure(
|
||||
options_flow["flow_id"],
|
||||
{
|
||||
"prompt": "Speak like a pirate",
|
||||
"temperature": 0.3,
|
||||
},
|
||||
new_options,
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
assert options["type"] is FlowResultType.CREATE_ENTRY
|
||||
assert options["data"]["prompt"] == "Speak like a pirate"
|
||||
assert options["data"]["temperature"] == 0.3
|
||||
assert options["data"][CONF_CHAT_MODEL] == DEFAULT_CHAT_MODEL
|
||||
assert options["data"][CONF_TOP_P] == DEFAULT_TOP_P
|
||||
assert options["data"][CONF_TOP_K] == DEFAULT_TOP_K
|
||||
assert options["data"][CONF_MAX_TOKENS] == DEFAULT_MAX_TOKENS
|
||||
assert (
|
||||
CONF_LLM_HASS_API not in options["data"]
|
||||
), "Options flow should not set this key"
|
||||
assert options["data"] == expected_options
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
|
|
@ -354,7 +354,7 @@ async def test_blocked_response(
|
|||
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
||||
assert result.response.error_code == "unknown", result
|
||||
assert result.response.as_dict()["speech"]["plain"]["speech"] == (
|
||||
"Sorry, I had a problem talking to Google Generative AI. Likely blocked"
|
||||
"Sorry, I had a problem getting a response from Google Generative AI."
|
||||
)
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue