Snips ASR and NLU component (#8156)
* Snips ASR and NLU component * Fix warning * Fix warnings * Fix lint issues * Add tests * Fix tabs * Fix newline * Fix quotes * Fix docstrings * Update tests * Remove logs * Fix lint warning * Update API * Fix Snips
This commit is contained in:
parent
c13fdd23c1
commit
b82003ae08
2 changed files with 191 additions and 0 deletions
138
homeassistant/components/snips.py
Normal file
138
homeassistant/components/snips.py
Normal file
|
@ -0,0 +1,138 @@
|
|||
"""
|
||||
Support for Snips on-device ASR and NLU.
|
||||
|
||||
For more details about this component, please refer to the documentation at
|
||||
https://home-assistant.io/components/snips/
|
||||
"""
|
||||
import asyncio
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import voluptuous as vol
|
||||
from homeassistant.helpers import template, script, config_validation as cv
|
||||
import homeassistant.loader as loader
|
||||
|
||||
DOMAIN = 'snips'
|
||||
DEPENDENCIES = ['mqtt']
|
||||
CONF_INTENTS = 'intents'
|
||||
CONF_ACTION = 'action'
|
||||
|
||||
INTENT_TOPIC = 'hermes/nlu/intentParsed'
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
CONFIG_SCHEMA = vol.Schema({
|
||||
DOMAIN: {
|
||||
CONF_INTENTS: {
|
||||
cv.string: {
|
||||
vol.Optional(CONF_ACTION): cv.SCRIPT_SCHEMA,
|
||||
}
|
||||
}
|
||||
}
|
||||
}, extra=vol.ALLOW_EXTRA)
|
||||
|
||||
INTENT_SCHEMA = vol.Schema({
|
||||
vol.Required('text'): str,
|
||||
vol.Required('intent'): {
|
||||
vol.Required('intent_name'): str
|
||||
},
|
||||
vol.Optional('slots'): [{
|
||||
vol.Required('slot_name'): str,
|
||||
vol.Required('value'): {
|
||||
vol.Required('kind'): str,
|
||||
vol.Required('value'): cv.match_all
|
||||
}
|
||||
}]
|
||||
}, extra=vol.ALLOW_EXTRA)
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_setup(hass, config):
|
||||
"""Activate Snips component."""
|
||||
mqtt = loader.get_component('mqtt')
|
||||
intents = config[DOMAIN].get(CONF_INTENTS, {})
|
||||
handler = IntentHandler(hass, intents)
|
||||
|
||||
@asyncio.coroutine
|
||||
def message_received(topic, payload, qos):
|
||||
"""Handle new messages on MQTT."""
|
||||
LOGGER.debug("New intent: %s", payload)
|
||||
yield from handler.handle_intent(payload)
|
||||
|
||||
yield from mqtt.async_subscribe(hass, INTENT_TOPIC, message_received)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class IntentHandler(object):
|
||||
"""Help handling intents."""
|
||||
|
||||
def __init__(self, hass, intents):
|
||||
"""Initialize the intent handler."""
|
||||
self.hass = hass
|
||||
intents = copy.deepcopy(intents)
|
||||
template.attach(hass, intents)
|
||||
|
||||
for name, intent in intents.items():
|
||||
if CONF_ACTION in intent:
|
||||
intent[CONF_ACTION] = script.Script(
|
||||
hass, intent[CONF_ACTION], "Snips intent {}".format(name))
|
||||
|
||||
self.intents = intents
|
||||
|
||||
@asyncio.coroutine
|
||||
def handle_intent(self, payload):
|
||||
"""Handle an intent."""
|
||||
try:
|
||||
response = json.loads(payload)
|
||||
except TypeError:
|
||||
LOGGER.error('Received invalid JSON: %s', payload)
|
||||
return
|
||||
|
||||
try:
|
||||
response = INTENT_SCHEMA(response)
|
||||
except vol.Invalid as err:
|
||||
LOGGER.error('Intent has invalid schema: %s. %s', err, response)
|
||||
return
|
||||
|
||||
intent = response['intent']['intent_name'].split('__')[-1]
|
||||
config = self.intents.get(intent)
|
||||
|
||||
if config is None:
|
||||
LOGGER.warning("Received unknown intent %s. %s", intent, response)
|
||||
return
|
||||
|
||||
action = config.get(CONF_ACTION)
|
||||
|
||||
if action is not None:
|
||||
slots = self.parse_slots(response)
|
||||
yield from action.async_run(slots)
|
||||
|
||||
def parse_slots(self, response):
|
||||
"""Parse the intent slots."""
|
||||
parameters = {}
|
||||
|
||||
for slot in response.get('slots', []):
|
||||
key = slot['slot_name']
|
||||
value = self.get_value(slot['value'])
|
||||
if value is not None:
|
||||
parameters[key] = value
|
||||
|
||||
return parameters
|
||||
|
||||
@staticmethod
|
||||
def get_value(value):
|
||||
"""Return the value of a given slot."""
|
||||
kind = value['kind']
|
||||
|
||||
if kind == "Custom":
|
||||
return value["value"]
|
||||
elif kind == "Builtin":
|
||||
try:
|
||||
return value["value"]["value"]
|
||||
except KeyError:
|
||||
return None
|
||||
else:
|
||||
LOGGER.warning('Received unknown slot type: %s', kind)
|
||||
|
||||
return None
|
53
tests/components/test_snips.py
Normal file
53
tests/components/test_snips.py
Normal file
|
@ -0,0 +1,53 @@
|
|||
"""Test the Snips component."""
|
||||
import asyncio
|
||||
|
||||
from homeassistant.bootstrap import async_setup_component
|
||||
from tests.common import async_fire_mqtt_message, async_mock_service
|
||||
|
||||
EXAMPLE_MSG = """
|
||||
{
|
||||
"text": "turn the lights green",
|
||||
"intent": {
|
||||
"intent_name": "Lights",
|
||||
"probability": 1
|
||||
},
|
||||
"slots": [
|
||||
{
|
||||
"slot_name": "light_color",
|
||||
"value": {
|
||||
"kind": "Custom",
|
||||
"value": "blue"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_snips_call_action(hass, mqtt_mock):
|
||||
"""Test calling action via Snips."""
|
||||
calls = async_mock_service(hass, 'test', 'service')
|
||||
|
||||
result = yield from async_setup_component(hass, "snips", {
|
||||
"snips": {
|
||||
"intents": {
|
||||
"Lights": {
|
||||
"action": {
|
||||
"service": "test.service",
|
||||
"data_template": {
|
||||
"color": "{{ light_color }}"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
assert result
|
||||
|
||||
async_fire_mqtt_message(hass, 'hermes/nlu/intentParsed',
|
||||
EXAMPLE_MSG)
|
||||
yield from hass.async_block_till_done()
|
||||
assert len(calls) == 1
|
||||
call = calls[0]
|
||||
assert call.data.get('color') == 'blue'
|
Loading…
Add table
Reference in a new issue