Extract Collection helper from Person integration (#30313)

* Add CRUD foundation

* Use collection helper in person integration

* Lint/pytest

* Add tests

* Lint

* Create notification
This commit is contained in:
Paulus Schoutsen 2020-01-03 21:37:11 +01:00 committed by GitHub
parent 3033dbd86c
commit b9aba30a6e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 1074 additions and 396 deletions

View file

@ -173,13 +173,13 @@ async def _async_get_device_automation_capabilities(hass, automation_type, autom
return capabilities return capabilities
@websocket_api.async_response
@websocket_api.websocket_command( @websocket_api.websocket_command(
{ {
vol.Required("type"): "device_automation/action/list", vol.Required("type"): "device_automation/action/list",
vol.Required("device_id"): str, vol.Required("device_id"): str,
} }
) )
@websocket_api.async_response
async def websocket_device_automation_list_actions(hass, connection, msg): async def websocket_device_automation_list_actions(hass, connection, msg):
"""Handle request for device actions.""" """Handle request for device actions."""
device_id = msg["device_id"] device_id = msg["device_id"]
@ -187,13 +187,13 @@ async def websocket_device_automation_list_actions(hass, connection, msg):
connection.send_result(msg["id"], actions) connection.send_result(msg["id"], actions)
@websocket_api.async_response
@websocket_api.websocket_command( @websocket_api.websocket_command(
{ {
vol.Required("type"): "device_automation/condition/list", vol.Required("type"): "device_automation/condition/list",
vol.Required("device_id"): str, vol.Required("device_id"): str,
} }
) )
@websocket_api.async_response
async def websocket_device_automation_list_conditions(hass, connection, msg): async def websocket_device_automation_list_conditions(hass, connection, msg):
"""Handle request for device conditions.""" """Handle request for device conditions."""
device_id = msg["device_id"] device_id = msg["device_id"]
@ -201,13 +201,13 @@ async def websocket_device_automation_list_conditions(hass, connection, msg):
connection.send_result(msg["id"], conditions) connection.send_result(msg["id"], conditions)
@websocket_api.async_response
@websocket_api.websocket_command( @websocket_api.websocket_command(
{ {
vol.Required("type"): "device_automation/trigger/list", vol.Required("type"): "device_automation/trigger/list",
vol.Required("device_id"): str, vol.Required("device_id"): str,
} }
) )
@websocket_api.async_response
async def websocket_device_automation_list_triggers(hass, connection, msg): async def websocket_device_automation_list_triggers(hass, connection, msg):
"""Handle request for device triggers.""" """Handle request for device triggers."""
device_id = msg["device_id"] device_id = msg["device_id"]
@ -215,13 +215,13 @@ async def websocket_device_automation_list_triggers(hass, connection, msg):
connection.send_result(msg["id"], triggers) connection.send_result(msg["id"], triggers)
@websocket_api.async_response
@websocket_api.websocket_command( @websocket_api.websocket_command(
{ {
vol.Required("type"): "device_automation/action/capabilities", vol.Required("type"): "device_automation/action/capabilities",
vol.Required("action"): dict, vol.Required("action"): dict,
} }
) )
@websocket_api.async_response
async def websocket_device_automation_get_action_capabilities(hass, connection, msg): async def websocket_device_automation_get_action_capabilities(hass, connection, msg):
"""Handle request for device action capabilities.""" """Handle request for device action capabilities."""
action = msg["action"] action = msg["action"]
@ -231,13 +231,13 @@ async def websocket_device_automation_get_action_capabilities(hass, connection,
connection.send_result(msg["id"], capabilities) connection.send_result(msg["id"], capabilities)
@websocket_api.async_response
@websocket_api.websocket_command( @websocket_api.websocket_command(
{ {
vol.Required("type"): "device_automation/condition/capabilities", vol.Required("type"): "device_automation/condition/capabilities",
vol.Required("condition"): dict, vol.Required("condition"): dict,
} }
) )
@websocket_api.async_response
async def websocket_device_automation_get_condition_capabilities(hass, connection, msg): async def websocket_device_automation_get_condition_capabilities(hass, connection, msg):
"""Handle request for device condition capabilities.""" """Handle request for device condition capabilities."""
condition = msg["condition"] condition = msg["condition"]
@ -247,13 +247,13 @@ async def websocket_device_automation_get_condition_capabilities(hass, connectio
connection.send_result(msg["id"], capabilities) connection.send_result(msg["id"], capabilities)
@websocket_api.async_response
@websocket_api.websocket_command( @websocket_api.websocket_command(
{ {
vol.Required("type"): "device_automation/trigger/capabilities", vol.Required("type"): "device_automation/trigger/capabilities",
vol.Required("trigger"): dict, vol.Required("trigger"): dict,
} }
) )
@websocket_api.async_response
async def websocket_device_automation_get_trigger_capabilities(hass, connection, msg): async def websocket_device_automation_get_trigger_capabilities(hass, connection, msg):
"""Handle request for device trigger capabilities.""" """Handle request for device trigger capabilities."""
trigger = msg["trigger"] trigger = msg["trigger"]

View file

@ -1,9 +1,6 @@
"""Support for tracking people.""" """Support for tracking people."""
from collections import OrderedDict
from itertools import chain
import logging import logging
from typing import Optional from typing import List, Optional
import uuid
import voluptuous as vol import voluptuous as vol
@ -28,6 +25,7 @@ from homeassistant.const import (
STATE_UNKNOWN, STATE_UNKNOWN,
) )
from homeassistant.core import Event, State, callback from homeassistant.core import Event, State, callback
from homeassistant.helpers import collection, entity_registry
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.event import async_track_state_change from homeassistant.helpers.event import async_track_state_change
@ -48,8 +46,7 @@ CONF_USER_ID = "user_id"
DOMAIN = "person" DOMAIN = "person"
STORAGE_KEY = DOMAIN STORAGE_KEY = DOMAIN
STORAGE_VERSION = 1 STORAGE_VERSION = 2
SAVE_DELAY = 10
# Device tracker states to ignore # Device tracker states to ignore
IGNORE_STATES = (STATE_UNKNOWN, STATE_UNAVAILABLE) IGNORE_STATES = (STATE_UNKNOWN, STATE_UNAVAILABLE)
@ -75,217 +72,184 @@ _UNDEF = object()
@bind_hass @bind_hass
async def async_create_person(hass, name, *, user_id=None, device_trackers=None): async def async_create_person(hass, name, *, user_id=None, device_trackers=None):
"""Create a new person.""" """Create a new person."""
await hass.data[DOMAIN].async_create_person( await hass.data[DOMAIN][1].async_create_item(
name=name, user_id=user_id, device_trackers=device_trackers {"name": name, "user_id": user_id, "device_trackers": device_trackers}
) )
class PersonManager: CREATE_FIELDS = {
"""Manage person data.""" vol.Required("name"): vol.All(str, vol.Length(min=1)),
vol.Optional("user_id"): vol.Any(str, None),
vol.Optional("device_trackers", default=list): vol.All(
cv.ensure_list, cv.entities_domain(DEVICE_TRACKER_DOMAIN)
),
}
UPDATE_FIELDS = {
vol.Optional("name"): vol.All(str, vol.Length(min=1)),
vol.Optional("user_id"): vol.Any(str, None),
vol.Optional("device_trackers", default=list): vol.All(
cv.ensure_list, cv.entities_domain(DEVICE_TRACKER_DOMAIN)
),
}
class PersonStore(Store):
"""Person storage."""
async def _async_migrate_func(self, old_version, old_data):
"""Migrate to the new version.
Migrate storage to use format of collection helper.
"""
return {"items": old_data["persons"]}
class PersonStorageCollection(collection.StorageCollection):
"""Person collection stored in storage."""
CREATE_SCHEMA = vol.Schema(CREATE_FIELDS)
UPDATE_SCHEMA = vol.Schema(UPDATE_FIELDS)
def __init__( def __init__(
self, hass: HomeAssistantType, component: EntityComponent, config_persons self,
store: Store,
logger: logging.Logger,
id_manager: collection.IDManager,
yaml_collection: collection.YamlCollection,
): ):
"""Initialize person storage.""" """Initialize a person storage collection."""
self.hass = hass super().__init__(store, logger, id_manager)
self.component = component self.async_add_listener(self._collection_changed)
self.store = Store(hass, STORAGE_VERSION, STORAGE_KEY) self.yaml_collection = yaml_collection
self.storage_data = None
config_data = self.config_data = OrderedDict() async def _process_create_data(self, data: dict) -> dict:
for conf in config_persons: """Validate the config is valid."""
person_id = conf[CONF_ID] data = self.CREATE_SCHEMA(data)
if person_id in config_data: user_id = data.get("user_id")
_LOGGER.error("Found config user with duplicate ID: %s", person_id)
continue
config_data[person_id] = conf
@property
def storage_persons(self):
"""Iterate over persons stored in storage."""
return list(self.storage_data.values())
@property
def config_persons(self):
"""Iterate over persons stored in config."""
return list(self.config_data.values())
async def async_initialize(self):
"""Get the person data."""
raw_storage = await self.store.async_load()
if raw_storage is None:
raw_storage = {"persons": []}
storage_data = self.storage_data = OrderedDict()
for person in raw_storage["persons"]:
storage_data[person[CONF_ID]] = person
entities = []
seen_users = set()
for person_conf in self.config_data.values():
person_id = person_conf[CONF_ID]
user_id = person_conf.get(CONF_USER_ID)
if user_id is not None:
if await self.hass.auth.async_get_user(user_id) is None:
_LOGGER.error("Invalid user_id detected for person %s", person_id)
continue
if user_id in seen_users:
_LOGGER.error(
"Duplicate user_id %s detected for person %s",
user_id,
person_id,
)
continue
seen_users.add(user_id)
entities.append(Person(person_conf, False))
# To make sure IDs don't overlap between config/storage
seen_persons = set(self.config_data)
for person_conf in storage_data.values():
person_id = person_conf[CONF_ID]
user_id = person_conf[CONF_USER_ID]
if person_id in seen_persons:
_LOGGER.error(
"Skipping adding person from storage with same ID as"
" configuration.yaml entry: %s",
person_id,
)
continue
if user_id is not None and user_id in seen_users:
_LOGGER.error(
"Duplicate user_id %s detected for person %s", user_id, person_id
)
continue
# To make sure all users have just 1 person linked.
seen_users.add(user_id)
entities.append(Person(person_conf, True))
if entities:
await self.component.async_add_entities(entities)
self.hass.bus.async_listen(EVENT_USER_REMOVED, self._user_removed)
async def async_create_person(self, *, name, device_trackers=None, user_id=None):
"""Create a new person."""
if not name:
raise ValueError("Name is required")
if user_id is not None: if user_id is not None:
await self._validate_user_id(user_id) await self._validate_user_id(user_id)
person = { return self.CREATE_SCHEMA(data)
CONF_ID: uuid.uuid4().hex,
CONF_NAME: name,
CONF_USER_ID: user_id,
CONF_DEVICE_TRACKERS: device_trackers or [],
}
self.storage_data[person[CONF_ID]] = person
self._async_schedule_save()
await self.component.async_add_entities([Person(person, True)])
return person
async def async_update_person( @callback
self, person_id, *, name=_UNDEF, device_trackers=_UNDEF, user_id=_UNDEF def _get_suggested_id(self, info: dict) -> str:
): """Suggest an ID based on the config."""
"""Update person.""" return info["name"]
current = self.storage_data.get(person_id)
if current is None: async def _update_data(self, data: dict, update_data: dict) -> dict:
raise ValueError("Invalid person specified.") """Return a new updated data object."""
update_data = self.UPDATE_SCHEMA(update_data)
changes = { user_id = update_data.get("user_id")
key: value
for key, value in (
(CONF_NAME, name),
(CONF_DEVICE_TRACKERS, device_trackers),
(CONF_USER_ID, user_id),
)
if value is not _UNDEF and current[key] != value
}
if CONF_USER_ID in changes and user_id is not None: if user_id is not None:
await self._validate_user_id(user_id) await self._validate_user_id(user_id)
self.storage_data[person_id].update(changes) return {**data, **update_data}
self._async_schedule_save()
for entity in self.component.entities:
if entity.unique_id == person_id:
entity.person_updated()
break
return self.storage_data[person_id]
async def async_delete_person(self, person_id):
"""Delete person."""
if person_id not in self.storage_data:
raise ValueError("Invalid person specified.")
self.storage_data.pop(person_id)
self._async_schedule_save()
ent_reg = await self.hass.helpers.entity_registry.async_get_registry()
for entity in self.component.entities:
if entity.unique_id == person_id:
await entity.async_remove()
ent_reg.async_remove(entity.entity_id)
break
@callback
def _async_schedule_save(self) -> None:
"""Schedule saving the area registry."""
self.store.async_delay_save(self._data_to_save, SAVE_DELAY)
@callback
def _data_to_save(self) -> dict:
"""Return data of area registry to store in a file."""
return {"persons": list(self.storage_data.values())}
async def _validate_user_id(self, user_id): async def _validate_user_id(self, user_id):
"""Validate the used user_id.""" """Validate the used user_id."""
if await self.hass.auth.async_get_user(user_id) is None: if await self.hass.auth.async_get_user(user_id) is None:
raise ValueError("User does not exist") raise ValueError("User does not exist")
if any( for persons in (self.data.values(), self.yaml_collection.async_items()):
person if any(person for person in persons if person.get(CONF_USER_ID) == user_id):
for person in chain(self.storage_data.values(), self.config_data.values()) raise ValueError("User already taken")
if person.get(CONF_USER_ID) == user_id
):
raise ValueError("User already taken")
async def _user_removed(self, event: Event): async def _collection_changed(
"""Handle event that a person is removed.""" self, change_type: str, item_id: str, config: Optional[dict]
user_id = event.data["user_id"] ) -> None:
for person in self.storage_data.values(): """Handle a collection change."""
if person[CONF_USER_ID] == user_id: if change_type != collection.CHANGE_REMOVED:
await self.async_update_person(person_id=person[CONF_ID], user_id=None) return
ent_reg = await entity_registry.async_get_registry(self.hass)
ent_reg.async_remove(ent_reg.async_get_entity_id(DOMAIN, DOMAIN, item_id))
async def filter_yaml_data(hass: HomeAssistantType, persons: List[dict]) -> List[dict]:
"""Validate YAML data that we can't validate via schema."""
filtered = []
person_invalid_user = []
for person_conf in persons:
user_id = person_conf.get(CONF_USER_ID)
if user_id is not None:
if await hass.auth.async_get_user(user_id) is None:
_LOGGER.error(
"Invalid user_id detected for person %s",
person_conf[collection.CONF_ID],
)
person_invalid_user.append(
f"- Person {person_conf[CONF_NAME]} (id: {person_conf[collection.CONF_ID]}) points at invalid user {user_id}"
)
continue
filtered.append(person_conf)
if person_invalid_user:
hass.components.persistent_notification.async_create(
f"""
The following persons point at invalid users:
{"- ".join(person_invalid_user)}
""",
"Invalid Person Configuration",
DOMAIN,
)
return filtered
async def async_setup(hass: HomeAssistantType, config: ConfigType): async def async_setup(hass: HomeAssistantType, config: ConfigType):
"""Set up the person component.""" """Set up the person component."""
component = EntityComponent(_LOGGER, DOMAIN, hass) entity_component = EntityComponent(_LOGGER, DOMAIN, hass)
conf_persons = config.get(DOMAIN, []) id_manager = collection.IDManager()
manager = hass.data[DOMAIN] = PersonManager(hass, component, conf_persons) yaml_collection = collection.YamlCollection(
await manager.async_initialize() logging.getLogger(f"{__name__}.yaml_collection"), id_manager
)
storage_collection = PersonStorageCollection(
PersonStore(hass, STORAGE_VERSION, STORAGE_KEY),
logging.getLogger(f"{__name__}.storage_collection"),
id_manager,
yaml_collection,
)
collection.attach_entity_component_collection(
entity_component, yaml_collection, lambda conf: Person(conf, False)
)
collection.attach_entity_component_collection(
entity_component, storage_collection, lambda conf: Person(conf, True)
)
await yaml_collection.async_load(
await filter_yaml_data(hass, config.get(DOMAIN, []))
)
await storage_collection.async_load()
hass.data[DOMAIN] = (yaml_collection, storage_collection)
collection.StorageCollectionWebsocket(
storage_collection, DOMAIN, DOMAIN, CREATE_FIELDS, UPDATE_FIELDS
).async_setup(hass, create_list=False)
websocket_api.async_register_command(hass, ws_list_person) websocket_api.async_register_command(hass, ws_list_person)
websocket_api.async_register_command(hass, ws_create_person)
websocket_api.async_register_command(hass, ws_update_person) async def _handle_user_removed(event: Event) -> None:
websocket_api.async_register_command(hass, ws_delete_person) """Handle a user being removed."""
user_id = event.data["user_id"]
for person in storage_collection.async_items():
if person[CONF_USER_ID] == user_id:
await storage_collection.async_update_item(
person[CONF_ID], {"user_id": None}
)
hass.bus.async_listen(EVENT_USER_REMOVED, _handle_user_removed)
return True return True
@ -353,21 +317,21 @@ class Person(RestoreEntity):
if self.hass.is_running: if self.hass.is_running:
# Update person now if hass is already running. # Update person now if hass is already running.
self.person_updated() await self.async_update_config(self._config)
else: else:
# Wait for hass start to not have race between person # Wait for hass start to not have race between person
# and device trackers finishing setup. # and device trackers finishing setup.
@callback async def person_start_hass(now):
def person_start_hass(now): await self.async_update_config(self._config)
self.person_updated()
self.hass.bus.async_listen_once( self.hass.bus.async_listen_once(
EVENT_HOMEASSISTANT_START, person_start_hass EVENT_HOMEASSISTANT_START, person_start_hass
) )
@callback async def async_update_config(self, config):
def person_updated(self):
"""Handle when the config is updated.""" """Handle when the config is updated."""
self._config = config
if self._unsub_track_device is not None: if self._unsub_track_device is not None:
self._unsub_track_device() self._unsub_track_device()
self._unsub_track_device = None self._unsub_track_device = None
@ -441,89 +405,12 @@ def ws_list_person(
hass: HomeAssistantType, connection: websocket_api.ActiveConnection, msg hass: HomeAssistantType, connection: websocket_api.ActiveConnection, msg
): ):
"""List persons.""" """List persons."""
manager: PersonManager = hass.data[DOMAIN] yaml, storage = hass.data[DOMAIN]
connection.send_result( connection.send_result(
msg["id"], msg["id"], {"storage": storage.async_items(), "config": yaml.async_items()},
{"storage": manager.storage_persons, "config": manager.config_persons},
) )
@websocket_api.websocket_command(
{
vol.Required("type"): "person/create",
vol.Required("name"): vol.All(str, vol.Length(min=1)),
vol.Optional("user_id"): vol.Any(str, None),
vol.Optional("device_trackers", default=[]): vol.All(
cv.ensure_list, cv.entities_domain(DEVICE_TRACKER_DOMAIN)
),
}
)
@websocket_api.require_admin
@websocket_api.async_response
async def ws_create_person(
hass: HomeAssistantType, connection: websocket_api.ActiveConnection, msg
):
"""Create a person."""
manager: PersonManager = hass.data[DOMAIN]
try:
person = await manager.async_create_person(
name=msg["name"],
user_id=msg.get("user_id"),
device_trackers=msg["device_trackers"],
)
connection.send_result(msg["id"], person)
except ValueError as err:
connection.send_error(
msg["id"], websocket_api.const.ERR_INVALID_FORMAT, str(err)
)
@websocket_api.websocket_command(
{
vol.Required("type"): "person/update",
vol.Required("person_id"): str,
vol.Required("name"): vol.All(str, vol.Length(min=1)),
vol.Optional("user_id"): vol.Any(str, None),
vol.Optional(CONF_DEVICE_TRACKERS, default=[]): vol.All(
cv.ensure_list, cv.entities_domain(DEVICE_TRACKER_DOMAIN)
),
}
)
@websocket_api.require_admin
@websocket_api.async_response
async def ws_update_person(
hass: HomeAssistantType, connection: websocket_api.ActiveConnection, msg
):
"""Update a person."""
manager: PersonManager = hass.data[DOMAIN]
changes = {}
for key in ("name", "user_id", "device_trackers"):
if key in msg:
changes[key] = msg[key]
try:
person = await manager.async_update_person(msg["person_id"], **changes)
connection.send_result(msg["id"], person)
except ValueError as err:
connection.send_error(
msg["id"], websocket_api.const.ERR_INVALID_FORMAT, str(err)
)
@websocket_api.websocket_command(
{vol.Required("type"): "person/delete", vol.Required("person_id"): str}
)
@websocket_api.require_admin
@websocket_api.async_response
async def ws_delete_person(
hass: HomeAssistantType, connection: websocket_api.ActiveConnection, msg
):
"""Delete a person."""
manager: PersonManager = hass.data[DOMAIN]
await manager.async_delete_person(msg["person_id"])
connection.send_result(msg["id"])
def _get_latest(prev: Optional[State], curr: State): def _get_latest(prev: Optional[State], curr: State):
"""Get latest state.""" """Get latest state."""
if prev is None or curr.last_updated > prev.last_updated: if prev is None or curr.last_updated > prev.last_updated:

View file

@ -1,5 +1,9 @@
"""WebSocket based API for Home Assistant.""" """WebSocket based API for Home Assistant."""
from homeassistant.core import callback from typing import Optional, Union, cast
import voluptuous as vol
from homeassistant.core import HomeAssistant, callback
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
from . import commands, connection, const, decorators, http, messages from . import commands, connection, const, decorators, http, messages
@ -26,13 +30,18 @@ websocket_command = decorators.websocket_command
@bind_hass @bind_hass
@callback @callback
def async_register_command(hass, command_or_handler, handler=None, schema=None): def async_register_command(
hass: HomeAssistant,
command_or_handler: Union[str, const.WebSocketCommandHandler],
handler: Optional[const.WebSocketCommandHandler] = None,
schema: Optional[vol.Schema] = None,
) -> None:
"""Register a websocket command.""" """Register a websocket command."""
# pylint: disable=protected-access # pylint: disable=protected-access
if handler is None: if handler is None:
handler = command_or_handler handler = cast(const.WebSocketCommandHandler, command_or_handler)
command = handler._ws_command command = handler._ws_command # type: ignore
schema = handler._ws_schema schema = handler._ws_schema # type: ignore
else: else:
command = command_or_handler command = command_or_handler
handlers = hass.data.get(DOMAIN) handlers = hass.data.get(DOMAIN)

View file

@ -107,7 +107,6 @@ def handle_unsubscribe_events(hass, connection, msg):
) )
@decorators.async_response
@decorators.websocket_command( @decorators.websocket_command(
{ {
vol.Required("type"): "call_service", vol.Required("type"): "call_service",
@ -116,6 +115,7 @@ def handle_unsubscribe_events(hass, connection, msg):
vol.Optional("service_data"): dict, vol.Optional("service_data"): dict,
} }
) )
@decorators.async_response
async def handle_call_service(hass, connection, msg): async def handle_call_service(hass, connection, msg):
"""Handle call service command. """Handle call service command.
@ -181,8 +181,8 @@ def handle_get_states(hass, connection, msg):
connection.send_message(messages.result_message(msg["id"], states)) connection.send_message(messages.result_message(msg["id"], states))
@decorators.async_response
@decorators.websocket_command({vol.Required("type"): "get_services"}) @decorators.websocket_command({vol.Required("type"): "get_services"})
@decorators.async_response
async def handle_get_services(hass, connection, msg): async def handle_get_services(hass, connection, msg):
"""Handle get services command. """Handle get services command.

View file

@ -1,6 +1,6 @@
"""Connection session.""" """Connection session."""
import asyncio import asyncio
from typing import Any, Callable, Dict, Hashable from typing import Any, Callable, Dict, Hashable, Optional
import voluptuous as vol import voluptuous as vol
@ -37,7 +37,7 @@ class ActiveConnection:
return Context(user_id=user.id) return Context(user_id=user.id)
@callback @callback
def send_result(self, msg_id, result=None): def send_result(self, msg_id: int, result: Optional[Any] = None) -> None:
"""Send a result message.""" """Send a result message."""
self.send_message(messages.result_message(msg_id, result)) self.send_message(messages.result_message(msg_id, result))
@ -49,7 +49,7 @@ class ActiveConnection:
self.send_message(content) self.send_message(content)
@callback @callback
def send_error(self, msg_id, code, message): def send_error(self, msg_id: int, code: str, message: str) -> None:
"""Send a error message.""" """Send a error message."""
self.send_message(messages.error_message(msg_id, code, message)) self.send_message(messages.error_message(msg_id, code, message))

View file

@ -3,9 +3,20 @@ import asyncio
from concurrent import futures from concurrent import futures
from functools import partial from functools import partial
import json import json
from typing import TYPE_CHECKING, Callable
from homeassistant.core import HomeAssistant
from homeassistant.helpers.json import JSONEncoder from homeassistant.helpers.json import JSONEncoder
if TYPE_CHECKING:
from .connection import ActiveConnection # noqa
WebSocketCommandHandler = Callable[
[HomeAssistant, "ActiveConnection", dict], None
] # pylint: disable=invalid-name
DOMAIN = "websocket_api" DOMAIN = "websocket_api"
URL = "/api/websocket" URL = "/api/websocket"
MAX_PENDING_MSG = 512 MAX_PENDING_MSG = 512

View file

@ -1,11 +1,13 @@
"""Decorators for the Websocket API.""" """Decorators for the Websocket API."""
from functools import wraps from functools import wraps
import logging import logging
from typing import Awaitable, Callable
from homeassistant.core import callback from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import Unauthorized from homeassistant.exceptions import Unauthorized
from . import messages from . import const, messages
from .connection import ActiveConnection
# mypy: allow-untyped-calls, allow-untyped-defs # mypy: allow-untyped-calls, allow-untyped-defs
@ -20,7 +22,9 @@ async def _handle_async_response(func, hass, connection, msg):
connection.async_handle_exception(msg, err) connection.async_handle_exception(msg, err)
def async_response(func): def async_response(
func: Callable[[HomeAssistant, ActiveConnection, dict], Awaitable[None]]
) -> const.WebSocketCommandHandler:
"""Decorate an async function to handle WebSocket API messages.""" """Decorate an async function to handle WebSocket API messages."""
@callback @callback
@ -32,7 +36,7 @@ def async_response(func):
return schedule_handler return schedule_handler
def require_admin(func): def require_admin(func: const.WebSocketCommandHandler) -> const.WebSocketCommandHandler:
"""Websocket decorator to require user to be an admin.""" """Websocket decorator to require user to be an admin."""
@wraps(func) @wraps(func)
@ -104,7 +108,9 @@ def ws_require_user(
return validator return validator
def websocket_command(schema): def websocket_command(
schema: dict,
) -> Callable[[const.WebSocketCommandHandler], const.WebSocketCommandHandler]:
"""Tag a function as a websocket command.""" """Tag a function as a websocket command."""
command = schema["type"] command = schema["type"]

View file

@ -0,0 +1,401 @@
"""Helper to deal with YAML + storage."""
from abc import ABC, abstractmethod
import logging
from typing import Any, Awaitable, Callable, Dict, List, Optional, cast
import voluptuous as vol
from voluptuous.humanize import humanize_error
from homeassistant.components import websocket_api
from homeassistant.const import CONF_ID
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.storage import Store
from homeassistant.util import slugify
STORAGE_VERSION = 1
SAVE_DELAY = 10
CHANGE_ADDED = "added"
CHANGE_UPDATED = "updated"
CHANGE_REMOVED = "removed"
ChangeListener = Callable[
[
# Change type
str,
# Item ID
str,
# New config (None if removed)
Optional[dict],
],
Awaitable[None],
] # pylint: disable=invalid-name
class CollectionError(HomeAssistantError):
"""Base class for collection related errors."""
class ItemNotFound(CollectionError):
"""Raised when an item is not found."""
def __init__(self, item_id: str):
"""Initialize item not found error."""
super().__init__(f"Item {item_id} not found.")
self.item_id = item_id
class IDManager:
"""Keep track of IDs across different collections."""
def __init__(self) -> None:
"""Initiate the ID manager."""
self.collections: List[Dict[str, Any]] = []
def add_collection(self, collection: Dict[str, Any]) -> None:
"""Add a collection to check for ID usage."""
self.collections.append(collection)
def has_id(self, item_id: str) -> bool:
"""Test if the ID exists."""
return any(item_id in collection for collection in self.collections)
def generate_id(self, suggestion: str) -> str:
"""Generate an ID."""
base = slugify(suggestion)
proposal = base
attempt = 1
while self.has_id(proposal):
attempt += 1
proposal = f"{base}_{attempt}"
return proposal
class ObservableCollection(ABC):
"""Base collection type that can be observed."""
def __init__(self, logger: logging.Logger, id_manager: Optional[IDManager] = None):
"""Initialize the base collection."""
self.logger = logger
self.id_manager = id_manager or IDManager()
self.data: Dict[str, dict] = {}
self.listeners: List[ChangeListener] = []
self.id_manager.add_collection(self.data)
@callback
def async_items(self) -> List[dict]:
"""Return list of items in collection."""
return list(self.data.values())
@callback
def async_add_listener(self, listener: ChangeListener) -> None:
"""Add a listener.
Will be called with (change_type, item_id, updated_config).
"""
self.listeners.append(listener)
async def notify_change(
self, change_type: str, item_id: str, item: Optional[dict]
) -> None:
"""Notify listeners of a change."""
self.logger.debug("%s %s: %s", change_type, item_id, item)
for listener in self.listeners:
await listener(change_type, item_id, item)
class YamlCollection(ObservableCollection):
"""Offer a fake CRUD interface on top of static YAML."""
async def async_load(self, data: List[dict]) -> None:
"""Load the storage Manager."""
for item in data:
item_id = item[CONF_ID]
if self.id_manager.has_id(item_id):
self.logger.warning("Duplicate ID '%s' detected, skipping", item_id)
continue
self.data[item_id] = item
await self.notify_change(CHANGE_ADDED, item[CONF_ID], item)
class StorageCollection(ObservableCollection):
"""Offer a CRUD interface on top of JSON storage."""
def __init__(
self,
store: Store,
logger: logging.Logger,
id_manager: Optional[IDManager] = None,
):
"""Initialize the storage collection."""
super().__init__(logger, id_manager)
self.store = store
@property
def hass(self) -> HomeAssistant:
"""Home Assistant object."""
return self.store.hass
async def async_load(self) -> None:
"""Load the storage Manager."""
raw_storage = cast(Optional[dict], await self.store.async_load())
if raw_storage is None:
raw_storage = {"items": []}
for item in raw_storage["items"]:
self.data[item[CONF_ID]] = item
await self.notify_change(CHANGE_ADDED, item[CONF_ID], item)
@abstractmethod
async def _process_create_data(self, data: dict) -> dict:
"""Validate the config is valid."""
@callback
@abstractmethod
def _get_suggested_id(self, info: dict) -> str:
"""Suggest an ID based on the config."""
@abstractmethod
async def _update_data(self, data: dict, update_data: dict) -> dict:
"""Return a new updated data object."""
async def async_create_item(self, data: dict) -> dict:
"""Create a new item."""
item = await self._process_create_data(data)
item[CONF_ID] = self.id_manager.generate_id(self._get_suggested_id(item))
self.data[item[CONF_ID]] = item
self._async_schedule_save()
await self.notify_change(CHANGE_ADDED, item[CONF_ID], item)
return item
async def async_update_item(self, item_id: str, updates: dict) -> dict:
"""Update item."""
if item_id not in self.data:
raise ItemNotFound(item_id)
if CONF_ID in updates:
raise ValueError("Cannot update ID")
current = self.data[item_id]
updated = await self._update_data(current, updates)
self.data[item_id] = updated
self._async_schedule_save()
await self.notify_change(CHANGE_UPDATED, item_id, updated)
return self.data[item_id]
async def async_delete_item(self, item_id: str) -> None:
"""Delete item."""
if item_id not in self.data:
raise ItemNotFound(item_id)
self.data.pop(item_id)
self._async_schedule_save()
await self.notify_change(CHANGE_REMOVED, item_id, None)
@callback
def _async_schedule_save(self) -> None:
"""Schedule saving the area registry."""
self.store.async_delay_save(self._data_to_save, SAVE_DELAY)
@callback
def _data_to_save(self) -> dict:
"""Return data of area registry to store in a file."""
return {"items": list(self.data.values())}
@callback
def attach_entity_component_collection(
entity_component: EntityComponent,
collection: ObservableCollection,
create_entity: Callable[[dict], Entity],
) -> None:
"""Map a collection to an entity component."""
entities = {}
async def _collection_changed(
change_type: str, item_id: str, config: Optional[dict]
) -> None:
"""Handle a collection change."""
if change_type == CHANGE_ADDED:
entity = create_entity(cast(dict, config))
await entity_component.async_add_entities([entity])
entities[item_id] = entity
return
if change_type == CHANGE_REMOVED:
entity = entities.pop(item_id)
await entity.async_remove()
return
# CHANGE_UPDATED
await entities[item_id].async_update_config(config) # type: ignore
collection.async_add_listener(_collection_changed)
class StorageCollectionWebsocket:
"""Class to expose storage collection management over websocket."""
def __init__(
self,
storage_collection: StorageCollection,
api_prefix: str,
model_name: str,
create_schema: dict,
update_schema: dict,
):
"""Initialize a websocket CRUD."""
self.storage_collection = storage_collection
self.api_prefix = api_prefix
self.model_name = model_name
self.create_schema = create_schema
self.update_schema = update_schema
assert self.api_prefix[-1] != "/", "API prefix should not end in /"
@property
def item_id_key(self) -> str:
"""Return item ID key."""
return f"{self.model_name}_id"
@callback
def async_setup(self, hass: HomeAssistant, *, create_list: bool = True) -> None:
"""Set up the websocket commands."""
if create_list:
websocket_api.async_register_command(
hass,
f"{self.api_prefix}/list",
self.ws_list_item,
websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
{vol.Required("type"): f"{self.api_prefix}/list"}
),
)
websocket_api.async_register_command(
hass,
f"{self.api_prefix}/create",
websocket_api.require_admin(
websocket_api.async_response(self.ws_create_item)
),
websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
{
**self.create_schema,
vol.Required("type"): f"{self.api_prefix}/create",
}
),
)
websocket_api.async_register_command(
hass,
f"{self.api_prefix}/update",
websocket_api.require_admin(
websocket_api.async_response(self.ws_update_item)
),
websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
{
**self.update_schema,
vol.Required("type"): f"{self.api_prefix}/update",
vol.Required(self.item_id_key): str,
}
),
)
websocket_api.async_register_command(
hass,
f"{self.api_prefix}/delete",
websocket_api.require_admin(
websocket_api.async_response(self.ws_delete_item)
),
websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
{
vol.Required("type"): f"{self.api_prefix}/delete",
vol.Required(self.item_id_key): str,
}
),
)
def ws_list_item(
self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None:
"""List items."""
connection.send_result(msg["id"], self.storage_collection.async_items())
async def ws_create_item(
self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None:
"""Create a item."""
try:
data = dict(msg)
data.pop("id")
data.pop("type")
item = await self.storage_collection.async_create_item(data)
connection.send_result(msg["id"], item)
except vol.Invalid as err:
connection.send_error(
msg["id"],
websocket_api.const.ERR_INVALID_FORMAT,
humanize_error(data, err),
)
except ValueError as err:
connection.send_error(
msg["id"], websocket_api.const.ERR_INVALID_FORMAT, str(err)
)
async def ws_update_item(
self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None:
"""Update a item."""
data = dict(msg)
msg_id = data.pop("id")
item_id = data.pop(self.item_id_key)
data.pop("type")
try:
item = await self.storage_collection.async_update_item(item_id, data)
connection.send_result(msg_id, item)
except ItemNotFound:
connection.send_error(
msg["id"],
websocket_api.const.ERR_NOT_FOUND,
f"Unable to find {self.item_id_key} {item_id}",
)
except vol.Invalid as err:
connection.send_error(
msg["id"],
websocket_api.const.ERR_INVALID_FORMAT,
humanize_error(data, err),
)
except ValueError as err:
connection.send_error(
msg_id, websocket_api.const.ERR_INVALID_FORMAT, str(err)
)
async def ws_delete_item(
self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None:
"""Delete a item."""
try:
await self.storage_collection.async_delete_item(msg[self.item_id_key])
except ItemNotFound:
connection.send_error(
msg["id"],
websocket_api.const.ERR_NOT_FOUND,
f"Unable to find {self.item_id_key} {msg[self.item_id_key]}",
)
connection.send_result(msg["id"])

View file

@ -473,8 +473,9 @@ class Entity(ABC):
self._on_remove = [] self._on_remove = []
self._on_remove.append(func) self._on_remove.append(func)
async def async_remove(self): async def async_remove(self) -> None:
"""Remove entity from Home Assistant.""" """Remove entity from Home Assistant."""
assert self.hass is not None
await self.async_internal_will_remove_from_hass() await self.async_internal_will_remove_from_hass()
await self.async_will_remove_from_hass() await self.async_will_remove_from_hass()

View file

@ -3,14 +3,6 @@ from unittest.mock import patch
import pytest import pytest
from homeassistant.components.websocket_api.auth import (
TYPE_AUTH,
TYPE_AUTH_OK,
TYPE_AUTH_REQUIRED,
)
from homeassistant.components.websocket_api.http import URL
from homeassistant.setup import async_setup_component
from tests.common import mock_coro from tests.common import mock_coro
@ -22,37 +14,3 @@ def prevent_io():
side_effect=lambda *args: mock_coro([]), side_effect=lambda *args: mock_coro([]),
): ):
yield yield
@pytest.fixture
def hass_ws_client(aiohttp_client, hass_access_token):
"""Websocket client fixture connected to websocket server."""
async def create_client(hass, access_token=hass_access_token):
"""Create a websocket client."""
assert await async_setup_component(hass, "websocket_api", {})
client = await aiohttp_client(hass.http.app)
with patch("homeassistant.components.http.auth.setup_auth"):
websocket = await client.ws_connect(URL)
auth_resp = await websocket.receive_json()
assert auth_resp["type"] == TYPE_AUTH_REQUIRED
if access_token is None:
await websocket.send_json(
{"type": TYPE_AUTH, "access_token": "incorrect"}
)
else:
await websocket.send_json(
{"type": TYPE_AUTH, "access_token": access_token}
)
auth_ok = await websocket.receive_json()
assert auth_ok["type"] == TYPE_AUTH_OK
# wrap in client
websocket.client = client
return websocket
return create_client

View file

@ -98,7 +98,7 @@ async def test_onboarding_user(hass, hass_storage, aiohttp_client):
assert user.name == "Test Name" assert user.name == "Test Name"
assert len(user.credentials) == 1 assert len(user.credentials) == 1
assert user.credentials[0].data["username"] == "test-user" assert user.credentials[0].data["username"] == "test-user"
assert len(hass.data["person"].storage_data) == 1 assert len(hass.data["person"][1].async_items()) == 1
# Validate refresh token 1 # Validate refresh token 1
resp = await client.post( resp = await client.post(

View file

@ -1,19 +1,15 @@
"""The tests for the person component.""" """The tests for the person component."""
from unittest.mock import Mock import logging
import pytest import pytest
from homeassistant.components import person
from homeassistant.components.device_tracker import ( from homeassistant.components.device_tracker import (
ATTR_SOURCE_TYPE, ATTR_SOURCE_TYPE,
SOURCE_TYPE_GPS, SOURCE_TYPE_GPS,
SOURCE_TYPE_ROUTER, SOURCE_TYPE_ROUTER,
) )
from homeassistant.components.person import ( from homeassistant.components.person import ATTR_SOURCE, ATTR_USER_ID, DOMAIN
ATTR_SOURCE,
ATTR_USER_ID,
DOMAIN,
PersonManager,
)
from homeassistant.const import ( from homeassistant.const import (
ATTR_GPS_ACCURACY, ATTR_GPS_ACCURACY,
ATTR_ID, ATTR_ID,
@ -23,20 +19,29 @@ from homeassistant.const import (
STATE_UNKNOWN, STATE_UNKNOWN,
) )
from homeassistant.core import CoreState, State from homeassistant.core import CoreState, State
from homeassistant.helpers import collection
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from tests.common import ( from tests.common import assert_setup_component, mock_component, mock_restore_cache
assert_setup_component,
mock_component,
mock_coro_func,
mock_restore_cache,
)
DEVICE_TRACKER = "device_tracker.test_tracker" DEVICE_TRACKER = "device_tracker.test_tracker"
DEVICE_TRACKER_2 = "device_tracker.test_tracker_2" DEVICE_TRACKER_2 = "device_tracker.test_tracker_2"
# pylint: disable=redefined-outer-name @pytest.fixture
def storage_collection(hass):
"""Return an empty storage collection."""
id_manager = collection.IDManager()
return person.PersonStorageCollection(
person.PersonStore(hass, person.STORAGE_VERSION, person.STORAGE_KEY),
logging.getLogger(f"{person.__name__}.storage_collection"),
id_manager,
collection.YamlCollection(
logging.getLogger(f"{person.__name__}.yaml_collection"), id_manager
),
)
@pytest.fixture @pytest.fixture
def storage_setup(hass, hass_storage, hass_admin_user): def storage_setup(hass, hass_storage, hass_admin_user):
"""Storage setup.""" """Storage setup."""
@ -433,21 +438,21 @@ async def test_load_person_storage_two_nonlinked(hass, hass_storage):
async def test_ws_list(hass, hass_ws_client, storage_setup): async def test_ws_list(hass, hass_ws_client, storage_setup):
"""Test listing via WS.""" """Test listing via WS."""
manager = hass.data[DOMAIN] manager = hass.data[DOMAIN][1]
client = await hass_ws_client(hass) client = await hass_ws_client(hass)
resp = await client.send_json({"id": 6, "type": "person/list"}) resp = await client.send_json({"id": 6, "type": "person/list"})
resp = await client.receive_json() resp = await client.receive_json()
assert resp["success"] assert resp["success"]
assert resp["result"]["storage"] == manager.storage_persons assert resp["result"]["storage"] == manager.async_items()
assert len(resp["result"]["storage"]) == 1 assert len(resp["result"]["storage"]) == 1
assert len(resp["result"]["config"]) == 0 assert len(resp["result"]["config"]) == 0
async def test_ws_create(hass, hass_ws_client, storage_setup, hass_read_only_user): async def test_ws_create(hass, hass_ws_client, storage_setup, hass_read_only_user):
"""Test creating via WS.""" """Test creating via WS."""
manager = hass.data[DOMAIN] manager = hass.data[DOMAIN][1]
client = await hass_ws_client(hass) client = await hass_ws_client(hass)
@ -462,7 +467,7 @@ async def test_ws_create(hass, hass_ws_client, storage_setup, hass_read_only_use
) )
resp = await client.receive_json() resp = await client.receive_json()
persons = manager.storage_persons persons = manager.async_items()
assert len(persons) == 2 assert len(persons) == 2
assert resp["success"] assert resp["success"]
@ -474,7 +479,7 @@ async def test_ws_create_requires_admin(
): ):
"""Test creating via WS requires admin.""" """Test creating via WS requires admin."""
hass_admin_user.groups = [] hass_admin_user.groups = []
manager = hass.data[DOMAIN] manager = hass.data[DOMAIN][1]
client = await hass_ws_client(hass) client = await hass_ws_client(hass)
@ -489,7 +494,7 @@ async def test_ws_create_requires_admin(
) )
resp = await client.receive_json() resp = await client.receive_json()
persons = manager.storage_persons persons = manager.async_items()
assert len(persons) == 1 assert len(persons) == 1
assert not resp["success"] assert not resp["success"]
@ -497,10 +502,10 @@ async def test_ws_create_requires_admin(
async def test_ws_update(hass, hass_ws_client, storage_setup): async def test_ws_update(hass, hass_ws_client, storage_setup):
"""Test updating via WS.""" """Test updating via WS."""
manager = hass.data[DOMAIN] manager = hass.data[DOMAIN][1]
client = await hass_ws_client(hass) client = await hass_ws_client(hass)
persons = manager.storage_persons persons = manager.async_items()
resp = await client.send_json( resp = await client.send_json(
{ {
@ -514,7 +519,7 @@ async def test_ws_update(hass, hass_ws_client, storage_setup):
) )
resp = await client.receive_json() resp = await client.receive_json()
persons = manager.storage_persons persons = manager.async_items()
assert len(persons) == 1 assert len(persons) == 1
assert resp["success"] assert resp["success"]
@ -533,10 +538,10 @@ async def test_ws_update_require_admin(
): ):
"""Test updating via WS requires admin.""" """Test updating via WS requires admin."""
hass_admin_user.groups = [] hass_admin_user.groups = []
manager = hass.data[DOMAIN] manager = hass.data[DOMAIN][1]
client = await hass_ws_client(hass) client = await hass_ws_client(hass)
original = dict(manager.storage_persons[0]) original = dict(manager.async_items()[0])
resp = await client.send_json( resp = await client.send_json(
{ {
@ -551,23 +556,23 @@ async def test_ws_update_require_admin(
resp = await client.receive_json() resp = await client.receive_json()
assert not resp["success"] assert not resp["success"]
not_updated = dict(manager.storage_persons[0]) not_updated = dict(manager.async_items()[0])
assert original == not_updated assert original == not_updated
async def test_ws_delete(hass, hass_ws_client, storage_setup): async def test_ws_delete(hass, hass_ws_client, storage_setup):
"""Test deleting via WS.""" """Test deleting via WS."""
manager = hass.data[DOMAIN] manager = hass.data[DOMAIN][1]
client = await hass_ws_client(hass) client = await hass_ws_client(hass)
persons = manager.storage_persons persons = manager.async_items()
resp = await client.send_json( resp = await client.send_json(
{"id": 6, "type": "person/delete", "person_id": persons[0]["id"]} {"id": 6, "type": "person/delete", "person_id": persons[0]["id"]}
) )
resp = await client.receive_json() resp = await client.receive_json()
persons = manager.storage_persons persons = manager.async_items()
assert len(persons) == 0 assert len(persons) == 0
assert resp["success"] assert resp["success"]
@ -581,7 +586,7 @@ async def test_ws_delete_require_admin(
): ):
"""Test deleting via WS requires admin.""" """Test deleting via WS requires admin."""
hass_admin_user.groups = [] hass_admin_user.groups = []
manager = hass.data[DOMAIN] manager = hass.data[DOMAIN][1]
client = await hass_ws_client(hass) client = await hass_ws_client(hass)
@ -589,7 +594,7 @@ async def test_ws_delete_require_admin(
{ {
"id": 6, "id": 6,
"type": "person/delete", "type": "person/delete",
"person_id": manager.storage_persons[0]["id"], "person_id": manager.async_items()[0]["id"],
"name": "Updated Name", "name": "Updated Name",
"device_trackers": [DEVICE_TRACKER_2], "device_trackers": [DEVICE_TRACKER_2],
"user_id": None, "user_id": None,
@ -598,61 +603,64 @@ async def test_ws_delete_require_admin(
resp = await client.receive_json() resp = await client.receive_json()
assert not resp["success"] assert not resp["success"]
persons = manager.storage_persons persons = manager.async_items()
assert len(persons) == 1 assert len(persons) == 1
async def test_create_invalid_user_id(hass): async def test_create_invalid_user_id(hass, storage_collection):
"""Test we do not allow invalid user ID during creation.""" """Test we do not allow invalid user ID during creation."""
manager = PersonManager(hass, Mock(), [])
await manager.async_initialize()
with pytest.raises(ValueError): with pytest.raises(ValueError):
await manager.async_create_person(name="Hello", user_id="non-existing") await storage_collection.async_create_item(
{"name": "Hello", "user_id": "non-existing"}
)
async def test_create_duplicate_user_id(hass, hass_admin_user): async def test_create_duplicate_user_id(hass, hass_admin_user, storage_collection):
"""Test we do not allow duplicate user ID during creation.""" """Test we do not allow duplicate user ID during creation."""
manager = PersonManager(hass, Mock(async_add_entities=mock_coro_func()), []) await storage_collection.async_create_item(
await manager.async_initialize() {"name": "Hello", "user_id": hass_admin_user.id}
await manager.async_create_person(name="Hello", user_id=hass_admin_user.id) )
with pytest.raises(ValueError): with pytest.raises(ValueError):
await manager.async_create_person(name="Hello", user_id=hass_admin_user.id) await storage_collection.async_create_item(
{"name": "Hello", "user_id": hass_admin_user.id}
)
async def test_update_double_user_id(hass, hass_admin_user): async def test_update_double_user_id(hass, hass_admin_user, storage_collection):
"""Test we do not allow double user ID during update.""" """Test we do not allow double user ID during update."""
manager = PersonManager(hass, Mock(async_add_entities=mock_coro_func()), []) await storage_collection.async_create_item(
await manager.async_initialize() {"name": "Hello", "user_id": hass_admin_user.id}
await manager.async_create_person(name="Hello", user_id=hass_admin_user.id) )
person = await manager.async_create_person(name="Hello") person = await storage_collection.async_create_item({"name": "Hello"})
with pytest.raises(ValueError): with pytest.raises(ValueError):
await manager.async_update_person( await storage_collection.async_update_item(
person_id=person["id"], user_id=hass_admin_user.id person["id"], {"user_id": hass_admin_user.id}
) )
async def test_update_invalid_user_id(hass): async def test_update_invalid_user_id(hass, storage_collection):
"""Test updating to invalid user ID.""" """Test updating to invalid user ID."""
manager = PersonManager(hass, Mock(async_add_entities=mock_coro_func()), []) person = await storage_collection.async_create_item({"name": "Hello"})
await manager.async_initialize()
person = await manager.async_create_person(name="Hello")
with pytest.raises(ValueError): with pytest.raises(ValueError):
await manager.async_update_person( await storage_collection.async_update_item(
person_id=person["id"], user_id="non-existing" person["id"], {"user_id": "non-existing"}
) )
async def test_update_person_when_user_removed(hass, hass_read_only_user): async def test_update_person_when_user_removed(
hass, storage_setup, hass_read_only_user
):
"""Update person when user is removed.""" """Update person when user is removed."""
manager = PersonManager(hass, Mock(async_add_entities=mock_coro_func()), []) storage_collection = hass.data[DOMAIN][1]
await manager.async_initialize()
person = await manager.async_create_person( person = await storage_collection.async_create_item(
name="Hello", user_id=hass_read_only_user.id {"name": "Hello", "user_id": hass_read_only_user.id}
) )
await hass.auth.async_remove_user(hass_read_only_user) await hass.auth.async_remove_user(hass_read_only_user)
await hass.async_block_till_done() await hass.async_block_till_done()
assert person["user_id"] is None
assert storage_collection.data[person["id"]]["user_id"] is None

View file

@ -9,6 +9,13 @@ import requests_mock as _requests_mock
from homeassistant import util from homeassistant import util
from homeassistant.auth.const import GROUP_ID_ADMIN, GROUP_ID_READ_ONLY from homeassistant.auth.const import GROUP_ID_ADMIN, GROUP_ID_READ_ONLY
from homeassistant.auth.providers import homeassistant, legacy_api_password from homeassistant.auth.providers import homeassistant, legacy_api_password
from homeassistant.components.websocket_api.auth import (
TYPE_AUTH,
TYPE_AUTH_OK,
TYPE_AUTH_REQUIRED,
)
from homeassistant.components.websocket_api.http import URL
from homeassistant.setup import async_setup_component
from homeassistant.util import location from homeassistant.util import location
pytest.register_assert_rewrite("tests.common") pytest.register_assert_rewrite("tests.common")
@ -187,3 +194,37 @@ def hass_client(hass, aiohttp_client, hass_access_token):
) )
return auth_client return auth_client
@pytest.fixture
def hass_ws_client(aiohttp_client, hass_access_token):
"""Websocket client fixture connected to websocket server."""
async def create_client(hass, access_token=hass_access_token):
"""Create a websocket client."""
assert await async_setup_component(hass, "websocket_api", {})
client = await aiohttp_client(hass.http.app)
with patch("homeassistant.components.http.auth.setup_auth"):
websocket = await client.ws_connect(URL)
auth_resp = await websocket.receive_json()
assert auth_resp["type"] == TYPE_AUTH_REQUIRED
if access_token is None:
await websocket.send_json(
{"type": TYPE_AUTH, "access_token": "incorrect"}
)
else:
await websocket.send_json(
{"type": TYPE_AUTH, "access_token": access_token}
)
auth_ok = await websocket.receive_json()
assert auth_ok["type"] == TYPE_AUTH_OK
# wrap in client
websocket.client = client
return websocket
return create_client

View file

@ -0,0 +1,356 @@
"""Tests for the collection helper."""
import logging
import pytest
import voluptuous as vol
from homeassistant.helpers import collection, entity, entity_component, storage
from tests.common import flush_store
LOGGER = logging.getLogger(__name__)
def track_changes(coll: collection.ObservableCollection):
"""Create helper to track changes in a collection."""
changes = []
async def listener(*args):
changes.append(args)
coll.async_add_listener(listener)
return changes
class MockEntity(entity.Entity):
"""Entity that is config based."""
def __init__(self, config):
"""Initialize entity."""
self._config = config
@property
def unique_id(self):
"""Return unique ID of entity."""
return self._config["id"]
@property
def name(self):
"""Return name of entity."""
return self._config["name"]
@property
def state(self):
"""Return state of entity."""
return self._config["state"]
async def async_update_config(self, config):
"""Update entity config."""
self._config = config
self.async_write_ha_state()
class MockStorageCollection(collection.StorageCollection):
"""Mock storage collection."""
async def _process_create_data(self, data: dict) -> dict:
"""Validate the config is valid."""
if "name" not in data:
raise ValueError("invalid")
return data
def _get_suggested_id(self, info: dict) -> str:
"""Suggest an ID based on the config."""
return info["name"]
async def _update_data(self, data: dict, update_data: dict) -> dict:
"""Return a new updated data object."""
return {**data, **update_data}
def test_id_manager():
"""Test the ID manager."""
id_manager = collection.IDManager()
assert not id_manager.has_id("some_id")
data = {}
id_manager.add_collection(data)
assert not id_manager.has_id("some_id")
data["some_id"] = 1
assert id_manager.has_id("some_id")
assert id_manager.generate_id("some_id") == "some_id_2"
assert id_manager.generate_id("bla") == "bla"
async def test_observable_collection():
"""Test observerable collection."""
coll = collection.ObservableCollection(LOGGER)
assert coll.async_items() == []
coll.data["bla"] = 1
assert coll.async_items() == [1]
changes = track_changes(coll)
await coll.notify_change("mock_type", "mock_id", {"mock": "item"})
assert len(changes) == 1
assert changes[0] == ("mock_type", "mock_id", {"mock": "item"})
async def test_yaml_collection():
"""Test a YAML collection."""
id_manager = collection.IDManager()
coll = collection.YamlCollection(LOGGER, id_manager)
changes = track_changes(coll)
await coll.async_load(
[{"id": "mock-1", "name": "Mock 1"}, {"id": "mock-2", "name": "Mock 2"}]
)
assert id_manager.has_id("mock-1")
assert id_manager.has_id("mock-2")
assert len(changes) == 2
assert changes[0] == (
collection.CHANGE_ADDED,
"mock-1",
{"id": "mock-1", "name": "Mock 1"},
)
assert changes[1] == (
collection.CHANGE_ADDED,
"mock-2",
{"id": "mock-2", "name": "Mock 2"},
)
async def test_yaml_collection_skipping_duplicate_ids():
"""Test YAML collection skipping duplicate IDs."""
id_manager = collection.IDManager()
id_manager.add_collection({"existing": True})
coll = collection.YamlCollection(LOGGER, id_manager)
changes = track_changes(coll)
await coll.async_load(
[{"id": "mock-1", "name": "Mock 1"}, {"id": "existing", "name": "Mock 2"}]
)
assert len(changes) == 1
assert changes[0] == (
collection.CHANGE_ADDED,
"mock-1",
{"id": "mock-1", "name": "Mock 1"},
)
async def test_storage_collection(hass):
"""Test storage collection."""
store = storage.Store(hass, 1, "test-data")
await store.async_save(
{
"items": [
{"id": "mock-1", "name": "Mock 1", "data": 1},
{"id": "mock-2", "name": "Mock 2", "data": 2},
]
}
)
id_manager = collection.IDManager()
coll = MockStorageCollection(store, LOGGER, id_manager)
changes = track_changes(coll)
await coll.async_load()
assert id_manager.has_id("mock-1")
assert id_manager.has_id("mock-2")
assert len(changes) == 2
assert changes[0] == (
collection.CHANGE_ADDED,
"mock-1",
{"id": "mock-1", "name": "Mock 1", "data": 1},
)
assert changes[1] == (
collection.CHANGE_ADDED,
"mock-2",
{"id": "mock-2", "name": "Mock 2", "data": 2},
)
item = await coll.async_create_item({"name": "Mock 3"})
assert item["id"] == "mock_3"
assert len(changes) == 3
assert changes[2] == (
collection.CHANGE_ADDED,
"mock_3",
{"id": "mock_3", "name": "Mock 3"},
)
updated_item = await coll.async_update_item("mock-2", {"name": "Mock 2 updated"})
assert id_manager.has_id("mock-2")
assert updated_item == {"id": "mock-2", "name": "Mock 2 updated", "data": 2}
assert len(changes) == 4
assert changes[3] == (collection.CHANGE_UPDATED, "mock-2", updated_item)
with pytest.raises(ValueError):
await coll.async_update_item("mock-2", {"id": "mock-2-updated"})
assert id_manager.has_id("mock-2")
assert not id_manager.has_id("mock-2-updated")
assert len(changes) == 4
await flush_store(store)
assert await storage.Store(hass, 1, "test-data").async_load() == {
"items": [
{"id": "mock-1", "name": "Mock 1", "data": 1},
{"id": "mock-2", "name": "Mock 2 updated", "data": 2},
{"id": "mock_3", "name": "Mock 3"},
]
}
async def test_attach_entity_component_collection(hass):
"""Test attaching collection to entity component."""
ent_comp = entity_component.EntityComponent(LOGGER, "test", hass)
coll = collection.ObservableCollection(LOGGER)
collection.attach_entity_component_collection(ent_comp, coll, MockEntity)
await coll.notify_change(
collection.CHANGE_ADDED,
"mock_id",
{"id": "mock_id", "state": "initial", "name": "Mock 1"},
)
assert hass.states.get("test.mock_1").name == "Mock 1"
assert hass.states.get("test.mock_1").state == "initial"
await coll.notify_change(
collection.CHANGE_UPDATED,
"mock_id",
{"id": "mock_id", "state": "second", "name": "Mock 1 updated"},
)
assert hass.states.get("test.mock_1").name == "Mock 1 updated"
assert hass.states.get("test.mock_1").state == "second"
await coll.notify_change(collection.CHANGE_REMOVED, "mock_id", None)
assert hass.states.get("test.mock_1") is None
async def test_storage_collection_websocket(hass, hass_ws_client):
"""Test exposing a storage collection via websockets."""
store = storage.Store(hass, 1, "test-data")
coll = MockStorageCollection(store, LOGGER)
changes = track_changes(coll)
collection.StorageCollectionWebsocket(
coll,
"test_item/collection",
"test_item",
{vol.Required("name"): str, vol.Required("immutable_string"): str},
{vol.Optional("name"): str},
).async_setup(hass)
client = await hass_ws_client(hass)
# Create invalid
await client.send_json(
{
"id": 1,
"type": "test_item/collection/create",
"name": 1,
# Forgot to add immutable_string
}
)
response = await client.receive_json()
assert not response["success"]
assert response["error"]["code"] == "invalid_format"
assert len(changes) == 0
# Create
await client.send_json(
{
"id": 2,
"type": "test_item/collection/create",
"name": "Initial Name",
"immutable_string": "no-changes",
}
)
response = await client.receive_json()
assert response["success"]
assert response["result"] == {
"id": "initial_name",
"name": "Initial Name",
"immutable_string": "no-changes",
}
assert len(changes) == 1
assert changes[0] == (collection.CHANGE_ADDED, "initial_name", response["result"])
# List
await client.send_json({"id": 3, "type": "test_item/collection/list"})
response = await client.receive_json()
assert response["success"]
assert response["result"] == [
{
"id": "initial_name",
"name": "Initial Name",
"immutable_string": "no-changes",
}
]
assert len(changes) == 1
# Update invalid data
await client.send_json(
{
"id": 4,
"type": "test_item/collection/update",
"test_item_id": "initial_name",
"immutable_string": "no-changes",
}
)
response = await client.receive_json()
assert not response["success"]
assert response["error"]["code"] == "invalid_format"
assert len(changes) == 1
# Update invalid item
await client.send_json(
{
"id": 5,
"type": "test_item/collection/update",
"test_item_id": "non-existing",
"name": "Updated name",
}
)
response = await client.receive_json()
assert not response["success"]
assert response["error"]["code"] == "not_found"
assert len(changes) == 1
# Update
await client.send_json(
{
"id": 6,
"type": "test_item/collection/update",
"test_item_id": "initial_name",
"name": "Updated name",
}
)
response = await client.receive_json()
assert response["success"]
assert response["result"] == {
"id": "initial_name",
"name": "Updated name",
"immutable_string": "no-changes",
}
assert len(changes) == 2
assert changes[1] == (collection.CHANGE_UPDATED, "initial_name", response["result"])
# Delete invalid ID
await client.send_json(
{"id": 7, "type": "test_item/collection/update", "test_item_id": "non-existing"}
)
response = await client.receive_json()
assert not response["success"]
assert response["error"]["code"] == "not_found"
assert len(changes) == 2
# Delete
await client.send_json(
{"id": 8, "type": "test_item/collection/delete", "test_item_id": "initial_name"}
)
response = await client.receive_json()
assert response["success"]
assert len(changes) == 3
assert changes[2] == (collection.CHANGE_REMOVED, "initial_name", None)