Keep track of Alexa authorization status (#63979)

This commit is contained in:
Erik Montnemery 2022-01-13 18:47:31 +01:00 committed by GitHub
parent 49a32c398c
commit be628a7c4d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 204 additions and 23 deletions

View file

@ -2,9 +2,13 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.helpers.storage import Store
from .const import DOMAIN
from .state_report import async_enable_proactive_mode from .state_report import async_enable_proactive_mode
STORE_AUTHORIZED = "authorized"
class AbstractConfig(ABC): class AbstractConfig(ABC):
"""Hold the configuration for Alexa.""" """Hold the configuration for Alexa."""
@ -14,6 +18,12 @@ class AbstractConfig(ABC):
def __init__(self, hass): def __init__(self, hass):
"""Initialize abstract config.""" """Initialize abstract config."""
self.hass = hass self.hass = hass
self._store = None
async def async_initialize(self):
"""Perform async initialization of config."""
self._store = AlexaConfigStore(self.hass)
await self._store.async_load()
@property @property
def supports_auth(self): def supports_auth(self):
@ -86,3 +96,48 @@ class AbstractConfig(ABC):
async def async_accept_grant(self, code): async def async_accept_grant(self, code):
"""Accept a grant.""" """Accept a grant."""
raise NotImplementedError raise NotImplementedError
@property
def authorized(self):
"""Return authorization status."""
return self._store.authorized
def set_authorized(self, authorized):
"""Set authorization status.
- Set when an incoming message is received from Alexa.
- Unset if state reporting fails
"""
self._store.set_authorized(authorized)
class AlexaConfigStore:
"""A configuration store for Alexa."""
_STORAGE_VERSION = 1
_STORAGE_KEY = DOMAIN
def __init__(self, hass):
"""Initialize a configuration store."""
self._data = None
self._hass = hass
self._store = Store(hass, self._STORAGE_VERSION, self._STORAGE_KEY)
@property
def authorized(self):
"""Return authorization status."""
return self._data[STORE_AUTHORIZED]
@callback
def set_authorized(self, authorized):
"""Set authorization status."""
if authorized != self._data[STORE_AUTHORIZED]:
self._data[STORE_AUTHORIZED] = authorized
self._store.async_delay_save(lambda: self._data, 1.0)
async def async_load(self):
"""Load saved configuration from disk."""
if data := await self._store.async_load():
self._data = data
else:
self._data = {STORE_AUTHORIZED: False}

View file

@ -18,6 +18,10 @@ class NoTokenAvailable(HomeAssistantError):
"""There is no access token available.""" """There is no access token available."""
class RequireRelink(Exception):
"""The skill needs to be relinked."""
class AlexaError(Exception): class AlexaError(Exception):
"""Base class for errors that can be serialized for the Alexa API. """Base class for errors that can be serialized for the Alexa API.

View file

@ -31,6 +31,8 @@ async def async_handle_message(hass, config, request, context=None, enabled=True
"Alexa API not enabled in Home Assistant configuration" "Alexa API not enabled in Home Assistant configuration"
) )
config.set_authorized(True)
if directive.has_endpoint: if directive.has_endpoint:
directive.load_entity(hass, config) directive.load_entity(hass, config)

View file

@ -97,6 +97,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> None:
by the cloud component which will call async_handle_message directly. by the cloud component which will call async_handle_message directly.
""" """
smart_home_config = AlexaConfig(hass, config) smart_home_config = AlexaConfig(hass, config)
await smart_home_config.async_initialize()
hass.http.register_view(SmartHomeView(smart_home_config)) hass.http.register_view(SmartHomeView(smart_home_config))
if smart_home_config.should_report_state: if smart_home_config.should_report_state:

View file

@ -17,6 +17,7 @@ import homeassistant.util.dt as dt_util
from .const import API_CHANGE, DATE_FORMAT, DOMAIN, Cause from .const import API_CHANGE, DATE_FORMAT, DOMAIN, Cause
from .entities import ENTITY_ADAPTERS, AlexaEntity, generate_alexa_id from .entities import ENTITY_ADAPTERS, AlexaEntity, generate_alexa_id
from .errors import NoTokenAvailable, RequireRelink
from .messages import AlexaResponse from .messages import AlexaResponse
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -113,7 +114,10 @@ async def async_send_changereport_message(
https://developer.amazon.com/docs/smarthome/state-reporting-for-a-smart-home-skill.html#report-state-with-changereport-events https://developer.amazon.com/docs/smarthome/state-reporting-for-a-smart-home-skill.html#report-state-with-changereport-events
""" """
token = await config.async_get_access_token() try:
token = await config.async_get_access_token()
except (RequireRelink, NoTokenAvailable):
config.set_authorized(False)
headers = {"Authorization": f"Bearer {token}"} headers = {"Authorization": f"Bearer {token}"}
@ -155,14 +159,18 @@ async def async_send_changereport_message(
response_json = json.loads(response_text) response_json = json.loads(response_text)
if ( if response_json["payload"]["code"] == "INVALID_ACCESS_TOKEN_EXCEPTION":
response_json["payload"]["code"] == "INVALID_ACCESS_TOKEN_EXCEPTION" if invalidate_access_token:
and not invalidate_access_token # Invalidate the access token and try again
): config.async_invalidate_access_token()
config.async_invalidate_access_token() return await async_send_changereport_message(
return await async_send_changereport_message( hass,
hass, config, alexa_entity, alexa_properties, invalidate_access_token=False config,
) alexa_entity,
alexa_properties,
invalidate_access_token=False,
)
config.set_authorized(False)
_LOGGER.error( _LOGGER.error(
"Error when sending ChangeReport to Alexa: %s: %s", "Error when sending ChangeReport to Alexa: %s: %s",

View file

@ -24,7 +24,7 @@ from homeassistant.helpers.event import async_call_later
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from homeassistant.util.dt import utcnow from homeassistant.util.dt import utcnow
from .const import CONF_ENTITY_CONFIG, CONF_FILTER, PREF_SHOULD_EXPOSE, RequireRelink from .const import CONF_ENTITY_CONFIG, CONF_FILTER, PREF_SHOULD_EXPOSE
from .prefs import CloudPreferences from .prefs import CloudPreferences
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -103,6 +103,7 @@ class AlexaConfig(alexa_config.AbstractConfig):
async def async_initialize(self): async def async_initialize(self):
"""Initialize the Alexa config.""" """Initialize the Alexa config."""
await super().async_initialize()
async def hass_started(hass): async def hass_started(hass):
if self.enabled and ALEXA_DOMAIN not in self.hass.config.components: if self.enabled and ALEXA_DOMAIN not in self.hass.config.components:
@ -167,7 +168,7 @@ class AlexaConfig(alexa_config.AbstractConfig):
"Alexa state reporting disabled", "Alexa state reporting disabled",
"cloud_alexa_report", "cloud_alexa_report",
) )
raise RequireRelink raise alexa_errors.RequireRelink
raise alexa_errors.NoTokenAvailable raise alexa_errors.NoTokenAvailable

View file

@ -61,7 +61,3 @@ MODE_DEV = "development"
MODE_PROD = "production" MODE_PROD = "production"
DISPATCHER_REMOTE_UPDATE = "cloud_remote_update" DISPATCHER_REMOTE_UPDATE = "cloud_remote_update"
class RequireRelink(Exception):
"""The skill needs to be relinked."""

View file

@ -35,7 +35,6 @@ from .const import (
PREF_GOOGLE_SECURE_DEVICES_PIN, PREF_GOOGLE_SECURE_DEVICES_PIN,
PREF_TTS_DEFAULT_VOICE, PREF_TTS_DEFAULT_VOICE,
REQUEST_TIMEOUT, REQUEST_TIMEOUT,
RequireRelink,
) )
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -366,15 +365,18 @@ async def websocket_update_prefs(hass, connection, msg):
msg["id"], "alexa_timeout", "Timeout validating Alexa access token." msg["id"], "alexa_timeout", "Timeout validating Alexa access token."
) )
return return
except (alexa_errors.NoTokenAvailable, RequireRelink): except (alexa_errors.NoTokenAvailable, alexa_errors.RequireRelink):
connection.send_error( connection.send_error(
msg["id"], msg["id"],
"alexa_relink", "alexa_relink",
"Please go to the Alexa app and re-link the Home Assistant " "Please go to the Alexa app and re-link the Home Assistant "
"skill and then try to enable state reporting.", "skill and then try to enable state reporting.",
) )
alexa_config.set_authorized(False)
return return
alexa_config.set_authorized(True)
await cloud.client.prefs.async_update(**changes) await cloud.client.prefs.async_update(**changes)
connection.send_message(websocket_api.result_message(msg["id"])) connection.send_message(websocket_api.result_message(msg["id"]))
@ -422,7 +424,8 @@ async def _account_data(cloud):
client = cloud.client client = cloud.client
remote = cloud.remote remote = cloud.remote
gconf = await client.get_google_config() alexa_config = await client.get_alexa_config()
google_config = await client.get_google_config()
# Load remote certificate # Load remote certificate
if remote.certificate: if remote.certificate:
@ -435,8 +438,9 @@ async def _account_data(cloud):
"email": claims["email"], "email": claims["email"],
"cloud": cloud.iot.state, "cloud": cloud.iot.state,
"prefs": client.prefs.as_dict(), "prefs": client.prefs.as_dict(),
"google_registered": gconf.has_registered_user_agent, "google_registered": google_config.has_registered_user_agent,
"google_entities": client.google_user_config["filter"].config, "google_entities": client.google_user_config["filter"].config,
"alexa_registered": alexa_config.authorized,
"alexa_entities": client.alexa_user_config["filter"].config, "alexa_entities": client.alexa_user_config["filter"].config,
"remote_domain": remote.instance_domain, "remote_domain": remote.instance_domain,
"remote_connected": remote.is_connected, "remote_connected": remote.is_connected,

View file

@ -1,5 +1,6 @@
"""Tests for the Alexa integration.""" """Tests for the Alexa integration."""
import re import re
from unittest.mock import Mock
from uuid import uuid4 from uuid import uuid4
from homeassistant.components.alexa import config, smart_home from homeassistant.components.alexa import config, smart_home
@ -23,6 +24,11 @@ class MockConfig(config.AbstractConfig):
"camera.test": {"display_categories": "CAMERA"}, "camera.test": {"display_categories": "CAMERA"},
} }
def __init__(self, hass):
"""Mock Alexa config."""
super().__init__(hass)
self._store = Mock(spec_set=config.AlexaConfigStore)
@property @property
def supports_auth(self): def supports_auth(self):
"""Return if config supports auth.""" """Return if config supports auth."""
@ -47,6 +53,10 @@ class MockConfig(config.AbstractConfig):
"""If an entity should be exposed.""" """If an entity should be exposed."""
return True return True
@callback
def async_invalidate_access_token(self):
"""Invalidate access token."""
async def async_get_access_token(self): async def async_get_access_token(self):
"""Get an access token.""" """Get an access token."""
return "thisisnotanacesstoken" return "thisisnotanacesstoken"

View file

@ -3975,3 +3975,14 @@ async def test_button(hass, domain):
await assert_scene_controller_works( await assert_scene_controller_works(
f"{domain}#ring_doorbell", f"{domain}.press", False, hass f"{domain}#ring_doorbell", f"{domain}.press", False, hass
) )
async def test_api_message_sets_authorized(hass):
"""Test an incoming API messages sets the authorized flag."""
msg = get_new_request("Alexa.PowerController", "TurnOn", "switch#xy")
async_mock_service(hass, "switch", "turn_on")
config = get_default_config()
config._store.set_authorized.assert_not_called()
await smart_home.async_handle_message(hass, config, msg)
config._store.set_authorized.assert_called_once_with(True)

View file

@ -41,6 +41,64 @@ async def test_report_state(hass, aioclient_mock):
assert call_json["event"]["endpoint"]["endpointId"] == "binary_sensor#test_contact" assert call_json["event"]["endpoint"]["endpointId"] == "binary_sensor#test_contact"
async def test_report_state_retry(hass, aioclient_mock):
"""Test proactive state retries once."""
aioclient_mock.post(
TEST_URL,
text='{"payload":{"code":"INVALID_ACCESS_TOKEN_EXCEPTION","description":""}}',
status=403,
)
hass.states.async_set(
"binary_sensor.test_contact",
"on",
{"friendly_name": "Test Contact Sensor", "device_class": "door"},
)
await state_report.async_enable_proactive_mode(hass, get_default_config())
hass.states.async_set(
"binary_sensor.test_contact",
"off",
{"friendly_name": "Test Contact Sensor", "device_class": "door"},
)
# To trigger event listener
await hass.async_block_till_done()
assert len(aioclient_mock.mock_calls) == 2
async def test_report_state_unsets_authorized_on_error(hass, aioclient_mock):
"""Test proactive state unsets authorized on error."""
aioclient_mock.post(
TEST_URL,
text='{"payload":{"code":"INVALID_ACCESS_TOKEN_EXCEPTION","description":""}}',
status=403,
)
hass.states.async_set(
"binary_sensor.test_contact",
"on",
{"friendly_name": "Test Contact Sensor", "device_class": "door"},
)
config = get_default_config()
await state_report.async_enable_proactive_mode(hass, config)
hass.states.async_set(
"binary_sensor.test_contact",
"off",
{"friendly_name": "Test Contact Sensor", "device_class": "door"},
)
config._store.set_authorized.assert_not_called()
# To trigger event listener
await hass.async_block_till_done()
config._store.set_authorized.assert_called_once_with(False)
async def test_report_state_instance(hass, aioclient_mock): async def test_report_state_instance(hass, aioclient_mock):
"""Test proactive state reports with instance.""" """Test proactive state reports with instance."""
aioclient_mock.post(TEST_URL, text="", status=202) aioclient_mock.post(TEST_URL, text="", status=202)

View file

@ -12,7 +12,7 @@ import pytest
from homeassistant.components.alexa import errors as alexa_errors from homeassistant.components.alexa import errors as alexa_errors
from homeassistant.components.alexa.entities import LightCapabilities from homeassistant.components.alexa.entities import LightCapabilities
from homeassistant.components.cloud.const import DOMAIN, RequireRelink from homeassistant.components.cloud.const import DOMAIN
from homeassistant.components.google_assistant.helpers import GoogleEntity from homeassistant.components.google_assistant.helpers import GoogleEntity
from homeassistant.core import State from homeassistant.core import State
from homeassistant.util.location import LocationInfo from homeassistant.util.location import LocationInfo
@ -414,6 +414,7 @@ async def test_websocket_status(
"exclude_entity_globs": [], "exclude_entity_globs": [],
"exclude_entities": [], "exclude_entities": [],
}, },
"alexa_registered": False,
"google_entities": { "google_entities": {
"include_domains": ["light"], "include_domains": ["light"],
"include_entity_globs": [], "include_entity_globs": [],
@ -509,6 +510,28 @@ async def test_websocket_update_preferences(
assert setup_api.tts_default_voice == ("en-GB", "male") assert setup_api.tts_default_voice == ("en-GB", "male")
async def test_websocket_update_preferences_alexa_report_state(
hass, hass_ws_client, aioclient_mock, setup_api, mock_cloud_login
):
"""Test updating alexa_report_state sets alexa authorized."""
client = await hass_ws_client(hass)
with patch(
"homeassistant.components.cloud.alexa_config.AlexaConfig"
".async_get_access_token",
), patch(
"homeassistant.components.cloud.alexa_config.AlexaConfig.set_authorized"
) as set_authorized_mock:
set_authorized_mock.assert_not_called()
await client.send_json(
{"id": 5, "type": "cloud/update_prefs", "alexa_report_state": True}
)
response = await client.receive_json()
set_authorized_mock.assert_called_once_with(True)
assert response["success"]
async def test_websocket_update_preferences_require_relink( async def test_websocket_update_preferences_require_relink(
hass, hass_ws_client, aioclient_mock, setup_api, mock_cloud_login hass, hass_ws_client, aioclient_mock, setup_api, mock_cloud_login
): ):
@ -518,12 +541,16 @@ async def test_websocket_update_preferences_require_relink(
with patch( with patch(
"homeassistant.components.cloud.alexa_config.AlexaConfig" "homeassistant.components.cloud.alexa_config.AlexaConfig"
".async_get_access_token", ".async_get_access_token",
side_effect=RequireRelink, side_effect=alexa_errors.RequireRelink,
): ), patch(
"homeassistant.components.cloud.alexa_config.AlexaConfig.set_authorized"
) as set_authorized_mock:
set_authorized_mock.assert_not_called()
await client.send_json( await client.send_json(
{"id": 5, "type": "cloud/update_prefs", "alexa_report_state": True} {"id": 5, "type": "cloud/update_prefs", "alexa_report_state": True}
) )
response = await client.receive_json() response = await client.receive_json()
set_authorized_mock.assert_called_once_with(False)
assert not response["success"] assert not response["success"]
assert response["error"]["code"] == "alexa_relink" assert response["error"]["code"] == "alexa_relink"
@ -539,11 +566,15 @@ async def test_websocket_update_preferences_no_token(
"homeassistant.components.cloud.alexa_config.AlexaConfig" "homeassistant.components.cloud.alexa_config.AlexaConfig"
".async_get_access_token", ".async_get_access_token",
side_effect=alexa_errors.NoTokenAvailable, side_effect=alexa_errors.NoTokenAvailable,
): ), patch(
"homeassistant.components.cloud.alexa_config.AlexaConfig.set_authorized"
) as set_authorized_mock:
set_authorized_mock.assert_not_called()
await client.send_json( await client.send_json(
{"id": 5, "type": "cloud/update_prefs", "alexa_report_state": True} {"id": 5, "type": "cloud/update_prefs", "alexa_report_state": True}
) )
response = await client.receive_json() response = await client.receive_json()
set_authorized_mock.assert_called_once_with(False)
assert not response["success"] assert not response["success"]
assert response["error"]["code"] == "alexa_relink" assert response["error"]["code"] == "alexa_relink"