Enhance SmartThings component subscription (#21124)

* Move to config v2 to store SmartApp oauth keys

* Add migration functionality.

* Regenerate refresh token on periodic basis

* Fix regenerate and misc. optimizations

* Review feedback

* Subscription sync logic now performs a difference operation

* Removed config entry reloading.
This commit is contained in:
Andrew Sayre 2019-02-22 13:35:12 -06:00 committed by Martin Hjelmare
parent d9712027e8
commit 8b38b82e73
14 changed files with 529 additions and 275 deletions

View file

@ -14,16 +14,20 @@ from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.dispatcher import ( from homeassistant.helpers.dispatcher import (
async_dispatcher_connect, async_dispatcher_send) async_dispatcher_connect, async_dispatcher_send)
from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity import Entity
from homeassistant.helpers.event import async_track_time_interval
from homeassistant.helpers.typing import ConfigType, HomeAssistantType from homeassistant.helpers.typing import ConfigType, HomeAssistantType
from .config_flow import SmartThingsFlowHandler # noqa from .config_flow import SmartThingsFlowHandler # noqa
from .const import ( from .const import (
CONF_APP_ID, CONF_INSTALLED_APP_ID, DATA_BROKERS, DATA_MANAGER, DOMAIN, CONF_APP_ID, CONF_INSTALLED_APP_ID, CONF_OAUTH_CLIENT_ID,
EVENT_BUTTON, SIGNAL_SMARTTHINGS_UPDATE, SUPPORTED_PLATFORMS) CONF_OAUTH_CLIENT_SECRET, CONF_REFRESH_TOKEN, DATA_BROKERS, DATA_MANAGER,
DOMAIN, EVENT_BUTTON, SIGNAL_SMARTTHINGS_UPDATE, SUPPORTED_PLATFORMS,
TOKEN_REFRESH_INTERVAL)
from .smartapp import ( from .smartapp import (
setup_smartapp, setup_smartapp_endpoint, validate_installed_app) setup_smartapp, setup_smartapp_endpoint, smartapp_sync_subscriptions,
validate_installed_app)
REQUIREMENTS = ['pysmartapp==0.3.0', 'pysmartthings==0.6.2'] REQUIREMENTS = ['pysmartapp==0.3.0', 'pysmartthings==0.6.3']
DEPENDENCIES = ['webhook'] DEPENDENCIES = ['webhook']
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -35,6 +39,33 @@ async def async_setup(hass: HomeAssistantType, config: ConfigType):
return True return True
async def async_migrate_entry(hass: HomeAssistantType, entry: ConfigEntry):
"""Handle migration of a previous version config entry.
A config entry created under a previous version must go through the
integration setup again so we can properly retrieve the needed data
elements. Force this by removing the entry and triggering a new flow.
"""
from pysmartthings import SmartThings
# Delete the installed app
api = SmartThings(async_get_clientsession(hass),
entry.data[CONF_ACCESS_TOKEN])
await api.delete_installed_app(entry.data[CONF_INSTALLED_APP_ID])
# Delete the entry
hass.async_create_task(
hass.config_entries.async_remove(entry.entry_id))
# only create new flow if there isn't a pending one for SmartThings.
flows = hass.config_entries.flow.async_progress()
if not [flow for flow in flows if flow['handler'] == DOMAIN]:
hass.async_create_task(
hass.config_entries.flow.async_init(
DOMAIN, context={'source': 'import'}))
# Return False because it could not be migrated.
return False
async def async_setup_entry(hass: HomeAssistantType, entry: ConfigEntry): async def async_setup_entry(hass: HomeAssistantType, entry: ConfigEntry):
"""Initialize config entry which represents an installed SmartApp.""" """Initialize config entry which represents an installed SmartApp."""
from pysmartthings import SmartThings from pysmartthings import SmartThings
@ -62,6 +93,14 @@ async def async_setup_entry(hass: HomeAssistantType, entry: ConfigEntry):
installed_app = await validate_installed_app( installed_app = await validate_installed_app(
api, entry.data[CONF_INSTALLED_APP_ID]) api, entry.data[CONF_INSTALLED_APP_ID])
# Get SmartApp token to sync subscriptions
token = await api.generate_tokens(
entry.data[CONF_OAUTH_CLIENT_ID],
entry.data[CONF_OAUTH_CLIENT_SECRET],
entry.data[CONF_REFRESH_TOKEN])
entry.data[CONF_REFRESH_TOKEN] = token.refresh_token
hass.config_entries.async_update_entry(entry)
# Get devices and their current status # Get devices and their current status
devices = await api.devices( devices = await api.devices(
location_ids=[installed_app.location_id]) location_ids=[installed_app.location_id])
@ -71,18 +110,21 @@ async def async_setup_entry(hass: HomeAssistantType, entry: ConfigEntry):
await device.status.refresh() await device.status.refresh()
except ClientResponseError: except ClientResponseError:
_LOGGER.debug("Unable to update status for device: %s (%s), " _LOGGER.debug("Unable to update status for device: %s (%s), "
"the device will be ignored", "the device will be excluded",
device.label, device.device_id, exc_info=True) device.label, device.device_id, exc_info=True)
devices.remove(device) devices.remove(device)
await asyncio.gather(*[retrieve_device_status(d) await asyncio.gather(*[retrieve_device_status(d)
for d in devices.copy()]) for d in devices.copy()])
# Sync device subscriptions
await smartapp_sync_subscriptions(
hass, token.access_token, installed_app.location_id,
installed_app.installed_app_id, devices)
# Setup device broker # Setup device broker
broker = DeviceBroker(hass, devices, broker = DeviceBroker(hass, entry, token, smart_app, devices)
installed_app.installed_app_id) broker.connect()
broker.event_handler_disconnect = \
smart_app.connect_event(broker.event_handler)
hass.data[DOMAIN][DATA_BROKERS][entry.entry_id] = broker hass.data[DOMAIN][DATA_BROKERS][entry.entry_id] = broker
except ClientResponseError as ex: except ClientResponseError as ex:
@ -117,8 +159,8 @@ async def async_setup_entry(hass: HomeAssistantType, entry: ConfigEntry):
async def async_unload_entry(hass: HomeAssistantType, entry: ConfigEntry): async def async_unload_entry(hass: HomeAssistantType, entry: ConfigEntry):
"""Unload a config entry.""" """Unload a config entry."""
broker = hass.data[DOMAIN][DATA_BROKERS].pop(entry.entry_id, None) broker = hass.data[DOMAIN][DATA_BROKERS].pop(entry.entry_id, None)
if broker and broker.event_handler_disconnect: if broker:
broker.event_handler_disconnect() broker.disconnect()
tasks = [hass.config_entries.async_forward_entry_unload(entry, component) tasks = [hass.config_entries.async_forward_entry_unload(entry, component)
for component in SUPPORTED_PLATFORMS] for component in SUPPORTED_PLATFORMS]
@ -128,14 +170,18 @@ async def async_unload_entry(hass: HomeAssistantType, entry: ConfigEntry):
class DeviceBroker: class DeviceBroker:
"""Manages an individual SmartThings config entry.""" """Manages an individual SmartThings config entry."""
def __init__(self, hass: HomeAssistantType, devices: Iterable, def __init__(self, hass: HomeAssistantType, entry: ConfigEntry,
installed_app_id: str): token, smart_app, devices: Iterable):
"""Create a new instance of the DeviceBroker.""" """Create a new instance of the DeviceBroker."""
self._hass = hass self._hass = hass
self._installed_app_id = installed_app_id self._entry = entry
self.assignments = self._assign_capabilities(devices) self._installed_app_id = entry.data[CONF_INSTALLED_APP_ID]
self._smart_app = smart_app
self._token = token
self._event_disconnect = None
self._regenerate_token_remove = None
self._assignments = self._assign_capabilities(devices)
self.devices = {device.device_id: device for device in devices} self.devices = {device.device_id: device for device in devices}
self.event_handler_disconnect = None
def _assign_capabilities(self, devices: Iterable): def _assign_capabilities(self, devices: Iterable):
"""Assign platforms to capabilities.""" """Assign platforms to capabilities."""
@ -158,17 +204,45 @@ class DeviceBroker:
assignments[device.device_id] = slots assignments[device.device_id] = slots
return assignments return assignments
def connect(self):
"""Connect handlers/listeners for device/lifecycle events."""
# Setup interval to regenerate the refresh token on a periodic basis.
# Tokens expire in 30 days and once expired, cannot be recovered.
async def regenerate_refresh_token(now):
"""Generate a new refresh token and update the config entry."""
await self._token.refresh(
self._entry.data[CONF_OAUTH_CLIENT_ID],
self._entry.data[CONF_OAUTH_CLIENT_SECRET])
self._entry.data[CONF_REFRESH_TOKEN] = self._token.refresh_token
self._hass.config_entries.async_update_entry(self._entry)
_LOGGER.debug('Regenerated refresh token for installed app: %s',
self._installed_app_id)
self._regenerate_token_remove = async_track_time_interval(
self._hass, regenerate_refresh_token, TOKEN_REFRESH_INTERVAL)
# Connect handler to incoming device events
self._event_disconnect = \
self._smart_app.connect_event(self._event_handler)
def disconnect(self):
"""Disconnects handlers/listeners for device/lifecycle events."""
if self._regenerate_token_remove:
self._regenerate_token_remove()
if self._event_disconnect:
self._event_disconnect()
def get_assigned(self, device_id: str, platform: str): def get_assigned(self, device_id: str, platform: str):
"""Get the capabilities assigned to the platform.""" """Get the capabilities assigned to the platform."""
slots = self.assignments.get(device_id, {}) slots = self._assignments.get(device_id, {})
return [key for key, value in slots.items() if value == platform] return [key for key, value in slots.items() if value == platform]
def any_assigned(self, device_id: str, platform: str): def any_assigned(self, device_id: str, platform: str):
"""Return True if the platform has any assigned capabilities.""" """Return True if the platform has any assigned capabilities."""
slots = self.assignments.get(device_id, {}) slots = self._assignments.get(device_id, {})
return any(value for value in slots.values() if value == platform) return any(value for value in slots.values() if value == platform)
async def event_handler(self, req, resp, app): async def _event_handler(self, req, resp, app):
"""Broker for incoming events.""" """Broker for incoming events."""
from pysmartapp.event import EVENT_TYPE_DEVICE from pysmartapp.event import EVENT_TYPE_DEVICE
from pysmartthings import Capability, Attribute from pysmartthings import Capability, Attribute

View file

@ -9,7 +9,8 @@ from homeassistant.const import CONF_ACCESS_TOKEN
from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.aiohttp_client import async_get_clientsession
from .const import ( from .const import (
CONF_APP_ID, CONF_INSTALLED_APP_ID, CONF_LOCATION_ID, DOMAIN, APP_OAUTH_CLIENT_NAME, APP_OAUTH_SCOPES, CONF_APP_ID, CONF_INSTALLED_APPS,
CONF_LOCATION_ID, CONF_OAUTH_CLIENT_ID, CONF_OAUTH_CLIENT_SECRET, DOMAIN,
VAL_UID_MATCHER) VAL_UID_MATCHER)
from .smartapp import ( from .smartapp import (
create_app, find_app, setup_smartapp, setup_smartapp_endpoint, update_app) create_app, find_app, setup_smartapp, setup_smartapp_endpoint, update_app)
@ -35,7 +36,7 @@ class SmartThingsFlowHandler(config_entries.ConfigFlow):
b) Config entries setup for all installations b) Config entries setup for all installations
""" """
VERSION = 1 VERSION = 2
CONNECTION_CLASS = config_entries.CONN_CLASS_CLOUD_PUSH CONNECTION_CLASS = config_entries.CONN_CLASS_CLOUD_PUSH
def __init__(self): def __init__(self):
@ -43,6 +44,8 @@ class SmartThingsFlowHandler(config_entries.ConfigFlow):
self.access_token = None self.access_token = None
self.app_id = None self.app_id = None
self.api = None self.api = None
self.oauth_client_secret = None
self.oauth_client_id = None
async def async_step_import(self, user_input=None): async def async_step_import(self, user_input=None):
"""Occurs when a previously entry setup fails and is re-initiated.""" """Occurs when a previously entry setup fails and is re-initiated."""
@ -50,7 +53,7 @@ class SmartThingsFlowHandler(config_entries.ConfigFlow):
async def async_step_user(self, user_input=None): async def async_step_user(self, user_input=None):
"""Get access token and validate it.""" """Get access token and validate it."""
from pysmartthings import APIResponseError, SmartThings from pysmartthings import APIResponseError, AppOAuth, SmartThings
errors = {} errors = {}
if not self.hass.config.api.base_url.lower().startswith('https://'): if not self.hass.config.api.base_url.lower().startswith('https://'):
@ -83,10 +86,18 @@ class SmartThingsFlowHandler(config_entries.ConfigFlow):
if app: if app:
await app.refresh() # load all attributes await app.refresh() # load all attributes
await update_app(self.hass, app) await update_app(self.hass, app)
# Get oauth client id/secret by regenerating it
app_oauth = AppOAuth(app.app_id)
app_oauth.client_name = APP_OAUTH_CLIENT_NAME
app_oauth.scope.extend(APP_OAUTH_SCOPES)
client = await self.api.generate_app_oauth(app_oauth)
else: else:
app = await create_app(self.hass, self.api) app, client = await create_app(self.hass, self.api)
setup_smartapp(self.hass, app) setup_smartapp(self.hass, app)
self.app_id = app.app_id self.app_id = app.app_id
self.oauth_client_secret = client.client_secret
self.oauth_client_id = client.client_id
except APIResponseError as ex: except APIResponseError as ex:
if ex.is_target_error(): if ex.is_target_error():
errors['base'] = 'webhook_error' errors['base'] = 'webhook_error'
@ -113,19 +124,23 @@ class SmartThingsFlowHandler(config_entries.ConfigFlow):
async def async_step_wait_install(self, user_input=None): async def async_step_wait_install(self, user_input=None):
"""Wait for SmartApp installation.""" """Wait for SmartApp installation."""
from pysmartthings import InstalledAppStatus
errors = {} errors = {}
if user_input is None: if user_input is None:
return self._show_step_wait_install(errors) return self._show_step_wait_install(errors)
# Find installed apps that were authorized # Find installed apps that were authorized
installed_apps = [app for app in await self.api.installed_apps( installed_apps = self.hass.data[DOMAIN][CONF_INSTALLED_APPS].copy()
installed_app_status=InstalledAppStatus.AUTHORIZED)
if app.app_id == self.app_id]
if not installed_apps: if not installed_apps:
errors['base'] = 'app_not_installed' errors['base'] = 'app_not_installed'
return self._show_step_wait_install(errors) return self._show_step_wait_install(errors)
self.hass.data[DOMAIN][CONF_INSTALLED_APPS].clear()
# Enrich the data
for installed_app in installed_apps:
installed_app[CONF_APP_ID] = self.app_id
installed_app[CONF_ACCESS_TOKEN] = self.access_token
installed_app[CONF_OAUTH_CLIENT_ID] = self.oauth_client_id
installed_app[CONF_OAUTH_CLIENT_SECRET] = self.oauth_client_secret
# User may have installed the SmartApp in more than one SmartThings # User may have installed the SmartApp in more than one SmartThings
# location. Config flows are created for the additional installations # location. Config flows are created for the additional installations
@ -133,21 +148,10 @@ class SmartThingsFlowHandler(config_entries.ConfigFlow):
self.hass.async_create_task( self.hass.async_create_task(
self.hass.config_entries.flow.async_init( self.hass.config_entries.flow.async_init(
DOMAIN, context={'source': 'install'}, DOMAIN, context={'source': 'install'},
data={ data=installed_app))
CONF_APP_ID: installed_app.app_id,
CONF_INSTALLED_APP_ID: installed_app.installed_app_id,
CONF_LOCATION_ID: installed_app.location_id,
CONF_ACCESS_TOKEN: self.access_token
}))
# return entity for the first one. # Create config entity for the first one.
installed_app = installed_apps[0] return await self.async_step_install(installed_apps[0])
return await self.async_step_install({
CONF_APP_ID: installed_app.app_id,
CONF_INSTALLED_APP_ID: installed_app.installed_app_id,
CONF_LOCATION_ID: installed_app.location_id,
CONF_ACCESS_TOKEN: self.access_token
})
def _show_step_user(self, errors): def _show_step_user(self, errors):
return self.async_show_form( return self.async_show_form(

View file

@ -1,14 +1,20 @@
"""Constants used by the SmartThings component and platforms.""" """Constants used by the SmartThings component and platforms."""
from datetime import timedelta
import re import re
APP_OAUTH_CLIENT_NAME = "Home Assistant"
APP_OAUTH_SCOPES = [ APP_OAUTH_SCOPES = [
'r:devices:*' 'r:devices:*'
] ]
APP_NAME_PREFIX = 'homeassistant.' APP_NAME_PREFIX = 'homeassistant.'
CONF_APP_ID = 'app_id' CONF_APP_ID = 'app_id'
CONF_INSTALLED_APP_ID = 'installed_app_id' CONF_INSTALLED_APP_ID = 'installed_app_id'
CONF_INSTALLED_APPS = 'installed_apps'
CONF_INSTANCE_ID = 'instance_id' CONF_INSTANCE_ID = 'instance_id'
CONF_LOCATION_ID = 'location_id' CONF_LOCATION_ID = 'location_id'
CONF_OAUTH_CLIENT_ID = 'client_id'
CONF_OAUTH_CLIENT_SECRET = 'client_secret'
CONF_REFRESH_TOKEN = 'refresh_token'
DATA_MANAGER = 'manager' DATA_MANAGER = 'manager'
DATA_BROKERS = 'brokers' DATA_BROKERS = 'brokers'
DOMAIN = 'smartthings' DOMAIN = 'smartthings'
@ -29,6 +35,7 @@ SUPPORTED_PLATFORMS = [
'binary_sensor', 'binary_sensor',
'sensor' 'sensor'
] ]
TOKEN_REFRESH_INTERVAL = timedelta(days=14)
VAL_UID = "^(?:([0-9a-fA-F]{32})|([0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]" \ VAL_UID = "^(?:([0-9a-fA-F]{32})|([0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]" \
"{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}))$" "{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}))$"
VAL_UID_MATCHER = re.compile(VAL_UID) VAL_UID_MATCHER = re.compile(VAL_UID)

View file

@ -13,15 +13,16 @@ from uuid import uuid4
from aiohttp import web from aiohttp import web
from homeassistant.components import webhook from homeassistant.components import webhook
from homeassistant.const import CONF_ACCESS_TOKEN, CONF_WEBHOOK_ID from homeassistant.const import CONF_WEBHOOK_ID
from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.dispatcher import ( from homeassistant.helpers.dispatcher import (
async_dispatcher_connect, async_dispatcher_send) async_dispatcher_connect, async_dispatcher_send)
from homeassistant.helpers.typing import HomeAssistantType from homeassistant.helpers.typing import HomeAssistantType
from .const import ( from .const import (
APP_NAME_PREFIX, APP_OAUTH_SCOPES, CONF_APP_ID, CONF_INSTALLED_APP_ID, APP_NAME_PREFIX, APP_OAUTH_CLIENT_NAME, APP_OAUTH_SCOPES, CONF_APP_ID,
CONF_INSTANCE_ID, CONF_LOCATION_ID, DATA_BROKERS, DATA_MANAGER, DOMAIN, CONF_INSTALLED_APP_ID, CONF_INSTALLED_APPS, CONF_INSTANCE_ID,
CONF_LOCATION_ID, CONF_REFRESH_TOKEN, DATA_BROKERS, DATA_MANAGER, DOMAIN,
SETTINGS_INSTANCE_ID, SIGNAL_SMARTAPP_PREFIX, STORAGE_KEY, STORAGE_VERSION) SETTINGS_INSTANCE_ID, SIGNAL_SMARTAPP_PREFIX, STORAGE_KEY, STORAGE_VERSION)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -83,7 +84,7 @@ async def create_app(hass: HomeAssistantType, api):
app = App() app = App()
for key, value in template.items(): for key, value in template.items():
setattr(app, key, value) setattr(app, key, value)
app = (await api.create_app(app))[0] app, client = await api.create_app(app)
_LOGGER.debug("Created SmartApp '%s' (%s)", app.app_name, app.app_id) _LOGGER.debug("Created SmartApp '%s' (%s)", app.app_name, app.app_id)
# Set unique hass id in settings # Set unique hass id in settings
@ -97,12 +98,12 @@ async def create_app(hass: HomeAssistantType, api):
# Set oauth scopes # Set oauth scopes
oauth = AppOAuth(app.app_id) oauth = AppOAuth(app.app_id)
oauth.client_name = 'Home Assistant' oauth.client_name = APP_OAUTH_CLIENT_NAME
oauth.scope.extend(APP_OAUTH_SCOPES) oauth.scope.extend(APP_OAUTH_SCOPES)
await api.update_app_oauth(oauth) await api.update_app_oauth(oauth)
_LOGGER.debug("Updated App OAuth for SmartApp '%s' (%s)", _LOGGER.debug("Updated App OAuth for SmartApp '%s' (%s)",
app.app_name, app.app_id) app.app_name, app.app_id)
return app return app, client
async def update_app(hass: HomeAssistantType, app): async def update_app(hass: HomeAssistantType, app):
@ -185,32 +186,24 @@ async def setup_smartapp_endpoint(hass: HomeAssistantType):
DATA_MANAGER: manager, DATA_MANAGER: manager,
CONF_INSTANCE_ID: config[CONF_INSTANCE_ID], CONF_INSTANCE_ID: config[CONF_INSTANCE_ID],
DATA_BROKERS: {}, DATA_BROKERS: {},
CONF_WEBHOOK_ID: config[CONF_WEBHOOK_ID] CONF_WEBHOOK_ID: config[CONF_WEBHOOK_ID],
CONF_INSTALLED_APPS: []
} }
async def smartapp_sync_subscriptions( async def smartapp_sync_subscriptions(
hass: HomeAssistantType, auth_token: str, location_id: str, hass: HomeAssistantType, auth_token: str, location_id: str,
installed_app_id: str, *, skip_delete=False): installed_app_id: str, devices):
"""Synchronize subscriptions of an installed up.""" """Synchronize subscriptions of an installed up."""
from pysmartthings import ( from pysmartthings import (
CAPABILITIES, SmartThings, SourceType, Subscription) CAPABILITIES, SmartThings, SourceType, Subscription,
SubscriptionEntity
)
api = SmartThings(async_get_clientsession(hass), auth_token) api = SmartThings(async_get_clientsession(hass), auth_token)
devices = await api.devices(location_ids=[location_id]) tasks = []
# Build set of capabilities and prune unsupported ones async def create_subscription(target: str):
capabilities = set()
for device in devices:
capabilities.update(device.capabilities)
capabilities.intersection_update(CAPABILITIES)
# Remove all (except for installs)
if not skip_delete:
await api.delete_subscriptions(installed_app_id)
# Create for each capability
async def create_subscription(target):
sub = Subscription() sub = Subscription()
sub.installed_app_id = installed_app_id sub.installed_app_id = installed_app_id
sub.location_id = location_id sub.location_id = location_id
@ -224,52 +217,89 @@ async def smartapp_sync_subscriptions(
_LOGGER.exception("Failed to create subscription for '%s' under " _LOGGER.exception("Failed to create subscription for '%s' under "
"app '%s'", target, installed_app_id) "app '%s'", target, installed_app_id)
tasks = [create_subscription(c) for c in capabilities] async def delete_subscription(sub: SubscriptionEntity):
await asyncio.gather(*tasks) try:
await api.delete_subscription(
installed_app_id, sub.subscription_id)
_LOGGER.debug("Removed subscription for '%s' under app '%s' "
"because it was no longer needed",
sub.capability, installed_app_id)
except Exception: # pylint:disable=broad-except
_LOGGER.exception("Failed to remove subscription for '%s' under "
"app '%s'", sub.capability, installed_app_id)
# Build set of capabilities and prune unsupported ones
capabilities = set()
for device in devices:
capabilities.update(device.capabilities)
capabilities.intersection_update(CAPABILITIES)
# Get current subscriptions and find differences
subscriptions = await api.subscriptions(installed_app_id)
for subscription in subscriptions:
if subscription.capability in capabilities:
capabilities.remove(subscription.capability)
else:
# Delete the subscription
tasks.append(delete_subscription(subscription))
# Remaining capabilities need subscriptions created
tasks.extend([create_subscription(c) for c in capabilities])
if tasks:
await asyncio.gather(*tasks)
else:
_LOGGER.debug("Subscriptions for app '%s' are up-to-date",
installed_app_id)
async def smartapp_install(hass: HomeAssistantType, req, resp, app): async def smartapp_install(hass: HomeAssistantType, req, resp, app):
""" """
Handle when a SmartApp is installed by the user into a location. Handle when a SmartApp is installed by the user into a location.
Setup subscriptions using the access token SmartThings provided in the Create a config entry representing the installation if this is not
event. An explicit subscription is required for each 'capability' in order the first installation under the account, otherwise store the data
to receive the related attribute updates. Finally, create a config entry for the config flow.
representing the installation if this is not the first installation under
the account.
""" """
await smartapp_sync_subscriptions( install_data = {
hass, req.auth_token, req.location_id, req.installed_app_id, CONF_INSTALLED_APP_ID: req.installed_app_id,
skip_delete=True) CONF_LOCATION_ID: req.location_id,
CONF_REFRESH_TOKEN: req.refresh_token
# The permanent access token is copied from another config flow with the }
# same parent app_id. If one is not found, that means the user is within # App attributes (client id/secret, etc...) are copied from another entry
# the initial config flow and the entry at the conclusion. # with the same parent app_id. If one is not found, the install data is
access_token = next(( # stored for the config flow to retrieve during the wait step.
entry.data.get(CONF_ACCESS_TOKEN) for entry entry = next((
entry for entry
in hass.config_entries.async_entries(DOMAIN) in hass.config_entries.async_entries(DOMAIN)
if entry.data[CONF_APP_ID] == app.app_id), None) if entry.data[CONF_APP_ID] == app.app_id), None)
if access_token: if entry:
data = entry.data.copy()
data.update(install_data)
# Add as job not needed because the current coroutine was invoked # Add as job not needed because the current coroutine was invoked
# from the dispatcher and is not being awaited. # from the dispatcher and is not being awaited.
await hass.config_entries.flow.async_init( await hass.config_entries.flow.async_init(
DOMAIN, context={'source': 'install'}, DOMAIN, context={'source': 'install'},
data={ data=data)
CONF_APP_ID: app.app_id, else:
CONF_INSTALLED_APP_ID: req.installed_app_id, # Store the data where the flow can find it
CONF_LOCATION_ID: req.location_id, hass.data[DOMAIN][CONF_INSTALLED_APPS].append(install_data)
CONF_ACCESS_TOKEN: access_token
})
async def smartapp_update(hass: HomeAssistantType, req, resp, app): async def smartapp_update(hass: HomeAssistantType, req, resp, app):
""" """
Handle when a SmartApp is updated (reconfigured) by the user. Handle when a SmartApp is updated (reconfigured) by the user.
Synchronize subscriptions to ensure we're up-to-date. Store the refresh token in the config entry.
""" """
await smartapp_sync_subscriptions( # Update refresh token in config entry
hass, req.auth_token, req.location_id, req.installed_app_id) entry = next((entry for entry in hass.config_entries.async_entries(DOMAIN)
if entry.data.get(CONF_INSTALLED_APP_ID) ==
req.installed_app_id),
None)
if entry:
entry.data[CONF_REFRESH_TOKEN] = req.refresh_token
hass.config_entries.async_update_entry(entry)
_LOGGER.debug("SmartApp '%s' under parent app '%s' was updated", _LOGGER.debug("SmartApp '%s' under parent app '%s' was updated",
req.installed_app_id, app.app_id) req.installed_app_id, app.app_id)

View file

@ -1252,7 +1252,7 @@ pysma==0.3.1
pysmartapp==0.3.0 pysmartapp==0.3.0
# homeassistant.components.smartthings # homeassistant.components.smartthings
pysmartthings==0.6.2 pysmartthings==0.6.3
# homeassistant.components.device_tracker.snmp # homeassistant.components.device_tracker.snmp
# homeassistant.components.sensor.snmp # homeassistant.components.sensor.snmp

View file

@ -223,7 +223,7 @@ pyqwikswitch==0.8
pysmartapp==0.3.0 pysmartapp==0.3.0
# homeassistant.components.smartthings # homeassistant.components.smartthings
pysmartthings==0.6.2 pysmartthings==0.6.3
# homeassistant.components.sonos # homeassistant.components.sonos
pysonos==0.0.6 pysonos==0.0.6

View file

@ -4,8 +4,8 @@ from unittest.mock import Mock, patch
from uuid import uuid4 from uuid import uuid4
from pysmartthings import ( from pysmartthings import (
CLASSIFICATION_AUTOMATION, AppEntity, AppSettings, DeviceEntity, CLASSIFICATION_AUTOMATION, AppEntity, AppOAuthClient, AppSettings,
InstalledApp, Location) DeviceEntity, InstalledApp, Location, Subscription)
from pysmartthings.api import Api from pysmartthings.api import Api
import pytest import pytest
@ -13,8 +13,9 @@ from homeassistant.components import webhook
from homeassistant.components.smartthings import DeviceBroker from homeassistant.components.smartthings import DeviceBroker
from homeassistant.components.smartthings.const import ( from homeassistant.components.smartthings.const import (
APP_NAME_PREFIX, CONF_APP_ID, CONF_INSTALLED_APP_ID, CONF_INSTANCE_ID, APP_NAME_PREFIX, CONF_APP_ID, CONF_INSTALLED_APP_ID, CONF_INSTANCE_ID,
CONF_LOCATION_ID, DATA_BROKERS, DOMAIN, SETTINGS_INSTANCE_ID, STORAGE_KEY, CONF_LOCATION_ID, CONF_OAUTH_CLIENT_ID, CONF_OAUTH_CLIENT_SECRET,
STORAGE_VERSION) CONF_REFRESH_TOKEN, DATA_BROKERS, DOMAIN, SETTINGS_INSTANCE_ID,
STORAGE_KEY, STORAGE_VERSION)
from homeassistant.config_entries import ( from homeassistant.config_entries import (
CONN_CLASS_CLOUD_PUSH, SOURCE_USER, ConfigEntry) CONN_CLASS_CLOUD_PUSH, SOURCE_USER, ConfigEntry)
from homeassistant.const import CONF_ACCESS_TOKEN, CONF_WEBHOOK_ID from homeassistant.const import CONF_ACCESS_TOKEN, CONF_WEBHOOK_ID
@ -26,9 +27,11 @@ from tests.common import mock_coro
async def setup_platform(hass, platform: str, *devices): async def setup_platform(hass, platform: str, *devices):
"""Set up the SmartThings platform and prerequisites.""" """Set up the SmartThings platform and prerequisites."""
hass.config.components.add(DOMAIN) hass.config.components.add(DOMAIN)
broker = DeviceBroker(hass, devices, '') config_entry = ConfigEntry(2, DOMAIN, "Test",
config_entry = ConfigEntry("1", DOMAIN, "Test", {}, {CONF_INSTALLED_APP_ID: str(uuid4())},
SOURCE_USER, CONN_CLASS_CLOUD_PUSH) SOURCE_USER, CONN_CLASS_CLOUD_PUSH)
broker = DeviceBroker(hass, config_entry, Mock(), Mock(), devices)
hass.data[DOMAIN] = { hass.data[DOMAIN] = {
DATA_BROKERS: { DATA_BROKERS: {
config_entry.entry_id: broker config_entry.entry_id: broker
@ -98,6 +101,15 @@ def app_fixture(hass, config_file):
return app return app
@pytest.fixture(name="app_oauth_client")
def app_oauth_client_fixture():
"""Fixture for a single app's oauth."""
return AppOAuthClient({
'oauthClientId': str(uuid4()),
'oauthClientSecret': str(uuid4())
})
@pytest.fixture(name='app_settings') @pytest.fixture(name='app_settings')
def app_settings_fixture(app, config_file): def app_settings_fixture(app, config_file):
"""Fixture for an app settings.""" """Fixture for an app settings."""
@ -225,12 +237,25 @@ def config_entry_fixture(hass, installed_app, location):
CONF_ACCESS_TOKEN: str(uuid4()), CONF_ACCESS_TOKEN: str(uuid4()),
CONF_INSTALLED_APP_ID: installed_app.installed_app_id, CONF_INSTALLED_APP_ID: installed_app.installed_app_id,
CONF_APP_ID: installed_app.app_id, CONF_APP_ID: installed_app.app_id,
CONF_LOCATION_ID: location.location_id CONF_LOCATION_ID: location.location_id,
CONF_REFRESH_TOKEN: str(uuid4()),
CONF_OAUTH_CLIENT_ID: str(uuid4()),
CONF_OAUTH_CLIENT_SECRET: str(uuid4())
} }
return ConfigEntry("1", DOMAIN, location.name, data, SOURCE_USER, return ConfigEntry(2, DOMAIN, location.name, data, SOURCE_USER,
CONN_CLASS_CLOUD_PUSH) CONN_CLASS_CLOUD_PUSH)
@pytest.fixture(name="subscription_factory")
def subscription_factory_fixture():
"""Fixture for creating mock subscriptions."""
def _factory(capability):
sub = Subscription()
sub.capability = capability
return sub
return _factory
@pytest.fixture(name="device_factory") @pytest.fixture(name="device_factory")
def device_factory_fixture(): def device_factory_fixture():
"""Fixture for creating mock devices.""" """Fixture for creating mock devices."""

View file

@ -6,31 +6,15 @@ real HTTP calls are not initiated during testing.
""" """
from pysmartthings import ATTRIBUTES, CAPABILITIES, Attribute, Capability from pysmartthings import ATTRIBUTES, CAPABILITIES, Attribute, Capability
from homeassistant.components.binary_sensor import DEVICE_CLASSES from homeassistant.components.binary_sensor import (
from homeassistant.components.smartthings import DeviceBroker, binary_sensor DEVICE_CLASSES, DOMAIN as BINARY_SENSOR_DOMAIN)
from homeassistant.components.smartthings import binary_sensor
from homeassistant.components.smartthings.const import ( from homeassistant.components.smartthings.const import (
DATA_BROKERS, DOMAIN, SIGNAL_SMARTTHINGS_UPDATE) DOMAIN, SIGNAL_SMARTTHINGS_UPDATE)
from homeassistant.config_entries import (
CONN_CLASS_CLOUD_PUSH, SOURCE_USER, ConfigEntry)
from homeassistant.const import ATTR_FRIENDLY_NAME from homeassistant.const import ATTR_FRIENDLY_NAME
from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.dispatcher import async_dispatcher_send
from .conftest import setup_platform
async def _setup_platform(hass, *devices):
"""Set up the SmartThings binary_sensor platform and prerequisites."""
hass.config.components.add(DOMAIN)
broker = DeviceBroker(hass, devices, '')
config_entry = ConfigEntry("1", DOMAIN, "Test", {},
SOURCE_USER, CONN_CLASS_CLOUD_PUSH)
hass.data[DOMAIN] = {
DATA_BROKERS: {
config_entry.entry_id: broker
}
}
await hass.config_entries.async_forward_entry_setup(
config_entry, 'binary_sensor')
await hass.async_block_till_done()
return config_entry
async def test_mapping_integrity(): async def test_mapping_integrity():
@ -56,7 +40,7 @@ async def test_entity_state(hass, device_factory):
"""Tests the state attributes properly match the light types.""" """Tests the state attributes properly match the light types."""
device = device_factory('Motion Sensor 1', [Capability.motion_sensor], device = device_factory('Motion Sensor 1', [Capability.motion_sensor],
{Attribute.motion: 'inactive'}) {Attribute.motion: 'inactive'})
await _setup_platform(hass, device) await setup_platform(hass, BINARY_SENSOR_DOMAIN, device)
state = hass.states.get('binary_sensor.motion_sensor_1_motion') state = hass.states.get('binary_sensor.motion_sensor_1_motion')
assert state.state == 'off' assert state.state == 'off'
assert state.attributes[ATTR_FRIENDLY_NAME] ==\ assert state.attributes[ATTR_FRIENDLY_NAME] ==\
@ -71,7 +55,7 @@ async def test_entity_and_device_attributes(hass, device_factory):
entity_registry = await hass.helpers.entity_registry.async_get_registry() entity_registry = await hass.helpers.entity_registry.async_get_registry()
device_registry = await hass.helpers.device_registry.async_get_registry() device_registry = await hass.helpers.device_registry.async_get_registry()
# Act # Act
await _setup_platform(hass, device) await setup_platform(hass, BINARY_SENSOR_DOMAIN, device)
# Assert # Assert
entry = entity_registry.async_get('binary_sensor.motion_sensor_1_motion') entry = entity_registry.async_get('binary_sensor.motion_sensor_1_motion')
assert entry assert entry
@ -89,7 +73,7 @@ async def test_update_from_signal(hass, device_factory):
# Arrange # Arrange
device = device_factory('Motion Sensor 1', [Capability.motion_sensor], device = device_factory('Motion Sensor 1', [Capability.motion_sensor],
{Attribute.motion: 'inactive'}) {Attribute.motion: 'inactive'})
await _setup_platform(hass, device) await setup_platform(hass, BINARY_SENSOR_DOMAIN, device)
device.status.apply_attribute_update( device.status.apply_attribute_update(
'main', Capability.motion_sensor, Attribute.motion, 'active') 'main', Capability.motion_sensor, Attribute.motion, 'active')
# Act # Act
@ -107,7 +91,7 @@ async def test_unload_config_entry(hass, device_factory):
# Arrange # Arrange
device = device_factory('Motion Sensor 1', [Capability.motion_sensor], device = device_factory('Motion Sensor 1', [Capability.motion_sensor],
{Attribute.motion: 'inactive'}) {Attribute.motion: 'inactive'})
config_entry = await _setup_platform(hass, device) config_entry = await setup_platform(hass, BINARY_SENSOR_DOMAIN, device)
# Act # Act
await hass.config_entries.async_forward_entry_unload( await hass.config_entries.async_forward_entry_unload(
config_entry, 'binary_sensor') config_entry, 'binary_sensor')

View file

@ -8,6 +8,9 @@ from pysmartthings import APIResponseError
from homeassistant import data_entry_flow from homeassistant import data_entry_flow
from homeassistant.components.smartthings.config_flow import ( from homeassistant.components.smartthings.config_flow import (
SmartThingsFlowHandler) SmartThingsFlowHandler)
from homeassistant.components.smartthings.const import (
CONF_INSTALLED_APP_ID, CONF_INSTALLED_APPS, CONF_LOCATION_ID,
CONF_REFRESH_TOKEN, DOMAIN)
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from tests.common import mock_coro from tests.common import mock_coro
@ -171,14 +174,16 @@ async def test_unknown_error(hass, smartthings_mock):
assert result['errors'] == {'base': 'app_setup_error'} assert result['errors'] == {'base': 'app_setup_error'}
async def test_app_created_then_show_wait_form(hass, app, smartthings_mock): async def test_app_created_then_show_wait_form(
hass, app, app_oauth_client, smartthings_mock):
"""Test SmartApp is created when one does not exist and shows wait form.""" """Test SmartApp is created when one does not exist and shows wait form."""
flow = SmartThingsFlowHandler() flow = SmartThingsFlowHandler()
flow.hass = hass flow.hass = hass
smartthings = smartthings_mock.return_value smartthings = smartthings_mock.return_value
smartthings.apps.return_value = mock_coro(return_value=[]) smartthings.apps.return_value = mock_coro(return_value=[])
smartthings.create_app.return_value = mock_coro(return_value=(app, None)) smartthings.create_app.return_value = \
mock_coro(return_value=(app, app_oauth_client))
smartthings.update_app_settings.return_value = mock_coro() smartthings.update_app_settings.return_value = mock_coro()
smartthings.update_app_oauth.return_value = mock_coro() smartthings.update_app_oauth.return_value = mock_coro()
@ -189,13 +194,15 @@ async def test_app_created_then_show_wait_form(hass, app, smartthings_mock):
async def test_app_updated_then_show_wait_form( async def test_app_updated_then_show_wait_form(
hass, app, smartthings_mock): hass, app, app_oauth_client, smartthings_mock):
"""Test SmartApp is updated when an existing is already created.""" """Test SmartApp is updated when an existing is already created."""
flow = SmartThingsFlowHandler() flow = SmartThingsFlowHandler()
flow.hass = hass flow.hass = hass
api = smartthings_mock.return_value api = smartthings_mock.return_value
api.apps.return_value = mock_coro(return_value=[app]) api.apps.return_value = mock_coro(return_value=[app])
api.generate_app_oauth.return_value = \
mock_coro(return_value=app_oauth_client)
result = await flow.async_step_user({'access_token': str(uuid4())}) result = await flow.async_step_user({'access_token': str(uuid4())})
@ -219,8 +226,6 @@ async def test_wait_form_displayed_after_checking(hass, smartthings_mock):
flow = SmartThingsFlowHandler() flow = SmartThingsFlowHandler()
flow.hass = hass flow.hass = hass
flow.access_token = str(uuid4()) flow.access_token = str(uuid4())
flow.api = smartthings_mock.return_value
flow.api.installed_apps.return_value = mock_coro(return_value=[])
result = await flow.async_step_wait_install({}) result = await flow.async_step_wait_install({})
@ -235,19 +240,29 @@ async def test_config_entry_created_when_installed(
flow = SmartThingsFlowHandler() flow = SmartThingsFlowHandler()
flow.hass = hass flow.hass = hass
flow.access_token = str(uuid4()) flow.access_token = str(uuid4())
flow.api = smartthings_mock.return_value
flow.app_id = installed_app.app_id flow.app_id = installed_app.app_id
flow.api.installed_apps.return_value = \ flow.api = smartthings_mock.return_value
mock_coro(return_value=[installed_app]) flow.oauth_client_id = str(uuid4())
flow.oauth_client_secret = str(uuid4())
data = {
CONF_REFRESH_TOKEN: str(uuid4()),
CONF_LOCATION_ID: installed_app.location_id,
CONF_INSTALLED_APP_ID: installed_app.installed_app_id
}
hass.data[DOMAIN][CONF_INSTALLED_APPS].append(data)
result = await flow.async_step_wait_install({}) result = await flow.async_step_wait_install({})
assert not hass.data[DOMAIN][CONF_INSTALLED_APPS]
assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
assert result['data']['app_id'] == installed_app.app_id assert result['data']['app_id'] == installed_app.app_id
assert result['data']['installed_app_id'] == \ assert result['data']['installed_app_id'] == \
installed_app.installed_app_id installed_app.installed_app_id
assert result['data']['location_id'] == installed_app.location_id assert result['data']['location_id'] == installed_app.location_id
assert result['data']['access_token'] == flow.access_token assert result['data']['access_token'] == flow.access_token
assert result['data']['refresh_token'] == data[CONF_REFRESH_TOKEN]
assert result['data']['client_secret'] == flow.oauth_client_secret
assert result['data']['client_id'] == flow.oauth_client_id
assert result['title'] == location.name assert result['title'] == location.name
@ -259,17 +274,31 @@ async def test_multiple_config_entry_created_when_installed(
flow.access_token = str(uuid4()) flow.access_token = str(uuid4())
flow.app_id = app.app_id flow.app_id = app.app_id
flow.api = smartthings_mock.return_value flow.api = smartthings_mock.return_value
flow.api.installed_apps.return_value = \ flow.oauth_client_id = str(uuid4())
mock_coro(return_value=installed_apps) flow.oauth_client_secret = str(uuid4())
for installed_app in installed_apps:
data = {
CONF_REFRESH_TOKEN: str(uuid4()),
CONF_LOCATION_ID: installed_app.location_id,
CONF_INSTALLED_APP_ID: installed_app.installed_app_id
}
hass.data[DOMAIN][CONF_INSTALLED_APPS].append(data)
install_data = hass.data[DOMAIN][CONF_INSTALLED_APPS].copy()
result = await flow.async_step_wait_install({}) result = await flow.async_step_wait_install({})
assert not hass.data[DOMAIN][CONF_INSTALLED_APPS]
assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
assert result['data']['app_id'] == installed_apps[0].app_id assert result['data']['app_id'] == installed_apps[0].app_id
assert result['data']['installed_app_id'] == \ assert result['data']['installed_app_id'] == \
installed_apps[0].installed_app_id installed_apps[0].installed_app_id
assert result['data']['location_id'] == installed_apps[0].location_id assert result['data']['location_id'] == installed_apps[0].location_id
assert result['data']['access_token'] == flow.access_token assert result['data']['access_token'] == flow.access_token
assert result['data']['refresh_token'] == \
install_data[0][CONF_REFRESH_TOKEN]
assert result['data']['client_secret'] == flow.oauth_client_secret
assert result['data']['client_id'] == flow.oauth_client_id
assert result['title'] == locations[0].name assert result['title'] == locations[0].name
await hass.async_block_till_done() await hass.async_block_till_done()
@ -280,4 +309,6 @@ async def test_multiple_config_entry_created_when_installed(
installed_apps[1].installed_app_id installed_apps[1].installed_app_id
assert entries[0].data['location_id'] == installed_apps[1].location_id assert entries[0].data['location_id'] == installed_apps[1].location_id
assert entries[0].data['access_token'] == flow.access_token assert entries[0].data['access_token'] == flow.access_token
assert entries[0].data['client_secret'] == flow.oauth_client_secret
assert entries[0].data['client_id'] == flow.oauth_client_id
assert entries[0].title == locations[1].name assert entries[0].title == locations[1].name

View file

@ -7,31 +7,15 @@ real HTTP calls are not initiated during testing.
from pysmartthings import Attribute, Capability from pysmartthings import Attribute, Capability
from homeassistant.components.fan import ( from homeassistant.components.fan import (
ATTR_SPEED, ATTR_SPEED_LIST, SPEED_HIGH, SPEED_LOW, SPEED_MEDIUM, ATTR_SPEED, ATTR_SPEED_LIST, DOMAIN as FAN_DOMAIN, SPEED_HIGH, SPEED_LOW,
SPEED_OFF, SUPPORT_SET_SPEED) SPEED_MEDIUM, SPEED_OFF, SUPPORT_SET_SPEED)
from homeassistant.components.smartthings import DeviceBroker, fan from homeassistant.components.smartthings import fan
from homeassistant.components.smartthings.const import ( from homeassistant.components.smartthings.const import (
DATA_BROKERS, DOMAIN, SIGNAL_SMARTTHINGS_UPDATE) DOMAIN, SIGNAL_SMARTTHINGS_UPDATE)
from homeassistant.config_entries import (
CONN_CLASS_CLOUD_PUSH, SOURCE_USER, ConfigEntry)
from homeassistant.const import ATTR_ENTITY_ID, ATTR_SUPPORTED_FEATURES from homeassistant.const import ATTR_ENTITY_ID, ATTR_SUPPORTED_FEATURES
from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.dispatcher import async_dispatcher_send
from .conftest import setup_platform
async def _setup_platform(hass, *devices):
"""Set up the SmartThings fan platform and prerequisites."""
hass.config.components.add(DOMAIN)
broker = DeviceBroker(hass, devices, '')
config_entry = ConfigEntry("1", DOMAIN, "Test", {},
SOURCE_USER, CONN_CLASS_CLOUD_PUSH)
hass.data[DOMAIN] = {
DATA_BROKERS: {
config_entry.entry_id: broker
}
}
await hass.config_entries.async_forward_entry_setup(config_entry, 'fan')
await hass.async_block_till_done()
return config_entry
async def test_async_setup_platform(): async def test_async_setup_platform():
@ -45,7 +29,7 @@ async def test_entity_state(hass, device_factory):
"Fan 1", "Fan 1",
capabilities=[Capability.switch, Capability.fan_speed], capabilities=[Capability.switch, Capability.fan_speed],
status={Attribute.switch: 'on', Attribute.fan_speed: 2}) status={Attribute.switch: 'on', Attribute.fan_speed: 2})
await _setup_platform(hass, device) await setup_platform(hass, FAN_DOMAIN, device)
# Dimmer 1 # Dimmer 1
state = hass.states.get('fan.fan_1') state = hass.states.get('fan.fan_1')
@ -63,11 +47,10 @@ async def test_entity_and_device_attributes(hass, device_factory):
"Fan 1", "Fan 1",
capabilities=[Capability.switch, Capability.fan_speed], capabilities=[Capability.switch, Capability.fan_speed],
status={Attribute.switch: 'on', Attribute.fan_speed: 2}) status={Attribute.switch: 'on', Attribute.fan_speed: 2})
await _setup_platform(hass, device) # Act
await setup_platform(hass, FAN_DOMAIN, device)
entity_registry = await hass.helpers.entity_registry.async_get_registry() entity_registry = await hass.helpers.entity_registry.async_get_registry()
device_registry = await hass.helpers.device_registry.async_get_registry() device_registry = await hass.helpers.device_registry.async_get_registry()
# Act
await _setup_platform(hass, device)
# Assert # Assert
entry = entity_registry.async_get("fan.fan_1") entry = entity_registry.async_get("fan.fan_1")
assert entry assert entry
@ -88,7 +71,7 @@ async def test_turn_off(hass, device_factory):
"Fan 1", "Fan 1",
capabilities=[Capability.switch, Capability.fan_speed], capabilities=[Capability.switch, Capability.fan_speed],
status={Attribute.switch: 'on', Attribute.fan_speed: 2}) status={Attribute.switch: 'on', Attribute.fan_speed: 2})
await _setup_platform(hass, device) await setup_platform(hass, FAN_DOMAIN, device)
# Act # Act
await hass.services.async_call( await hass.services.async_call(
'fan', 'turn_off', {'entity_id': 'fan.fan_1'}, 'fan', 'turn_off', {'entity_id': 'fan.fan_1'},
@ -106,7 +89,7 @@ async def test_turn_on(hass, device_factory):
"Fan 1", "Fan 1",
capabilities=[Capability.switch, Capability.fan_speed], capabilities=[Capability.switch, Capability.fan_speed],
status={Attribute.switch: 'off', Attribute.fan_speed: 0}) status={Attribute.switch: 'off', Attribute.fan_speed: 0})
await _setup_platform(hass, device) await setup_platform(hass, FAN_DOMAIN, device)
# Act # Act
await hass.services.async_call( await hass.services.async_call(
'fan', 'turn_on', {ATTR_ENTITY_ID: "fan.fan_1"}, 'fan', 'turn_on', {ATTR_ENTITY_ID: "fan.fan_1"},
@ -124,7 +107,7 @@ async def test_turn_on_with_speed(hass, device_factory):
"Fan 1", "Fan 1",
capabilities=[Capability.switch, Capability.fan_speed], capabilities=[Capability.switch, Capability.fan_speed],
status={Attribute.switch: 'off', Attribute.fan_speed: 0}) status={Attribute.switch: 'off', Attribute.fan_speed: 0})
await _setup_platform(hass, device) await setup_platform(hass, FAN_DOMAIN, device)
# Act # Act
await hass.services.async_call( await hass.services.async_call(
'fan', 'turn_on', 'fan', 'turn_on',
@ -145,7 +128,7 @@ async def test_set_speed(hass, device_factory):
"Fan 1", "Fan 1",
capabilities=[Capability.switch, Capability.fan_speed], capabilities=[Capability.switch, Capability.fan_speed],
status={Attribute.switch: 'off', Attribute.fan_speed: 0}) status={Attribute.switch: 'off', Attribute.fan_speed: 0})
await _setup_platform(hass, device) await setup_platform(hass, FAN_DOMAIN, device)
# Act # Act
await hass.services.async_call( await hass.services.async_call(
'fan', 'set_speed', 'fan', 'set_speed',
@ -166,7 +149,7 @@ async def test_update_from_signal(hass, device_factory):
"Fan 1", "Fan 1",
capabilities=[Capability.switch, Capability.fan_speed], capabilities=[Capability.switch, Capability.fan_speed],
status={Attribute.switch: 'off', Attribute.fan_speed: 0}) status={Attribute.switch: 'off', Attribute.fan_speed: 0})
await _setup_platform(hass, device) await setup_platform(hass, FAN_DOMAIN, device)
await device.switch_on(True) await device.switch_on(True)
# Act # Act
async_dispatcher_send(hass, SIGNAL_SMARTTHINGS_UPDATE, async_dispatcher_send(hass, SIGNAL_SMARTTHINGS_UPDATE,
@ -185,7 +168,7 @@ async def test_unload_config_entry(hass, device_factory):
"Fan 1", "Fan 1",
capabilities=[Capability.switch, Capability.fan_speed], capabilities=[Capability.switch, Capability.fan_speed],
status={Attribute.switch: 'off', Attribute.fan_speed: 0}) status={Attribute.switch: 'off', Attribute.fan_speed: 0})
config_entry = await _setup_platform(hass, device) config_entry = await setup_platform(hass, FAN_DOMAIN, device)
# Act # Act
await hass.config_entries.async_forward_entry_unload( await hass.config_entries.async_forward_entry_unload(
config_entry, 'fan') config_entry, 'fan')

View file

@ -8,14 +8,33 @@ import pytest
from homeassistant.components import smartthings from homeassistant.components import smartthings
from homeassistant.components.smartthings.const import ( from homeassistant.components.smartthings.const import (
DATA_BROKERS, DOMAIN, EVENT_BUTTON, SIGNAL_SMARTTHINGS_UPDATE, CONF_INSTALLED_APP_ID, CONF_REFRESH_TOKEN, DATA_BROKERS, DOMAIN,
SUPPORTED_PLATFORMS) EVENT_BUTTON, SIGNAL_SMARTTHINGS_UPDATE, SUPPORTED_PLATFORMS)
from homeassistant.exceptions import ConfigEntryNotReady from homeassistant.exceptions import ConfigEntryNotReady
from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.dispatcher import async_dispatcher_connect
from tests.common import mock_coro from tests.common import mock_coro
async def test_migration_creates_new_flow(
hass, smartthings_mock, config_entry):
"""Test migration deletes app and creates new flow."""
config_entry.version = 1
setattr(hass.config_entries, '_entries', [config_entry])
api = smartthings_mock.return_value
api.delete_installed_app.return_value = mock_coro()
await smartthings.async_migrate_entry(hass, config_entry)
assert api.delete_installed_app.call_count == 1
await hass.async_block_till_done()
assert not hass.config_entries.async_entries(DOMAIN)
flows = hass.config_entries.flow.async_progress()
assert len(flows) == 1
assert flows[0]['handler'] == 'smartthings'
assert flows[0]['context'] == {'source': 'import'}
async def test_unrecoverable_api_errors_create_new_flow( async def test_unrecoverable_api_errors_create_new_flow(
hass, config_entry, smartthings_mock): hass, config_entry, smartthings_mock):
""" """
@ -101,14 +120,22 @@ async def test_unauthorized_installed_app_raises_not_ready(
async def test_config_entry_loads_platforms( async def test_config_entry_loads_platforms(
hass, config_entry, app, installed_app, hass, config_entry, app, installed_app,
device, smartthings_mock): device, smartthings_mock, subscription_factory):
"""Test config entry loads properly and proxies to platforms.""" """Test config entry loads properly and proxies to platforms."""
setattr(hass.config_entries, '_entries', [config_entry]) setattr(hass.config_entries, '_entries', [config_entry])
api = smartthings_mock.return_value api = smartthings_mock.return_value
api.app.return_value = mock_coro(return_value=app) api.app.return_value = mock_coro(return_value=app)
api.installed_app.return_value = mock_coro(return_value=installed_app) api.installed_app.return_value = mock_coro(return_value=installed_app)
api.devices.return_value = mock_coro(return_value=[device]) api.devices.side_effect = \
lambda *args, **kwargs: mock_coro(return_value=[device])
mock_token = Mock()
mock_token.access_token.return_value = str(uuid4())
mock_token.refresh_token.return_value = str(uuid4())
api.generate_tokens.return_value = mock_coro(return_value=mock_token)
subscriptions = [subscription_factory(capability)
for capability in device.capabilities]
api.subscriptions.return_value = mock_coro(return_value=subscriptions)
with patch.object(hass.config_entries, 'async_forward_entry_setup', with patch.object(hass.config_entries, 'async_forward_entry_setup',
return_value=mock_coro()) as forward_mock: return_value=mock_coro()) as forward_mock:
@ -120,8 +147,12 @@ async def test_config_entry_loads_platforms(
async def test_unload_entry(hass, config_entry): async def test_unload_entry(hass, config_entry):
"""Test entries are unloaded correctly.""" """Test entries are unloaded correctly."""
broker = Mock() connect_disconnect = Mock()
broker.event_handler_disconnect = Mock() smart_app = Mock()
smart_app.connect_event.return_value = connect_disconnect
broker = smartthings.DeviceBroker(
hass, config_entry, Mock(), smart_app, [])
broker.connect()
hass.data[DOMAIN][DATA_BROKERS][config_entry.entry_id] = broker hass.data[DOMAIN][DATA_BROKERS][config_entry.entry_id] = broker
with patch.object(hass.config_entries, 'async_forward_entry_unload', with patch.object(hass.config_entries, 'async_forward_entry_unload',
@ -129,15 +160,41 @@ async def test_unload_entry(hass, config_entry):
return_value=True return_value=True
)) as forward_mock: )) as forward_mock:
assert await smartthings.async_unload_entry(hass, config_entry) assert await smartthings.async_unload_entry(hass, config_entry)
assert broker.event_handler_disconnect.call_count == 1
assert connect_disconnect.call_count == 1
assert config_entry.entry_id not in hass.data[DOMAIN][DATA_BROKERS] assert config_entry.entry_id not in hass.data[DOMAIN][DATA_BROKERS]
# Assert platforms unloaded # Assert platforms unloaded
await hass.async_block_till_done() await hass.async_block_till_done()
assert forward_mock.call_count == len(SUPPORTED_PLATFORMS) assert forward_mock.call_count == len(SUPPORTED_PLATFORMS)
async def test_broker_regenerates_token(
hass, config_entry):
"""Test the device broker regenerates the refresh token."""
token = Mock()
token.refresh_token = str(uuid4())
token.refresh.return_value = mock_coro()
stored_action = None
def async_track_time_interval(hass, action, interval):
nonlocal stored_action
stored_action = action
with patch('homeassistant.components.smartthings'
'.async_track_time_interval',
new=async_track_time_interval):
broker = smartthings.DeviceBroker(
hass, config_entry, token, Mock(), [])
broker.connect()
assert stored_action
await stored_action(None) # pylint:disable=not-callable
assert token.refresh.call_count == 1
assert config_entry.data[CONF_REFRESH_TOKEN] == token.refresh_token
async def test_event_handler_dispatches_updated_devices( async def test_event_handler_dispatches_updated_devices(
hass, device_factory, event_request_factory): hass, config_entry, device_factory, event_request_factory):
"""Test the event handler dispatches updated devices.""" """Test the event handler dispatches updated devices."""
devices = [ devices = [
device_factory('Bedroom 1 Switch', ['switch']), device_factory('Bedroom 1 Switch', ['switch']),
@ -147,6 +204,7 @@ async def test_event_handler_dispatches_updated_devices(
device_ids = [devices[0].device_id, devices[1].device_id, device_ids = [devices[0].device_id, devices[1].device_id,
devices[2].device_id] devices[2].device_id]
request = event_request_factory(device_ids) request = event_request_factory(device_ids)
config_entry.data[CONF_INSTALLED_APP_ID] = request.installed_app_id
called = False called = False
def signal(ids): def signal(ids):
@ -154,10 +212,13 @@ async def test_event_handler_dispatches_updated_devices(
called = True called = True
assert device_ids == ids assert device_ids == ids
async_dispatcher_connect(hass, SIGNAL_SMARTTHINGS_UPDATE, signal) async_dispatcher_connect(hass, SIGNAL_SMARTTHINGS_UPDATE, signal)
broker = smartthings.DeviceBroker(
hass, devices, request.installed_app_id)
await broker.event_handler(request, None, None) broker = smartthings.DeviceBroker(
hass, config_entry, Mock(), Mock(), devices)
broker.connect()
# pylint:disable=protected-access
await broker._event_handler(request, None, None)
await hass.async_block_till_done() await hass.async_block_till_done()
assert called assert called
@ -166,7 +227,7 @@ async def test_event_handler_dispatches_updated_devices(
async def test_event_handler_ignores_other_installed_app( async def test_event_handler_ignores_other_installed_app(
hass, device_factory, event_request_factory): hass, config_entry, device_factory, event_request_factory):
"""Test the event handler dispatches updated devices.""" """Test the event handler dispatches updated devices."""
device = device_factory('Bedroom 1 Switch', ['switch']) device = device_factory('Bedroom 1 Switch', ['switch'])
request = event_request_factory([device.device_id]) request = event_request_factory([device.device_id])
@ -176,21 +237,26 @@ async def test_event_handler_ignores_other_installed_app(
nonlocal called nonlocal called
called = True called = True
async_dispatcher_connect(hass, SIGNAL_SMARTTHINGS_UPDATE, signal) async_dispatcher_connect(hass, SIGNAL_SMARTTHINGS_UPDATE, signal)
broker = smartthings.DeviceBroker(hass, [device], str(uuid4())) broker = smartthings.DeviceBroker(
hass, config_entry, Mock(), Mock(), [device])
broker.connect()
await broker.event_handler(request, None, None) # pylint:disable=protected-access
await broker._event_handler(request, None, None)
await hass.async_block_till_done() await hass.async_block_till_done()
assert not called assert not called
async def test_event_handler_fires_button_events( async def test_event_handler_fires_button_events(
hass, device_factory, event_factory, event_request_factory): hass, config_entry, device_factory, event_factory,
event_request_factory):
"""Test the event handler fires button events.""" """Test the event handler fires button events."""
device = device_factory('Button 1', ['button']) device = device_factory('Button 1', ['button'])
event = event_factory(device.device_id, capability='button', event = event_factory(device.device_id, capability='button',
attribute='button', value='pushed') attribute='button', value='pushed')
request = event_request_factory(events=[event]) request = event_request_factory(events=[event])
config_entry.data[CONF_INSTALLED_APP_ID] = request.installed_app_id
called = False called = False
def handler(evt): def handler(evt):
@ -205,8 +271,11 @@ async def test_event_handler_fires_button_events(
} }
hass.bus.async_listen(EVENT_BUTTON, handler) hass.bus.async_listen(EVENT_BUTTON, handler)
broker = smartthings.DeviceBroker( broker = smartthings.DeviceBroker(
hass, [device], request.installed_app_id) hass, config_entry, Mock(), Mock(), [device])
await broker.event_handler(request, None, None) broker.connect()
# pylint:disable=protected-access
await broker._event_handler(request, None, None)
await hass.async_block_till_done() await hass.async_block_till_done()
assert called assert called

View file

@ -9,15 +9,16 @@ import pytest
from homeassistant.components.light import ( from homeassistant.components.light import (
ATTR_BRIGHTNESS, ATTR_COLOR_TEMP, ATTR_HS_COLOR, ATTR_TRANSITION, ATTR_BRIGHTNESS, ATTR_COLOR_TEMP, ATTR_HS_COLOR, ATTR_TRANSITION,
SUPPORT_BRIGHTNESS, SUPPORT_COLOR, SUPPORT_COLOR_TEMP, SUPPORT_TRANSITION) DOMAIN as LIGHT_DOMAIN, SUPPORT_BRIGHTNESS, SUPPORT_COLOR,
from homeassistant.components.smartthings import DeviceBroker, light SUPPORT_COLOR_TEMP, SUPPORT_TRANSITION)
from homeassistant.components.smartthings import light
from homeassistant.components.smartthings.const import ( from homeassistant.components.smartthings.const import (
DATA_BROKERS, DOMAIN, SIGNAL_SMARTTHINGS_UPDATE) DOMAIN, SIGNAL_SMARTTHINGS_UPDATE)
from homeassistant.config_entries import (
CONN_CLASS_CLOUD_PUSH, SOURCE_USER, ConfigEntry)
from homeassistant.const import ATTR_ENTITY_ID, ATTR_SUPPORTED_FEATURES from homeassistant.const import ATTR_ENTITY_ID, ATTR_SUPPORTED_FEATURES
from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.dispatcher import async_dispatcher_send
from .conftest import setup_platform
@pytest.fixture(name="light_devices") @pytest.fixture(name="light_devices")
def light_devices_fixture(device_factory): def light_devices_fixture(device_factory):
@ -44,22 +45,6 @@ def light_devices_fixture(device_factory):
] ]
async def _setup_platform(hass, *devices):
"""Set up the SmartThings light platform and prerequisites."""
hass.config.components.add(DOMAIN)
broker = DeviceBroker(hass, devices, '')
config_entry = ConfigEntry("1", DOMAIN, "Test", {},
SOURCE_USER, CONN_CLASS_CLOUD_PUSH)
hass.data[DOMAIN] = {
DATA_BROKERS: {
config_entry.entry_id: broker
}
}
await hass.config_entries.async_forward_entry_setup(config_entry, 'light')
await hass.async_block_till_done()
return config_entry
async def test_async_setup_platform(): async def test_async_setup_platform():
"""Test setup platform does nothing (it uses config entries).""" """Test setup platform does nothing (it uses config entries)."""
await light.async_setup_platform(None, None, None) await light.async_setup_platform(None, None, None)
@ -67,7 +52,7 @@ async def test_async_setup_platform():
async def test_entity_state(hass, light_devices): async def test_entity_state(hass, light_devices):
"""Tests the state attributes properly match the light types.""" """Tests the state attributes properly match the light types."""
await _setup_platform(hass, *light_devices) await setup_platform(hass, LIGHT_DOMAIN, *light_devices)
# Dimmer 1 # Dimmer 1
state = hass.states.get('light.dimmer_1') state = hass.states.get('light.dimmer_1')
@ -101,7 +86,7 @@ async def test_entity_and_device_attributes(hass, device_factory):
entity_registry = await hass.helpers.entity_registry.async_get_registry() entity_registry = await hass.helpers.entity_registry.async_get_registry()
device_registry = await hass.helpers.device_registry.async_get_registry() device_registry = await hass.helpers.device_registry.async_get_registry()
# Act # Act
await _setup_platform(hass, device) await setup_platform(hass, LIGHT_DOMAIN, device)
# Assert # Assert
entry = entity_registry.async_get("light.light_1") entry = entity_registry.async_get("light.light_1")
assert entry assert entry
@ -118,7 +103,7 @@ async def test_entity_and_device_attributes(hass, device_factory):
async def test_turn_off(hass, light_devices): async def test_turn_off(hass, light_devices):
"""Test the light turns of successfully.""" """Test the light turns of successfully."""
# Arrange # Arrange
await _setup_platform(hass, *light_devices) await setup_platform(hass, LIGHT_DOMAIN, *light_devices)
# Act # Act
await hass.services.async_call( await hass.services.async_call(
'light', 'turn_off', {'entity_id': 'light.color_dimmer_2'}, 'light', 'turn_off', {'entity_id': 'light.color_dimmer_2'},
@ -132,7 +117,7 @@ async def test_turn_off(hass, light_devices):
async def test_turn_off_with_transition(hass, light_devices): async def test_turn_off_with_transition(hass, light_devices):
"""Test the light turns of successfully with transition.""" """Test the light turns of successfully with transition."""
# Arrange # Arrange
await _setup_platform(hass, *light_devices) await setup_platform(hass, LIGHT_DOMAIN, *light_devices)
# Act # Act
await hass.services.async_call( await hass.services.async_call(
'light', 'turn_off', 'light', 'turn_off',
@ -147,7 +132,7 @@ async def test_turn_off_with_transition(hass, light_devices):
async def test_turn_on(hass, light_devices): async def test_turn_on(hass, light_devices):
"""Test the light turns of successfully.""" """Test the light turns of successfully."""
# Arrange # Arrange
await _setup_platform(hass, *light_devices) await setup_platform(hass, LIGHT_DOMAIN, *light_devices)
# Act # Act
await hass.services.async_call( await hass.services.async_call(
'light', 'turn_on', {ATTR_ENTITY_ID: "light.color_dimmer_1"}, 'light', 'turn_on', {ATTR_ENTITY_ID: "light.color_dimmer_1"},
@ -161,7 +146,7 @@ async def test_turn_on(hass, light_devices):
async def test_turn_on_with_brightness(hass, light_devices): async def test_turn_on_with_brightness(hass, light_devices):
"""Test the light turns on to the specified brightness.""" """Test the light turns on to the specified brightness."""
# Arrange # Arrange
await _setup_platform(hass, *light_devices) await setup_platform(hass, LIGHT_DOMAIN, *light_devices)
# Act # Act
await hass.services.async_call( await hass.services.async_call(
'light', 'turn_on', 'light', 'turn_on',
@ -185,7 +170,7 @@ async def test_turn_on_with_minimal_brightness(hass, light_devices):
set the level to zero, which turns off the lights in SmartThings. set the level to zero, which turns off the lights in SmartThings.
""" """
# Arrange # Arrange
await _setup_platform(hass, *light_devices) await setup_platform(hass, LIGHT_DOMAIN, *light_devices)
# Act # Act
await hass.services.async_call( await hass.services.async_call(
'light', 'turn_on', 'light', 'turn_on',
@ -203,7 +188,7 @@ async def test_turn_on_with_minimal_brightness(hass, light_devices):
async def test_turn_on_with_color(hass, light_devices): async def test_turn_on_with_color(hass, light_devices):
"""Test the light turns on with color.""" """Test the light turns on with color."""
# Arrange # Arrange
await _setup_platform(hass, *light_devices) await setup_platform(hass, LIGHT_DOMAIN, *light_devices)
# Act # Act
await hass.services.async_call( await hass.services.async_call(
'light', 'turn_on', 'light', 'turn_on',
@ -220,7 +205,7 @@ async def test_turn_on_with_color(hass, light_devices):
async def test_turn_on_with_color_temp(hass, light_devices): async def test_turn_on_with_color_temp(hass, light_devices):
"""Test the light turns on with color temp.""" """Test the light turns on with color temp."""
# Arrange # Arrange
await _setup_platform(hass, *light_devices) await setup_platform(hass, LIGHT_DOMAIN, *light_devices)
# Act # Act
await hass.services.async_call( await hass.services.async_call(
'light', 'turn_on', 'light', 'turn_on',
@ -244,7 +229,7 @@ async def test_update_from_signal(hass, device_factory):
status={Attribute.switch: 'off', Attribute.level: 100, status={Attribute.switch: 'off', Attribute.level: 100,
Attribute.hue: 76.0, Attribute.saturation: 55.0, Attribute.hue: 76.0, Attribute.saturation: 55.0,
Attribute.color_temperature: 4500}) Attribute.color_temperature: 4500})
await _setup_platform(hass, device) await setup_platform(hass, LIGHT_DOMAIN, device)
await device.switch_on(True) await device.switch_on(True)
# Act # Act
async_dispatcher_send(hass, SIGNAL_SMARTTHINGS_UPDATE, async_dispatcher_send(hass, SIGNAL_SMARTTHINGS_UPDATE,
@ -266,7 +251,7 @@ async def test_unload_config_entry(hass, device_factory):
status={Attribute.switch: 'off', Attribute.level: 100, status={Attribute.switch: 'off', Attribute.level: 100,
Attribute.hue: 76.0, Attribute.saturation: 55.0, Attribute.hue: 76.0, Attribute.saturation: 55.0,
Attribute.color_temperature: 4500}) Attribute.color_temperature: 4500})
config_entry = await _setup_platform(hass, device) config_entry = await setup_platform(hass, LIGHT_DOMAIN, device)
# Act # Act
await hass.config_entries.async_forward_entry_unload( await hass.config_entries.async_forward_entry_unload(
config_entry, 'light') config_entry, 'light')

View file

@ -5,7 +5,9 @@ from uuid import uuid4
from pysmartthings import AppEntity, Capability from pysmartthings import AppEntity, Capability
from homeassistant.components.smartthings import smartapp from homeassistant.components.smartthings import smartapp
from homeassistant.components.smartthings.const import DATA_MANAGER, DOMAIN from homeassistant.components.smartthings.const import (
CONF_INSTALLED_APP_ID, CONF_INSTALLED_APPS, CONF_LOCATION_ID,
CONF_REFRESH_TOKEN, DATA_MANAGER, DOMAIN)
from tests.common import mock_coro from tests.common import mock_coro
@ -35,31 +37,26 @@ async def test_update_app_updated_needed(hass, app):
assert mock_app.classifications == app.classifications assert mock_app.classifications == app.classifications
async def test_smartapp_install_abort_if_no_other( async def test_smartapp_install_store_if_no_other(
hass, smartthings_mock, device_factory): hass, smartthings_mock, device_factory):
"""Test aborts if no other app was configured already.""" """Test aborts if no other app was configured already."""
# Arrange # Arrange
api = smartthings_mock.return_value
api.create_subscription.return_value = mock_coro()
app = Mock() app = Mock()
app.app_id = uuid4() app.app_id = uuid4()
request = Mock() request = Mock()
request.installed_app_id = uuid4() request.installed_app_id = str(uuid4())
request.auth_token = uuid4() request.auth_token = str(uuid4())
request.location_id = uuid4() request.location_id = str(uuid4())
devices = [ request.refresh_token = str(uuid4())
device_factory('', [Capability.battery, 'ping']),
device_factory('', [Capability.switch, Capability.switch_level]),
device_factory('', [Capability.switch])
]
api.devices = Mock()
api.devices.return_value = mock_coro(return_value=devices)
# Act # Act
await smartapp.smartapp_install(hass, request, None, app) await smartapp.smartapp_install(hass, request, None, app)
# Assert # Assert
entries = hass.config_entries.async_entries('smartthings') entries = hass.config_entries.async_entries('smartthings')
assert not entries assert not entries
assert api.create_subscription.call_count == 3 data = hass.data[DOMAIN][CONF_INSTALLED_APPS][0]
assert data[CONF_REFRESH_TOKEN] == request.refresh_token
assert data[CONF_LOCATION_ID] == request.location_id
assert data[CONF_INSTALLED_APP_ID] == request.installed_app_id
async def test_smartapp_install_creates_flow( async def test_smartapp_install_creates_flow(
@ -68,12 +65,12 @@ async def test_smartapp_install_creates_flow(
# Arrange # Arrange
setattr(hass.config_entries, '_entries', [config_entry]) setattr(hass.config_entries, '_entries', [config_entry])
api = smartthings_mock.return_value api = smartthings_mock.return_value
api.create_subscription.return_value = mock_coro()
app = Mock() app = Mock()
app.app_id = config_entry.data['app_id'] app.app_id = config_entry.data['app_id']
request = Mock() request = Mock()
request.installed_app_id = str(uuid4()) request.installed_app_id = str(uuid4())
request.auth_token = str(uuid4()) request.auth_token = str(uuid4())
request.refresh_token = str(uuid4())
request.location_id = location.location_id request.location_id = location.location_id
devices = [ devices = [
device_factory('', [Capability.battery, 'ping']), device_factory('', [Capability.battery, 'ping']),
@ -88,42 +85,42 @@ async def test_smartapp_install_creates_flow(
await hass.async_block_till_done() await hass.async_block_till_done()
entries = hass.config_entries.async_entries('smartthings') entries = hass.config_entries.async_entries('smartthings')
assert len(entries) == 2 assert len(entries) == 2
assert api.create_subscription.call_count == 3
assert entries[1].data['app_id'] == app.app_id assert entries[1].data['app_id'] == app.app_id
assert entries[1].data['installed_app_id'] == request.installed_app_id assert entries[1].data['installed_app_id'] == request.installed_app_id
assert entries[1].data['location_id'] == request.location_id assert entries[1].data['location_id'] == request.location_id
assert entries[1].data['access_token'] == \ assert entries[1].data['access_token'] == \
config_entry.data['access_token'] config_entry.data['access_token']
assert entries[1].data['refresh_token'] == request.refresh_token
assert entries[1].data['client_secret'] == \
config_entry.data['client_secret']
assert entries[1].data['client_id'] == config_entry.data['client_id']
assert entries[1].title == location.name assert entries[1].title == location.name
async def test_smartapp_update_syncs_subs( async def test_smartapp_update_saves_token(
hass, smartthings_mock, config_entry, location, device_factory): hass, smartthings_mock, location, device_factory):
"""Test update synchronizes subscriptions.""" """Test update saves token."""
# Arrange # Arrange
setattr(hass.config_entries, '_entries', [config_entry]) entry = Mock()
entry.data = {
'installed_app_id': str(uuid4()),
'app_id': str(uuid4())
}
entry.domain = DOMAIN
setattr(hass.config_entries, '_entries', [entry])
app = Mock() app = Mock()
app.app_id = config_entry.data['app_id'] app.app_id = entry.data['app_id']
api = smartthings_mock.return_value
api.delete_subscriptions = Mock()
api.delete_subscriptions.return_value = mock_coro()
api.create_subscription.return_value = mock_coro()
request = Mock() request = Mock()
request.installed_app_id = str(uuid4()) request.installed_app_id = entry.data['installed_app_id']
request.auth_token = str(uuid4()) request.auth_token = str(uuid4())
request.refresh_token = str(uuid4())
request.location_id = location.location_id request.location_id = location.location_id
devices = [
device_factory('', [Capability.battery, 'ping']),
device_factory('', [Capability.switch, Capability.switch_level]),
device_factory('', [Capability.switch])
]
api.devices = Mock()
api.devices.return_value = mock_coro(return_value=devices)
# Act # Act
await smartapp.smartapp_update(hass, request, None, app) await smartapp.smartapp_update(hass, request, None, app)
# Assert # Assert
assert api.create_subscription.call_count == 3 assert entry.data[CONF_REFRESH_TOKEN] == request.refresh_token
assert api.delete_subscriptions.call_count == 1
async def test_smartapp_uninstall(hass, config_entry): async def test_smartapp_uninstall(hass, config_entry):
@ -152,3 +149,83 @@ async def test_smartapp_webhook(hass):
result = await smartapp.smartapp_webhook(hass, '', request) result = await smartapp.smartapp_webhook(hass, '', request)
assert result.body == b'{}' assert result.body == b'{}'
async def test_smartapp_sync_subscriptions(
hass, smartthings_mock, device_factory, subscription_factory):
"""Test synchronization adds and removes."""
api = smartthings_mock.return_value
api.delete_subscription.side_effect = lambda loc_id, sub_id: mock_coro()
api.create_subscription.side_effect = lambda sub: mock_coro()
subscriptions = [
subscription_factory(Capability.thermostat),
subscription_factory(Capability.switch),
subscription_factory(Capability.switch_level)
]
api.subscriptions.return_value = mock_coro(return_value=subscriptions)
devices = [
device_factory('', [Capability.battery, 'ping']),
device_factory('', [Capability.switch, Capability.switch_level]),
device_factory('', [Capability.switch])
]
await smartapp.smartapp_sync_subscriptions(
hass, str(uuid4()), str(uuid4()), str(uuid4()), devices)
assert api.subscriptions.call_count == 1
assert api.delete_subscription.call_count == 1
assert api.create_subscription.call_count == 1
async def test_smartapp_sync_subscriptions_up_to_date(
hass, smartthings_mock, device_factory, subscription_factory):
"""Test synchronization does nothing when current."""
api = smartthings_mock.return_value
api.delete_subscription.side_effect = lambda loc_id, sub_id: mock_coro()
api.create_subscription.side_effect = lambda sub: mock_coro()
subscriptions = [
subscription_factory(Capability.battery),
subscription_factory(Capability.switch),
subscription_factory(Capability.switch_level)
]
api.subscriptions.return_value = mock_coro(return_value=subscriptions)
devices = [
device_factory('', [Capability.battery, 'ping']),
device_factory('', [Capability.switch, Capability.switch_level]),
device_factory('', [Capability.switch])
]
await smartapp.smartapp_sync_subscriptions(
hass, str(uuid4()), str(uuid4()), str(uuid4()), devices)
assert api.subscriptions.call_count == 1
assert api.delete_subscription.call_count == 0
assert api.create_subscription.call_count == 0
async def test_smartapp_sync_subscriptions_handles_exceptions(
hass, smartthings_mock, device_factory, subscription_factory):
"""Test synchronization does nothing when current."""
api = smartthings_mock.return_value
api.delete_subscription.side_effect = \
lambda loc_id, sub_id: mock_coro(exception=Exception)
api.create_subscription.side_effect = \
lambda sub: mock_coro(exception=Exception)
subscriptions = [
subscription_factory(Capability.battery),
subscription_factory(Capability.switch),
subscription_factory(Capability.switch_level)
]
api.subscriptions.return_value = mock_coro(return_value=subscriptions)
devices = [
device_factory('', [Capability.thermostat, 'ping']),
device_factory('', [Capability.switch, Capability.switch_level]),
device_factory('', [Capability.switch])
]
await smartapp.smartapp_sync_subscriptions(
hass, str(uuid4()), str(uuid4()), str(uuid4()), devices)
assert api.subscriptions.call_count == 1
assert api.delete_subscription.call_count == 1
assert api.create_subscription.call_count == 1

View file

@ -6,28 +6,13 @@ real HTTP calls are not initiated during testing.
""" """
from pysmartthings import Attribute, Capability from pysmartthings import Attribute, Capability
from homeassistant.components.smartthings import DeviceBroker, switch from homeassistant.components.smartthings import switch
from homeassistant.components.smartthings.const import ( from homeassistant.components.smartthings.const import (
DATA_BROKERS, DOMAIN, SIGNAL_SMARTTHINGS_UPDATE) DOMAIN, SIGNAL_SMARTTHINGS_UPDATE)
from homeassistant.config_entries import ( from homeassistant.components.switch import DOMAIN as SWITCH_DOMAIN
CONN_CLASS_CLOUD_PUSH, SOURCE_USER, ConfigEntry)
from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.dispatcher import async_dispatcher_send
from .conftest import setup_platform
async def _setup_platform(hass, *devices):
"""Set up the SmartThings switch platform and prerequisites."""
hass.config.components.add(DOMAIN)
broker = DeviceBroker(hass, devices, '')
config_entry = ConfigEntry("1", DOMAIN, "Test", {},
SOURCE_USER, CONN_CLASS_CLOUD_PUSH)
hass.data[DOMAIN] = {
DATA_BROKERS: {
config_entry.entry_id: broker
}
}
await hass.config_entries.async_forward_entry_setup(config_entry, 'switch')
await hass.async_block_till_done()
return config_entry
async def test_async_setup_platform(): async def test_async_setup_platform():
@ -43,7 +28,7 @@ async def test_entity_and_device_attributes(hass, device_factory):
entity_registry = await hass.helpers.entity_registry.async_get_registry() entity_registry = await hass.helpers.entity_registry.async_get_registry()
device_registry = await hass.helpers.device_registry.async_get_registry() device_registry = await hass.helpers.device_registry.async_get_registry()
# Act # Act
await _setup_platform(hass, device) await setup_platform(hass, SWITCH_DOMAIN, device)
# Assert # Assert
entry = entity_registry.async_get('switch.switch_1') entry = entity_registry.async_get('switch.switch_1')
assert entry assert entry
@ -62,7 +47,7 @@ async def test_turn_off(hass, device_factory):
# Arrange # Arrange
device = device_factory('Switch_1', [Capability.switch], device = device_factory('Switch_1', [Capability.switch],
{Attribute.switch: 'on'}) {Attribute.switch: 'on'})
await _setup_platform(hass, device) await setup_platform(hass, SWITCH_DOMAIN, device)
# Act # Act
await hass.services.async_call( await hass.services.async_call(
'switch', 'turn_off', {'entity_id': 'switch.switch_1'}, 'switch', 'turn_off', {'entity_id': 'switch.switch_1'},
@ -78,7 +63,7 @@ async def test_turn_on(hass, device_factory):
# Arrange # Arrange
device = device_factory('Switch_1', [Capability.switch], device = device_factory('Switch_1', [Capability.switch],
{Attribute.switch: 'off'}) {Attribute.switch: 'off'})
await _setup_platform(hass, device) await setup_platform(hass, SWITCH_DOMAIN, device)
# Act # Act
await hass.services.async_call( await hass.services.async_call(
'switch', 'turn_on', {'entity_id': 'switch.switch_1'}, 'switch', 'turn_on', {'entity_id': 'switch.switch_1'},
@ -94,7 +79,7 @@ async def test_update_from_signal(hass, device_factory):
# Arrange # Arrange
device = device_factory('Switch_1', [Capability.switch], device = device_factory('Switch_1', [Capability.switch],
{Attribute.switch: 'off'}) {Attribute.switch: 'off'})
await _setup_platform(hass, device) await setup_platform(hass, SWITCH_DOMAIN, device)
await device.switch_on(True) await device.switch_on(True)
# Act # Act
async_dispatcher_send(hass, SIGNAL_SMARTTHINGS_UPDATE, async_dispatcher_send(hass, SIGNAL_SMARTTHINGS_UPDATE,
@ -111,7 +96,7 @@ async def test_unload_config_entry(hass, device_factory):
# Arrange # Arrange
device = device_factory('Switch 1', [Capability.switch], device = device_factory('Switch 1', [Capability.switch],
{Attribute.switch: 'on'}) {Attribute.switch: 'on'})
config_entry = await _setup_platform(hass, device) config_entry = await setup_platform(hass, SWITCH_DOMAIN, device)
# Act # Act
await hass.config_entries.async_forward_entry_unload( await hass.config_entries.async_forward_entry_unload(
config_entry, 'switch') config_entry, 'switch')