"""Conversation support for OpenAI.""" import json from typing import Any, Literal import openai import voluptuous as vol from voluptuous_openapi import convert from homeassistant.components import assist_pipeline, conversation from homeassistant.components.conversation import trace from homeassistant.config_entries import ConfigEntry from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError, TemplateError from homeassistant.helpers import device_registry as dr, intent, llm, template from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.util import ulid from .const import ( CONF_CHAT_MODEL, CONF_MAX_TOKENS, CONF_PROMPT, CONF_TEMPERATURE, CONF_TOP_P, DOMAIN, LOGGER, RECOMMENDED_CHAT_MODEL, RECOMMENDED_MAX_TOKENS, RECOMMENDED_TEMPERATURE, RECOMMENDED_TOP_P, ) # Max number of back and forth with the LLM to generate a response MAX_TOOL_ITERATIONS = 10 async def async_setup_entry( hass: HomeAssistant, config_entry: ConfigEntry, async_add_entities: AddEntitiesCallback, ) -> None: """Set up conversation entities.""" agent = OpenAIConversationEntity(config_entry) async_add_entities([agent]) def _format_tool(tool: llm.Tool) -> dict[str, Any]: """Format tool specification.""" tool_spec = {"name": tool.name} if tool.description: tool_spec["description"] = tool.description tool_spec["parameters"] = convert(tool.parameters) return {"type": "function", "function": tool_spec} class OpenAIConversationEntity( conversation.ConversationEntity, conversation.AbstractConversationAgent ): """OpenAI conversation agent.""" _attr_has_entity_name = True _attr_name = None def __init__(self, entry: ConfigEntry) -> None: """Initialize the agent.""" self.entry = entry self.history: dict[str, list[dict]] = {} self._attr_unique_id = entry.entry_id self._attr_device_info = dr.DeviceInfo( identifiers={(DOMAIN, entry.entry_id)}, name=entry.title, manufacturer="OpenAI", model="ChatGPT", entry_type=dr.DeviceEntryType.SERVICE, ) @property def supported_languages(self) -> list[str] | Literal["*"]: """Return a list of supported languages.""" return MATCH_ALL async def async_added_to_hass(self) -> None: """When entity is added to Home Assistant.""" await super().async_added_to_hass() assist_pipeline.async_migrate_engine( self.hass, "conversation", self.entry.entry_id, self.entity_id ) conversation.async_set_agent(self.hass, self.entry, self) async def async_will_remove_from_hass(self) -> None: """When entity will be removed from Home Assistant.""" conversation.async_unset_agent(self.hass, self.entry) await super().async_will_remove_from_hass() async def async_process( self, user_input: conversation.ConversationInput ) -> conversation.ConversationResult: """Process a sentence.""" options = self.entry.options intent_response = intent.IntentResponse(language=user_input.language) llm_api: llm.API | None = None tools: list[dict[str, Any]] | None = None if options.get(CONF_LLM_HASS_API): try: llm_api = llm.async_get_api(self.hass, options[CONF_LLM_HASS_API]) except HomeAssistantError as err: LOGGER.error("Error getting LLM API: %s", err) intent_response.async_set_error( intent.IntentResponseErrorCode.UNKNOWN, f"Error preparing LLM API: {err}", ) return conversation.ConversationResult( response=intent_response, conversation_id=user_input.conversation_id ) tools = [_format_tool(tool) for tool in llm_api.async_get_tools()] if user_input.conversation_id in self.history: conversation_id = 
        if user_input.conversation_id in self.history:
            conversation_id = user_input.conversation_id
            messages = self.history[conversation_id]
        else:
            conversation_id = ulid.ulid_now()

            try:
                if llm_api:
                    empty_tool_input = llm.ToolInput(
                        tool_name="",
                        tool_args={},
                        platform=DOMAIN,
                        context=user_input.context,
                        user_prompt=user_input.text,
                        language=user_input.language,
                        assistant=conversation.DOMAIN,
                        device_id=user_input.device_id,
                    )
                    api_prompt = await llm_api.async_get_api_prompt(empty_tool_input)
                else:
                    api_prompt = llm.async_render_no_api_prompt(self.hass)

                prompt = "\n".join(
                    (
                        template.Template(
                            options.get(CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT),
                            self.hass,
                        ).async_render(
                            {
                                "ha_name": self.hass.config.location_name,
                            },
                            parse_result=False,
                        ),
                        api_prompt,
                    )
                )
            except TemplateError as err:
                LOGGER.error("Error rendering prompt: %s", err)
                intent_response = intent.IntentResponse(language=user_input.language)
                intent_response.async_set_error(
                    intent.IntentResponseErrorCode.UNKNOWN,
                    f"Sorry, I had a problem with my template: {err}",
                )
                return conversation.ConversationResult(
                    response=intent_response, conversation_id=conversation_id
                )

            messages = [{"role": "system", "content": prompt}]

        messages.append({"role": "user", "content": user_input.text})

        LOGGER.debug("Prompt: %s", messages)
        trace.async_conversation_trace_append(
            trace.ConversationTraceEventType.AGENT_DETAIL, {"messages": messages}
        )

        client = self.hass.data[DOMAIN][self.entry.entry_id]

        # To prevent infinite loops, we limit the number of iterations
        for _iteration in range(MAX_TOOL_ITERATIONS):
            try:
                result = await client.chat.completions.create(
                    model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
                    messages=messages,
                    tools=tools,
                    max_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
                    top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
                    temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
                    user=conversation_id,
                )
            except openai.OpenAIError as err:
                intent_response = intent.IntentResponse(language=user_input.language)
                intent_response.async_set_error(
                    intent.IntentResponseErrorCode.UNKNOWN,
                    f"Sorry, I had a problem talking to OpenAI: {err}",
                )
                return conversation.ConversationResult(
                    response=intent_response, conversation_id=conversation_id
                )

            LOGGER.debug("Response %s", result)
            response = result.choices[0].message
            messages.append(response)
            tool_calls = response.tool_calls

            if not tool_calls or not llm_api:
                break

            for tool_call in tool_calls:
                tool_input = llm.ToolInput(
                    tool_name=tool_call.function.name,
                    tool_args=json.loads(tool_call.function.arguments),
                    platform=DOMAIN,
                    context=user_input.context,
                    user_prompt=user_input.text,
                    language=user_input.language,
                    assistant=conversation.DOMAIN,
                    device_id=user_input.device_id,
                )
                LOGGER.debug(
                    "Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
                )

                try:
                    tool_response = await llm_api.async_call_tool(tool_input)
                except (HomeAssistantError, vol.Invalid) as e:
                    tool_response = {"error": type(e).__name__}
                    if str(e):
                        tool_response["error_text"] = str(e)

                LOGGER.debug("Tool response: %s", tool_response)
                messages.append(
                    {
                        "role": "tool",
                        "tool_call_id": tool_call.id,
                        "name": tool_call.function.name,
                        "content": json.dumps(tool_response),
                    }
                )

        self.history[conversation_id] = messages

        intent_response = intent.IntentResponse(language=user_input.language)
        intent_response.async_set_speech(response.content)
        return conversation.ConversationResult(
            response=intent_response, conversation_id=conversation_id
        )