Support reloading the group notify platform ()

This commit is contained in:
J. Nick Koston 2020-09-02 17:12:07 -05:00 committed by GitHub
parent 2d2efeb9bb
commit a778690b64
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 470 additions and 129 deletions
homeassistant
tests

View file

@ -58,7 +58,7 @@ ATTR_ALL = "all"
SERVICE_SET = "set" SERVICE_SET = "set"
SERVICE_REMOVE = "remove" SERVICE_REMOVE = "remove"
PLATFORMS = ["light", "cover"] PLATFORMS = ["light", "cover", "notify"]
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)

View file

@ -1,6 +1,6 @@
# Describes the format for available group services # Describes the format for available group services
reload: reload:
description: Reload group configuration. description: Reload group configuration, entities, and notify services.
set: set:
description: Create/Update a user group. description: Create/Update a user group.

View file

@ -2,11 +2,12 @@
import asyncio import asyncio
from functools import partial from functools import partial
import logging import logging
from typing import Optional from typing import Any, Dict, Optional
import voluptuous as vol import voluptuous as vol
from homeassistant.const import CONF_NAME, CONF_PLATFORM from homeassistant.const import CONF_NAME, CONF_PLATFORM
from homeassistant.core import ServiceCall
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_per_platform, discovery from homeassistant.helpers import config_per_platform, discovery
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
@ -37,10 +38,6 @@ DOMAIN = "notify"
SERVICE_NOTIFY = "notify" SERVICE_NOTIFY = "notify"
NOTIFY_SERVICES = "notify_services" NOTIFY_SERVICES = "notify_services"
SERVICE = "service"
TARGETS = "targets"
FRIENDLY_NAME = "friendly_name"
TARGET_FRIENDLY_NAME = "target_friendly_name"
PLATFORM_SCHEMA = vol.Schema( PLATFORM_SCHEMA = vol.Schema(
{vol.Required(CONF_PLATFORM): cv.string, vol.Optional(CONF_NAME): cv.string}, {vol.Required(CONF_PLATFORM): cv.string, vol.Optional(CONF_NAME): cv.string},
@ -58,88 +55,160 @@ NOTIFY_SERVICE_SCHEMA = vol.Schema(
@bind_hass @bind_hass
async def async_reload(hass, integration_name): async def async_reload(hass: HomeAssistantType, integration_name: str) -> None:
"""Register notify services for an integration.""" """Register notify services for an integration."""
if not _async_integration_has_notify_services(hass, integration_name):
return
tasks = [
notify_service.async_register_services()
for notify_service in hass.data[NOTIFY_SERVICES][integration_name]
]
await asyncio.gather(*tasks)
@bind_hass
async def async_reset_platform(hass: HomeAssistantType, integration_name: str) -> None:
"""Unregister notify services for an integration."""
if not _async_integration_has_notify_services(hass, integration_name):
return
tasks = [
notify_service.async_unregister_services()
for notify_service in hass.data[NOTIFY_SERVICES][integration_name]
]
await asyncio.gather(*tasks)
del hass.data[NOTIFY_SERVICES][integration_name]
def _async_integration_has_notify_services(
hass: HomeAssistantType, integration_name: str
) -> bool:
"""Determine if an integration has notify services registered."""
if ( if (
NOTIFY_SERVICES not in hass.data NOTIFY_SERVICES not in hass.data
or integration_name not in hass.data[NOTIFY_SERVICES] or integration_name not in hass.data[NOTIFY_SERVICES]
): ):
return return False
tasks = [ return True
_async_setup_notify_services(hass, data)
for data in hass.data[NOTIFY_SERVICES][integration_name]
]
await asyncio.gather(*tasks)
async def _async_setup_notify_services(hass, data): class BaseNotificationService:
"""Create or remove the notify services.""" """An abstract class for notification services."""
notify_service = data[SERVICE]
friendly_name = data[FRIENDLY_NAME]
targets = data[TARGETS]
async def _async_notify_message(service): hass: Optional[HomeAssistantType] = None
def send_message(self, message, **kwargs):
"""Send a message.
kwargs can contain ATTR_TITLE to specify a title.
"""
raise NotImplementedError()
async def async_send_message(self, message: Any, **kwargs: Any) -> None:
"""Send a message.
kwargs can contain ATTR_TITLE to specify a title.
"""
await self.hass.async_add_job(partial(self.send_message, message, **kwargs)) # type: ignore
async def _async_notify_message_service(self, service: ServiceCall) -> None:
"""Handle sending notification message service calls.""" """Handle sending notification message service calls."""
await _async_notify_message_service(hass, service, notify_service, targets) kwargs = {}
message = service.data[ATTR_MESSAGE]
title = service.data.get(ATTR_TITLE)
if hasattr(notify_service, "targets"): if title:
target_friendly_name = data[TARGET_FRIENDLY_NAME] title.hass = self.hass
stale_targets = set(targets) kwargs[ATTR_TITLE] = title.async_render()
for name, target in notify_service.targets.items(): if self._registered_targets.get(service.service) is not None:
target_name = slugify(f"{target_friendly_name}_{name}") kwargs[ATTR_TARGET] = [self._registered_targets[service.service]]
if target_name in stale_targets: elif service.data.get(ATTR_TARGET) is not None:
stale_targets.remove(target_name) kwargs[ATTR_TARGET] = service.data.get(ATTR_TARGET)
if target_name in targets:
continue
targets[target_name] = target
hass.services.async_register(
DOMAIN,
target_name,
_async_notify_message,
schema=NOTIFY_SERVICE_SCHEMA,
)
for stale_target_name in stale_targets: message.hass = self.hass
del targets[stale_target_name] kwargs[ATTR_MESSAGE] = message.async_render()
hass.services.async_remove( kwargs[ATTR_DATA] = service.data.get(ATTR_DATA)
DOMAIN,
stale_target_name,
)
friendly_name_slug = slugify(friendly_name) await self.async_send_message(**kwargs)
if hass.services.has_service(DOMAIN, friendly_name_slug):
return
hass.services.async_register( async def async_setup(
DOMAIN, self,
friendly_name_slug, hass: HomeAssistantType,
_async_notify_message, service_name: str,
schema=NOTIFY_SERVICE_SCHEMA, target_service_name_prefix: str,
) ) -> None:
"""Store the data for the notify service."""
# pylint: disable=attribute-defined-outside-init
self.hass = hass
self._service_name = service_name
self._target_service_name_prefix = target_service_name_prefix
self._registered_targets: Dict = {}
async def async_register_services(self) -> None:
"""Create or update the notify services."""
assert self.hass
async def _async_notify_message_service(hass, service, notify_service, targets): if hasattr(self, "targets"):
"""Handle sending notification message service calls.""" stale_targets = set(self._registered_targets)
kwargs = {}
message = service.data[ATTR_MESSAGE]
title = service.data.get(ATTR_TITLE)
if title: # pylint: disable=no-member
title.hass = hass for name, target in self.targets.items(): # type: ignore
kwargs[ATTR_TITLE] = title.async_render() target_name = slugify(f"{self._target_service_name_prefix}_{name}")
if target_name in stale_targets:
stale_targets.remove(target_name)
if target_name in self._registered_targets:
continue
self._registered_targets[target_name] = target
self.hass.services.async_register(
DOMAIN,
target_name,
self._async_notify_message_service,
schema=NOTIFY_SERVICE_SCHEMA,
)
if targets.get(service.service) is not None: for stale_target_name in stale_targets:
kwargs[ATTR_TARGET] = [targets[service.service]] del self._registered_targets[stale_target_name]
elif service.data.get(ATTR_TARGET) is not None: self.hass.services.async_remove(
kwargs[ATTR_TARGET] = service.data.get(ATTR_TARGET) DOMAIN,
stale_target_name,
)
message.hass = hass if self.hass.services.has_service(DOMAIN, self._service_name):
kwargs[ATTR_MESSAGE] = message.async_render() return
kwargs[ATTR_DATA] = service.data.get(ATTR_DATA)
await notify_service.async_send_message(**kwargs) self.hass.services.async_register(
DOMAIN,
self._service_name,
self._async_notify_message_service,
schema=NOTIFY_SERVICE_SCHEMA,
)
async def async_unregister_services(self) -> None:
"""Unregister the notify services."""
assert self.hass
if self._registered_targets:
remove_targets = set(self._registered_targets)
for remove_target_name in remove_targets:
del self._registered_targets[remove_target_name]
self.hass.services.async_remove(
DOMAIN,
remove_target_name,
)
if not self.hass.services.has_service(DOMAIN, self._service_name):
return
self.hass.services.async_remove(
DOMAIN,
self._service_name,
)
async def async_setup(hass, config): async def async_setup(hass, config):
@ -188,31 +257,19 @@ async def async_setup(hass, config):
_LOGGER.exception("Error setting up platform %s", integration_name) _LOGGER.exception("Error setting up platform %s", integration_name)
return return
notify_service.hass = hass
if discovery_info is None: if discovery_info is None:
discovery_info = {} discovery_info = {}
target_friendly_name = ( conf_name = p_config.get(CONF_NAME) or discovery_info.get(CONF_NAME)
p_config.get(CONF_NAME) or discovery_info.get(CONF_NAME) or integration_name target_service_name_prefix = conf_name or integration_name
service_name = slugify(conf_name or SERVICE_NOTIFY)
await notify_service.async_setup(hass, service_name, target_service_name_prefix)
await notify_service.async_register_services()
hass.data[NOTIFY_SERVICES].setdefault(integration_name, []).append(
notify_service
) )
friendly_name = (
p_config.get(CONF_NAME) or discovery_info.get(CONF_NAME) or SERVICE_NOTIFY
)
data = {
FRIENDLY_NAME: friendly_name,
# The targets use a slightly different friendly name
# selection pattern than the base service
TARGET_FRIENDLY_NAME: target_friendly_name,
SERVICE: notify_service,
TARGETS: {},
}
hass.data[NOTIFY_SERVICES].setdefault(integration_name, [])
hass.data[NOTIFY_SERVICES][integration_name].append(data)
await _async_setup_notify_services(hass, data)
hass.config.components.add(f"{DOMAIN}.{integration_name}") hass.config.components.add(f"{DOMAIN}.{integration_name}")
return True return True
@ -232,23 +289,3 @@ async def async_setup(hass, config):
discovery.async_listen_platform(hass, DOMAIN, async_platform_discovered) discovery.async_listen_platform(hass, DOMAIN, async_platform_discovered)
return True return True
class BaseNotificationService:
"""An abstract class for notification services."""
hass: Optional[HomeAssistantType] = None
def send_message(self, message, **kwargs):
"""Send a message.
kwargs can contain ATTR_TITLE to specify a title.
"""
raise NotImplementedError()
async def async_send_message(self, message, **kwargs):
"""Send a message.
kwargs can contain ATTR_TITLE to specify a title.
"""
await self.hass.async_add_job(partial(self.send_message, message, **kwargs))

View file

@ -2,7 +2,7 @@
import asyncio import asyncio
import logging import logging
from typing import Any, Dict, Iterable, Optional from typing import Any, Dict, Iterable, List, Optional
from homeassistant import config as conf_util from homeassistant import config as conf_util
from homeassistant.const import SERVICE_RELOAD from homeassistant.const import SERVICE_RELOAD
@ -12,6 +12,7 @@ from homeassistant.helpers import config_per_platform
from homeassistant.helpers.entity_platform import DATA_ENTITY_PLATFORM, EntityPlatform from homeassistant.helpers.entity_platform import DATA_ENTITY_PLATFORM, EntityPlatform
from homeassistant.helpers.typing import HomeAssistantType from homeassistant.helpers.typing import HomeAssistantType
from homeassistant.loader import async_get_integration from homeassistant.loader import async_get_integration
from homeassistant.setup import async_setup_component
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -34,29 +35,94 @@ async def async_reload_integration_platforms(
_LOGGER.error(err) _LOGGER.error(err)
return return
for integration_platform in integration_platforms: tasks = [
platform = async_get_platform(hass, integration_name, integration_platform) _resetup_platform(
hass, integration_name, integration_platform, unprocessed_conf
if not platform:
continue
integration = await async_get_integration(hass, integration_platform)
conf = await conf_util.async_process_component_config(
hass, unprocessed_conf, integration
) )
for integration_platform in integration_platforms
]
if not conf: await asyncio.gather(*tasks)
async def _resetup_platform(
hass: HomeAssistantType,
integration_name: str,
integration_platform: str,
unprocessed_conf: Dict,
) -> None:
"""Resetup a platform."""
integration = await async_get_integration(hass, integration_platform)
conf = await conf_util.async_process_component_config(
hass, unprocessed_conf, integration
)
if not conf:
return
root_config: Dict = {integration_platform: []}
# Extract only the config for template, ignore the rest.
for p_type, p_config in config_per_platform(conf, integration_platform):
if p_type != integration_name:
continue continue
await platform.async_reset() root_config[integration_platform].append(p_config)
# Extract only the config for template, ignore the rest. component = integration.get_component()
for p_type, p_config in config_per_platform(conf, integration_platform):
if p_type != integration_name:
continue
await platform.async_setup(p_config) # type: ignore if hasattr(component, "async_reset_platform"):
# If the integration has its own way to reset
# use this method.
await component.async_reset_platform(hass, integration_name) # type: ignore
await component.async_setup(hass, root_config) # type: ignore
return
# If its an entity platform, we use the entity_platform
# async_reset method
platform = async_get_platform(hass, integration_name, integration_platform)
if platform:
await _async_reconfig_platform(platform, root_config[integration_platform])
return
if not root_config[integration_platform]:
# No config for this platform
# and its not loaded. Nothing to do
return
await _async_setup_platform(
hass, integration_name, integration_platform, root_config[integration_platform]
)
async def _async_setup_platform(
hass: HomeAssistantType,
integration_name: str,
integration_platform: str,
platform_configs: List[Dict],
) -> None:
"""Platform for the first time when new configuration is added."""
if integration_platform not in hass.data:
await async_setup_component(
hass, integration_platform, {integration_platform: platform_configs}
)
return
entity_component = hass.data[integration_platform]
tasks = [
entity_component.async_setup_platform(integration_name, p_config)
for p_config in platform_configs
]
await asyncio.gather(*tasks)
async def _async_reconfig_platform(
platform: EntityPlatform, platform_configs: List[Dict]
) -> None:
"""Reconfigure an already loaded platform."""
await platform.async_reset()
tasks = [platform.async_setup(p_config) for p_config in platform_configs] # type: ignore
await asyncio.gather(*tasks)
async def async_integration_yaml_config( async def async_integration_yaml_config(

View file

@ -683,5 +683,79 @@ async def test_reload(hass):
assert hass.states.get("light.outside_patio_lights_g") is not None assert hass.states.get("light.outside_patio_lights_g") is not None
async def test_reload_with_platform_not_setup(hass):
"""Test the ability to reload lights."""
hass.states.async_set("light.bowl", STATE_ON)
await async_setup_component(
hass,
LIGHT_DOMAIN,
{
LIGHT_DOMAIN: [
{"platform": "demo"},
]
},
)
assert await async_setup_component(
hass,
"group",
{
"group": {
"group_zero": {"entities": "light.Bowl", "icon": "mdi:work"},
}
},
)
await hass.async_block_till_done()
yaml_path = path.join(
_get_fixtures_base_path(),
"fixtures",
"group/configuration.yaml",
)
with patch.object(hass_config, "YAML_CONFIG_FILE", yaml_path):
await hass.services.async_call(
DOMAIN,
SERVICE_RELOAD,
{},
blocking=True,
)
await hass.async_block_till_done()
assert hass.states.get("light.light_group") is None
assert hass.states.get("light.master_hall_lights_g") is not None
assert hass.states.get("light.outside_patio_lights_g") is not None
async def test_reload_with_base_integration_platform_not_setup(hass):
"""Test the ability to reload lights."""
assert await async_setup_component(
hass,
"group",
{
"group": {
"group_zero": {"entities": "light.Bowl", "icon": "mdi:work"},
}
},
)
await hass.async_block_till_done()
yaml_path = path.join(
_get_fixtures_base_path(),
"fixtures",
"group/configuration.yaml",
)
with patch.object(hass_config, "YAML_CONFIG_FILE", yaml_path):
await hass.services.async_call(
DOMAIN,
SERVICE_RELOAD,
{},
blocking=True,
)
await hass.async_block_till_done()
assert hass.states.get("light.light_group") is None
assert hass.states.get("light.master_hall_lights_g") is not None
assert hass.states.get("light.outside_patio_lights_g") is not None
def _get_fixtures_base_path(): def _get_fixtures_base_path():
return path.dirname(path.dirname(path.dirname(__file__))) return path.dirname(path.dirname(path.dirname(__file__)))

View file

@ -1,11 +1,14 @@
"""The tests for the notify.group platform.""" """The tests for the notify.group platform."""
import asyncio import asyncio
from os import path
import unittest import unittest
from homeassistant import config as hass_config
import homeassistant.components.demo.notify as demo import homeassistant.components.demo.notify as demo
from homeassistant.components.group import SERVICE_RELOAD
import homeassistant.components.group.notify as group import homeassistant.components.group.notify as group
import homeassistant.components.notify as notify import homeassistant.components.notify as notify
from homeassistant.setup import setup_component from homeassistant.setup import async_setup_component, setup_component
from tests.async_mock import MagicMock, patch from tests.async_mock import MagicMock, patch
from tests.common import assert_setup_component, get_test_home_assistant from tests.common import assert_setup_component, get_test_home_assistant
@ -90,3 +93,58 @@ class TestNotifyGroup(unittest.TestCase):
"title": "Test notification", "title": "Test notification",
"data": {"hello": "world", "test": "message"}, "data": {"hello": "world", "test": "message"},
} }
async def test_reload_notify(hass):
"""Verify we can reload the notify service."""
assert await async_setup_component(
hass,
"group",
{},
)
await hass.async_block_till_done()
assert await async_setup_component(
hass,
notify.DOMAIN,
{
notify.DOMAIN: [
{"name": "demo1", "platform": "demo"},
{"name": "demo2", "platform": "demo"},
{
"name": "group_notify",
"platform": "group",
"services": [{"service": "demo1"}],
},
]
},
)
await hass.async_block_till_done()
assert hass.services.has_service(notify.DOMAIN, "demo1")
assert hass.services.has_service(notify.DOMAIN, "demo2")
assert hass.services.has_service(notify.DOMAIN, "group_notify")
yaml_path = path.join(
_get_fixtures_base_path(),
"fixtures",
"group/configuration.yaml",
)
with patch.object(hass_config, "YAML_CONFIG_FILE", yaml_path):
await hass.services.async_call(
"group",
SERVICE_RELOAD,
{},
blocking=True,
)
await hass.async_block_till_done()
assert hass.services.has_service(notify.DOMAIN, "demo1")
assert hass.services.has_service(notify.DOMAIN, "demo2")
assert not hass.services.has_service(notify.DOMAIN, "group_notify")
assert hass.services.has_service(notify.DOMAIN, "new_group_notify")
def _get_fixtures_base_path():
return path.dirname(path.dirname(path.dirname(__file__)))

View file

@ -9,3 +9,10 @@ light:
entities: entities:
- light.outside_patio_lights - light.outside_patio_lights
- light.outside_patio_lights_2 - light.outside_patio_lights_2
notify:
- platform: group
name: new_group_notify
services:
- service: demo1
- service: demo2

View file

@ -13,8 +13,9 @@ from homeassistant.helpers.reload import (
async_reload_integration_platforms, async_reload_integration_platforms,
async_setup_reload_service, async_setup_reload_service,
) )
from homeassistant.loader import async_get_integration
from tests.async_mock import Mock, patch from tests.async_mock import AsyncMock, Mock, patch
from tests.common import ( from tests.common import (
MockModule, MockModule,
MockPlatform, MockPlatform,
@ -109,6 +110,104 @@ async def test_setup_reload_service(hass):
assert len(setup_called) == 2 assert len(setup_called) == 2
async def test_setup_reload_service_when_async_process_component_config_fails(hass):
"""Test setting up a reload service with the config processing failing."""
component_setup = Mock(return_value=True)
setup_called = []
async def setup_platform(*args):
setup_called.append(args)
mock_integration(hass, MockModule(DOMAIN, setup=component_setup))
mock_integration(hass, MockModule(PLATFORM, dependencies=[DOMAIN]))
mock_platform = MockPlatform(async_setup_platform=setup_platform)
mock_entity_platform(hass, f"{DOMAIN}.{PLATFORM}", mock_platform)
component = EntityComponent(_LOGGER, DOMAIN, hass)
await component.async_setup({DOMAIN: {"platform": PLATFORM, "sensors": None}})
await hass.async_block_till_done()
assert component_setup.called
assert f"{DOMAIN}.{PLATFORM}" in hass.config.components
assert len(setup_called) == 1
await async_setup_reload_service(hass, PLATFORM, [DOMAIN])
yaml_path = path.join(
_get_fixtures_base_path(),
"fixtures",
"helpers/reload_configuration.yaml",
)
with patch.object(config, "YAML_CONFIG_FILE", yaml_path), patch.object(
config, "async_process_component_config", return_value=None
):
await hass.services.async_call(
PLATFORM,
SERVICE_RELOAD,
{},
blocking=True,
)
await hass.async_block_till_done()
assert len(setup_called) == 1
async def test_setup_reload_service_with_platform_that_provides_async_reset_platform(
hass,
):
"""Test setting up a reload service using a platform that has its own async_reset_platform."""
component_setup = AsyncMock(return_value=True)
setup_called = []
async_reset_platform_called = []
async def setup_platform(*args):
setup_called.append(args)
async def async_reset_platform(*args):
async_reset_platform_called.append(args)
mock_integration(hass, MockModule(DOMAIN, async_setup=component_setup))
integration = await async_get_integration(hass, DOMAIN)
integration.get_component().async_reset_platform = async_reset_platform
mock_integration(hass, MockModule(PLATFORM, dependencies=[DOMAIN]))
mock_platform = MockPlatform(async_setup_platform=setup_platform)
mock_entity_platform(hass, f"{DOMAIN}.{PLATFORM}", mock_platform)
component = EntityComponent(_LOGGER, DOMAIN, hass)
await component.async_setup({DOMAIN: {"platform": PLATFORM, "name": "xyz"}})
await hass.async_block_till_done()
assert component_setup.called
assert f"{DOMAIN}.{PLATFORM}" in hass.config.components
assert len(setup_called) == 1
await async_setup_reload_service(hass, PLATFORM, [DOMAIN])
yaml_path = path.join(
_get_fixtures_base_path(),
"fixtures",
"helpers/reload_configuration.yaml",
)
with patch.object(config, "YAML_CONFIG_FILE", yaml_path):
await hass.services.async_call(
PLATFORM,
SERVICE_RELOAD,
{},
blocking=True,
)
await hass.async_block_till_done()
assert len(setup_called) == 1
assert len(async_reset_platform_called) == 1
async def test_async_integration_yaml_config(hass): async def test_async_integration_yaml_config(hass):
"""Test loading yaml config for an integration.""" """Test loading yaml config for an integration."""
mock_integration(hass, MockModule(DOMAIN)) mock_integration(hass, MockModule(DOMAIN))