Asyncio notify component migration (#5377)
* Async migrate notify/platform * convert group to async * fix unittest
This commit is contained in:
parent
f7ac644c11
commit
2a362fd1ff
6 changed files with 136 additions and 88 deletions
|
@ -4,13 +4,15 @@ Provides functionality to notify people.
|
|||
For more details about this component, please refer to the documentation at
|
||||
https://home-assistant.io/components/notify/
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from functools import partial
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
import homeassistant.bootstrap as bootstrap
|
||||
from homeassistant.bootstrap import async_prepare_setup_platform
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.config import load_yaml_config_file
|
||||
from homeassistant.const import CONF_NAME, CONF_PLATFORM
|
||||
|
@ -64,91 +66,110 @@ def send_message(hass, message, title=None, data=None):
|
|||
hass.services.call(DOMAIN, SERVICE_NOTIFY, info)
|
||||
|
||||
|
||||
def setup(hass, config):
|
||||
@asyncio.coroutine
|
||||
def async_setup(hass, config):
|
||||
"""Setup the notify services."""
|
||||
descriptions = load_yaml_config_file(
|
||||
descriptions = yield from hass.loop.run_in_executor(
|
||||
None, load_yaml_config_file,
|
||||
os.path.join(os.path.dirname(__file__), 'services.yaml'))
|
||||
|
||||
targets = {}
|
||||
|
||||
def setup_notify_platform(platform, p_config=None, discovery_info=None):
|
||||
@asyncio.coroutine
|
||||
def async_setup_platform(p_type, p_config=None, discovery_info=None):
|
||||
"""Set up a notify platform."""
|
||||
if p_config is None:
|
||||
p_config = {}
|
||||
if discovery_info is None:
|
||||
discovery_info = {}
|
||||
|
||||
notify_implementation = bootstrap.prepare_setup_platform(
|
||||
hass, config, DOMAIN, platform)
|
||||
platform = yield from async_prepare_setup_platform(
|
||||
hass, config, DOMAIN, p_type)
|
||||
|
||||
if notify_implementation is None:
|
||||
if platform is None:
|
||||
_LOGGER.error("Unknown notification service specified")
|
||||
return False
|
||||
return
|
||||
|
||||
notify_service = notify_implementation.get_service(
|
||||
hass, p_config, discovery_info)
|
||||
_LOGGER.info("Setting up %s.%s", DOMAIN, p_type)
|
||||
notify_service = None
|
||||
try:
|
||||
if hasattr(platform, 'async_get_service'):
|
||||
notify_service = yield from \
|
||||
platform.async_get_service(hass, p_config, discovery_info)
|
||||
elif hasattr(platform, 'get_service'):
|
||||
notify_service = yield from hass.loop.run_in_executor(
|
||||
None, platform.get_service, hass, p_config, discovery_info)
|
||||
else:
|
||||
raise HomeAssistantError("Invalid notify platform.")
|
||||
|
||||
if notify_service is None:
|
||||
_LOGGER.error("Failed to initialize notification service %s",
|
||||
platform)
|
||||
return False
|
||||
if notify_service is None:
|
||||
_LOGGER.error(
|
||||
"Failed to initialize notification service %s", p_type)
|
||||
return
|
||||
|
||||
def notify_message(notify_service, call):
|
||||
except Exception: # pylint: disable=broad-except
|
||||
_LOGGER.exception('Error setting up platform %s', p_type)
|
||||
return
|
||||
|
||||
notify_service.hass = hass
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_notify_message(service):
|
||||
"""Handle sending notification message service calls."""
|
||||
kwargs = {}
|
||||
message = call.data[ATTR_MESSAGE]
|
||||
title = call.data.get(ATTR_TITLE)
|
||||
message = service.data[ATTR_MESSAGE]
|
||||
title = service.data.get(ATTR_TITLE)
|
||||
|
||||
if title:
|
||||
title.hass = hass
|
||||
kwargs[ATTR_TITLE] = title.render()
|
||||
kwargs[ATTR_TITLE] = title.async_render()
|
||||
|
||||
if targets.get(call.service) is not None:
|
||||
kwargs[ATTR_TARGET] = [targets[call.service]]
|
||||
elif call.data.get(ATTR_TARGET) is not None:
|
||||
kwargs[ATTR_TARGET] = call.data.get(ATTR_TARGET)
|
||||
if targets.get(service.service) is not None:
|
||||
kwargs[ATTR_TARGET] = [targets[service.service]]
|
||||
elif service.data.get(ATTR_TARGET) is not None:
|
||||
kwargs[ATTR_TARGET] = service.data.get(ATTR_TARGET)
|
||||
|
||||
message.hass = hass
|
||||
kwargs[ATTR_MESSAGE] = message.render()
|
||||
kwargs[ATTR_DATA] = call.data.get(ATTR_DATA)
|
||||
kwargs[ATTR_MESSAGE] = message.async_render()
|
||||
kwargs[ATTR_DATA] = service.data.get(ATTR_DATA)
|
||||
|
||||
notify_service.send_message(**kwargs)
|
||||
|
||||
service_call_handler = partial(notify_message, notify_service)
|
||||
yield from notify_service.async_send_message(**kwargs)
|
||||
|
||||
if hasattr(notify_service, 'targets'):
|
||||
platform_name = (
|
||||
p_config.get(CONF_NAME) or discovery_info.get(CONF_NAME) or
|
||||
platform)
|
||||
p_type)
|
||||
for name, target in notify_service.targets.items():
|
||||
target_name = slugify('{}_{}'.format(platform_name, name))
|
||||
targets[target_name] = target
|
||||
hass.services.register(DOMAIN, target_name,
|
||||
service_call_handler,
|
||||
descriptions.get(SERVICE_NOTIFY),
|
||||
schema=NOTIFY_SERVICE_SCHEMA)
|
||||
hass.services.async_register(
|
||||
DOMAIN, target_name, async_notify_message,
|
||||
descriptions.get(SERVICE_NOTIFY),
|
||||
schema=NOTIFY_SERVICE_SCHEMA)
|
||||
|
||||
platform_name = (
|
||||
p_config.get(CONF_NAME) or discovery_info.get(CONF_NAME) or
|
||||
SERVICE_NOTIFY)
|
||||
platform_name_slug = slugify(platform_name)
|
||||
|
||||
hass.services.register(
|
||||
DOMAIN, platform_name_slug, service_call_handler,
|
||||
hass.services.async_register(
|
||||
DOMAIN, platform_name_slug, async_notify_message,
|
||||
descriptions.get(SERVICE_NOTIFY), schema=NOTIFY_SERVICE_SCHEMA)
|
||||
|
||||
return True
|
||||
|
||||
for platform, p_config in config_per_platform(config, DOMAIN):
|
||||
if not setup_notify_platform(platform, p_config):
|
||||
_LOGGER.error("Failed to set up platform %s", platform)
|
||||
continue
|
||||
setup_tasks = [async_setup_platform(p_type, p_config) for p_type, p_config
|
||||
in config_per_platform(config, DOMAIN)]
|
||||
|
||||
def platform_discovered(platform, info):
|
||||
if setup_tasks:
|
||||
yield from asyncio.wait(setup_tasks, loop=hass.loop)
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_platform_discovered(platform, info):
|
||||
"""Callback to load a platform."""
|
||||
setup_notify_platform(platform, discovery_info=info)
|
||||
yield from async_setup_platform(platform, discovery_info=info)
|
||||
|
||||
discovery.listen_platform(hass, DOMAIN, platform_discovered)
|
||||
discovery.async_listen_platform(hass, DOMAIN, async_platform_discovered)
|
||||
|
||||
return True
|
||||
|
||||
|
@ -156,9 +177,20 @@ def setup(hass, config):
|
|||
class BaseNotificationService(object):
|
||||
"""An abstract class for notification services."""
|
||||
|
||||
hass = None
|
||||
|
||||
def send_message(self, message, **kwargs):
|
||||
"""Send a message.
|
||||
|
||||
kwargs can contain ATTR_TITLE to specify a title.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError()
|
||||
|
||||
def async_send_message(self, message, **kwargs):
|
||||
"""Send a message.
|
||||
|
||||
kwargs can contain ATTR_TITLE to specify a title.
|
||||
This method must be run in the event loop and returns a coroutine.
|
||||
"""
|
||||
return self.hass.loop.run_in_executor(
|
||||
None, partial(self.send_message, message, **kwargs))
|
||||
|
|
|
@ -4,7 +4,7 @@ import asyncio
|
|||
import voluptuous as vol
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.components.notify import (
|
||||
PLATFORM_SCHEMA, BaseNotificationService)
|
||||
PLATFORM_SCHEMA, BaseNotificationService, ATTR_TARGET)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -32,20 +32,16 @@ class DiscordNotificationService(BaseNotificationService):
|
|||
self.hass = hass
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_send_message(self, message, target):
|
||||
def async_send_message(self, message, **kwargs):
|
||||
"""Login to Discord, send message to channel(s) and log out."""
|
||||
import discord
|
||||
discord_bot = discord.Client(loop=self.hass.loop)
|
||||
|
||||
yield from discord_bot.login(self.token)
|
||||
|
||||
for channelid in target:
|
||||
for channelid in kwargs[ATTR_TARGET]:
|
||||
channel = discord.Object(id=channelid)
|
||||
yield from discord_bot.send_message(channel, message)
|
||||
|
||||
yield from discord_bot.logout()
|
||||
yield from discord_bot.close()
|
||||
|
||||
def send_message(self, message=None, target=None, **kwargs):
|
||||
"""Send a message using Discord."""
|
||||
self.hass.async_add_job(self.async_send_message(message, target))
|
||||
|
|
|
@ -4,15 +4,15 @@ Group platform for notify component.
|
|||
For more details about this platform, please refer to the documentation at
|
||||
https://home-assistant.io/components/notify.group/
|
||||
"""
|
||||
import asyncio
|
||||
import collections
|
||||
from copy import deepcopy
|
||||
import logging
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.const import ATTR_SERVICE
|
||||
from homeassistant.components.notify import (DOMAIN, ATTR_MESSAGE, ATTR_DATA,
|
||||
PLATFORM_SCHEMA,
|
||||
BaseNotificationService)
|
||||
from homeassistant.components.notify import (
|
||||
DOMAIN, ATTR_MESSAGE, ATTR_DATA, PLATFORM_SCHEMA, BaseNotificationService)
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
@ -28,7 +28,10 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({
|
|||
|
||||
|
||||
def update(input_dict, update_source):
|
||||
"""Deep update a dictionary."""
|
||||
"""Deep update a dictionary.
|
||||
|
||||
Async friendly.
|
||||
"""
|
||||
for key, val in update_source.items():
|
||||
if isinstance(val, collections.Mapping):
|
||||
recurse = update(input_dict.get(key, {}), val)
|
||||
|
@ -38,7 +41,8 @@ def update(input_dict, update_source):
|
|||
return input_dict
|
||||
|
||||
|
||||
def get_service(hass, config, discovery_info=None):
|
||||
@asyncio.coroutine
|
||||
def async_get_service(hass, config, discovery_info=None):
|
||||
"""Get the Group notification service."""
|
||||
return GroupNotifyPlatform(hass, config.get(CONF_SERVICES))
|
||||
|
||||
|
@ -51,14 +55,19 @@ class GroupNotifyPlatform(BaseNotificationService):
|
|||
self.hass = hass
|
||||
self.entities = entities
|
||||
|
||||
def send_message(self, message="", **kwargs):
|
||||
@asyncio.coroutine
|
||||
def async_send_message(self, message="", **kwargs):
|
||||
"""Send message to all entities in the group."""
|
||||
payload = {ATTR_MESSAGE: message}
|
||||
payload.update({key: val for key, val in kwargs.items() if val})
|
||||
|
||||
tasks = []
|
||||
for entity in self.entities:
|
||||
sending_payload = deepcopy(payload.copy())
|
||||
if entity.get(ATTR_DATA) is not None:
|
||||
update(sending_payload, entity.get(ATTR_DATA))
|
||||
self.hass.services.call(DOMAIN, entity.get(ATTR_SERVICE),
|
||||
sending_payload)
|
||||
tasks.append(self.hass.services.async_call(
|
||||
DOMAIN, entity.get(ATTR_SERVICE), sending_payload))
|
||||
|
||||
if tasks:
|
||||
yield from asyncio.wait(tasks, loop=self.hass.loop)
|
||||
|
|
|
@ -41,7 +41,9 @@ class TestApns(unittest.TestCase):
|
|||
assert setup_component(self.hass, notify.DOMAIN, CONFIG)
|
||||
assert handle_config[notify.DOMAIN]
|
||||
|
||||
def test_apns_setup_full(self):
|
||||
@patch('os.path.isfile', return_value=True)
|
||||
@patch('os.access', return_value=True)
|
||||
def test_apns_setup_full(self, mock_access, mock_isfile):
|
||||
"""Test setup with all data."""
|
||||
config = {
|
||||
'notify': {
|
||||
|
@ -53,7 +55,9 @@ class TestApns(unittest.TestCase):
|
|||
}
|
||||
}
|
||||
|
||||
self.assertTrue(notify.setup(self.hass, config))
|
||||
with assert_setup_component(1) as handle_config:
|
||||
assert setup_component(self.hass, notify.DOMAIN, config)
|
||||
assert handle_config[notify.DOMAIN]
|
||||
|
||||
def test_apns_setup_missing_name(self):
|
||||
"""Test setup with missing name."""
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
"""The tests for the notify demo platform."""
|
||||
import asyncio
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
|
@ -16,6 +17,12 @@ CONFIG = {
|
|||
}
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def mock_setup_platform():
|
||||
"""Mock prepare_setup_platform."""
|
||||
return None
|
||||
|
||||
|
||||
class TestNotifyDemo(unittest.TestCase):
|
||||
"""Test the demo notify."""
|
||||
|
||||
|
@ -45,23 +52,16 @@ class TestNotifyDemo(unittest.TestCase):
|
|||
"""Test setup."""
|
||||
self._setup_notify()
|
||||
|
||||
@patch('homeassistant.bootstrap.prepare_setup_platform')
|
||||
@patch('homeassistant.bootstrap.async_prepare_setup_platform',
|
||||
return_value=mock_setup_platform())
|
||||
def test_no_prepare_setup_platform(self, mock_prep_setup_platform):
|
||||
"""Test missing notify platform."""
|
||||
mock_prep_setup_platform.return_value = None
|
||||
with self.assertLogs('homeassistant.components.notify',
|
||||
level='ERROR') as log_handle:
|
||||
self._setup_notify()
|
||||
self.hass.block_till_done()
|
||||
assert mock_prep_setup_platform.called
|
||||
self.assertEqual(
|
||||
log_handle.output,
|
||||
['ERROR:homeassistant.components.notify:'
|
||||
'Unknown notification service specified',
|
||||
'ERROR:homeassistant.components.notify:'
|
||||
'Failed to set up platform demo'])
|
||||
with assert_setup_component(0):
|
||||
setup_component(self.hass, notify.DOMAIN, CONFIG)
|
||||
|
||||
@patch('homeassistant.components.notify.demo.get_service')
|
||||
assert mock_prep_setup_platform.called
|
||||
|
||||
@patch('homeassistant.components.notify.demo.get_service', autospec=True)
|
||||
def test_no_notify_service(self, mock_demo_get_service):
|
||||
"""Test missing platform notify service instance."""
|
||||
mock_demo_get_service.return_value = None
|
||||
|
@ -73,11 +73,9 @@ class TestNotifyDemo(unittest.TestCase):
|
|||
self.assertEqual(
|
||||
log_handle.output,
|
||||
['ERROR:homeassistant.components.notify:'
|
||||
'Failed to initialize notification service demo',
|
||||
'ERROR:homeassistant.components.notify:'
|
||||
'Failed to set up platform demo'])
|
||||
'Failed to initialize notification service demo'])
|
||||
|
||||
@patch('homeassistant.components.notify.demo.get_service')
|
||||
@patch('homeassistant.components.notify.demo.get_service', autospec=True)
|
||||
def test_discover_notify(self, mock_demo_get_service):
|
||||
"""Test discovery of notify demo platform."""
|
||||
assert notify.DOMAIN not in self.hass.config.components
|
||||
|
|
|
@ -5,6 +5,7 @@ from unittest.mock import MagicMock, patch
|
|||
from homeassistant.bootstrap import setup_component
|
||||
import homeassistant.components.notify as notify
|
||||
from homeassistant.components.notify import group, demo
|
||||
from homeassistant.util.async import run_coroutine_threadsafe
|
||||
|
||||
from tests.common import assert_setup_component, get_test_home_assistant
|
||||
|
||||
|
@ -16,8 +17,11 @@ class TestNotifyGroup(unittest.TestCase):
|
|||
"""Setup things to be run when tests are started."""
|
||||
self.hass = get_test_home_assistant()
|
||||
self.events = []
|
||||
self.service1 = MagicMock()
|
||||
self.service2 = MagicMock()
|
||||
self.service1 = demo.DemoNotificationService(self.hass)
|
||||
self.service2 = demo.DemoNotificationService(self.hass)
|
||||
|
||||
self.service1.send_message = MagicMock(autospec=True)
|
||||
self.service2.send_message = MagicMock(autospec=True)
|
||||
|
||||
def mock_get_service(hass, config, discovery_info=None):
|
||||
if config['name'] == 'demo1':
|
||||
|
@ -37,11 +41,14 @@ class TestNotifyGroup(unittest.TestCase):
|
|||
}]
|
||||
})
|
||||
|
||||
self.service = group.get_service(self.hass, {'services': [
|
||||
{'service': 'demo1'},
|
||||
{'service': 'demo2',
|
||||
'data': {'target': 'unnamed device',
|
||||
'data': {'test': 'message'}}}]})
|
||||
self.service = run_coroutine_threadsafe(
|
||||
group.async_get_service(self.hass, {'services': [
|
||||
{'service': 'demo1'},
|
||||
{'service': 'demo2',
|
||||
'data': {'target': 'unnamed device',
|
||||
'data': {'test': 'message'}}}]}),
|
||||
self.hass.loop
|
||||
).result()
|
||||
|
||||
assert self.service is not None
|
||||
|
||||
|
@ -51,17 +58,19 @@ class TestNotifyGroup(unittest.TestCase):
|
|||
|
||||
def test_send_message_with_data(self):
|
||||
"""Test sending a message with to a notify group."""
|
||||
self.service.send_message('Hello', title='Test notification',
|
||||
data={'hello': 'world'})
|
||||
run_coroutine_threadsafe(
|
||||
self.service.async_send_message(
|
||||
'Hello', title='Test notification', data={'hello': 'world'}),
|
||||
self.hass.loop).result()
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert self.service1.send_message.mock_calls[0][1][0] == 'Hello'
|
||||
assert self.service1.send_message.mock_calls[0][2] == {
|
||||
'message': 'Hello',
|
||||
'title': 'Test notification',
|
||||
'data': {'hello': 'world'}
|
||||
}
|
||||
assert self.service2.send_message.mock_calls[0][1][0] == 'Hello'
|
||||
assert self.service2.send_message.mock_calls[0][2] == {
|
||||
'message': 'Hello',
|
||||
'target': ['unnamed device'],
|
||||
'title': 'Test notification',
|
||||
'data': {'hello': 'world', 'test': 'message'}
|
||||
|
|
Loading…
Add table
Reference in a new issue