Adjust registry access in openai_conversation (#88882)

This commit is contained in:
epenet 2023-03-01 03:59:44 +01:00 committed by GitHub
parent 246f9784c8
commit c724e7c29f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 19 deletions

View file

@ -12,7 +12,7 @@ from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_API_KEY
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryNotReady, TemplateError
from homeassistant.helpers import area_registry, intent, template
from homeassistant.helpers import area_registry as ar, intent, template
from homeassistant.util import ulid
from .const import (
@ -150,7 +150,7 @@ class OpenAIAgent(conversation.AbstractConversationAgent):
return template.Template(raw_prompt, self.hass).async_render(
{
"ha_name": self.hass.config.location_name,
"areas": list(area_registry.async_get(self.hass).areas.values()),
"areas": list(ar.async_get(self.hass).areas.values()),
},
parse_result=False,
)

View file

@ -5,20 +5,22 @@ from openai import error
from homeassistant.components import conversation
from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers import area_registry, device_registry, intent
from homeassistant.helpers import area_registry as ar, device_registry as dr, intent
from tests.common import MockConfigEntry
async def test_default_prompt(hass: HomeAssistant, mock_init_component) -> None:
async def test_default_prompt(
hass: HomeAssistant,
mock_init_component,
area_registry: ar.AreaRegistry,
device_registry: dr.DeviceRegistry,
) -> None:
"""Test that the default prompt works."""
device_reg = device_registry.async_get(hass)
area_reg = area_registry.async_get(hass)
for i in range(3):
area_reg.async_create(f"{i}Empty Area")
area_registry.async_create(f"{i}Empty Area")
device_reg.async_get_or_create(
device_registry.async_get_or_create(
config_entry_id="1234",
connections={("test", "1234")},
name="Test Device",
@ -27,16 +29,16 @@ async def test_default_prompt(hass: HomeAssistant, mock_init_component) -> None:
suggested_area="Test Area",
)
for i in range(3):
device_reg.async_get_or_create(
device_registry.async_get_or_create(
config_entry_id="1234",
connections={("test", f"{i}abcd")},
name="Test Service",
manufacturer="Test Manufacturer",
model="Test Model",
suggested_area="Test Area",
entry_type=device_registry.DeviceEntryType.SERVICE,
entry_type=dr.DeviceEntryType.SERVICE,
)
device_reg.async_get_or_create(
device_registry.async_get_or_create(
config_entry_id="1234",
connections={("test", "5678")},
name="Test Device 2",
@ -44,7 +46,7 @@ async def test_default_prompt(hass: HomeAssistant, mock_init_component) -> None:
model="Device 2",
suggested_area="Test Area 2",
)
device_reg.async_get_or_create(
device_registry.async_get_or_create(
config_entry_id="1234",
connections={("test", "9876")},
name="Test Device 3",
@ -52,13 +54,13 @@ async def test_default_prompt(hass: HomeAssistant, mock_init_component) -> None:
model="Test Model 3A",
suggested_area="Test Area 2",
)
device_reg.async_get_or_create(
device_registry.async_get_or_create(
config_entry_id="1234",
connections={("test", "qwer")},
name="Test Device 4",
suggested_area="Test Area 2",
)
device = device_reg.async_get_or_create(
device = device_registry.async_get_or_create(
config_entry_id="1234",
connections={("test", "9876-disabled")},
name="Test Device 3",
@ -66,17 +68,17 @@ async def test_default_prompt(hass: HomeAssistant, mock_init_component) -> None:
model="Test Model 3A",
suggested_area="Test Area 2",
)
device_reg.async_update_device(
device.id, disabled_by=device_registry.DeviceEntryDisabler.USER
device_registry.async_update_device(
device.id, disabled_by=dr.DeviceEntryDisabler.USER
)
device_reg.async_get_or_create(
device_registry.async_get_or_create(
config_entry_id="1234",
connections={("test", "9876-no-name")},
manufacturer="Test Manufacturer NoName",
model="Test Model NoName",
suggested_area="Test Area 2",
)
device_reg.async_get_or_create(
device_registry.async_get_or_create(
config_entry_id="1234",
connections={("test", "9876-integer-values")},
name=1,