Google Assistant: Create and pass context to service calls (#21551)
* Google Assistant: Create and pass context to service calls * Refactor request data into separate object and pass to execute.
This commit is contained in:
parent
fc1ee9be43
commit
d1038ea79f
8 changed files with 345 additions and 265 deletions
|
@ -1,4 +1,5 @@
|
|||
"""Helper classes for Google Assistant integration."""
|
||||
from homeassistant.core import Context
|
||||
|
||||
|
||||
class SmartHomeError(Exception):
|
||||
|
@ -16,10 +17,19 @@ class SmartHomeError(Exception):
|
|||
class Config:
|
||||
"""Hold the configuration for Google Assistant."""
|
||||
|
||||
def __init__(self, should_expose, allow_unlock, agent_user_id,
|
||||
def __init__(self, should_expose, allow_unlock,
|
||||
entity_config=None):
|
||||
"""Initialize the configuration."""
|
||||
self.should_expose = should_expose
|
||||
self.agent_user_id = agent_user_id
|
||||
self.entity_config = entity_config or {}
|
||||
self.allow_unlock = allow_unlock
|
||||
|
||||
|
||||
class RequestData:
|
||||
"""Hold data associated with a particular request."""
|
||||
|
||||
def __init__(self, config, user_id, request_id):
|
||||
"""Initialize the request data."""
|
||||
self.config = config
|
||||
self.request_id = request_id
|
||||
self.context = Context(user_id=user_id)
|
||||
|
|
|
@ -71,17 +71,16 @@ class GoogleAssistantView(HomeAssistantView):
|
|||
|
||||
def __init__(self, is_exposed, entity_config, allow_unlock):
|
||||
"""Initialize the Google Assistant request handler."""
|
||||
self.is_exposed = is_exposed
|
||||
self.entity_config = entity_config
|
||||
self.allow_unlock = allow_unlock
|
||||
self.config = Config(is_exposed,
|
||||
allow_unlock,
|
||||
entity_config)
|
||||
|
||||
async def post(self, request: Request) -> Response:
|
||||
"""Handle Google Assistant requests."""
|
||||
message = await request.json() # type: dict
|
||||
config = Config(self.is_exposed,
|
||||
self.allow_unlock,
|
||||
request['hass_user'].id,
|
||||
self.entity_config)
|
||||
result = await async_handle_message(
|
||||
request.app['hass'], config, message)
|
||||
request.app['hass'],
|
||||
self.config,
|
||||
request['hass_user'].id,
|
||||
message)
|
||||
return self.json(result)
|
||||
|
|
|
@ -36,7 +36,7 @@ from .const import (
|
|||
ERR_UNKNOWN_ERROR,
|
||||
EVENT_COMMAND_RECEIVED, EVENT_SYNC_RECEIVED, EVENT_QUERY_RECEIVED
|
||||
)
|
||||
from .helpers import SmartHomeError
|
||||
from .helpers import SmartHomeError, RequestData
|
||||
|
||||
HANDLERS = Registry()
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
@ -87,7 +87,8 @@ class _GoogleEntity:
|
|||
domain = state.domain
|
||||
features = state.attributes.get(ATTR_SUPPORTED_FEATURES, 0)
|
||||
|
||||
return [Trait(self.hass, state, self.config) for Trait in trait.TRAITS
|
||||
return [Trait(self.hass, state, self.config)
|
||||
for Trait in trait.TRAITS
|
||||
if Trait.supported(domain, features)]
|
||||
|
||||
async def sync_serialize(self):
|
||||
|
@ -178,7 +179,7 @@ class _GoogleEntity:
|
|||
|
||||
return attrs
|
||||
|
||||
async def execute(self, command, params):
|
||||
async def execute(self, command, data, params):
|
||||
"""Execute a command.
|
||||
|
||||
https://developers.google.com/actions/smarthome/create-app#actiondevicesexecute
|
||||
|
@ -186,7 +187,7 @@ class _GoogleEntity:
|
|||
executed = False
|
||||
for trt in self.traits():
|
||||
if trt.can_execute(command, params):
|
||||
await trt.execute(command, params)
|
||||
await trt.execute(command, data, params)
|
||||
executed = True
|
||||
break
|
||||
|
||||
|
@ -202,9 +203,13 @@ class _GoogleEntity:
|
|||
self.state = self.hass.states.get(self.entity_id)
|
||||
|
||||
|
||||
async def async_handle_message(hass, config, message):
|
||||
async def async_handle_message(hass, config, user_id, message):
|
||||
"""Handle incoming API messages."""
|
||||
response = await _process(hass, config, message)
|
||||
request_id = message.get('requestId') # type: str
|
||||
|
||||
data = RequestData(config, user_id, request_id)
|
||||
|
||||
response = await _process(hass, data, message)
|
||||
|
||||
if response and 'errorCode' in response['payload']:
|
||||
_LOGGER.error('Error handling message %s: %s',
|
||||
|
@ -213,14 +218,13 @@ async def async_handle_message(hass, config, message):
|
|||
return response
|
||||
|
||||
|
||||
async def _process(hass, config, message):
|
||||
async def _process(hass, data, message):
|
||||
"""Process a message."""
|
||||
request_id = message.get('requestId') # type: str
|
||||
inputs = message.get('inputs') # type: list
|
||||
|
||||
if len(inputs) != 1:
|
||||
return {
|
||||
'requestId': request_id,
|
||||
'requestId': data.request_id,
|
||||
'payload': {'errorCode': ERR_PROTOCOL_ERROR}
|
||||
}
|
||||
|
||||
|
@ -228,49 +232,49 @@ async def _process(hass, config, message):
|
|||
|
||||
if handler is None:
|
||||
return {
|
||||
'requestId': request_id,
|
||||
'requestId': data.request_id,
|
||||
'payload': {'errorCode': ERR_PROTOCOL_ERROR}
|
||||
}
|
||||
|
||||
try:
|
||||
result = await handler(hass, config, request_id,
|
||||
inputs[0].get('payload'))
|
||||
result = await handler(hass, data, inputs[0].get('payload'))
|
||||
except SmartHomeError as err:
|
||||
return {
|
||||
'requestId': request_id,
|
||||
'requestId': data.request_id,
|
||||
'payload': {'errorCode': err.code}
|
||||
}
|
||||
except Exception: # pylint: disable=broad-except
|
||||
_LOGGER.exception('Unexpected error')
|
||||
return {
|
||||
'requestId': request_id,
|
||||
'requestId': data.request_id,
|
||||
'payload': {'errorCode': ERR_UNKNOWN_ERROR}
|
||||
}
|
||||
|
||||
if result is None:
|
||||
return None
|
||||
return {'requestId': request_id, 'payload': result}
|
||||
return {'requestId': data.request_id, 'payload': result}
|
||||
|
||||
|
||||
@HANDLERS.register('action.devices.SYNC')
|
||||
async def async_devices_sync(hass, config, request_id, payload):
|
||||
async def async_devices_sync(hass, data, payload):
|
||||
"""Handle action.devices.SYNC request.
|
||||
|
||||
https://developers.google.com/actions/smarthome/create-app#actiondevicessync
|
||||
"""
|
||||
hass.bus.async_fire(EVENT_SYNC_RECEIVED, {
|
||||
'request_id': request_id
|
||||
})
|
||||
hass.bus.async_fire(
|
||||
EVENT_SYNC_RECEIVED,
|
||||
{'request_id': data.request_id},
|
||||
context=data.context)
|
||||
|
||||
devices = []
|
||||
for state in hass.states.async_all():
|
||||
if state.entity_id in CLOUD_NEVER_EXPOSED_ENTITIES:
|
||||
continue
|
||||
|
||||
if not config.should_expose(state):
|
||||
if not data.config.should_expose(state):
|
||||
continue
|
||||
|
||||
entity = _GoogleEntity(hass, config, state)
|
||||
entity = _GoogleEntity(hass, data.config, state)
|
||||
serialized = await entity.sync_serialize()
|
||||
|
||||
if serialized is None:
|
||||
|
@ -280,7 +284,7 @@ async def async_devices_sync(hass, config, request_id, payload):
|
|||
devices.append(serialized)
|
||||
|
||||
response = {
|
||||
'agentUserId': config.agent_user_id,
|
||||
'agentUserId': data.context.user_id,
|
||||
'devices': devices,
|
||||
}
|
||||
|
||||
|
@ -288,7 +292,7 @@ async def async_devices_sync(hass, config, request_id, payload):
|
|||
|
||||
|
||||
@HANDLERS.register('action.devices.QUERY')
|
||||
async def async_devices_query(hass, config, request_id, payload):
|
||||
async def async_devices_query(hass, data, payload):
|
||||
"""Handle action.devices.QUERY request.
|
||||
|
||||
https://developers.google.com/actions/smarthome/create-app#actiondevicesquery
|
||||
|
@ -298,23 +302,27 @@ async def async_devices_query(hass, config, request_id, payload):
|
|||
devid = device['id']
|
||||
state = hass.states.get(devid)
|
||||
|
||||
hass.bus.async_fire(EVENT_QUERY_RECEIVED, {
|
||||
'request_id': request_id,
|
||||
ATTR_ENTITY_ID: devid,
|
||||
})
|
||||
hass.bus.async_fire(
|
||||
EVENT_QUERY_RECEIVED,
|
||||
{
|
||||
'request_id': data.request_id,
|
||||
ATTR_ENTITY_ID: devid,
|
||||
},
|
||||
context=data.context)
|
||||
|
||||
if not state:
|
||||
# If we can't find a state, the device is offline
|
||||
devices[devid] = {'online': False}
|
||||
continue
|
||||
|
||||
devices[devid] = _GoogleEntity(hass, config, state).query_serialize()
|
||||
entity = _GoogleEntity(hass, data.config, state)
|
||||
devices[devid] = entity.query_serialize()
|
||||
|
||||
return {'devices': devices}
|
||||
|
||||
|
||||
@HANDLERS.register('action.devices.EXECUTE')
|
||||
async def handle_devices_execute(hass, config, request_id, payload):
|
||||
async def handle_devices_execute(hass, data, payload):
|
||||
"""Handle action.devices.EXECUTE request.
|
||||
|
||||
https://developers.google.com/actions/smarthome/create-app#actiondevicesexecute
|
||||
|
@ -327,11 +335,14 @@ async def handle_devices_execute(hass, config, request_id, payload):
|
|||
command['execution']):
|
||||
entity_id = device['id']
|
||||
|
||||
hass.bus.async_fire(EVENT_COMMAND_RECEIVED, {
|
||||
'request_id': request_id,
|
||||
ATTR_ENTITY_ID: entity_id,
|
||||
'execution': execution
|
||||
})
|
||||
hass.bus.async_fire(
|
||||
EVENT_COMMAND_RECEIVED,
|
||||
{
|
||||
'request_id': data.request_id,
|
||||
ATTR_ENTITY_ID: entity_id,
|
||||
'execution': execution
|
||||
},
|
||||
context=data.context)
|
||||
|
||||
# Happens if error occurred. Skip entity for further processing
|
||||
if entity_id in results:
|
||||
|
@ -348,10 +359,11 @@ async def handle_devices_execute(hass, config, request_id, payload):
|
|||
}
|
||||
continue
|
||||
|
||||
entities[entity_id] = _GoogleEntity(hass, config, state)
|
||||
entities[entity_id] = _GoogleEntity(hass, data.config, state)
|
||||
|
||||
try:
|
||||
await entities[entity_id].execute(execution['command'],
|
||||
data,
|
||||
execution.get('params', {}))
|
||||
except SmartHomeError as err:
|
||||
results[entity_id] = {
|
||||
|
@ -378,7 +390,7 @@ async def handle_devices_execute(hass, config, request_id, payload):
|
|||
|
||||
|
||||
@HANDLERS.register('action.devices.DISCONNECT')
|
||||
async def async_devices_disconnect(hass, config, request_id, payload):
|
||||
async def async_devices_disconnect(hass, data, payload):
|
||||
"""Handle action.devices.DISCONNECT request.
|
||||
|
||||
https://developers.google.com/actions/smarthome/create#actiondevicesdisconnect
|
||||
|
|
|
@ -102,7 +102,7 @@ class _Trait:
|
|||
"""Test if command can be executed."""
|
||||
return command in self.commands
|
||||
|
||||
async def execute(self, command, params):
|
||||
async def execute(self, command, data, params):
|
||||
"""Execute a trait command."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -159,7 +159,7 @@ class BrightnessTrait(_Trait):
|
|||
|
||||
return response
|
||||
|
||||
async def execute(self, command, params):
|
||||
async def execute(self, command, data, params):
|
||||
"""Execute a brightness command."""
|
||||
domain = self.state.domain
|
||||
|
||||
|
@ -168,20 +168,20 @@ class BrightnessTrait(_Trait):
|
|||
light.DOMAIN, light.SERVICE_TURN_ON, {
|
||||
ATTR_ENTITY_ID: self.state.entity_id,
|
||||
light.ATTR_BRIGHTNESS_PCT: params['brightness']
|
||||
}, blocking=True)
|
||||
}, blocking=True, context=data.context)
|
||||
elif domain == cover.DOMAIN:
|
||||
await self.hass.services.async_call(
|
||||
cover.DOMAIN, cover.SERVICE_SET_COVER_POSITION, {
|
||||
ATTR_ENTITY_ID: self.state.entity_id,
|
||||
cover.ATTR_POSITION: params['brightness']
|
||||
}, blocking=True)
|
||||
}, blocking=True, context=data.context)
|
||||
elif domain == media_player.DOMAIN:
|
||||
await self.hass.services.async_call(
|
||||
media_player.DOMAIN, media_player.SERVICE_VOLUME_SET, {
|
||||
ATTR_ENTITY_ID: self.state.entity_id,
|
||||
media_player.ATTR_MEDIA_VOLUME_LEVEL:
|
||||
params['brightness'] / 100
|
||||
}, blocking=True)
|
||||
}, blocking=True, context=data.context)
|
||||
|
||||
|
||||
@register_trait
|
||||
|
@ -221,7 +221,7 @@ class OnOffTrait(_Trait):
|
|||
return {'on': self.state.state != cover.STATE_CLOSED}
|
||||
return {'on': self.state.state != STATE_OFF}
|
||||
|
||||
async def execute(self, command, params):
|
||||
async def execute(self, command, data, params):
|
||||
"""Execute an OnOff command."""
|
||||
domain = self.state.domain
|
||||
|
||||
|
@ -242,7 +242,7 @@ class OnOffTrait(_Trait):
|
|||
|
||||
await self.hass.services.async_call(service_domain, service, {
|
||||
ATTR_ENTITY_ID: self.state.entity_id
|
||||
}, blocking=True)
|
||||
}, blocking=True, context=data.context)
|
||||
|
||||
|
||||
@register_trait
|
||||
|
@ -288,7 +288,7 @@ class ColorSpectrumTrait(_Trait):
|
|||
return (command in self.commands and
|
||||
'spectrumRGB' in params.get('color', {}))
|
||||
|
||||
async def execute(self, command, params):
|
||||
async def execute(self, command, data, params):
|
||||
"""Execute a color spectrum command."""
|
||||
# Convert integer to hex format and left pad with 0's till length 6
|
||||
hex_value = "{0:06x}".format(params['color']['spectrumRGB'])
|
||||
|
@ -298,7 +298,7 @@ class ColorSpectrumTrait(_Trait):
|
|||
await self.hass.services.async_call(light.DOMAIN, SERVICE_TURN_ON, {
|
||||
ATTR_ENTITY_ID: self.state.entity_id,
|
||||
light.ATTR_HS_COLOR: color
|
||||
}, blocking=True)
|
||||
}, blocking=True, context=data.context)
|
||||
|
||||
|
||||
@register_trait
|
||||
|
@ -355,7 +355,7 @@ class ColorTemperatureTrait(_Trait):
|
|||
return (command in self.commands and
|
||||
'temperature' in params.get('color', {}))
|
||||
|
||||
async def execute(self, command, params):
|
||||
async def execute(self, command, data, params):
|
||||
"""Execute a color temperature command."""
|
||||
temp = color_util.color_temperature_kelvin_to_mired(
|
||||
params['color']['temperature'])
|
||||
|
@ -371,7 +371,7 @@ class ColorTemperatureTrait(_Trait):
|
|||
await self.hass.services.async_call(light.DOMAIN, SERVICE_TURN_ON, {
|
||||
ATTR_ENTITY_ID: self.state.entity_id,
|
||||
light.ATTR_COLOR_TEMP: temp,
|
||||
}, blocking=True)
|
||||
}, blocking=True, context=data.context)
|
||||
|
||||
|
||||
@register_trait
|
||||
|
@ -400,13 +400,14 @@ class SceneTrait(_Trait):
|
|||
"""Return scene query attributes."""
|
||||
return {}
|
||||
|
||||
async def execute(self, command, params):
|
||||
async def execute(self, command, data, params):
|
||||
"""Execute a scene command."""
|
||||
# Don't block for scripts as they can be slow.
|
||||
await self.hass.services.async_call(
|
||||
self.state.domain, SERVICE_TURN_ON, {
|
||||
ATTR_ENTITY_ID: self.state.entity_id
|
||||
}, blocking=self.state.domain != script.DOMAIN)
|
||||
}, blocking=self.state.domain != script.DOMAIN,
|
||||
context=data.context)
|
||||
|
||||
|
||||
@register_trait
|
||||
|
@ -434,12 +435,12 @@ class DockTrait(_Trait):
|
|||
"""Return dock query attributes."""
|
||||
return {'isDocked': self.state.state == vacuum.STATE_DOCKED}
|
||||
|
||||
async def execute(self, command, params):
|
||||
async def execute(self, command, data, params):
|
||||
"""Execute a dock command."""
|
||||
await self.hass.services.async_call(
|
||||
self.state.domain, vacuum.SERVICE_RETURN_TO_BASE, {
|
||||
ATTR_ENTITY_ID: self.state.entity_id
|
||||
}, blocking=True)
|
||||
}, blocking=True, context=data.context)
|
||||
|
||||
|
||||
@register_trait
|
||||
|
@ -473,30 +474,30 @@ class StartStopTrait(_Trait):
|
|||
'isPaused': self.state.state == vacuum.STATE_PAUSED,
|
||||
}
|
||||
|
||||
async def execute(self, command, params):
|
||||
async def execute(self, command, data, params):
|
||||
"""Execute a StartStop command."""
|
||||
if command == COMMAND_STARTSTOP:
|
||||
if params['start']:
|
||||
await self.hass.services.async_call(
|
||||
self.state.domain, vacuum.SERVICE_START, {
|
||||
ATTR_ENTITY_ID: self.state.entity_id
|
||||
}, blocking=True)
|
||||
}, blocking=True, context=data.context)
|
||||
else:
|
||||
await self.hass.services.async_call(
|
||||
self.state.domain, vacuum.SERVICE_STOP, {
|
||||
ATTR_ENTITY_ID: self.state.entity_id
|
||||
}, blocking=True)
|
||||
}, blocking=True, context=data.context)
|
||||
elif command == COMMAND_PAUSEUNPAUSE:
|
||||
if params['pause']:
|
||||
await self.hass.services.async_call(
|
||||
self.state.domain, vacuum.SERVICE_PAUSE, {
|
||||
ATTR_ENTITY_ID: self.state.entity_id
|
||||
}, blocking=True)
|
||||
}, blocking=True, context=data.context)
|
||||
else:
|
||||
await self.hass.services.async_call(
|
||||
self.state.domain, vacuum.SERVICE_START, {
|
||||
ATTR_ENTITY_ID: self.state.entity_id
|
||||
}, blocking=True)
|
||||
}, blocking=True, context=data.context)
|
||||
|
||||
|
||||
@register_trait
|
||||
|
@ -584,7 +585,7 @@ class TemperatureSettingTrait(_Trait):
|
|||
|
||||
return response
|
||||
|
||||
async def execute(self, command, params):
|
||||
async def execute(self, command, data, params):
|
||||
"""Execute a temperature point or mode command."""
|
||||
# All sent in temperatures are always in Celsius
|
||||
unit = self.hass.config.units.temperature_unit
|
||||
|
@ -608,7 +609,7 @@ class TemperatureSettingTrait(_Trait):
|
|||
climate.DOMAIN, climate.SERVICE_SET_TEMPERATURE, {
|
||||
ATTR_ENTITY_ID: self.state.entity_id,
|
||||
ATTR_TEMPERATURE: temp
|
||||
}, blocking=True)
|
||||
}, blocking=True, context=data.context)
|
||||
|
||||
elif command == COMMAND_THERMOSTAT_TEMPERATURE_SET_RANGE:
|
||||
temp_high = temp_util.convert(
|
||||
|
@ -640,7 +641,7 @@ class TemperatureSettingTrait(_Trait):
|
|||
ATTR_ENTITY_ID: self.state.entity_id,
|
||||
climate.ATTR_TARGET_TEMP_HIGH: temp_high,
|
||||
climate.ATTR_TARGET_TEMP_LOW: temp_low,
|
||||
}, blocking=True)
|
||||
}, blocking=True, context=data.context)
|
||||
|
||||
elif command == COMMAND_THERMOSTAT_SET_MODE:
|
||||
await self.hass.services.async_call(
|
||||
|
@ -648,7 +649,7 @@ class TemperatureSettingTrait(_Trait):
|
|||
ATTR_ENTITY_ID: self.state.entity_id,
|
||||
climate.ATTR_OPERATION_MODE:
|
||||
self.google_to_hass[params['thermostatMode']],
|
||||
}, blocking=True)
|
||||
}, blocking=True, context=data.context)
|
||||
|
||||
|
||||
@register_trait
|
||||
|
@ -681,7 +682,7 @@ class LockUnlockTrait(_Trait):
|
|||
allowed_unlock = not params['lock'] and self.config.allow_unlock
|
||||
return params['lock'] or allowed_unlock
|
||||
|
||||
async def execute(self, command, params):
|
||||
async def execute(self, command, data, params):
|
||||
"""Execute an LockUnlock command."""
|
||||
if params['lock']:
|
||||
service = lock.SERVICE_LOCK
|
||||
|
@ -690,7 +691,7 @@ class LockUnlockTrait(_Trait):
|
|||
|
||||
await self.hass.services.async_call(lock.DOMAIN, service, {
|
||||
ATTR_ENTITY_ID: self.state.entity_id
|
||||
}, blocking=True)
|
||||
}, blocking=True, context=data.context)
|
||||
|
||||
|
||||
@register_trait
|
||||
|
@ -760,13 +761,13 @@ class FanSpeedTrait(_Trait):
|
|||
|
||||
return response
|
||||
|
||||
async def execute(self, command, params):
|
||||
async def execute(self, command, data, params):
|
||||
"""Execute an SetFanSpeed command."""
|
||||
await self.hass.services.async_call(
|
||||
fan.DOMAIN, fan.SERVICE_SET_SPEED, {
|
||||
ATTR_ENTITY_ID: self.state.entity_id,
|
||||
fan.ATTR_SPEED: params['fanSpeed']
|
||||
}, blocking=True)
|
||||
}, blocking=True, context=data.context)
|
||||
|
||||
|
||||
@register_trait
|
||||
|
@ -934,7 +935,7 @@ class ModesTrait(_Trait):
|
|||
|
||||
return response
|
||||
|
||||
async def execute(self, command, params):
|
||||
async def execute(self, command, data, params):
|
||||
"""Execute an SetModes command."""
|
||||
settings = params.get('updateModeSettings')
|
||||
requested_source = settings.get(
|
||||
|
@ -951,4 +952,4 @@ class ModesTrait(_Trait):
|
|||
media_player.SERVICE_SELECT_SOURCE, {
|
||||
ATTR_ENTITY_ID: self.state.entity_id,
|
||||
media_player.ATTR_INPUT_SOURCE: source
|
||||
}, blocking=True)
|
||||
}, blocking=True, context=data.context)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue