Google gen updates (#117893)
* Add a recommended model for Google Gen AI * Add recommended settings to Google Gen AI * Revert no API msg * Use correct default settings * Make sure options are cleared when using recommended * Update snapshots * address comments
This commit is contained in:
parent
c0bcf00bf8
commit
d1af40f1eb
7 changed files with 212 additions and 110 deletions
|
@ -34,16 +34,18 @@ from .const import (
|
||||||
CONF_CHAT_MODEL,
|
CONF_CHAT_MODEL,
|
||||||
CONF_MAX_TOKENS,
|
CONF_MAX_TOKENS,
|
||||||
CONF_PROMPT,
|
CONF_PROMPT,
|
||||||
|
CONF_RECOMMENDED,
|
||||||
CONF_TEMPERATURE,
|
CONF_TEMPERATURE,
|
||||||
|
CONF_TONE_PROMPT,
|
||||||
CONF_TOP_K,
|
CONF_TOP_K,
|
||||||
CONF_TOP_P,
|
CONF_TOP_P,
|
||||||
DEFAULT_CHAT_MODEL,
|
|
||||||
DEFAULT_MAX_TOKENS,
|
|
||||||
DEFAULT_PROMPT,
|
DEFAULT_PROMPT,
|
||||||
DEFAULT_TEMPERATURE,
|
|
||||||
DEFAULT_TOP_K,
|
|
||||||
DEFAULT_TOP_P,
|
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
|
RECOMMENDED_CHAT_MODEL,
|
||||||
|
RECOMMENDED_MAX_TOKENS,
|
||||||
|
RECOMMENDED_TEMPERATURE,
|
||||||
|
RECOMMENDED_TOP_K,
|
||||||
|
RECOMMENDED_TOP_P,
|
||||||
)
|
)
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
@ -54,6 +56,12 @@ STEP_USER_DATA_SCHEMA = vol.Schema(
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
RECOMMENDED_OPTIONS = {
|
||||||
|
CONF_RECOMMENDED: True,
|
||||||
|
CONF_LLM_HASS_API: llm.LLM_API_ASSIST,
|
||||||
|
CONF_TONE_PROMPT: "",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
|
async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
|
||||||
"""Validate the user input allows us to connect.
|
"""Validate the user input allows us to connect.
|
||||||
|
@ -94,7 +102,7 @@ class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||||
return self.async_create_entry(
|
return self.async_create_entry(
|
||||||
title="Google Generative AI",
|
title="Google Generative AI",
|
||||||
data=user_input,
|
data=user_input,
|
||||||
options={CONF_LLM_HASS_API: llm.LLM_API_ASSIST},
|
options=RECOMMENDED_OPTIONS,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.async_show_form(
|
return self.async_show_form(
|
||||||
|
@ -115,18 +123,37 @@ class GoogleGenerativeAIOptionsFlow(OptionsFlow):
|
||||||
def __init__(self, config_entry: ConfigEntry) -> None:
|
def __init__(self, config_entry: ConfigEntry) -> None:
|
||||||
"""Initialize options flow."""
|
"""Initialize options flow."""
|
||||||
self.config_entry = config_entry
|
self.config_entry = config_entry
|
||||||
|
self.last_rendered_recommended = config_entry.options.get(
|
||||||
|
CONF_RECOMMENDED, False
|
||||||
|
)
|
||||||
|
|
||||||
async def async_step_init(
|
async def async_step_init(
|
||||||
self, user_input: dict[str, Any] | None = None
|
self, user_input: dict[str, Any] | None = None
|
||||||
) -> ConfigFlowResult:
|
) -> ConfigFlowResult:
|
||||||
"""Manage the options."""
|
"""Manage the options."""
|
||||||
|
options: dict[str, Any] | MappingProxyType[str, Any] = self.config_entry.options
|
||||||
|
|
||||||
if user_input is not None:
|
if user_input is not None:
|
||||||
if user_input[CONF_LLM_HASS_API] == "none":
|
if user_input[CONF_RECOMMENDED] == self.last_rendered_recommended:
|
||||||
user_input.pop(CONF_LLM_HASS_API)
|
if user_input[CONF_LLM_HASS_API] == "none":
|
||||||
return self.async_create_entry(title="", data=user_input)
|
user_input.pop(CONF_LLM_HASS_API)
|
||||||
schema = await google_generative_ai_config_option_schema(
|
return self.async_create_entry(title="", data=user_input)
|
||||||
self.hass, self.config_entry.options
|
|
||||||
)
|
# Re-render the options again, now with the recommended options shown/hidden
|
||||||
|
self.last_rendered_recommended = user_input[CONF_RECOMMENDED]
|
||||||
|
|
||||||
|
# If we switch to not recommended, generate used prompt.
|
||||||
|
if user_input[CONF_RECOMMENDED]:
|
||||||
|
options = RECOMMENDED_OPTIONS
|
||||||
|
else:
|
||||||
|
options = {
|
||||||
|
CONF_RECOMMENDED: False,
|
||||||
|
CONF_PROMPT: DEFAULT_PROMPT
|
||||||
|
+ "\n"
|
||||||
|
+ user_input.get(CONF_TONE_PROMPT, ""),
|
||||||
|
}
|
||||||
|
|
||||||
|
schema = await google_generative_ai_config_option_schema(self.hass, options)
|
||||||
return self.async_show_form(
|
return self.async_show_form(
|
||||||
step_id="init",
|
step_id="init",
|
||||||
data_schema=vol.Schema(schema),
|
data_schema=vol.Schema(schema),
|
||||||
|
@ -135,41 +162,16 @@ class GoogleGenerativeAIOptionsFlow(OptionsFlow):
|
||||||
|
|
||||||
async def google_generative_ai_config_option_schema(
|
async def google_generative_ai_config_option_schema(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
options: MappingProxyType[str, Any],
|
options: dict[str, Any] | MappingProxyType[str, Any],
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""Return a schema for Google Generative AI completion options."""
|
"""Return a schema for Google Generative AI completion options."""
|
||||||
api_models = await hass.async_add_executor_job(partial(genai.list_models))
|
hass_apis: list[SelectOptionDict] = [
|
||||||
|
|
||||||
models: list[SelectOptionDict] = [
|
|
||||||
SelectOptionDict(
|
|
||||||
label="Gemini 1.5 Flash (recommended)",
|
|
||||||
value="models/gemini-1.5-flash-latest",
|
|
||||||
),
|
|
||||||
]
|
|
||||||
models.extend(
|
|
||||||
SelectOptionDict(
|
|
||||||
label=api_model.display_name,
|
|
||||||
value=api_model.name,
|
|
||||||
)
|
|
||||||
for api_model in sorted(api_models, key=lambda x: x.display_name)
|
|
||||||
if (
|
|
||||||
api_model.name
|
|
||||||
not in (
|
|
||||||
"models/gemini-1.0-pro", # duplicate of gemini-pro
|
|
||||||
"models/gemini-1.5-flash-latest",
|
|
||||||
)
|
|
||||||
and "vision" not in api_model.name
|
|
||||||
and "generateContent" in api_model.supported_generation_methods
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
apis: list[SelectOptionDict] = [
|
|
||||||
SelectOptionDict(
|
SelectOptionDict(
|
||||||
label="No control",
|
label="No control",
|
||||||
value="none",
|
value="none",
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
apis.extend(
|
hass_apis.extend(
|
||||||
SelectOptionDict(
|
SelectOptionDict(
|
||||||
label=api.name,
|
label=api.name,
|
||||||
value=api.id,
|
value=api.id,
|
||||||
|
@ -177,45 +179,77 @@ async def google_generative_ai_config_option_schema(
|
||||||
for api in llm.async_get_apis(hass)
|
for api in llm.async_get_apis(hass)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if options.get(CONF_RECOMMENDED):
|
||||||
|
return {
|
||||||
|
vol.Required(
|
||||||
|
CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False)
|
||||||
|
): bool,
|
||||||
|
vol.Optional(
|
||||||
|
CONF_TONE_PROMPT,
|
||||||
|
description={"suggested_value": options.get(CONF_TONE_PROMPT)},
|
||||||
|
default="",
|
||||||
|
): TemplateSelector(),
|
||||||
|
vol.Optional(
|
||||||
|
CONF_LLM_HASS_API,
|
||||||
|
description={"suggested_value": options.get(CONF_LLM_HASS_API)},
|
||||||
|
default="none",
|
||||||
|
): SelectSelector(SelectSelectorConfig(options=hass_apis)),
|
||||||
|
}
|
||||||
|
|
||||||
|
api_models = await hass.async_add_executor_job(partial(genai.list_models))
|
||||||
|
|
||||||
|
models = [
|
||||||
|
SelectOptionDict(
|
||||||
|
label=api_model.display_name,
|
||||||
|
value=api_model.name,
|
||||||
|
)
|
||||||
|
for api_model in sorted(api_models, key=lambda x: x.display_name)
|
||||||
|
if (
|
||||||
|
api_model.name != "models/gemini-1.0-pro" # duplicate of gemini-pro
|
||||||
|
and "vision" not in api_model.name
|
||||||
|
and "generateContent" in api_model.supported_generation_methods
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
vol.Required(
|
||||||
|
CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False)
|
||||||
|
): bool,
|
||||||
vol.Optional(
|
vol.Optional(
|
||||||
CONF_CHAT_MODEL,
|
CONF_CHAT_MODEL,
|
||||||
description={"suggested_value": options.get(CONF_CHAT_MODEL)},
|
description={"suggested_value": options.get(CONF_CHAT_MODEL)},
|
||||||
default=DEFAULT_CHAT_MODEL,
|
default=RECOMMENDED_CHAT_MODEL,
|
||||||
): SelectSelector(
|
): SelectSelector(
|
||||||
SelectSelectorConfig(
|
SelectSelectorConfig(mode=SelectSelectorMode.DROPDOWN, options=models)
|
||||||
mode=SelectSelectorMode.DROPDOWN,
|
|
||||||
options=models,
|
|
||||||
)
|
|
||||||
),
|
),
|
||||||
vol.Optional(
|
|
||||||
CONF_LLM_HASS_API,
|
|
||||||
description={"suggested_value": options.get(CONF_LLM_HASS_API)},
|
|
||||||
default="none",
|
|
||||||
): SelectSelector(SelectSelectorConfig(options=apis)),
|
|
||||||
vol.Optional(
|
vol.Optional(
|
||||||
CONF_PROMPT,
|
CONF_PROMPT,
|
||||||
description={"suggested_value": options.get(CONF_PROMPT)},
|
description={"suggested_value": options.get(CONF_PROMPT)},
|
||||||
default=DEFAULT_PROMPT,
|
default=DEFAULT_PROMPT,
|
||||||
): TemplateSelector(),
|
): TemplateSelector(),
|
||||||
|
vol.Optional(
|
||||||
|
CONF_LLM_HASS_API,
|
||||||
|
description={"suggested_value": options.get(CONF_LLM_HASS_API)},
|
||||||
|
default="none",
|
||||||
|
): SelectSelector(SelectSelectorConfig(options=hass_apis)),
|
||||||
vol.Optional(
|
vol.Optional(
|
||||||
CONF_TEMPERATURE,
|
CONF_TEMPERATURE,
|
||||||
description={"suggested_value": options.get(CONF_TEMPERATURE)},
|
description={"suggested_value": options.get(CONF_TEMPERATURE)},
|
||||||
default=DEFAULT_TEMPERATURE,
|
default=RECOMMENDED_TEMPERATURE,
|
||||||
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
|
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
|
||||||
vol.Optional(
|
vol.Optional(
|
||||||
CONF_TOP_P,
|
CONF_TOP_P,
|
||||||
description={"suggested_value": options.get(CONF_TOP_P)},
|
description={"suggested_value": options.get(CONF_TOP_P)},
|
||||||
default=DEFAULT_TOP_P,
|
default=RECOMMENDED_TOP_P,
|
||||||
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
|
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
|
||||||
vol.Optional(
|
vol.Optional(
|
||||||
CONF_TOP_K,
|
CONF_TOP_K,
|
||||||
description={"suggested_value": options.get(CONF_TOP_K)},
|
description={"suggested_value": options.get(CONF_TOP_K)},
|
||||||
default=DEFAULT_TOP_K,
|
default=RECOMMENDED_TOP_K,
|
||||||
): int,
|
): int,
|
||||||
vol.Optional(
|
vol.Optional(
|
||||||
CONF_MAX_TOKENS,
|
CONF_MAX_TOKENS,
|
||||||
description={"suggested_value": options.get(CONF_MAX_TOKENS)},
|
description={"suggested_value": options.get(CONF_MAX_TOKENS)},
|
||||||
default=DEFAULT_MAX_TOKENS,
|
default=RECOMMENDED_MAX_TOKENS,
|
||||||
): int,
|
): int,
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,6 +5,7 @@ import logging
|
||||||
DOMAIN = "google_generative_ai_conversation"
|
DOMAIN = "google_generative_ai_conversation"
|
||||||
LOGGER = logging.getLogger(__package__)
|
LOGGER = logging.getLogger(__package__)
|
||||||
CONF_PROMPT = "prompt"
|
CONF_PROMPT = "prompt"
|
||||||
|
CONF_TONE_PROMPT = "tone_prompt"
|
||||||
DEFAULT_PROMPT = """This smart home is controlled by Home Assistant.
|
DEFAULT_PROMPT = """This smart home is controlled by Home Assistant.
|
||||||
|
|
||||||
An overview of the areas and the devices in this smart home:
|
An overview of the areas and the devices in this smart home:
|
||||||
|
@ -23,14 +24,14 @@ An overview of the areas and the devices in this smart home:
|
||||||
{%- endfor %}
|
{%- endfor %}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
CONF_RECOMMENDED = "recommended"
|
||||||
CONF_CHAT_MODEL = "chat_model"
|
CONF_CHAT_MODEL = "chat_model"
|
||||||
DEFAULT_CHAT_MODEL = "models/gemini-pro"
|
RECOMMENDED_CHAT_MODEL = "models/gemini-1.5-flash-latest"
|
||||||
CONF_TEMPERATURE = "temperature"
|
CONF_TEMPERATURE = "temperature"
|
||||||
DEFAULT_TEMPERATURE = 0.9
|
RECOMMENDED_TEMPERATURE = 1.0
|
||||||
CONF_TOP_P = "top_p"
|
CONF_TOP_P = "top_p"
|
||||||
DEFAULT_TOP_P = 1.0
|
RECOMMENDED_TOP_P = 0.95
|
||||||
CONF_TOP_K = "top_k"
|
CONF_TOP_K = "top_k"
|
||||||
DEFAULT_TOP_K = 1
|
RECOMMENDED_TOP_K = 64
|
||||||
CONF_MAX_TOKENS = "max_tokens"
|
CONF_MAX_TOKENS = "max_tokens"
|
||||||
DEFAULT_MAX_TOKENS = 150
|
RECOMMENDED_MAX_TOKENS = 150
|
||||||
DEFAULT_ALLOW_HASS_ACCESS = False
|
|
||||||
|
|
|
@ -25,16 +25,17 @@ from .const import (
|
||||||
CONF_MAX_TOKENS,
|
CONF_MAX_TOKENS,
|
||||||
CONF_PROMPT,
|
CONF_PROMPT,
|
||||||
CONF_TEMPERATURE,
|
CONF_TEMPERATURE,
|
||||||
|
CONF_TONE_PROMPT,
|
||||||
CONF_TOP_K,
|
CONF_TOP_K,
|
||||||
CONF_TOP_P,
|
CONF_TOP_P,
|
||||||
DEFAULT_CHAT_MODEL,
|
|
||||||
DEFAULT_MAX_TOKENS,
|
|
||||||
DEFAULT_PROMPT,
|
DEFAULT_PROMPT,
|
||||||
DEFAULT_TEMPERATURE,
|
|
||||||
DEFAULT_TOP_K,
|
|
||||||
DEFAULT_TOP_P,
|
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
LOGGER,
|
LOGGER,
|
||||||
|
RECOMMENDED_CHAT_MODEL,
|
||||||
|
RECOMMENDED_MAX_TOKENS,
|
||||||
|
RECOMMENDED_TEMPERATURE,
|
||||||
|
RECOMMENDED_TOP_K,
|
||||||
|
RECOMMENDED_TOP_P,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Max number of back and forth with the LLM to generate a response
|
# Max number of back and forth with the LLM to generate a response
|
||||||
|
@ -156,17 +157,16 @@ class GoogleGenerativeAIConversationEntity(
|
||||||
)
|
)
|
||||||
tools = [_format_tool(tool) for tool in llm_api.async_get_tools()]
|
tools = [_format_tool(tool) for tool in llm_api.async_get_tools()]
|
||||||
|
|
||||||
raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
|
|
||||||
model = genai.GenerativeModel(
|
model = genai.GenerativeModel(
|
||||||
model_name=self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL),
|
model_name=self.entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
|
||||||
generation_config={
|
generation_config={
|
||||||
"temperature": self.entry.options.get(
|
"temperature": self.entry.options.get(
|
||||||
CONF_TEMPERATURE, DEFAULT_TEMPERATURE
|
CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE
|
||||||
),
|
),
|
||||||
"top_p": self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P),
|
"top_p": self.entry.options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
|
||||||
"top_k": self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K),
|
"top_k": self.entry.options.get(CONF_TOP_K, RECOMMENDED_TOP_K),
|
||||||
"max_output_tokens": self.entry.options.get(
|
"max_output_tokens": self.entry.options.get(
|
||||||
CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS
|
CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
tools=tools or None,
|
tools=tools or None,
|
||||||
|
@ -179,6 +179,10 @@ class GoogleGenerativeAIConversationEntity(
|
||||||
conversation_id = ulid.ulid_now()
|
conversation_id = ulid.ulid_now()
|
||||||
messages = [{}, {}]
|
messages = [{}, {}]
|
||||||
|
|
||||||
|
raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
|
||||||
|
if tone_prompt := self.entry.options.get(CONF_TONE_PROMPT):
|
||||||
|
raw_prompt += "\n" + tone_prompt
|
||||||
|
|
||||||
try:
|
try:
|
||||||
prompt = self._async_generate_prompt(raw_prompt, llm_api)
|
prompt = self._async_generate_prompt(raw_prompt, llm_api)
|
||||||
except TemplateError as err:
|
except TemplateError as err:
|
||||||
|
@ -221,7 +225,7 @@ class GoogleGenerativeAIConversationEntity(
|
||||||
if not chat_response.parts:
|
if not chat_response.parts:
|
||||||
intent_response.async_set_error(
|
intent_response.async_set_error(
|
||||||
intent.IntentResponseErrorCode.UNKNOWN,
|
intent.IntentResponseErrorCode.UNKNOWN,
|
||||||
"Sorry, I had a problem talking to Google Generative AI. Likely blocked",
|
"Sorry, I had a problem getting a response from Google Generative AI.",
|
||||||
)
|
)
|
||||||
return conversation.ConversationResult(
|
return conversation.ConversationResult(
|
||||||
response=intent_response, conversation_id=conversation_id
|
response=intent_response, conversation_id=conversation_id
|
||||||
|
|
|
@ -18,13 +18,19 @@
|
||||||
"step": {
|
"step": {
|
||||||
"init": {
|
"init": {
|
||||||
"data": {
|
"data": {
|
||||||
"prompt": "Prompt Template",
|
"recommended": "Recommended settings",
|
||||||
|
"prompt": "Prompt",
|
||||||
|
"tone_prompt": "Tone",
|
||||||
"chat_model": "[%key:common::generic::model%]",
|
"chat_model": "[%key:common::generic::model%]",
|
||||||
"temperature": "Temperature",
|
"temperature": "Temperature",
|
||||||
"top_p": "Top P",
|
"top_p": "Top P",
|
||||||
"top_k": "Top K",
|
"top_k": "Top K",
|
||||||
"max_tokens": "Maximum tokens to return in response",
|
"max_tokens": "Maximum tokens to return in response",
|
||||||
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]"
|
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]"
|
||||||
|
},
|
||||||
|
"data_description": {
|
||||||
|
"prompt": "Extra data to provide to the LLM. This can be a template.",
|
||||||
|
"tone_prompt": "Instructions for the LLM on the style of the generated text. This can be a template."
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,11 +8,11 @@
|
||||||
dict({
|
dict({
|
||||||
'generation_config': dict({
|
'generation_config': dict({
|
||||||
'max_output_tokens': 150,
|
'max_output_tokens': 150,
|
||||||
'temperature': 0.9,
|
'temperature': 1.0,
|
||||||
'top_k': 1,
|
'top_k': 64,
|
||||||
'top_p': 1.0,
|
'top_p': 0.95,
|
||||||
}),
|
}),
|
||||||
'model_name': 'models/gemini-pro',
|
'model_name': 'models/gemini-1.5-flash-latest',
|
||||||
'tools': None,
|
'tools': None,
|
||||||
}),
|
}),
|
||||||
),
|
),
|
||||||
|
@ -67,11 +67,11 @@
|
||||||
dict({
|
dict({
|
||||||
'generation_config': dict({
|
'generation_config': dict({
|
||||||
'max_output_tokens': 150,
|
'max_output_tokens': 150,
|
||||||
'temperature': 0.9,
|
'temperature': 1.0,
|
||||||
'top_k': 1,
|
'top_k': 64,
|
||||||
'top_p': 1.0,
|
'top_p': 0.95,
|
||||||
}),
|
}),
|
||||||
'model_name': 'models/gemini-pro',
|
'model_name': 'models/gemini-1.5-flash-latest',
|
||||||
'tools': None,
|
'tools': None,
|
||||||
}),
|
}),
|
||||||
),
|
),
|
||||||
|
@ -126,11 +126,11 @@
|
||||||
dict({
|
dict({
|
||||||
'generation_config': dict({
|
'generation_config': dict({
|
||||||
'max_output_tokens': 150,
|
'max_output_tokens': 150,
|
||||||
'temperature': 0.9,
|
'temperature': 1.0,
|
||||||
'top_k': 1,
|
'top_k': 64,
|
||||||
'top_p': 1.0,
|
'top_p': 0.95,
|
||||||
}),
|
}),
|
||||||
'model_name': 'models/gemini-pro',
|
'model_name': 'models/gemini-1.5-flash-latest',
|
||||||
'tools': None,
|
'tools': None,
|
||||||
}),
|
}),
|
||||||
),
|
),
|
||||||
|
@ -185,11 +185,11 @@
|
||||||
dict({
|
dict({
|
||||||
'generation_config': dict({
|
'generation_config': dict({
|
||||||
'max_output_tokens': 150,
|
'max_output_tokens': 150,
|
||||||
'temperature': 0.9,
|
'temperature': 1.0,
|
||||||
'top_k': 1,
|
'top_k': 64,
|
||||||
'top_p': 1.0,
|
'top_p': 0.95,
|
||||||
}),
|
}),
|
||||||
'model_name': 'models/gemini-pro',
|
'model_name': 'models/gemini-1.5-flash-latest',
|
||||||
'tools': None,
|
'tools': None,
|
||||||
}),
|
}),
|
||||||
),
|
),
|
||||||
|
|
|
@ -10,13 +10,17 @@ from homeassistant import config_entries
|
||||||
from homeassistant.components.google_generative_ai_conversation.const import (
|
from homeassistant.components.google_generative_ai_conversation.const import (
|
||||||
CONF_CHAT_MODEL,
|
CONF_CHAT_MODEL,
|
||||||
CONF_MAX_TOKENS,
|
CONF_MAX_TOKENS,
|
||||||
|
CONF_PROMPT,
|
||||||
|
CONF_RECOMMENDED,
|
||||||
|
CONF_TEMPERATURE,
|
||||||
|
CONF_TONE_PROMPT,
|
||||||
CONF_TOP_K,
|
CONF_TOP_K,
|
||||||
CONF_TOP_P,
|
CONF_TOP_P,
|
||||||
DEFAULT_CHAT_MODEL,
|
|
||||||
DEFAULT_MAX_TOKENS,
|
|
||||||
DEFAULT_TOP_K,
|
|
||||||
DEFAULT_TOP_P,
|
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
|
RECOMMENDED_CHAT_MODEL,
|
||||||
|
RECOMMENDED_MAX_TOKENS,
|
||||||
|
RECOMMENDED_TOP_K,
|
||||||
|
RECOMMENDED_TOP_P,
|
||||||
)
|
)
|
||||||
from homeassistant.const import CONF_LLM_HASS_API
|
from homeassistant.const import CONF_LLM_HASS_API
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
|
@ -42,7 +46,7 @@ def mock_models():
|
||||||
model_10_pro.name = "models/gemini-pro"
|
model_10_pro.name = "models/gemini-pro"
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.components.google_generative_ai_conversation.config_flow.genai.list_models",
|
"homeassistant.components.google_generative_ai_conversation.config_flow.genai.list_models",
|
||||||
return_value=[model_10_pro],
|
return_value=[model_15_flash, model_10_pro],
|
||||||
):
|
):
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
@ -84,36 +88,89 @@ async def test_form(hass: HomeAssistant) -> None:
|
||||||
"api_key": "bla",
|
"api_key": "bla",
|
||||||
}
|
}
|
||||||
assert result2["options"] == {
|
assert result2["options"] == {
|
||||||
|
CONF_RECOMMENDED: True,
|
||||||
CONF_LLM_HASS_API: llm.LLM_API_ASSIST,
|
CONF_LLM_HASS_API: llm.LLM_API_ASSIST,
|
||||||
|
CONF_TONE_PROMPT: "",
|
||||||
}
|
}
|
||||||
assert len(mock_setup_entry.mock_calls) == 1
|
assert len(mock_setup_entry.mock_calls) == 1
|
||||||
|
|
||||||
|
|
||||||
async def test_options(
|
@pytest.mark.parametrize(
|
||||||
hass: HomeAssistant, mock_config_entry, mock_init_component, mock_models
|
("current_options", "new_options", "expected_options"),
|
||||||
|
[
|
||||||
|
(
|
||||||
|
{
|
||||||
|
CONF_RECOMMENDED: True,
|
||||||
|
CONF_LLM_HASS_API: "none",
|
||||||
|
CONF_TONE_PROMPT: "bla",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
CONF_RECOMMENDED: False,
|
||||||
|
CONF_PROMPT: "Speak like a pirate",
|
||||||
|
CONF_TEMPERATURE: 0.3,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
CONF_RECOMMENDED: False,
|
||||||
|
CONF_PROMPT: "Speak like a pirate",
|
||||||
|
CONF_TEMPERATURE: 0.3,
|
||||||
|
CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
|
||||||
|
CONF_TOP_P: RECOMMENDED_TOP_P,
|
||||||
|
CONF_TOP_K: RECOMMENDED_TOP_K,
|
||||||
|
CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{
|
||||||
|
CONF_RECOMMENDED: False,
|
||||||
|
CONF_PROMPT: "Speak like a pirate",
|
||||||
|
CONF_TEMPERATURE: 0.3,
|
||||||
|
CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
|
||||||
|
CONF_TOP_P: RECOMMENDED_TOP_P,
|
||||||
|
CONF_TOP_K: RECOMMENDED_TOP_K,
|
||||||
|
CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
CONF_RECOMMENDED: True,
|
||||||
|
CONF_LLM_HASS_API: "assist",
|
||||||
|
CONF_TONE_PROMPT: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
CONF_RECOMMENDED: True,
|
||||||
|
CONF_LLM_HASS_API: "assist",
|
||||||
|
CONF_TONE_PROMPT: "",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_options_switching(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_config_entry,
|
||||||
|
mock_init_component,
|
||||||
|
mock_models,
|
||||||
|
current_options,
|
||||||
|
new_options,
|
||||||
|
expected_options,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test the options form."""
|
"""Test the options form."""
|
||||||
|
hass.config_entries.async_update_entry(mock_config_entry, options=current_options)
|
||||||
options_flow = await hass.config_entries.options.async_init(
|
options_flow = await hass.config_entries.options.async_init(
|
||||||
mock_config_entry.entry_id
|
mock_config_entry.entry_id
|
||||||
)
|
)
|
||||||
|
if current_options.get(CONF_RECOMMENDED) != new_options.get(CONF_RECOMMENDED):
|
||||||
|
options_flow = await hass.config_entries.options.async_configure(
|
||||||
|
options_flow["flow_id"],
|
||||||
|
{
|
||||||
|
**current_options,
|
||||||
|
CONF_RECOMMENDED: new_options[CONF_RECOMMENDED],
|
||||||
|
},
|
||||||
|
)
|
||||||
options = await hass.config_entries.options.async_configure(
|
options = await hass.config_entries.options.async_configure(
|
||||||
options_flow["flow_id"],
|
options_flow["flow_id"],
|
||||||
{
|
new_options,
|
||||||
"prompt": "Speak like a pirate",
|
|
||||||
"temperature": 0.3,
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
assert options["type"] is FlowResultType.CREATE_ENTRY
|
assert options["type"] is FlowResultType.CREATE_ENTRY
|
||||||
assert options["data"]["prompt"] == "Speak like a pirate"
|
assert options["data"] == expected_options
|
||||||
assert options["data"]["temperature"] == 0.3
|
|
||||||
assert options["data"][CONF_CHAT_MODEL] == DEFAULT_CHAT_MODEL
|
|
||||||
assert options["data"][CONF_TOP_P] == DEFAULT_TOP_P
|
|
||||||
assert options["data"][CONF_TOP_K] == DEFAULT_TOP_K
|
|
||||||
assert options["data"][CONF_MAX_TOKENS] == DEFAULT_MAX_TOKENS
|
|
||||||
assert (
|
|
||||||
CONF_LLM_HASS_API not in options["data"]
|
|
||||||
), "Options flow should not set this key"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
|
|
@ -354,7 +354,7 @@ async def test_blocked_response(
|
||||||
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
||||||
assert result.response.error_code == "unknown", result
|
assert result.response.error_code == "unknown", result
|
||||||
assert result.response.as_dict()["speech"]["plain"]["speech"] == (
|
assert result.response.as_dict()["speech"]["plain"]["speech"] == (
|
||||||
"Sorry, I had a problem talking to Google Generative AI. Likely blocked"
|
"Sorry, I had a problem getting a response from Google Generative AI."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue