diff --git a/homeassistant/components/alexa/__init__.py b/homeassistant/components/alexa/__init__.py index 337d8993b28..0bfa01a83ca 100644 --- a/homeassistant/components/alexa/__init__.py +++ b/homeassistant/components/alexa/__init__.py @@ -13,8 +13,9 @@ from homeassistant.helpers import entityfilter from . import flash_briefings, intent, smart_home from .const import ( - CONF_AUDIO, CONF_DISPLAY_URL, CONF_TEXT, CONF_TITLE, CONF_UID, DOMAIN, - CONF_FILTER, CONF_ENTITY_CONFIG) + CONF_AUDIO, CONF_CLIENT_ID, CONF_CLIENT_SECRET, CONF_DISPLAY_URL, + CONF_ENDPOINT, CONF_TEXT, CONF_TITLE, CONF_UID, DOMAIN, CONF_FILTER, + CONF_ENTITY_CONFIG) _LOGGER = logging.getLogger(__name__) @@ -30,6 +31,9 @@ ALEXA_ENTITY_SCHEMA = vol.Schema({ }) SMART_HOME_SCHEMA = vol.Schema({ + vol.Optional(CONF_ENDPOINT): cv.string, + vol.Optional(CONF_CLIENT_ID): cv.string, + vol.Optional(CONF_CLIENT_SECRET): cv.string, vol.Optional(CONF_FILTER, default={}): entityfilter.FILTER_SCHEMA, vol.Optional(CONF_ENTITY_CONFIG): {cv.entity_id: ALEXA_ENTITY_SCHEMA} }) diff --git a/homeassistant/components/alexa/auth.py b/homeassistant/components/alexa/auth.py new file mode 100644 index 00000000000..978cb611895 --- /dev/null +++ b/homeassistant/components/alexa/auth.py @@ -0,0 +1,154 @@ +"""Support for Alexa skill auth.""" + +import asyncio +import json +import logging +from datetime import timedelta +import aiohttp +import async_timeout + +from homeassistant.core import callback +from homeassistant.helpers import aiohttp_client +from homeassistant.util import dt +from .const import DEFAULT_TIMEOUT + +_LOGGER = logging.getLogger(__name__) + +LWA_TOKEN_URI = "https://api.amazon.com/auth/o2/token" +LWA_HEADERS = { + "Content-Type": "application/x-www-form-urlencoded;charset=UTF-8" +} + +PREEMPTIVE_REFRESH_TTL_IN_SECONDS = 300 +STORAGE_KEY = 'alexa_auth' +STORAGE_VERSION = 1 +STORAGE_EXPIRE_TIME = "expire_time" +STORAGE_ACCESS_TOKEN = "access_token" +STORAGE_REFRESH_TOKEN = "refresh_token" + + +class Auth: + """Handle authentication to send events to Alexa.""" + + def __init__(self, hass, client_id, client_secret): + """Initialize the Auth class.""" + self.hass = hass + + self.client_id = client_id + self.client_secret = client_secret + + self._prefs = None + self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY) + + self._get_token_lock = asyncio.Lock(loop=hass.loop) + + async def async_do_auth(self, accept_grant_code): + """Do authentication with an AcceptGrant code.""" + # access token not retrieved yet for the first time, so this should + # be an access token request + + lwa_params = { + "grant_type": "authorization_code", + "code": accept_grant_code, + "client_id": self.client_id, + "client_secret": self.client_secret + } + _LOGGER.debug("Calling LWA to get the access token (first time), " + "with: %s", json.dumps(lwa_params)) + + return await self._async_request_new_token(lwa_params) + + async def async_get_access_token(self): + """Perform access token or token refresh request.""" + async with self._get_token_lock: + if self._prefs is None: + await self.async_load_preferences() + + if self.is_token_valid(): + _LOGGER.debug("Token still valid, using it.") + return self._prefs[STORAGE_ACCESS_TOKEN] + + if self._prefs[STORAGE_REFRESH_TOKEN] is None: + _LOGGER.debug("Token invalid and no refresh token available.") + return None + + lwa_params = { + "grant_type": "refresh_token", + "refresh_token": self._prefs[STORAGE_REFRESH_TOKEN], + "client_id": self.client_id, + "client_secret": self.client_secret + } + + _LOGGER.debug("Calling LWA to refresh the access token.") + return await self._async_request_new_token(lwa_params) + + @callback + def is_token_valid(self): + """Check if a token is already loaded and if it is still valid.""" + if not self._prefs[STORAGE_ACCESS_TOKEN]: + return False + + expire_time = dt.parse_datetime(self._prefs[STORAGE_EXPIRE_TIME]) + preemptive_expire_time = expire_time - timedelta( + seconds=PREEMPTIVE_REFRESH_TTL_IN_SECONDS) + + return dt.utcnow() < preemptive_expire_time + + async def _async_request_new_token(self, lwa_params): + + try: + session = aiohttp_client.async_get_clientsession(self.hass) + with async_timeout.timeout(DEFAULT_TIMEOUT, loop=self.hass.loop): + response = await session.post(LWA_TOKEN_URI, + headers=LWA_HEADERS, + data=lwa_params, + allow_redirects=True) + + except (asyncio.TimeoutError, aiohttp.ClientError): + _LOGGER.error("Timeout calling LWA to get auth token.") + return None + + _LOGGER.debug("LWA response header: %s", response.headers) + _LOGGER.debug("LWA response status: %s", response.status) + + if response.status != 200: + _LOGGER.error("Error calling LWA to get auth token.") + return None + + response_json = await response.json() + _LOGGER.debug("LWA response body : %s", response_json) + + access_token = response_json["access_token"] + refresh_token = response_json["refresh_token"] + expires_in = response_json["expires_in"] + expire_time = dt.utcnow() + timedelta(seconds=expires_in) + + await self._async_update_preferences(access_token, refresh_token, + expire_time.isoformat()) + + return access_token + + async def async_load_preferences(self): + """Load preferences with stored tokens.""" + self._prefs = await self._store.async_load() + + if self._prefs is None: + self._prefs = { + STORAGE_ACCESS_TOKEN: None, + STORAGE_REFRESH_TOKEN: None, + STORAGE_EXPIRE_TIME: None + } + + async def _async_update_preferences(self, access_token, refresh_token, + expire_time): + """Update user preferences.""" + if self._prefs is None: + await self.async_load_preferences() + + if access_token is not None: + self._prefs[STORAGE_ACCESS_TOKEN] = access_token + if refresh_token is not None: + self._prefs[STORAGE_REFRESH_TOKEN] = refresh_token + if expire_time is not None: + self._prefs[STORAGE_EXPIRE_TIME] = expire_time + await self._store.async_save(self._prefs) diff --git a/homeassistant/components/alexa/const.py b/homeassistant/components/alexa/const.py index 7d6489b535a..78f7d02f5f0 100644 --- a/homeassistant/components/alexa/const.py +++ b/homeassistant/components/alexa/const.py @@ -10,6 +10,9 @@ CONF_DISPLAY_URL = 'display_url' CONF_FILTER = 'filter' CONF_ENTITY_CONFIG = 'entity_config' +CONF_ENDPOINT = 'endpoint' +CONF_CLIENT_ID = 'client_id' +CONF_CLIENT_SECRET = 'client_secret' ATTR_UID = 'uid' ATTR_UPDATE_DATE = 'updateDate' @@ -21,3 +24,5 @@ ATTR_REDIRECTION_URL = 'redirectionURL' SYN_RESOLUTION_MATCH = 'ER_SUCCESS_MATCH' DATE_FORMAT = '%Y-%m-%dT%H:%M:%S.0Z' + +DEFAULT_TIMEOUT = 30 diff --git a/homeassistant/components/alexa/smart_home.py b/homeassistant/components/alexa/smart_home.py index f06b853087f..1558a1bf218 100644 --- a/homeassistant/components/alexa/smart_home.py +++ b/homeassistant/components/alexa/smart_home.py @@ -5,15 +5,22 @@ https://developer.amazon.com/docs/smarthome/understand-the-smart-home-skill-api. https://developer.amazon.com/docs/device-apis/message-guide.html """ +import asyncio from collections import OrderedDict from datetime import datetime +import json import logging import math from uuid import uuid4 +import aiohttp +import async_timeout + from homeassistant.components import ( alert, automation, binary_sensor, climate, cover, fan, group, http, input_boolean, light, lock, media_player, scene, script, sensor, switch) +from homeassistant.helpers import aiohttp_client +from homeassistant.helpers.event import async_track_state_change from homeassistant.const import ( ATTR_DEVICE_CLASS, ATTR_ENTITY_ID, ATTR_SUPPORTED_FEATURES, ATTR_TEMPERATURE, ATTR_UNIT_OF_MEASUREMENT, CLOUD_NEVER_EXPOSED_ENTITIES, @@ -21,13 +28,15 @@ from homeassistant.const import ( SERVICE_MEDIA_PLAY, SERVICE_MEDIA_PREVIOUS_TRACK, SERVICE_MEDIA_STOP, SERVICE_SET_COVER_POSITION, SERVICE_TURN_OFF, SERVICE_TURN_ON, SERVICE_UNLOCK, SERVICE_VOLUME_SET, STATE_LOCKED, STATE_ON, STATE_UNLOCKED, - TEMP_CELSIUS, TEMP_FAHRENHEIT) + TEMP_CELSIUS, TEMP_FAHRENHEIT, MATCH_ALL) import homeassistant.core as ha import homeassistant.util.color as color_util from homeassistant.util.decorator import Registry from homeassistant.util.temperature import convert as convert_temperature -from .const import CONF_ENTITY_CONFIG, CONF_FILTER +from .const import CONF_CLIENT_ID, CONF_CLIENT_SECRET, CONF_ENDPOINT, \ + CONF_ENTITY_CONFIG, CONF_FILTER, DATE_FORMAT, DEFAULT_TIMEOUT +from .auth import Auth _LOGGER = logging.getLogger(__name__) @@ -37,6 +46,8 @@ API_EVENT = 'event' API_CONTEXT = 'context' API_HEADER = 'header' API_PAYLOAD = 'payload' +API_SCOPE = 'scope' +API_CHANGE = 'change' API_TEMP_UNITS = { TEMP_FAHRENHEIT: 'FAHRENHEIT', @@ -66,6 +77,8 @@ HANDLERS = Registry() ENTITY_ADAPTERS = Registry() EVENT_ALEXA_SMART_HOME = 'alexa_smart_home' +AUTH_KEY = "alexa.smart_home.auth" + class _DisplayCategory: """Possible display categories for Discovery response. @@ -375,6 +388,8 @@ class _AlexaInterface: 'name': prop_name, 'namespace': self.name(), 'value': prop_value, + 'timeOfSample': datetime.now().strftime(DATE_FORMAT), + 'uncertaintyInMilliseconds': 0 } @@ -390,6 +405,9 @@ class _AlexaPowerController(_AlexaInterface): def properties_supported(self): return [{'name': 'powerState'}] + def properties_proactively_reported(self): + return True + def properties_retrievable(self): return True @@ -417,6 +435,9 @@ class _AlexaLockController(_AlexaInterface): def properties_retrievable(self): return True + def properties_proactively_reported(self): + return True + def get_property(self, name): if name != 'lockState': raise _UnsupportedProperty(name) @@ -454,6 +475,9 @@ class _AlexaBrightnessController(_AlexaInterface): def properties_supported(self): return [{'name': 'brightness'}] + def properties_proactively_reported(self): + return True + def properties_retrievable(self): return True @@ -585,6 +609,9 @@ class _AlexaTemperatureSensor(_AlexaInterface): def properties_supported(self): return [{'name': 'temperature'}] + def properties_proactively_reported(self): + return True + def properties_retrievable(self): return True @@ -625,6 +652,9 @@ class _AlexaContactSensor(_AlexaInterface): def properties_supported(self): return [{'name': 'detectionState'}] + def properties_proactively_reported(self): + return True + def properties_retrievable(self): return True @@ -648,6 +678,9 @@ class _AlexaMotionSensor(_AlexaInterface): def properties_supported(self): return [{'name': 'detectionState'}] + def properties_proactively_reported(self): + return True + def properties_retrievable(self): return True @@ -686,6 +719,9 @@ class _AlexaThermostatController(_AlexaInterface): properties.append({'name': 'thermostatMode'}) return properties + def properties_proactively_reported(self): + return True + def properties_retrievable(self): return True @@ -948,8 +984,11 @@ class _Cause: class Config: """Hold the configuration for Alexa.""" - def __init__(self, should_expose, entity_config=None): + def __init__(self, endpoint, async_get_access_token, should_expose, + entity_config=None): """Initialize the configuration.""" + self.endpoint = endpoint + self.async_get_access_token = async_get_access_token self.should_expose = should_expose self.entity_config = entity_config or {} @@ -964,12 +1003,62 @@ def async_setup(hass, config): Even if that's disabled, the functionality in this module may still be used by the cloud component which will call async_handle_message directly. """ + if config.get(CONF_CLIENT_ID) and config.get(CONF_CLIENT_SECRET): + hass.data[AUTH_KEY] = Auth(hass, config[CONF_CLIENT_ID], + config[CONF_CLIENT_SECRET]) + + async_get_access_token = \ + hass.data[AUTH_KEY].async_get_access_token if AUTH_KEY in hass.data \ + else None + smart_home_config = Config( + endpoint=config.get(CONF_ENDPOINT), + async_get_access_token=async_get_access_token, should_expose=config[CONF_FILTER], entity_config=config.get(CONF_ENTITY_CONFIG), ) hass.http.register_view(SmartHomeView(smart_home_config)) + if AUTH_KEY in hass.data: + hass.loop.create_task( + async_enable_proactive_mode(hass, smart_home_config)) + + +async def async_enable_proactive_mode(hass, smart_home_config): + """Enable the proactive mode. + + Proactive mode makes this component report state changes to Alexa. + """ + if smart_home_config.async_get_access_token is None: + # no function to call to get token + return + + if await smart_home_config.async_get_access_token() is None: + # not ready yet + return + + async def async_entity_state_listener(changed_entity, old_state, + new_state): + if not smart_home_config.should_expose(changed_entity): + _LOGGER.debug("Not exposing %s because filtered by config", + changed_entity) + return + + if new_state.domain not in ENTITY_ADAPTERS: + return + + alexa_changed_entity = \ + ENTITY_ADAPTERS[new_state.domain](hass, smart_home_config, + new_state) + + for interface in alexa_changed_entity.interfaces(): + if interface.properties_proactively_reported(): + await async_send_changereport_message(hass, smart_home_config, + alexa_changed_entity) + return + + async_track_state_change(hass, MATCH_ALL, async_entity_state_listener) + class SmartHomeView(http.HomeAssistantView): """Expose Smart Home v3 payload interface via HTTP POST.""" @@ -1112,6 +1201,24 @@ class _AlexaResponse: """ self._response[API_EVENT][API_HEADER]['correlationToken'] = token + def set_endpoint_full(self, bearer_token, endpoint_id, cookie=None): + """Set the endpoint dictionary. + + This is used to send proactive messages to Alexa. + """ + self._response[API_EVENT][API_ENDPOINT] = { + API_SCOPE: { + 'type': 'BearerToken', + 'token': bearer_token + } + } + + if endpoint_id is not None: + self._response[API_EVENT][API_ENDPOINT]['endpointId'] = endpoint_id + + if cookie is not None: + self._response[API_EVENT][API_ENDPOINT]['cookie'] = cookie + def set_endpoint(self, endpoint): """Set the endpoint. @@ -1222,6 +1329,62 @@ async def async_handle_message( return response.serialize() +async def async_send_changereport_message(hass, config, alexa_entity): + """Send a ChangeReport message for an Alexa entity.""" + token = await config.async_get_access_token() + if not token: + _LOGGER.error("Invalid access token.") + return + + headers = { + "Authorization": "Bearer {}".format(token), + "Content-Type": "application/json;charset=UTF-8" + } + + endpoint = alexa_entity.entity_id() + + # this sends all the properties of the Alexa Entity, whether they have + # changed or not. this should be improved, and properties that have not + # changed should be moved to the 'context' object + properties = list(alexa_entity.serialize_properties()) + + payload = { + API_CHANGE: { + 'cause': {'type': _Cause.APP_INTERACTION}, + 'properties': properties + } + } + + message = _AlexaResponse(name='ChangeReport', namespace='Alexa', + payload=payload) + message.set_endpoint_full(token, endpoint) + + message_str = json.dumps(message.serialize()) + + try: + session = aiohttp_client.async_get_clientsession(hass) + with async_timeout.timeout(DEFAULT_TIMEOUT, loop=hass.loop): + response = await session.post(config.endpoint, + headers=headers, + data=message_str, + allow_redirects=True) + + except (asyncio.TimeoutError, aiohttp.ClientError): + _LOGGER.error("Timeout calling LWA to get auth token.") + return None + + response_text = await response.text() + + _LOGGER.debug("Sent: %s", message_str) + _LOGGER.debug("Received (%s): %s", response.status, response_text) + + if response.status != 202: + response_json = json.loads(response_text) + _LOGGER.error("Error when sending ChangeReport to Alexa: %s: %s", + response_json["payload"]["code"], + response_json["payload"]["description"]) + + @HANDLERS.register(('Alexa.Discovery', 'Discover')) async def async_api_discovery(hass, config, directive, context): """Create a API formatted discovery response. @@ -1258,8 +1421,9 @@ async def async_api_discovery(hass, config, directive, context): i.serialize_discovery() for i in alexa_entity.interfaces()] if not endpoint['capabilities']: - _LOGGER.debug("Not exposing %s because it has no capabilities", - entity.entity_id) + _LOGGER.debug( + "Not exposing %s because it has no capabilities", + entity.entity_id) continue discovery_endpoints.append(endpoint) @@ -1270,6 +1434,25 @@ async def async_api_discovery(hass, config, directive, context): ) +@HANDLERS.register(('Alexa.Authorization', 'AcceptGrant')) +async def async_api_accept_grant(hass, config, directive, context): + """Create a API formatted AcceptGrant response. + + Async friendly. + """ + auth_code = directive.payload['grant']['code'] + _LOGGER.debug("AcceptGrant code: %s", auth_code) + + if AUTH_KEY in hass.data: + await hass.data[AUTH_KEY].async_do_auth(auth_code) + await async_enable_proactive_mode(hass, config) + + return directive.response( + name='AcceptGrant.Response', + namespace='Alexa.Authorization', + payload={}) + + @HANDLERS.register(('Alexa.PowerController', 'TurnOn')) async def async_api_turn_on(hass, config, directive, context): """Process a turn on request.""" diff --git a/homeassistant/components/cloud/__init__.py b/homeassistant/components/cloud/__init__.py index fd5b413043e..d938dd20e67 100644 --- a/homeassistant/components/cloud/__init__.py +++ b/homeassistant/components/cloud/__init__.py @@ -99,6 +99,8 @@ async def async_setup(hass, config): kwargs[CONF_GOOGLE_ACTIONS] = GACTIONS_SCHEMA({}) kwargs[CONF_ALEXA] = alexa_sh.Config( + endpoint=None, + async_get_access_token=None, should_expose=alexa_conf[CONF_FILTER], entity_config=alexa_conf.get(CONF_ENTITY_CONFIG), ) diff --git a/tests/components/alexa/test_smart_home.py b/tests/components/alexa/test_smart_home.py index ddf66d1c617..93551076461 100644 --- a/tests/components/alexa/test_smart_home.py +++ b/tests/components/alexa/test_smart_home.py @@ -11,11 +11,24 @@ from homeassistant.const import ( from homeassistant.setup import async_setup_component from homeassistant.components import alexa from homeassistant.components.alexa import smart_home +from homeassistant.components.alexa.auth import Auth from homeassistant.helpers import entityfilter from tests.common import async_mock_service -DEFAULT_CONFIG = smart_home.Config(should_expose=lambda entity_id: True) + +async def get_access_token(): + """Return a test access token.""" + return "thisisnotanacesstoken" + + +TEST_URL = "https://api.amazonalexa.com/v3/events" +TEST_TOKEN_URL = "https://api.amazon.com/auth/o2/token" + +DEFAULT_CONFIG = smart_home.Config( + endpoint=TEST_URL, + async_get_access_token=get_access_token, + should_expose=lambda entity_id: True) @pytest.fixture @@ -940,12 +953,15 @@ async def test_exclude_filters(hass): hass.states.async_set( 'cover.deny', 'off', {'friendly_name': "Blocked cover"}) - config = smart_home.Config(should_expose=entityfilter.generate_filter( - include_domains=[], - include_entities=[], - exclude_domains=['script'], - exclude_entities=['cover.deny'], - )) + config = smart_home.Config( + endpoint=None, + async_get_access_token=None, + should_expose=entityfilter.generate_filter( + include_domains=[], + include_entities=[], + exclude_domains=['script'], + exclude_entities=['cover.deny'], + )) msg = await smart_home.async_handle_message(hass, config, request) await hass.async_block_till_done() @@ -972,12 +988,15 @@ async def test_include_filters(hass): hass.states.async_set( 'group.allow', 'off', {'friendly_name': "Allowed group"}) - config = smart_home.Config(should_expose=entityfilter.generate_filter( - include_domains=['automation', 'group'], - include_entities=['script.deny'], - exclude_domains=[], - exclude_entities=[], - )) + config = smart_home.Config( + endpoint=None, + async_get_access_token=None, + should_expose=entityfilter.generate_filter( + include_domains=['automation', 'group'], + include_entities=['script.deny'], + exclude_domains=[], + exclude_entities=[], + )) msg = await smart_home.async_handle_message(hass, config, request) await hass.async_block_till_done() @@ -998,12 +1017,15 @@ async def test_never_exposed_entities(hass): hass.states.async_set( 'group.allow', 'off', {'friendly_name': "Allowed group"}) - config = smart_home.Config(should_expose=entityfilter.generate_filter( - include_domains=['group'], - include_entities=[], - exclude_domains=[], - exclude_entities=[], - )) + config = smart_home.Config( + endpoint=None, + async_get_access_token=None, + should_expose=entityfilter.generate_filter( + include_domains=['group'], + include_entities=[], + exclude_domains=[], + exclude_entities=[], + )) msg = await smart_home.async_handle_message(hass, config, request) await hass.async_block_till_done() @@ -1293,6 +1315,33 @@ async def test_api_increase_color_temp(hass, result, initial): assert msg['header']['name'] == 'Response' +async def test_api_accept_grant(hass): + """Test api AcceptGrant process.""" + request = get_new_request("Alexa.Authorization", "AcceptGrant") + + # add payload + request['directive']['payload'] = { + 'grant': { + 'type': 'OAuth2.AuthorizationCode', + 'code': 'VGhpcyBpcyBhbiBhdXRob3JpemF0aW9uIGNvZGUuIDotKQ==' + }, + 'grantee': { + 'type': 'BearerToken', + 'token': 'access-token-from-skill' + } + } + + # setup test devices + msg = await smart_home.async_handle_message( + hass, DEFAULT_CONFIG, request) + await hass.async_block_till_done() + + assert 'event' in msg + msg = msg['event'] + + assert msg['header']['name'] == 'AcceptGrant.Response' + + async def test_report_lock_state(hass): """Test LockController implements lockState property.""" hass.states.async_set( @@ -1412,6 +1461,8 @@ async def test_entity_config(hass): 'light.test_1', 'on', {'friendly_name': "Test light 1"}) config = smart_home.Config( + endpoint=None, + async_get_access_token=None, should_expose=lambda entity_id: True, entity_config={ 'light.test_1': { @@ -1598,3 +1649,104 @@ async def test_disabled(hass): assert msg['header']['name'] == 'ErrorResponse' assert msg['header']['namespace'] == 'Alexa' assert msg['payload']['type'] == 'BRIDGE_UNREACHABLE' + + +async def test_report_state(hass, aioclient_mock): + """Test proactive state reports.""" + aioclient_mock.post(TEST_URL, json={'data': 'is irrelevant'}) + + hass.states.async_set( + 'binary_sensor.test_contact', + 'on', + { + 'friendly_name': "Test Contact Sensor", + 'device_class': 'door', + } + ) + + await smart_home.async_enable_proactive_mode(hass, DEFAULT_CONFIG) + + hass.states.async_set( + 'binary_sensor.test_contact', + 'off', + { + 'friendly_name': "Test Contact Sensor", + 'device_class': 'door', + } + ) + + # To trigger event listener + await hass.async_block_till_done() + + assert len(aioclient_mock.mock_calls) == 1 + call = aioclient_mock.mock_calls + + call_json = json.loads(call[0][2]) + assert call_json["event"]["payload"]["change"]["properties"][0][ + "value"] == "NOT_DETECTED" + assert call_json["event"]["endpoint"][ + "endpointId"] == "binary_sensor#test_contact" + + +async def run_auth_get_access_token(hass, aioclient_mock, expires_in, + client_id, client_secret, + accept_grant_code, refresh_token): + """Do auth and request a new token for tests.""" + aioclient_mock.post(TEST_TOKEN_URL, + json={'access_token': 'the_access_token', + 'refresh_token': refresh_token, + 'expires_in': expires_in}) + + auth = Auth(hass, client_id, client_secret) + await auth.async_do_auth(accept_grant_code) + await auth.async_get_access_token() + + +async def test_auth_get_access_token_expired(hass, aioclient_mock): + """Test the auth get access token function.""" + client_id = "client123" + client_secret = "shhhhh" + accept_grant_code = "abcdefg" + refresh_token = "refresher" + + await run_auth_get_access_token(hass, aioclient_mock, -5, + client_id, client_secret, + accept_grant_code, refresh_token) + + assert len(aioclient_mock.mock_calls) == 2 + calls = aioclient_mock.mock_calls + + auth_call_json = calls[0][2] + token_call_json = calls[1][2] + + assert auth_call_json["grant_type"] == "authorization_code" + assert auth_call_json["code"] == accept_grant_code + assert auth_call_json["client_id"] == client_id + assert auth_call_json["client_secret"] == client_secret + + assert token_call_json["grant_type"] == "refresh_token" + assert token_call_json["refresh_token"] == refresh_token + assert token_call_json["client_id"] == client_id + assert token_call_json["client_secret"] == client_secret + + +async def test_auth_get_access_token_not_expired(hass, aioclient_mock): + """Test the auth get access token function.""" + client_id = "client123" + client_secret = "shhhhh" + accept_grant_code = "abcdefg" + refresh_token = "refresher" + + await run_auth_get_access_token(hass, aioclient_mock, 555, + client_id, client_secret, + accept_grant_code, refresh_token) + + assert len(aioclient_mock.mock_calls) == 1 + call = aioclient_mock.mock_calls + + auth_call_json = call[0][2] + + assert auth_call_json["grant_type"] == "authorization_code" + assert auth_call_json["code"] == accept_grant_code + assert auth_call_json["client_id"] == client_id + assert auth_call_json["client_secret"] == client_secret