Google Assistant: Create and pass context to service calls (#21551)

* Google Assistant: Create and pass context to service calls

* Refactor request data into separate object and pass to execute.
This commit is contained in:
Penny Wood 2019-03-06 12:00:53 +08:00 committed by Paulus Schoutsen
parent fc1ee9be43
commit d1038ea79f
8 changed files with 345 additions and 265 deletions

View file

@ -1,4 +1,5 @@
"""Helper classes for Google Assistant integration."""
from homeassistant.core import Context
class SmartHomeError(Exception):
@ -16,10 +17,19 @@ class SmartHomeError(Exception):
class Config:
"""Hold the configuration for Google Assistant."""
def __init__(self, should_expose, allow_unlock, agent_user_id,
def __init__(self, should_expose, allow_unlock,
entity_config=None):
"""Initialize the configuration."""
self.should_expose = should_expose
self.agent_user_id = agent_user_id
self.entity_config = entity_config or {}
self.allow_unlock = allow_unlock
class RequestData:
"""Hold data associated with a particular request."""
def __init__(self, config, user_id, request_id):
"""Initialize the request data."""
self.config = config
self.request_id = request_id
self.context = Context(user_id=user_id)

View file

@ -71,17 +71,16 @@ class GoogleAssistantView(HomeAssistantView):
def __init__(self, is_exposed, entity_config, allow_unlock):
"""Initialize the Google Assistant request handler."""
self.is_exposed = is_exposed
self.entity_config = entity_config
self.allow_unlock = allow_unlock
self.config = Config(is_exposed,
allow_unlock,
entity_config)
async def post(self, request: Request) -> Response:
"""Handle Google Assistant requests."""
message = await request.json() # type: dict
config = Config(self.is_exposed,
self.allow_unlock,
request['hass_user'].id,
self.entity_config)
result = await async_handle_message(
request.app['hass'], config, message)
request.app['hass'],
self.config,
request['hass_user'].id,
message)
return self.json(result)

View file

@ -36,7 +36,7 @@ from .const import (
ERR_UNKNOWN_ERROR,
EVENT_COMMAND_RECEIVED, EVENT_SYNC_RECEIVED, EVENT_QUERY_RECEIVED
)
from .helpers import SmartHomeError
from .helpers import SmartHomeError, RequestData
HANDLERS = Registry()
_LOGGER = logging.getLogger(__name__)
@ -87,7 +87,8 @@ class _GoogleEntity:
domain = state.domain
features = state.attributes.get(ATTR_SUPPORTED_FEATURES, 0)
return [Trait(self.hass, state, self.config) for Trait in trait.TRAITS
return [Trait(self.hass, state, self.config)
for Trait in trait.TRAITS
if Trait.supported(domain, features)]
async def sync_serialize(self):
@ -178,7 +179,7 @@ class _GoogleEntity:
return attrs
async def execute(self, command, params):
async def execute(self, command, data, params):
"""Execute a command.
https://developers.google.com/actions/smarthome/create-app#actiondevicesexecute
@ -186,7 +187,7 @@ class _GoogleEntity:
executed = False
for trt in self.traits():
if trt.can_execute(command, params):
await trt.execute(command, params)
await trt.execute(command, data, params)
executed = True
break
@ -202,9 +203,13 @@ class _GoogleEntity:
self.state = self.hass.states.get(self.entity_id)
async def async_handle_message(hass, config, message):
async def async_handle_message(hass, config, user_id, message):
"""Handle incoming API messages."""
response = await _process(hass, config, message)
request_id = message.get('requestId') # type: str
data = RequestData(config, user_id, request_id)
response = await _process(hass, data, message)
if response and 'errorCode' in response['payload']:
_LOGGER.error('Error handling message %s: %s',
@ -213,14 +218,13 @@ async def async_handle_message(hass, config, message):
return response
async def _process(hass, config, message):
async def _process(hass, data, message):
"""Process a message."""
request_id = message.get('requestId') # type: str
inputs = message.get('inputs') # type: list
if len(inputs) != 1:
return {
'requestId': request_id,
'requestId': data.request_id,
'payload': {'errorCode': ERR_PROTOCOL_ERROR}
}
@ -228,49 +232,49 @@ async def _process(hass, config, message):
if handler is None:
return {
'requestId': request_id,
'requestId': data.request_id,
'payload': {'errorCode': ERR_PROTOCOL_ERROR}
}
try:
result = await handler(hass, config, request_id,
inputs[0].get('payload'))
result = await handler(hass, data, inputs[0].get('payload'))
except SmartHomeError as err:
return {
'requestId': request_id,
'requestId': data.request_id,
'payload': {'errorCode': err.code}
}
except Exception: # pylint: disable=broad-except
_LOGGER.exception('Unexpected error')
return {
'requestId': request_id,
'requestId': data.request_id,
'payload': {'errorCode': ERR_UNKNOWN_ERROR}
}
if result is None:
return None
return {'requestId': request_id, 'payload': result}
return {'requestId': data.request_id, 'payload': result}
@HANDLERS.register('action.devices.SYNC')
async def async_devices_sync(hass, config, request_id, payload):
async def async_devices_sync(hass, data, payload):
"""Handle action.devices.SYNC request.
https://developers.google.com/actions/smarthome/create-app#actiondevicessync
"""
hass.bus.async_fire(EVENT_SYNC_RECEIVED, {
'request_id': request_id
})
hass.bus.async_fire(
EVENT_SYNC_RECEIVED,
{'request_id': data.request_id},
context=data.context)
devices = []
for state in hass.states.async_all():
if state.entity_id in CLOUD_NEVER_EXPOSED_ENTITIES:
continue
if not config.should_expose(state):
if not data.config.should_expose(state):
continue
entity = _GoogleEntity(hass, config, state)
entity = _GoogleEntity(hass, data.config, state)
serialized = await entity.sync_serialize()
if serialized is None:
@ -280,7 +284,7 @@ async def async_devices_sync(hass, config, request_id, payload):
devices.append(serialized)
response = {
'agentUserId': config.agent_user_id,
'agentUserId': data.context.user_id,
'devices': devices,
}
@ -288,7 +292,7 @@ async def async_devices_sync(hass, config, request_id, payload):
@HANDLERS.register('action.devices.QUERY')
async def async_devices_query(hass, config, request_id, payload):
async def async_devices_query(hass, data, payload):
"""Handle action.devices.QUERY request.
https://developers.google.com/actions/smarthome/create-app#actiondevicesquery
@ -298,23 +302,27 @@ async def async_devices_query(hass, config, request_id, payload):
devid = device['id']
state = hass.states.get(devid)
hass.bus.async_fire(EVENT_QUERY_RECEIVED, {
'request_id': request_id,
ATTR_ENTITY_ID: devid,
})
hass.bus.async_fire(
EVENT_QUERY_RECEIVED,
{
'request_id': data.request_id,
ATTR_ENTITY_ID: devid,
},
context=data.context)
if not state:
# If we can't find a state, the device is offline
devices[devid] = {'online': False}
continue
devices[devid] = _GoogleEntity(hass, config, state).query_serialize()
entity = _GoogleEntity(hass, data.config, state)
devices[devid] = entity.query_serialize()
return {'devices': devices}
@HANDLERS.register('action.devices.EXECUTE')
async def handle_devices_execute(hass, config, request_id, payload):
async def handle_devices_execute(hass, data, payload):
"""Handle action.devices.EXECUTE request.
https://developers.google.com/actions/smarthome/create-app#actiondevicesexecute
@ -327,11 +335,14 @@ async def handle_devices_execute(hass, config, request_id, payload):
command['execution']):
entity_id = device['id']
hass.bus.async_fire(EVENT_COMMAND_RECEIVED, {
'request_id': request_id,
ATTR_ENTITY_ID: entity_id,
'execution': execution
})
hass.bus.async_fire(
EVENT_COMMAND_RECEIVED,
{
'request_id': data.request_id,
ATTR_ENTITY_ID: entity_id,
'execution': execution
},
context=data.context)
# Happens if error occurred. Skip entity for further processing
if entity_id in results:
@ -348,10 +359,11 @@ async def handle_devices_execute(hass, config, request_id, payload):
}
continue
entities[entity_id] = _GoogleEntity(hass, config, state)
entities[entity_id] = _GoogleEntity(hass, data.config, state)
try:
await entities[entity_id].execute(execution['command'],
data,
execution.get('params', {}))
except SmartHomeError as err:
results[entity_id] = {
@ -378,7 +390,7 @@ async def handle_devices_execute(hass, config, request_id, payload):
@HANDLERS.register('action.devices.DISCONNECT')
async def async_devices_disconnect(hass, config, request_id, payload):
async def async_devices_disconnect(hass, data, payload):
"""Handle action.devices.DISCONNECT request.
https://developers.google.com/actions/smarthome/create#actiondevicesdisconnect

View file

@ -102,7 +102,7 @@ class _Trait:
"""Test if command can be executed."""
return command in self.commands
async def execute(self, command, params):
async def execute(self, command, data, params):
"""Execute a trait command."""
raise NotImplementedError
@ -159,7 +159,7 @@ class BrightnessTrait(_Trait):
return response
async def execute(self, command, params):
async def execute(self, command, data, params):
"""Execute a brightness command."""
domain = self.state.domain
@ -168,20 +168,20 @@ class BrightnessTrait(_Trait):
light.DOMAIN, light.SERVICE_TURN_ON, {
ATTR_ENTITY_ID: self.state.entity_id,
light.ATTR_BRIGHTNESS_PCT: params['brightness']
}, blocking=True)
}, blocking=True, context=data.context)
elif domain == cover.DOMAIN:
await self.hass.services.async_call(
cover.DOMAIN, cover.SERVICE_SET_COVER_POSITION, {
ATTR_ENTITY_ID: self.state.entity_id,
cover.ATTR_POSITION: params['brightness']
}, blocking=True)
}, blocking=True, context=data.context)
elif domain == media_player.DOMAIN:
await self.hass.services.async_call(
media_player.DOMAIN, media_player.SERVICE_VOLUME_SET, {
ATTR_ENTITY_ID: self.state.entity_id,
media_player.ATTR_MEDIA_VOLUME_LEVEL:
params['brightness'] / 100
}, blocking=True)
}, blocking=True, context=data.context)
@register_trait
@ -221,7 +221,7 @@ class OnOffTrait(_Trait):
return {'on': self.state.state != cover.STATE_CLOSED}
return {'on': self.state.state != STATE_OFF}
async def execute(self, command, params):
async def execute(self, command, data, params):
"""Execute an OnOff command."""
domain = self.state.domain
@ -242,7 +242,7 @@ class OnOffTrait(_Trait):
await self.hass.services.async_call(service_domain, service, {
ATTR_ENTITY_ID: self.state.entity_id
}, blocking=True)
}, blocking=True, context=data.context)
@register_trait
@ -288,7 +288,7 @@ class ColorSpectrumTrait(_Trait):
return (command in self.commands and
'spectrumRGB' in params.get('color', {}))
async def execute(self, command, params):
async def execute(self, command, data, params):
"""Execute a color spectrum command."""
# Convert integer to hex format and left pad with 0's till length 6
hex_value = "{0:06x}".format(params['color']['spectrumRGB'])
@ -298,7 +298,7 @@ class ColorSpectrumTrait(_Trait):
await self.hass.services.async_call(light.DOMAIN, SERVICE_TURN_ON, {
ATTR_ENTITY_ID: self.state.entity_id,
light.ATTR_HS_COLOR: color
}, blocking=True)
}, blocking=True, context=data.context)
@register_trait
@ -355,7 +355,7 @@ class ColorTemperatureTrait(_Trait):
return (command in self.commands and
'temperature' in params.get('color', {}))
async def execute(self, command, params):
async def execute(self, command, data, params):
"""Execute a color temperature command."""
temp = color_util.color_temperature_kelvin_to_mired(
params['color']['temperature'])
@ -371,7 +371,7 @@ class ColorTemperatureTrait(_Trait):
await self.hass.services.async_call(light.DOMAIN, SERVICE_TURN_ON, {
ATTR_ENTITY_ID: self.state.entity_id,
light.ATTR_COLOR_TEMP: temp,
}, blocking=True)
}, blocking=True, context=data.context)
@register_trait
@ -400,13 +400,14 @@ class SceneTrait(_Trait):
"""Return scene query attributes."""
return {}
async def execute(self, command, params):
async def execute(self, command, data, params):
"""Execute a scene command."""
# Don't block for scripts as they can be slow.
await self.hass.services.async_call(
self.state.domain, SERVICE_TURN_ON, {
ATTR_ENTITY_ID: self.state.entity_id
}, blocking=self.state.domain != script.DOMAIN)
}, blocking=self.state.domain != script.DOMAIN,
context=data.context)
@register_trait
@ -434,12 +435,12 @@ class DockTrait(_Trait):
"""Return dock query attributes."""
return {'isDocked': self.state.state == vacuum.STATE_DOCKED}
async def execute(self, command, params):
async def execute(self, command, data, params):
"""Execute a dock command."""
await self.hass.services.async_call(
self.state.domain, vacuum.SERVICE_RETURN_TO_BASE, {
ATTR_ENTITY_ID: self.state.entity_id
}, blocking=True)
}, blocking=True, context=data.context)
@register_trait
@ -473,30 +474,30 @@ class StartStopTrait(_Trait):
'isPaused': self.state.state == vacuum.STATE_PAUSED,
}
async def execute(self, command, params):
async def execute(self, command, data, params):
"""Execute a StartStop command."""
if command == COMMAND_STARTSTOP:
if params['start']:
await self.hass.services.async_call(
self.state.domain, vacuum.SERVICE_START, {
ATTR_ENTITY_ID: self.state.entity_id
}, blocking=True)
}, blocking=True, context=data.context)
else:
await self.hass.services.async_call(
self.state.domain, vacuum.SERVICE_STOP, {
ATTR_ENTITY_ID: self.state.entity_id
}, blocking=True)
}, blocking=True, context=data.context)
elif command == COMMAND_PAUSEUNPAUSE:
if params['pause']:
await self.hass.services.async_call(
self.state.domain, vacuum.SERVICE_PAUSE, {
ATTR_ENTITY_ID: self.state.entity_id
}, blocking=True)
}, blocking=True, context=data.context)
else:
await self.hass.services.async_call(
self.state.domain, vacuum.SERVICE_START, {
ATTR_ENTITY_ID: self.state.entity_id
}, blocking=True)
}, blocking=True, context=data.context)
@register_trait
@ -584,7 +585,7 @@ class TemperatureSettingTrait(_Trait):
return response
async def execute(self, command, params):
async def execute(self, command, data, params):
"""Execute a temperature point or mode command."""
# All sent in temperatures are always in Celsius
unit = self.hass.config.units.temperature_unit
@ -608,7 +609,7 @@ class TemperatureSettingTrait(_Trait):
climate.DOMAIN, climate.SERVICE_SET_TEMPERATURE, {
ATTR_ENTITY_ID: self.state.entity_id,
ATTR_TEMPERATURE: temp
}, blocking=True)
}, blocking=True, context=data.context)
elif command == COMMAND_THERMOSTAT_TEMPERATURE_SET_RANGE:
temp_high = temp_util.convert(
@ -640,7 +641,7 @@ class TemperatureSettingTrait(_Trait):
ATTR_ENTITY_ID: self.state.entity_id,
climate.ATTR_TARGET_TEMP_HIGH: temp_high,
climate.ATTR_TARGET_TEMP_LOW: temp_low,
}, blocking=True)
}, blocking=True, context=data.context)
elif command == COMMAND_THERMOSTAT_SET_MODE:
await self.hass.services.async_call(
@ -648,7 +649,7 @@ class TemperatureSettingTrait(_Trait):
ATTR_ENTITY_ID: self.state.entity_id,
climate.ATTR_OPERATION_MODE:
self.google_to_hass[params['thermostatMode']],
}, blocking=True)
}, blocking=True, context=data.context)
@register_trait
@ -681,7 +682,7 @@ class LockUnlockTrait(_Trait):
allowed_unlock = not params['lock'] and self.config.allow_unlock
return params['lock'] or allowed_unlock
async def execute(self, command, params):
async def execute(self, command, data, params):
"""Execute an LockUnlock command."""
if params['lock']:
service = lock.SERVICE_LOCK
@ -690,7 +691,7 @@ class LockUnlockTrait(_Trait):
await self.hass.services.async_call(lock.DOMAIN, service, {
ATTR_ENTITY_ID: self.state.entity_id
}, blocking=True)
}, blocking=True, context=data.context)
@register_trait
@ -760,13 +761,13 @@ class FanSpeedTrait(_Trait):
return response
async def execute(self, command, params):
async def execute(self, command, data, params):
"""Execute an SetFanSpeed command."""
await self.hass.services.async_call(
fan.DOMAIN, fan.SERVICE_SET_SPEED, {
ATTR_ENTITY_ID: self.state.entity_id,
fan.ATTR_SPEED: params['fanSpeed']
}, blocking=True)
}, blocking=True, context=data.context)
@register_trait
@ -934,7 +935,7 @@ class ModesTrait(_Trait):
return response
async def execute(self, command, params):
async def execute(self, command, data, params):
"""Execute an SetModes command."""
settings = params.get('updateModeSettings')
requested_source = settings.get(
@ -951,4 +952,4 @@ class ModesTrait(_Trait):
media_player.SERVICE_SELECT_SOURCE, {
ATTR_ENTITY_ID: self.state.entity_id,
media_player.ATTR_INPUT_SOURCE: source
}, blocking=True)
}, blocking=True, context=data.context)