Make sure we check access token expiration (#15207)

* Make sure we check access token expiration

* Use correct access token websocket
This commit is contained in:
Paulus Schoutsen 2018-06-29 00:02:33 -04:00 committed by GitHub
parent 2205090795
commit 39971ee919
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 74 additions and 11 deletions

View file

@ -159,9 +159,10 @@ class AccessToken:
default=attr.Factory(generate_secret))
@property
def expires(self):
"""Return datetime when this token expires."""
return self.created_at + self.refresh_token.access_token_expiration
def expired(self):
"""Return if this token has expired."""
expires = self.created_at + self.refresh_token.access_token_expiration
return dt_util.utcnow() > expires
@attr.s(slots=True)
@ -272,7 +273,12 @@ class AuthManager:
self.login_flow = data_entry_flow.FlowManager(
hass, self._async_create_login_flow,
self._async_finish_login_flow)
self.access_tokens = {}
self._access_tokens = {}
@property
def active(self):
"""Return if any auth providers are registered."""
return bool(self._providers)
@property
def async_auth_providers(self):
@ -308,13 +314,22 @@ class AuthManager:
def async_create_access_token(self, refresh_token):
"""Create a new access token."""
access_token = AccessToken(refresh_token)
self.access_tokens[access_token.token] = access_token
self._access_tokens[access_token.token] = access_token
return access_token
@callback
def async_get_access_token(self, token):
"""Get an access token."""
return self.access_tokens.get(token)
tkn = self._access_tokens.get(token)
if tkn is None:
return None
if tkn.expired:
self._access_tokens.pop(token)
return None
return tkn
async def async_create_client(self, name, *, redirect_uris=None,
no_secret=False):