Google assistant storage of connected agents (#29158)

* Make async_report_state take agent_user_id

* Attempt to store synced agents

* Drop now not needed initialization

* Make sure cloud uses the all sync on changed preferences

* Some more places to use all version of sync

* Get the agent_user_id from the request context if available

* Minor cleanup

* Remove the old fixed agent_user_id for cloud

Instead pass along cloud_user where appropriate.

* async_delay_save takes a function

* Adjust test for delayed store

* Remove unused save function

* Add login check.
This commit is contained in:
Joakim Plate 2019-12-03 07:05:59 +01:00 committed by Paulus Schoutsen
parent 434b783b4c
commit 2569c4ae37
16 changed files with 260 additions and 83 deletions

View file

@ -101,6 +101,7 @@ class CloudClient(Interface):
self._google_config = google_config.CloudGoogleConfig( self._google_config = google_config.CloudGoogleConfig(
self._hass, self.google_user_config, cloud_user, self._prefs, self.cloud self._hass, self.google_user_config, cloud_user, self._prefs, self.cloud
) )
await self._google_config.async_initialize()
return self._google_config return self._google_config

View file

@ -42,12 +42,7 @@ class CloudGoogleConfig(AbstractConfig):
@property @property
def enabled(self): def enabled(self):
"""Return if Google is enabled.""" """Return if Google is enabled."""
return self._prefs.google_enabled return self._cloud.is_logged_in and self._prefs.google_enabled
@property
def agent_user_id(self):
"""Return Agent User Id to use for query responses."""
return self._cloud.username
@property @property
def entity_config(self): def entity_config(self):
@ -62,7 +57,7 @@ class CloudGoogleConfig(AbstractConfig):
@property @property
def should_report_state(self): def should_report_state(self):
"""Return if states should be proactively reported.""" """Return if states should be proactively reported."""
return self._prefs.google_report_state return self._cloud.is_logged_in and self._prefs.google_report_state
@property @property
def local_sdk_webhook_id(self): def local_sdk_webhook_id(self):
@ -104,7 +99,7 @@ class CloudGoogleConfig(AbstractConfig):
entity_config = entity_configs.get(state.entity_id, {}) entity_config = entity_configs.get(state.entity_id, {})
return not entity_config.get(PREF_DISABLE_2FA, DEFAULT_DISABLE_2FA) return not entity_config.get(PREF_DISABLE_2FA, DEFAULT_DISABLE_2FA)
async def async_report_state(self, message): async def async_report_state(self, message, agent_user_id: str):
"""Send a state report to Google.""" """Send a state report to Google."""
try: try:
await self._cloud.google_report_state.async_send_message(message) await self._cloud.google_report_state.async_send_message(message)
@ -132,13 +127,6 @@ class CloudGoogleConfig(AbstractConfig):
_LOGGER.debug("Finished requesting syncing: %s", req.status) _LOGGER.debug("Finished requesting syncing: %s", req.status)
return req.status return req.status
async def async_deactivate_report_state(self):
"""Turn off report state and disable further state reporting.
Called when the user disconnects their account from Google.
"""
await self._prefs.async_update(google_report_state=False)
async def _async_prefs_updated(self, prefs): async def _async_prefs_updated(self, prefs):
"""Handle updated preferences.""" """Handle updated preferences."""
if self.should_report_state != self.is_reporting_state: if self.should_report_state != self.is_reporting_state:
@ -149,7 +137,7 @@ class CloudGoogleConfig(AbstractConfig):
# State reporting is reported as a property on entities. # State reporting is reported as a property on entities.
# So when we change it, we need to sync all entities. # So when we change it, we need to sync all entities.
await self.async_sync_entities(self.agent_user_id) await self.async_sync_entities_all()
# If entity prefs are the same or we have filter in config.yaml, # If entity prefs are the same or we have filter in config.yaml,
# don't sync. # don't sync.
@ -157,7 +145,7 @@ class CloudGoogleConfig(AbstractConfig):
self._cur_entity_prefs is not prefs.google_entity_configs self._cur_entity_prefs is not prefs.google_entity_configs
and self._config["filter"].empty_filter and self._config["filter"].empty_filter
): ):
self.async_schedule_google_sync(self.agent_user_id) self.async_schedule_google_sync_all()
if self.enabled and not self.is_local_sdk_active: if self.enabled and not self.is_local_sdk_active:
self.async_enable_local_sdk() self.async_enable_local_sdk()
@ -173,4 +161,4 @@ class CloudGoogleConfig(AbstractConfig):
# Schedule a sync if a change was made to an entity that Google knows about # Schedule a sync if a change was made to an entity that Google knows about
if self._should_expose_entity_id(entity_id): if self._should_expose_entity_id(entity_id):
await self.async_sync_entities(self.agent_user_id) await self.async_sync_entities_all()

View file

@ -175,7 +175,7 @@ class GoogleActionsSyncView(HomeAssistantView):
hass = request.app["hass"] hass = request.app["hass"]
cloud: Cloud = hass.data[DOMAIN] cloud: Cloud = hass.data[DOMAIN]
gconf = await cloud.client.get_google_config() gconf = await cloud.client.get_google_config()
status = await gconf.async_sync_entities(gconf.agent_user_id) status = await gconf.async_sync_entities(gconf.cloud_user)
return self.json({}, status_code=status) return self.json({}, status_code=status)

View file

@ -93,7 +93,9 @@ CONFIG_SCHEMA = vol.Schema({DOMAIN: GOOGLE_ASSISTANT_SCHEMA}, extra=vol.ALLOW_EX
async def async_setup(hass: HomeAssistant, yaml_config: Dict[str, Any]): async def async_setup(hass: HomeAssistant, yaml_config: Dict[str, Any]):
"""Activate Google Actions component.""" """Activate Google Actions component."""
config = yaml_config.get(DOMAIN, {}) config = yaml_config.get(DOMAIN, {})
google_config = GoogleConfig(hass, config) google_config = GoogleConfig(hass, config)
await google_config.async_initialize()
hass.http.register_view(GoogleAssistantView(google_config)) hass.http.register_view(GoogleAssistantView(google_config))

View file

@ -141,3 +141,5 @@ DEVICE_CLASS_TO_GOOGLE_TYPES = {
CHALLENGE_ACK_NEEDED = "ackNeeded" CHALLENGE_ACK_NEEDED = "ackNeeded"
CHALLENGE_PIN_NEEDED = "pinNeeded" CHALLENGE_PIN_NEEDED = "pinNeeded"
CHALLENGE_FAILED_PIN_NEEDED = "challengeFailedPinNeeded" CHALLENGE_FAILED_PIN_NEEDED = "challengeFailedPinNeeded"
STORE_AGENT_USER_IDS = "agent_user_ids"

View file

@ -10,6 +10,7 @@ from aiohttp.web import json_response
from homeassistant.core import Context, callback, HomeAssistant, State from homeassistant.core import Context, callback, HomeAssistant, State
from homeassistant.helpers.event import async_call_later from homeassistant.helpers.event import async_call_later
from homeassistant.components import webhook from homeassistant.components import webhook
from homeassistant.helpers.storage import Store
from homeassistant.const import ( from homeassistant.const import (
CONF_NAME, CONF_NAME,
STATE_UNAVAILABLE, STATE_UNAVAILABLE,
@ -26,6 +27,7 @@ from .const import (
ERR_FUNCTION_NOT_SUPPORTED, ERR_FUNCTION_NOT_SUPPORTED,
DEVICE_CLASS_TO_GOOGLE_TYPES, DEVICE_CLASS_TO_GOOGLE_TYPES,
CONF_ROOM_HINT, CONF_ROOM_HINT,
STORE_AGENT_USER_IDS,
) )
from .error import SmartHomeError from .error import SmartHomeError
@ -41,19 +43,20 @@ class AbstractConfig:
def __init__(self, hass): def __init__(self, hass):
"""Initialize abstract config.""" """Initialize abstract config."""
self.hass = hass self.hass = hass
self._store = None
self._google_sync_unsub = {} self._google_sync_unsub = {}
self._local_sdk_active = False self._local_sdk_active = False
async def async_initialize(self):
"""Perform async initialization of config."""
self._store = GoogleConfigStore(self.hass)
await self._store.async_load()
@property @property
def enabled(self): def enabled(self):
"""Return if Google is enabled.""" """Return if Google is enabled."""
return False return False
@property
def agent_user_id(self):
"""Return Agent User Id to use for query responses."""
return None
@property @property
def entity_config(self): def entity_config(self):
"""Return entity config.""" """Return entity config."""
@ -101,10 +104,18 @@ class AbstractConfig:
# pylint: disable=no-self-use # pylint: disable=no-self-use
return True return True
async def async_report_state(self, message): async def async_report_state(self, message, agent_user_id: str):
"""Send a state report to Google.""" """Send a state report to Google."""
raise NotImplementedError raise NotImplementedError
async def async_report_state_all(self, message):
"""Send a state report to Google for all previously synced users."""
jobs = [
self.async_report_state(message, agent_user_id)
for agent_user_id in self._store.agent_user_ids
]
await gather(*jobs)
def async_enable_report_state(self): def async_enable_report_state(self):
"""Enable proactive mode.""" """Enable proactive mode."""
# Circular dep # Circular dep
@ -123,9 +134,18 @@ class AbstractConfig:
"""Sync all entities to Google.""" """Sync all entities to Google."""
# Remove any pending sync # Remove any pending sync
self._google_sync_unsub.pop(agent_user_id, lambda: None)() self._google_sync_unsub.pop(agent_user_id, lambda: None)()
return await self._async_request_sync_devices(agent_user_id) return await self._async_request_sync_devices(agent_user_id)
async def async_sync_entities_all(self):
"""Sync all entities to Google for all registered agents."""
res = await gather(
*[
self.async_sync_entities(agent_user_id)
for agent_user_id in self._store.agent_user_ids
]
)
return max(res, default=204)
@callback @callback
def async_schedule_google_sync(self, agent_user_id: str): def async_schedule_google_sync(self, agent_user_id: str):
"""Schedule a sync.""" """Schedule a sync."""
@ -141,6 +161,12 @@ class AbstractConfig:
self.hass, SYNC_DELAY, _schedule_callback self.hass, SYNC_DELAY, _schedule_callback
) )
@callback
def async_schedule_google_sync_all(self):
"""Schedule a sync for all registered agents."""
for agent_user_id in self._store.agent_user_ids:
self.async_schedule_google_sync(agent_user_id)
async def _async_request_sync_devices(self, agent_user_id: str) -> int: async def _async_request_sync_devices(self, agent_user_id: str) -> int:
"""Trigger a sync with Google. """Trigger a sync with Google.
@ -148,11 +174,19 @@ class AbstractConfig:
""" """
raise NotImplementedError raise NotImplementedError
async def async_deactivate_report_state(self): async def async_connect_agent_user(self, agent_user_id: str):
"""Add an synced and known agent_user_id.
Called when a completed sync response have been sent to Google.
"""
self._store.add_agent_user_id(agent_user_id)
async def async_disconnect_agent_user(self, agent_user_id: str):
"""Turn off report state and disable further state reporting. """Turn off report state and disable further state reporting.
Called when the user disconnects their account from Google. Called when the user disconnects their account from Google.
""" """
self._store.pop_agent_user_id(agent_user_id)
@callback @callback
def async_enable_local_sdk(self): def async_enable_local_sdk(self):
@ -199,6 +233,44 @@ class AbstractConfig:
return json_response(result) return json_response(result)
class GoogleConfigStore:
"""A configuration store for google assistant."""
_STORAGE_VERSION = 1
_STORAGE_KEY = DOMAIN
def __init__(self, hass):
"""Initialize a configuration store."""
self._hass = hass
self._store = Store(hass, self._STORAGE_VERSION, self._STORAGE_KEY)
self._data = {STORE_AGENT_USER_IDS: {}}
@property
def agent_user_ids(self):
"""Return a list of connected agent user_ids."""
return self._data[STORE_AGENT_USER_IDS]
@callback
def add_agent_user_id(self, agent_user_id):
"""Add an agent user id to store."""
if agent_user_id not in self._data[STORE_AGENT_USER_IDS]:
self._data[STORE_AGENT_USER_IDS][agent_user_id] = {}
self._store.async_delay_save(lambda: self._data, 1.0)
@callback
def pop_agent_user_id(self, agent_user_id):
"""Remove agent user id from store."""
if agent_user_id in self._data[STORE_AGENT_USER_IDS]:
self._data[STORE_AGENT_USER_IDS].pop(agent_user_id, None)
self._store.async_delay_save(lambda: self._data, 1.0)
async def async_load(self):
"""Store current configuration to disk."""
data = await self._store.async_load()
if data:
self._data = data
class RequestData: class RequestData:
"""Hold data associated with a particular request.""" """Hold data associated with a particular request."""
@ -278,7 +350,7 @@ class GoogleEntity:
trait.might_2fa(domain, features, device_class) for trait in self.traits() trait.might_2fa(domain, features, device_class) for trait in self.traits()
) )
async def sync_serialize(self): async def sync_serialize(self, agent_user_id):
"""Serialize entity for a SYNC response. """Serialize entity for a SYNC response.
https://developers.google.com/actions/smarthome/create-app#actiondevicessync https://developers.google.com/actions/smarthome/create-app#actiondevicessync
@ -314,7 +386,7 @@ class GoogleEntity:
"webhookId": self.config.local_sdk_webhook_id, "webhookId": self.config.local_sdk_webhook_id,
"httpPort": self.hass.config.api.port, "httpPort": self.hass.config.api.port,
"httpSSL": self.hass.config.api.use_ssl, "httpSSL": self.hass.config.api.use_ssl,
"proxyDeviceId": self.config.agent_user_id, "proxyDeviceId": agent_user_id,
} }
for trt in traits: for trt in traits:

View file

@ -81,11 +81,6 @@ class GoogleConfig(AbstractConfig):
"""Return if Google is enabled.""" """Return if Google is enabled."""
return True return True
@property
def agent_user_id(self):
"""Return Agent User Id to use for query responses."""
return None
@property @property
def entity_config(self): def entity_config(self):
"""Return entity config.""" """Return entity config."""
@ -214,11 +209,11 @@ class GoogleConfig(AbstractConfig):
_LOGGER.error("Could not contact %s", url) _LOGGER.error("Could not contact %s", url)
return 500 return 500
async def async_report_state(self, message): async def async_report_state(self, message, agent_user_id: str):
"""Send a state report to Google.""" """Send a state report to Google."""
data = { data = {
"requestId": uuid4().hex, "requestId": uuid4().hex,
"agentUserId": (await self.hass.auth.async_get_owner()).id, "agentUserId": agent_user_id,
"payload": message, "payload": message,
} }
await self.async_call_homegraph_api(REPORT_STATE_BASE_URL, data) await self.async_call_homegraph_api(REPORT_STATE_BASE_URL, data)

View file

@ -45,7 +45,7 @@ def async_enable_report_state(hass: HomeAssistant, google_config: AbstractConfig
if entity_data == old_entity.query_serialize(): if entity_data == old_entity.query_serialize():
return return
await google_config.async_report_state( await google_config.async_report_state_all(
{"devices": {"states": {changed_entity: entity_data}}} {"devices": {"states": {changed_entity: entity_data}}}
) )
@ -62,7 +62,7 @@ def async_enable_report_state(hass: HomeAssistant, google_config: AbstractConfig
except SmartHomeError: except SmartHomeError:
continue continue
await google_config.async_report_state({"devices": {"states": entities}}) await google_config.async_report_state_all({"devices": {"states": entities}})
async_call_later(hass, INITIAL_REPORT_DELAY, inital_report) async_call_later(hass, INITIAL_REPORT_DELAY, inital_report)

View file

@ -79,18 +79,19 @@ async def async_devices_sync(hass, data, payload):
EVENT_SYNC_RECEIVED, {"request_id": data.request_id}, context=data.context EVENT_SYNC_RECEIVED, {"request_id": data.request_id}, context=data.context
) )
agent_user_id = data.context.user_id
devices = await asyncio.gather( devices = await asyncio.gather(
*( *(
entity.sync_serialize() entity.sync_serialize(agent_user_id)
for entity in async_get_entities(hass, data.config) for entity in async_get_entities(hass, data.config)
if entity.should_expose() if entity.should_expose()
) )
) )
response = { response = {"agentUserId": agent_user_id, "devices": devices}
"agentUserId": data.config.agent_user_id or data.context.user_id,
"devices": devices, await data.config.async_connect_agent_user(agent_user_id)
}
return response return response
@ -197,7 +198,7 @@ async def async_devices_disconnect(hass, data: RequestData, payload):
https://developers.google.com/assistant/smarthome/develop/process-intents#DISCONNECT https://developers.google.com/assistant/smarthome/develop/process-intents#DISCONNECT
""" """
await data.config.async_deactivate_report_state() await data.config.async_disconnect_agent_user(data.context.user_id)
return None return None
@ -209,7 +210,7 @@ async def async_devices_identify(hass, data: RequestData, payload):
""" """
return { return {
"device": { "device": {
"id": data.config.agent_user_id, "id": data.context.user_id,
"isLocalOnly": True, "isLocalOnly": True,
"isProxy": True, "isProxy": True,
"deviceInfo": { "deviceInfo": {

View file

@ -102,16 +102,13 @@ async def test_handler_google_actions(hass):
reqid = "5711642932632160983" reqid = "5711642932632160983"
data = {"requestId": reqid, "inputs": [{"intent": "action.devices.SYNC"}]} data = {"requestId": reqid, "inputs": [{"intent": "action.devices.SYNC"}]}
with patch( config = await cloud.client.get_google_config()
"hass_nabucasa.Cloud._decode_claims", resp = await cloud.client.async_google_message(data)
return_value={"cognito:username": "myUserName"},
):
resp = await cloud.client.async_google_message(data)
assert resp["requestId"] == reqid assert resp["requestId"] == reqid
payload = resp["payload"] payload = resp["payload"]
assert payload["agentUserId"] == "myUserName" assert payload["agentUserId"] == config.cloud_user
devices = payload["devices"] devices = payload["devices"]
assert len(devices) == 1 assert len(devices) == 1

View file

@ -19,6 +19,8 @@ async def test_google_update_report_state(hass, cloud_prefs):
cloud_prefs, cloud_prefs,
Mock(claims={"cognito:username": "abcdefghjkl"}), Mock(claims={"cognito:username": "abcdefghjkl"}),
) )
await config.async_initialize()
await config.async_connect_agent_user("mock-user-id")
with patch.object( with patch.object(
config, "async_sync_entities", side_effect=mock_coro config, "async_sync_entities", side_effect=mock_coro
@ -58,6 +60,8 @@ async def test_google_update_expose_trigger_sync(hass, cloud_prefs):
cloud_prefs, cloud_prefs,
Mock(claims={"cognito:username": "abcdefghjkl"}), Mock(claims={"cognito:username": "abcdefghjkl"}),
) )
await config.async_initialize()
await config.async_connect_agent_user("mock-user-id")
with patch.object( with patch.object(
config, "async_sync_entities", side_effect=mock_coro config, "async_sync_entities", side_effect=mock_coro
@ -95,6 +99,8 @@ async def test_google_entity_registry_sync(hass, mock_cloud_login, cloud_prefs):
config = CloudGoogleConfig( config = CloudGoogleConfig(
hass, GACTIONS_SCHEMA({}), "mock-user-id", cloud_prefs, hass.data["cloud"] hass, GACTIONS_SCHEMA({}), "mock-user-id", cloud_prefs, hass.data["cloud"]
) )
await config.async_initialize()
await config.async_connect_agent_user("mock-user-id")
with patch.object( with patch.object(
config, "async_sync_entities", side_effect=mock_coro config, "async_sync_entities", side_effect=mock_coro

View file

@ -1,7 +1,18 @@
"""Tests for the Google Assistant integration.""" """Tests for the Google Assistant integration."""
from asynctest.mock import MagicMock
from homeassistant.components.google_assistant import helpers from homeassistant.components.google_assistant import helpers
def mock_google_config_store(agent_user_ids=None):
"""Fake a storage for google assistant."""
store = MagicMock(spec=helpers.GoogleConfigStore)
if agent_user_ids is not None:
store.agent_user_ids = agent_user_ids
else:
store.agent_user_ids = {}
return store
class MockConfig(helpers.AbstractConfig): class MockConfig(helpers.AbstractConfig):
"""Fake config that always exposes everything.""" """Fake config that always exposes everything."""
@ -15,6 +26,7 @@ class MockConfig(helpers.AbstractConfig):
local_sdk_webhook_id=None, local_sdk_webhook_id=None,
local_sdk_user_id=None, local_sdk_user_id=None,
enabled=True, enabled=True,
agent_user_ids=None,
): ):
"""Initialize config.""" """Initialize config."""
super().__init__(hass) super().__init__(hass)
@ -24,6 +36,7 @@ class MockConfig(helpers.AbstractConfig):
self._local_sdk_webhook_id = local_sdk_webhook_id self._local_sdk_webhook_id = local_sdk_webhook_id
self._local_sdk_user_id = local_sdk_user_id self._local_sdk_user_id = local_sdk_user_id
self._enabled = enabled self._enabled = enabled
self._store = mock_google_config_store(agent_user_ids)
@property @property
def enabled(self): def enabled(self):

View file

@ -1,11 +1,18 @@
"""Test Google Assistant helpers.""" """Test Google Assistant helpers."""
from unittest.mock import Mock from asynctest.mock import Mock, patch, call
from datetime import timedelta
import pytest
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from homeassistant.components.google_assistant import helpers from homeassistant.components.google_assistant import helpers
from homeassistant.components.google_assistant.const import EVENT_COMMAND_RECEIVED from homeassistant.components.google_assistant.const import EVENT_COMMAND_RECEIVED
from homeassistant.util import dt
from . import MockConfig from . import MockConfig
from tests.common import async_capture_events, async_mock_service from tests.common import (
async_capture_events,
async_mock_service,
async_fire_time_changed,
)
async def test_google_entity_sync_serialize_with_local_sdk(hass): async def test_google_entity_sync_serialize_with_local_sdk(hass):
@ -19,13 +26,13 @@ async def test_google_entity_sync_serialize_with_local_sdk(hass):
) )
entity = helpers.GoogleEntity(hass, config, hass.states.get("light.ceiling_lights")) entity = helpers.GoogleEntity(hass, config, hass.states.get("light.ceiling_lights"))
serialized = await entity.sync_serialize() serialized = await entity.sync_serialize(None)
assert "otherDeviceIds" not in serialized assert "otherDeviceIds" not in serialized
assert "customData" not in serialized assert "customData" not in serialized
config.async_enable_local_sdk() config.async_enable_local_sdk()
serialized = await entity.sync_serialize() serialized = await entity.sync_serialize(None)
assert serialized["otherDeviceIds"] == [{"deviceId": "light.ceiling_lights"}] assert serialized["otherDeviceIds"] == [{"deviceId": "light.ceiling_lights"}]
assert serialized["customData"] == { assert serialized["customData"] == {
"httpPort": 1234, "httpPort": 1234,
@ -128,3 +135,84 @@ async def test_config_local_sdk_if_disabled(hass, hass_client):
resp = await client.post("/api/webhook/mock-webhook-id") resp = await client.post("/api/webhook/mock-webhook-id")
assert resp.status == 200 assert resp.status == 200
assert await resp.read() == b"" assert await resp.read() == b""
async def test_agent_user_id_storage(hass, hass_storage):
"""Test a disconnect message."""
hass_storage["google_assistant"] = {
"version": 1,
"key": "google_assistant",
"data": {"agent_user_ids": {"agent_1": {}}},
}
store = helpers.GoogleConfigStore(hass)
await store.async_load()
assert hass_storage["google_assistant"] == {
"version": 1,
"key": "google_assistant",
"data": {"agent_user_ids": {"agent_1": {}}},
}
async def _check_after_delay(data):
async_fire_time_changed(hass, dt.utcnow() + timedelta(seconds=2))
await hass.async_block_till_done()
assert hass_storage["google_assistant"] == {
"version": 1,
"key": "google_assistant",
"data": data,
}
store.add_agent_user_id("agent_2")
await _check_after_delay({"agent_user_ids": {"agent_1": {}, "agent_2": {}}})
store.pop_agent_user_id("agent_1")
await _check_after_delay({"agent_user_ids": {"agent_2": {}}})
async def test_agent_user_id_connect():
"""Test the connection and disconnection of users."""
config = MockConfig()
store = config._store
await config.async_connect_agent_user("agent_2")
assert store.add_agent_user_id.call_args == call("agent_2")
await config.async_connect_agent_user("agent_1")
assert store.add_agent_user_id.call_args == call("agent_1")
await config.async_disconnect_agent_user("agent_2")
assert store.pop_agent_user_id.call_args == call("agent_2")
await config.async_disconnect_agent_user("agent_1")
assert store.pop_agent_user_id.call_args == call("agent_1")
@pytest.mark.parametrize("agents", [{}, {"1"}, {"1", "2"}])
async def test_report_state_all(agents):
"""Test a disconnect message."""
config = MockConfig(agent_user_ids=agents)
data = {}
with patch.object(config, "async_report_state") as mock:
await config.async_report_state_all(data)
assert sorted(mock.mock_calls) == sorted(
[call(data, agent) for agent in agents]
)
@pytest.mark.parametrize(
"agents, result", [({}, 204), ({"1": 200}, 200), ({"1": 200, "2": 300}, 300)],
)
async def test_sync_entities_all(agents, result):
"""Test sync entities ."""
config = MockConfig(agent_user_ids=set(agents.keys()))
with patch.object(
config,
"async_sync_entities",
side_effect=lambda agent_user_id: agents[agent_user_id],
) as mock:
res = await config.async_sync_entities_all()
assert sorted(mock.mock_calls) == sorted([call(agent) for agent in agents])
assert res == result

View file

@ -12,7 +12,6 @@ from homeassistant.components.google_assistant.const import (
REPORT_STATE_BASE_URL, REPORT_STATE_BASE_URL,
HOMEGRAPH_TOKEN_URL, HOMEGRAPH_TOKEN_URL,
) )
from homeassistant.auth.models import User
DUMMY_CONFIG = GOOGLE_ASSISTANT_SCHEMA( DUMMY_CONFIG = GOOGLE_ASSISTANT_SCHEMA(
{ {
@ -67,6 +66,7 @@ async def test_update_access_token(hass):
jwt = "dummyjwt" jwt = "dummyjwt"
config = GoogleConfig(hass, DUMMY_CONFIG) config = GoogleConfig(hass, DUMMY_CONFIG)
await config.async_initialize()
base_time = datetime(2019, 10, 14, tzinfo=timezone.utc) base_time = datetime(2019, 10, 14, tzinfo=timezone.utc)
with patch( with patch(
@ -99,6 +99,8 @@ async def test_update_access_token(hass):
async def test_call_homegraph_api(hass, aioclient_mock, hass_storage): async def test_call_homegraph_api(hass, aioclient_mock, hass_storage):
"""Test the function to call the homegraph api.""" """Test the function to call the homegraph api."""
config = GoogleConfig(hass, DUMMY_CONFIG) config = GoogleConfig(hass, DUMMY_CONFIG)
await config.async_initialize()
with patch( with patch(
"homeassistant.components.google_assistant.http._get_homegraph_token" "homeassistant.components.google_assistant.http._get_homegraph_token"
) as mock_get_token: ) as mock_get_token:
@ -120,6 +122,8 @@ async def test_call_homegraph_api(hass, aioclient_mock, hass_storage):
async def test_call_homegraph_api_retry(hass, aioclient_mock, hass_storage): async def test_call_homegraph_api_retry(hass, aioclient_mock, hass_storage):
"""Test the that the calls get retried with new token on 401.""" """Test the that the calls get retried with new token on 401."""
config = GoogleConfig(hass, DUMMY_CONFIG) config = GoogleConfig(hass, DUMMY_CONFIG)
await config.async_initialize()
with patch( with patch(
"homeassistant.components.google_assistant.http._get_homegraph_token" "homeassistant.components.google_assistant.http._get_homegraph_token"
) as mock_get_token: ) as mock_get_token:
@ -143,8 +147,10 @@ async def test_call_homegraph_api_retry(hass, aioclient_mock, hass_storage):
async def test_call_homegraph_api_key(hass, aioclient_mock, hass_storage): async def test_call_homegraph_api_key(hass, aioclient_mock, hass_storage):
"""Test the function to call the homegraph api.""" """Test the function to call the homegraph api."""
config = GoogleConfig( config = GoogleConfig(
hass, GOOGLE_ASSISTANT_SCHEMA({"project_id": "1234", "api_key": "dummy_key"}) hass, GOOGLE_ASSISTANT_SCHEMA({"project_id": "1234", "api_key": "dummy_key"}),
) )
await config.async_initialize()
aioclient_mock.post(MOCK_URL, status=200, json={}) aioclient_mock.post(MOCK_URL, status=200, json={})
res = await config.async_call_homegraph_api_key(MOCK_URL, MOCK_JSON) res = await config.async_call_homegraph_api_key(MOCK_URL, MOCK_JSON)
@ -159,8 +165,10 @@ async def test_call_homegraph_api_key(hass, aioclient_mock, hass_storage):
async def test_call_homegraph_api_key_fail(hass, aioclient_mock, hass_storage): async def test_call_homegraph_api_key_fail(hass, aioclient_mock, hass_storage):
"""Test the function to call the homegraph api.""" """Test the function to call the homegraph api."""
config = GoogleConfig( config = GoogleConfig(
hass, GOOGLE_ASSISTANT_SCHEMA({"project_id": "1234", "api_key": "dummy_key"}) hass, GOOGLE_ASSISTANT_SCHEMA({"project_id": "1234", "api_key": "dummy_key"}),
) )
await config.async_initialize()
aioclient_mock.post(MOCK_URL, status=666, json={}) aioclient_mock.post(MOCK_URL, status=666, json={})
res = await config.async_call_homegraph_api_key(MOCK_URL, MOCK_JSON) res = await config.async_call_homegraph_api_key(MOCK_URL, MOCK_JSON)
@ -170,17 +178,16 @@ async def test_call_homegraph_api_key_fail(hass, aioclient_mock, hass_storage):
async def test_report_state(hass, aioclient_mock, hass_storage): async def test_report_state(hass, aioclient_mock, hass_storage):
"""Test the report state function.""" """Test the report state function."""
agent_user_id = "user"
config = GoogleConfig(hass, DUMMY_CONFIG) config = GoogleConfig(hass, DUMMY_CONFIG)
await config.async_initialize()
await config.async_connect_agent_user(agent_user_id)
message = {"devices": {}} message = {"devices": {}}
owner = User(name="Test User", perm_lookup=None, groups=[], is_owner=True)
with patch.object(config, "async_call_homegraph_api") as mock_call, patch.object( with patch.object(config, "async_call_homegraph_api") as mock_call:
hass.auth, "async_get_owner" await config.async_report_state(message, agent_user_id)
) as mock_get_owner:
mock_get_owner.return_value = owner
await config.async_report_state(message)
mock_call.assert_called_once_with( mock_call.assert_called_once_with(
REPORT_STATE_BASE_URL, REPORT_STATE_BASE_URL,
{"requestId": ANY, "agentUserId": owner.id, "payload": message}, {"requestId": ANY, "agentUserId": agent_user_id, "payload": message},
) )

View file

@ -16,7 +16,7 @@ async def test_report_state(hass, caplog):
hass.states.async_set("switch.ac", "on") hass.states.async_set("switch.ac", "on")
with patch.object( with patch.object(
BASIC_CONFIG, "async_report_state", side_effect=mock_coro BASIC_CONFIG, "async_report_state_all", side_effect=mock_coro
) as mock_report, patch.object(report_state, "INITIAL_REPORT_DELAY", 0): ) as mock_report, patch.object(report_state, "INITIAL_REPORT_DELAY", 0):
unsub = report_state.async_enable_report_state(hass, BASIC_CONFIG) unsub = report_state.async_enable_report_state(hass, BASIC_CONFIG)
@ -35,7 +35,7 @@ async def test_report_state(hass, caplog):
} }
with patch.object( with patch.object(
BASIC_CONFIG, "async_report_state", side_effect=mock_coro BASIC_CONFIG, "async_report_state_all", side_effect=mock_coro
) as mock_report: ) as mock_report:
hass.states.async_set("light.kitchen", "on") hass.states.async_set("light.kitchen", "on")
await hass.async_block_till_done() await hass.async_block_till_done()
@ -48,7 +48,7 @@ async def test_report_state(hass, caplog):
# Test that state changes that change something that Google doesn't care about # Test that state changes that change something that Google doesn't care about
# do not trigger a state report. # do not trigger a state report.
with patch.object( with patch.object(
BASIC_CONFIG, "async_report_state", side_effect=mock_coro BASIC_CONFIG, "async_report_state_all", side_effect=mock_coro
) as mock_report: ) as mock_report:
hass.states.async_set( hass.states.async_set(
"light.kitchen", "on", {"irrelevant": "should_be_ignored"} "light.kitchen", "on", {"irrelevant": "should_be_ignored"}
@ -59,7 +59,7 @@ async def test_report_state(hass, caplog):
# Test that entities that we can't query don't report a state # Test that entities that we can't query don't report a state
with patch.object( with patch.object(
BASIC_CONFIG, "async_report_state", side_effect=mock_coro BASIC_CONFIG, "async_report_state_all", side_effect=mock_coro
) as mock_report, patch( ) as mock_report, patch(
"homeassistant.components.google_assistant.report_state.GoogleEntity.query_serialize", "homeassistant.components.google_assistant.report_state.GoogleEntity.query_serialize",
side_effect=error.SmartHomeError("mock-error", "mock-msg"), side_effect=error.SmartHomeError("mock-error", "mock-msg"),
@ -73,7 +73,7 @@ async def test_report_state(hass, caplog):
unsub() unsub()
with patch.object( with patch.object(
BASIC_CONFIG, "async_report_state", side_effect=mock_coro BASIC_CONFIG, "async_report_state_all", side_effect=mock_coro
) as mock_report: ) as mock_report:
hass.states.async_set("light.kitchen", "on") hass.states.async_set("light.kitchen", "on")
await hass.async_block_till_done() await hass.async_block_till_done()

View file

@ -455,7 +455,7 @@ async def test_serialize_input_boolean(hass):
state = State("input_boolean.bla", "on") state = State("input_boolean.bla", "on")
# pylint: disable=protected-access # pylint: disable=protected-access
entity = sh.GoogleEntity(hass, BASIC_CONFIG, state) entity = sh.GoogleEntity(hass, BASIC_CONFIG, state)
result = await entity.sync_serialize() result = await entity.sync_serialize(None)
assert result == { assert result == {
"id": "input_boolean.bla", "id": "input_boolean.bla",
"attributes": {}, "attributes": {},
@ -664,8 +664,8 @@ async def test_query_disconnect(hass):
config.async_enable_report_state() config.async_enable_report_state()
assert config._unsub_report_state is not None assert config._unsub_report_state is not None
with patch.object( with patch.object(
config, "async_deactivate_report_state", side_effect=mock_coro config, "async_disconnect_agent_user", side_effect=mock_coro
) as mock_deactivate: ) as mock_disconnect:
result = await sh.async_handle_message( result = await sh.async_handle_message(
hass, hass,
config, config,
@ -673,7 +673,7 @@ async def test_query_disconnect(hass):
{"inputs": [{"intent": "action.devices.DISCONNECT"}], "requestId": REQ_ID}, {"inputs": [{"intent": "action.devices.DISCONNECT"}], "requestId": REQ_ID},
) )
assert result is None assert result is None
assert len(mock_deactivate.mock_calls) == 1 assert len(mock_disconnect.mock_calls) == 1
async def test_trait_execute_adding_query_data(hass): async def test_trait_execute_adding_query_data(hass):
@ -741,10 +741,12 @@ async def test_trait_execute_adding_query_data(hass):
async def test_identify(hass): async def test_identify(hass):
"""Test identify message.""" """Test identify message."""
user_agent_id = "mock-user-id"
proxy_device_id = user_agent_id
result = await sh.async_handle_message( result = await sh.async_handle_message(
hass, hass,
BASIC_CONFIG, BASIC_CONFIG,
None, user_agent_id,
{ {
"requestId": REQ_ID, "requestId": REQ_ID,
"inputs": [ "inputs": [
@ -778,7 +780,7 @@ async def test_identify(hass):
"customData": { "customData": {
"httpPort": 8123, "httpPort": 8123,
"httpSSL": False, "httpSSL": False,
"proxyDeviceId": BASIC_CONFIG.agent_user_id, "proxyDeviceId": proxy_device_id,
"webhookId": "dde3b9800a905e886cc4d38e226a6e7e3f2a6993d2b9b9f63d13e42ee7de3219", "webhookId": "dde3b9800a905e886cc4d38e226a6e7e3f2a6993d2b9b9f63d13e42ee7de3219",
}, },
} }
@ -790,7 +792,7 @@ async def test_identify(hass):
"requestId": REQ_ID, "requestId": REQ_ID,
"payload": { "payload": {
"device": { "device": {
"id": BASIC_CONFIG.agent_user_id, "id": proxy_device_id,
"isLocalOnly": True, "isLocalOnly": True,
"isProxy": True, "isProxy": True,
"deviceInfo": { "deviceInfo": {
@ -822,10 +824,13 @@ async def test_reachable_devices(hass):
should_expose=lambda state: state.entity_id != "light.not_expose" should_expose=lambda state: state.entity_id != "light.not_expose"
) )
user_agent_id = "mock-user-id"
proxy_device_id = user_agent_id
result = await sh.async_handle_message( result = await sh.async_handle_message(
hass, hass,
config, config,
None, user_agent_id,
{ {
"requestId": REQ_ID, "requestId": REQ_ID,
"inputs": [ "inputs": [
@ -834,7 +839,7 @@ async def test_reachable_devices(hass):
"payload": { "payload": {
"device": { "device": {
"proxyDevice": { "proxyDevice": {
"id": "6a04f0f7-6125-4356-a846-861df7e01497", "id": proxy_device_id,
"customData": "{}", "customData": "{}",
"proxyData": "{}", "proxyData": "{}",
} }
@ -849,7 +854,7 @@ async def test_reachable_devices(hass):
"customData": { "customData": {
"httpPort": 8123, "httpPort": 8123,
"httpSSL": False, "httpSSL": False,
"proxyDeviceId": BASIC_CONFIG.agent_user_id, "proxyDeviceId": proxy_device_id,
"webhookId": "dde3b9800a905e886cc4d38e226a6e7e3f2a6993d2b9b9f63d13e42ee7de3219", "webhookId": "dde3b9800a905e886cc4d38e226a6e7e3f2a6993d2b9b9f63d13e42ee7de3219",
}, },
}, },
@ -858,11 +863,11 @@ async def test_reachable_devices(hass):
"customData": { "customData": {
"httpPort": 8123, "httpPort": 8123,
"httpSSL": False, "httpSSL": False,
"proxyDeviceId": BASIC_CONFIG.agent_user_id, "proxyDeviceId": proxy_device_id,
"webhookId": "dde3b9800a905e886cc4d38e226a6e7e3f2a6993d2b9b9f63d13e42ee7de3219", "webhookId": "dde3b9800a905e886cc4d38e226a6e7e3f2a6993d2b9b9f63d13e42ee7de3219",
}, },
}, },
{"id": BASIC_CONFIG.agent_user_id, "customData": {}}, {"id": proxy_device_id, "customData": {}},
], ],
}, },
) )