Support reloading the group notify platform (#39511)

This commit is contained in:
J. Nick Koston 2020-09-02 17:12:07 -05:00 committed by GitHub
parent 2d2efeb9bb
commit a778690b64
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 470 additions and 129 deletions

View file

@ -58,7 +58,7 @@ ATTR_ALL = "all"
SERVICE_SET = "set"
SERVICE_REMOVE = "remove"
PLATFORMS = ["light", "cover"]
PLATFORMS = ["light", "cover", "notify"]
_LOGGER = logging.getLogger(__name__)

View file

@ -1,6 +1,6 @@
# Describes the format for available group services
reload:
description: Reload group configuration.
description: Reload group configuration, entities, and notify services.
set:
description: Create/Update a user group.

View file

@ -2,11 +2,12 @@
import asyncio
from functools import partial
import logging
from typing import Optional
from typing import Any, Dict, Optional
import voluptuous as vol
from homeassistant.const import CONF_NAME, CONF_PLATFORM
from homeassistant.core import ServiceCall
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_per_platform, discovery
import homeassistant.helpers.config_validation as cv
@ -37,10 +38,6 @@ DOMAIN = "notify"
SERVICE_NOTIFY = "notify"
NOTIFY_SERVICES = "notify_services"
SERVICE = "service"
TARGETS = "targets"
FRIENDLY_NAME = "friendly_name"
TARGET_FRIENDLY_NAME = "target_friendly_name"
PLATFORM_SCHEMA = vol.Schema(
{vol.Required(CONF_PLATFORM): cv.string, vol.Optional(CONF_NAME): cv.string},
@ -58,88 +55,160 @@ NOTIFY_SERVICE_SCHEMA = vol.Schema(
@bind_hass
async def async_reload(hass, integration_name):
async def async_reload(hass: HomeAssistantType, integration_name: str) -> None:
"""Register notify services for an integration."""
if not _async_integration_has_notify_services(hass, integration_name):
return
tasks = [
notify_service.async_register_services()
for notify_service in hass.data[NOTIFY_SERVICES][integration_name]
]
await asyncio.gather(*tasks)
@bind_hass
async def async_reset_platform(hass: HomeAssistantType, integration_name: str) -> None:
"""Unregister notify services for an integration."""
if not _async_integration_has_notify_services(hass, integration_name):
return
tasks = [
notify_service.async_unregister_services()
for notify_service in hass.data[NOTIFY_SERVICES][integration_name]
]
await asyncio.gather(*tasks)
del hass.data[NOTIFY_SERVICES][integration_name]
def _async_integration_has_notify_services(
hass: HomeAssistantType, integration_name: str
) -> bool:
"""Determine if an integration has notify services registered."""
if (
NOTIFY_SERVICES not in hass.data
or integration_name not in hass.data[NOTIFY_SERVICES]
):
return
return False
tasks = [
_async_setup_notify_services(hass, data)
for data in hass.data[NOTIFY_SERVICES][integration_name]
]
await asyncio.gather(*tasks)
return True
async def _async_setup_notify_services(hass, data):
"""Create or remove the notify services."""
notify_service = data[SERVICE]
friendly_name = data[FRIENDLY_NAME]
targets = data[TARGETS]
class BaseNotificationService:
"""An abstract class for notification services."""
async def _async_notify_message(service):
hass: Optional[HomeAssistantType] = None
def send_message(self, message, **kwargs):
"""Send a message.
kwargs can contain ATTR_TITLE to specify a title.
"""
raise NotImplementedError()
async def async_send_message(self, message: Any, **kwargs: Any) -> None:
"""Send a message.
kwargs can contain ATTR_TITLE to specify a title.
"""
await self.hass.async_add_job(partial(self.send_message, message, **kwargs)) # type: ignore
async def _async_notify_message_service(self, service: ServiceCall) -> None:
"""Handle sending notification message service calls."""
await _async_notify_message_service(hass, service, notify_service, targets)
kwargs = {}
message = service.data[ATTR_MESSAGE]
title = service.data.get(ATTR_TITLE)
if hasattr(notify_service, "targets"):
target_friendly_name = data[TARGET_FRIENDLY_NAME]
stale_targets = set(targets)
if title:
title.hass = self.hass
kwargs[ATTR_TITLE] = title.async_render()
for name, target in notify_service.targets.items():
target_name = slugify(f"{target_friendly_name}_{name}")
if target_name in stale_targets:
stale_targets.remove(target_name)
if target_name in targets:
continue
targets[target_name] = target
hass.services.async_register(
DOMAIN,
target_name,
_async_notify_message,
schema=NOTIFY_SERVICE_SCHEMA,
)
if self._registered_targets.get(service.service) is not None:
kwargs[ATTR_TARGET] = [self._registered_targets[service.service]]
elif service.data.get(ATTR_TARGET) is not None:
kwargs[ATTR_TARGET] = service.data.get(ATTR_TARGET)
for stale_target_name in stale_targets:
del targets[stale_target_name]
hass.services.async_remove(
DOMAIN,
stale_target_name,
)
message.hass = self.hass
kwargs[ATTR_MESSAGE] = message.async_render()
kwargs[ATTR_DATA] = service.data.get(ATTR_DATA)
friendly_name_slug = slugify(friendly_name)
if hass.services.has_service(DOMAIN, friendly_name_slug):
return
await self.async_send_message(**kwargs)
hass.services.async_register(
DOMAIN,
friendly_name_slug,
_async_notify_message,
schema=NOTIFY_SERVICE_SCHEMA,
)
async def async_setup(
self,
hass: HomeAssistantType,
service_name: str,
target_service_name_prefix: str,
) -> None:
"""Store the data for the notify service."""
# pylint: disable=attribute-defined-outside-init
self.hass = hass
self._service_name = service_name
self._target_service_name_prefix = target_service_name_prefix
self._registered_targets: Dict = {}
async def async_register_services(self) -> None:
"""Create or update the notify services."""
assert self.hass
async def _async_notify_message_service(hass, service, notify_service, targets):
"""Handle sending notification message service calls."""
kwargs = {}
message = service.data[ATTR_MESSAGE]
title = service.data.get(ATTR_TITLE)
if hasattr(self, "targets"):
stale_targets = set(self._registered_targets)
if title:
title.hass = hass
kwargs[ATTR_TITLE] = title.async_render()
# pylint: disable=no-member
for name, target in self.targets.items(): # type: ignore
target_name = slugify(f"{self._target_service_name_prefix}_{name}")
if target_name in stale_targets:
stale_targets.remove(target_name)
if target_name in self._registered_targets:
continue
self._registered_targets[target_name] = target
self.hass.services.async_register(
DOMAIN,
target_name,
self._async_notify_message_service,
schema=NOTIFY_SERVICE_SCHEMA,
)
if targets.get(service.service) is not None:
kwargs[ATTR_TARGET] = [targets[service.service]]
elif service.data.get(ATTR_TARGET) is not None:
kwargs[ATTR_TARGET] = service.data.get(ATTR_TARGET)
for stale_target_name in stale_targets:
del self._registered_targets[stale_target_name]
self.hass.services.async_remove(
DOMAIN,
stale_target_name,
)
message.hass = hass
kwargs[ATTR_MESSAGE] = message.async_render()
kwargs[ATTR_DATA] = service.data.get(ATTR_DATA)
if self.hass.services.has_service(DOMAIN, self._service_name):
return
await notify_service.async_send_message(**kwargs)
self.hass.services.async_register(
DOMAIN,
self._service_name,
self._async_notify_message_service,
schema=NOTIFY_SERVICE_SCHEMA,
)
async def async_unregister_services(self) -> None:
"""Unregister the notify services."""
assert self.hass
if self._registered_targets:
remove_targets = set(self._registered_targets)
for remove_target_name in remove_targets:
del self._registered_targets[remove_target_name]
self.hass.services.async_remove(
DOMAIN,
remove_target_name,
)
if not self.hass.services.has_service(DOMAIN, self._service_name):
return
self.hass.services.async_remove(
DOMAIN,
self._service_name,
)
async def async_setup(hass, config):
@ -188,31 +257,19 @@ async def async_setup(hass, config):
_LOGGER.exception("Error setting up platform %s", integration_name)
return
notify_service.hass = hass
if discovery_info is None:
discovery_info = {}
target_friendly_name = (
p_config.get(CONF_NAME) or discovery_info.get(CONF_NAME) or integration_name
conf_name = p_config.get(CONF_NAME) or discovery_info.get(CONF_NAME)
target_service_name_prefix = conf_name or integration_name
service_name = slugify(conf_name or SERVICE_NOTIFY)
await notify_service.async_setup(hass, service_name, target_service_name_prefix)
await notify_service.async_register_services()
hass.data[NOTIFY_SERVICES].setdefault(integration_name, []).append(
notify_service
)
friendly_name = (
p_config.get(CONF_NAME) or discovery_info.get(CONF_NAME) or SERVICE_NOTIFY
)
data = {
FRIENDLY_NAME: friendly_name,
# The targets use a slightly different friendly name
# selection pattern than the base service
TARGET_FRIENDLY_NAME: target_friendly_name,
SERVICE: notify_service,
TARGETS: {},
}
hass.data[NOTIFY_SERVICES].setdefault(integration_name, [])
hass.data[NOTIFY_SERVICES][integration_name].append(data)
await _async_setup_notify_services(hass, data)
hass.config.components.add(f"{DOMAIN}.{integration_name}")
return True
@ -232,23 +289,3 @@ async def async_setup(hass, config):
discovery.async_listen_platform(hass, DOMAIN, async_platform_discovered)
return True
class BaseNotificationService:
"""An abstract class for notification services."""
hass: Optional[HomeAssistantType] = None
def send_message(self, message, **kwargs):
"""Send a message.
kwargs can contain ATTR_TITLE to specify a title.
"""
raise NotImplementedError()
async def async_send_message(self, message, **kwargs):
"""Send a message.
kwargs can contain ATTR_TITLE to specify a title.
"""
await self.hass.async_add_job(partial(self.send_message, message, **kwargs))

View file

@ -2,7 +2,7 @@
import asyncio
import logging
from typing import Any, Dict, Iterable, Optional
from typing import Any, Dict, Iterable, List, Optional
from homeassistant import config as conf_util
from homeassistant.const import SERVICE_RELOAD
@ -12,6 +12,7 @@ from homeassistant.helpers import config_per_platform
from homeassistant.helpers.entity_platform import DATA_ENTITY_PLATFORM, EntityPlatform
from homeassistant.helpers.typing import HomeAssistantType
from homeassistant.loader import async_get_integration
from homeassistant.setup import async_setup_component
_LOGGER = logging.getLogger(__name__)
@ -34,29 +35,94 @@ async def async_reload_integration_platforms(
_LOGGER.error(err)
return
for integration_platform in integration_platforms:
platform = async_get_platform(hass, integration_name, integration_platform)
if not platform:
continue
integration = await async_get_integration(hass, integration_platform)
conf = await conf_util.async_process_component_config(
hass, unprocessed_conf, integration
tasks = [
_resetup_platform(
hass, integration_name, integration_platform, unprocessed_conf
)
for integration_platform in integration_platforms
]
if not conf:
await asyncio.gather(*tasks)
async def _resetup_platform(
hass: HomeAssistantType,
integration_name: str,
integration_platform: str,
unprocessed_conf: Dict,
) -> None:
"""Resetup a platform."""
integration = await async_get_integration(hass, integration_platform)
conf = await conf_util.async_process_component_config(
hass, unprocessed_conf, integration
)
if not conf:
return
root_config: Dict = {integration_platform: []}
# Extract only the config for template, ignore the rest.
for p_type, p_config in config_per_platform(conf, integration_platform):
if p_type != integration_name:
continue
await platform.async_reset()
root_config[integration_platform].append(p_config)
# Extract only the config for template, ignore the rest.
for p_type, p_config in config_per_platform(conf, integration_platform):
if p_type != integration_name:
continue
component = integration.get_component()
await platform.async_setup(p_config) # type: ignore
if hasattr(component, "async_reset_platform"):
# If the integration has its own way to reset
# use this method.
await component.async_reset_platform(hass, integration_name) # type: ignore
await component.async_setup(hass, root_config) # type: ignore
return
# If its an entity platform, we use the entity_platform
# async_reset method
platform = async_get_platform(hass, integration_name, integration_platform)
if platform:
await _async_reconfig_platform(platform, root_config[integration_platform])
return
if not root_config[integration_platform]:
# No config for this platform
# and its not loaded. Nothing to do
return
await _async_setup_platform(
hass, integration_name, integration_platform, root_config[integration_platform]
)
async def _async_setup_platform(
hass: HomeAssistantType,
integration_name: str,
integration_platform: str,
platform_configs: List[Dict],
) -> None:
"""Platform for the first time when new configuration is added."""
if integration_platform not in hass.data:
await async_setup_component(
hass, integration_platform, {integration_platform: platform_configs}
)
return
entity_component = hass.data[integration_platform]
tasks = [
entity_component.async_setup_platform(integration_name, p_config)
for p_config in platform_configs
]
await asyncio.gather(*tasks)
async def _async_reconfig_platform(
platform: EntityPlatform, platform_configs: List[Dict]
) -> None:
"""Reconfigure an already loaded platform."""
await platform.async_reset()
tasks = [platform.async_setup(p_config) for p_config in platform_configs] # type: ignore
await asyncio.gather(*tasks)
async def async_integration_yaml_config(

View file

@ -683,5 +683,79 @@ async def test_reload(hass):
assert hass.states.get("light.outside_patio_lights_g") is not None
async def test_reload_with_platform_not_setup(hass):
"""Test the ability to reload lights."""
hass.states.async_set("light.bowl", STATE_ON)
await async_setup_component(
hass,
LIGHT_DOMAIN,
{
LIGHT_DOMAIN: [
{"platform": "demo"},
]
},
)
assert await async_setup_component(
hass,
"group",
{
"group": {
"group_zero": {"entities": "light.Bowl", "icon": "mdi:work"},
}
},
)
await hass.async_block_till_done()
yaml_path = path.join(
_get_fixtures_base_path(),
"fixtures",
"group/configuration.yaml",
)
with patch.object(hass_config, "YAML_CONFIG_FILE", yaml_path):
await hass.services.async_call(
DOMAIN,
SERVICE_RELOAD,
{},
blocking=True,
)
await hass.async_block_till_done()
assert hass.states.get("light.light_group") is None
assert hass.states.get("light.master_hall_lights_g") is not None
assert hass.states.get("light.outside_patio_lights_g") is not None
async def test_reload_with_base_integration_platform_not_setup(hass):
"""Test the ability to reload lights."""
assert await async_setup_component(
hass,
"group",
{
"group": {
"group_zero": {"entities": "light.Bowl", "icon": "mdi:work"},
}
},
)
await hass.async_block_till_done()
yaml_path = path.join(
_get_fixtures_base_path(),
"fixtures",
"group/configuration.yaml",
)
with patch.object(hass_config, "YAML_CONFIG_FILE", yaml_path):
await hass.services.async_call(
DOMAIN,
SERVICE_RELOAD,
{},
blocking=True,
)
await hass.async_block_till_done()
assert hass.states.get("light.light_group") is None
assert hass.states.get("light.master_hall_lights_g") is not None
assert hass.states.get("light.outside_patio_lights_g") is not None
def _get_fixtures_base_path():
return path.dirname(path.dirname(path.dirname(__file__)))

View file

@ -1,11 +1,14 @@
"""The tests for the notify.group platform."""
import asyncio
from os import path
import unittest
from homeassistant import config as hass_config
import homeassistant.components.demo.notify as demo
from homeassistant.components.group import SERVICE_RELOAD
import homeassistant.components.group.notify as group
import homeassistant.components.notify as notify
from homeassistant.setup import setup_component
from homeassistant.setup import async_setup_component, setup_component
from tests.async_mock import MagicMock, patch
from tests.common import assert_setup_component, get_test_home_assistant
@ -90,3 +93,58 @@ class TestNotifyGroup(unittest.TestCase):
"title": "Test notification",
"data": {"hello": "world", "test": "message"},
}
async def test_reload_notify(hass):
"""Verify we can reload the notify service."""
assert await async_setup_component(
hass,
"group",
{},
)
await hass.async_block_till_done()
assert await async_setup_component(
hass,
notify.DOMAIN,
{
notify.DOMAIN: [
{"name": "demo1", "platform": "demo"},
{"name": "demo2", "platform": "demo"},
{
"name": "group_notify",
"platform": "group",
"services": [{"service": "demo1"}],
},
]
},
)
await hass.async_block_till_done()
assert hass.services.has_service(notify.DOMAIN, "demo1")
assert hass.services.has_service(notify.DOMAIN, "demo2")
assert hass.services.has_service(notify.DOMAIN, "group_notify")
yaml_path = path.join(
_get_fixtures_base_path(),
"fixtures",
"group/configuration.yaml",
)
with patch.object(hass_config, "YAML_CONFIG_FILE", yaml_path):
await hass.services.async_call(
"group",
SERVICE_RELOAD,
{},
blocking=True,
)
await hass.async_block_till_done()
assert hass.services.has_service(notify.DOMAIN, "demo1")
assert hass.services.has_service(notify.DOMAIN, "demo2")
assert not hass.services.has_service(notify.DOMAIN, "group_notify")
assert hass.services.has_service(notify.DOMAIN, "new_group_notify")
def _get_fixtures_base_path():
return path.dirname(path.dirname(path.dirname(__file__)))

View file

@ -9,3 +9,10 @@ light:
entities:
- light.outside_patio_lights
- light.outside_patio_lights_2
notify:
- platform: group
name: new_group_notify
services:
- service: demo1
- service: demo2

View file

@ -13,8 +13,9 @@ from homeassistant.helpers.reload import (
async_reload_integration_platforms,
async_setup_reload_service,
)
from homeassistant.loader import async_get_integration
from tests.async_mock import Mock, patch
from tests.async_mock import AsyncMock, Mock, patch
from tests.common import (
MockModule,
MockPlatform,
@ -109,6 +110,104 @@ async def test_setup_reload_service(hass):
assert len(setup_called) == 2
async def test_setup_reload_service_when_async_process_component_config_fails(hass):
"""Test setting up a reload service with the config processing failing."""
component_setup = Mock(return_value=True)
setup_called = []
async def setup_platform(*args):
setup_called.append(args)
mock_integration(hass, MockModule(DOMAIN, setup=component_setup))
mock_integration(hass, MockModule(PLATFORM, dependencies=[DOMAIN]))
mock_platform = MockPlatform(async_setup_platform=setup_platform)
mock_entity_platform(hass, f"{DOMAIN}.{PLATFORM}", mock_platform)
component = EntityComponent(_LOGGER, DOMAIN, hass)
await component.async_setup({DOMAIN: {"platform": PLATFORM, "sensors": None}})
await hass.async_block_till_done()
assert component_setup.called
assert f"{DOMAIN}.{PLATFORM}" in hass.config.components
assert len(setup_called) == 1
await async_setup_reload_service(hass, PLATFORM, [DOMAIN])
yaml_path = path.join(
_get_fixtures_base_path(),
"fixtures",
"helpers/reload_configuration.yaml",
)
with patch.object(config, "YAML_CONFIG_FILE", yaml_path), patch.object(
config, "async_process_component_config", return_value=None
):
await hass.services.async_call(
PLATFORM,
SERVICE_RELOAD,
{},
blocking=True,
)
await hass.async_block_till_done()
assert len(setup_called) == 1
async def test_setup_reload_service_with_platform_that_provides_async_reset_platform(
hass,
):
"""Test setting up a reload service using a platform that has its own async_reset_platform."""
component_setup = AsyncMock(return_value=True)
setup_called = []
async_reset_platform_called = []
async def setup_platform(*args):
setup_called.append(args)
async def async_reset_platform(*args):
async_reset_platform_called.append(args)
mock_integration(hass, MockModule(DOMAIN, async_setup=component_setup))
integration = await async_get_integration(hass, DOMAIN)
integration.get_component().async_reset_platform = async_reset_platform
mock_integration(hass, MockModule(PLATFORM, dependencies=[DOMAIN]))
mock_platform = MockPlatform(async_setup_platform=setup_platform)
mock_entity_platform(hass, f"{DOMAIN}.{PLATFORM}", mock_platform)
component = EntityComponent(_LOGGER, DOMAIN, hass)
await component.async_setup({DOMAIN: {"platform": PLATFORM, "name": "xyz"}})
await hass.async_block_till_done()
assert component_setup.called
assert f"{DOMAIN}.{PLATFORM}" in hass.config.components
assert len(setup_called) == 1
await async_setup_reload_service(hass, PLATFORM, [DOMAIN])
yaml_path = path.join(
_get_fixtures_base_path(),
"fixtures",
"helpers/reload_configuration.yaml",
)
with patch.object(config, "YAML_CONFIG_FILE", yaml_path):
await hass.services.async_call(
PLATFORM,
SERVICE_RELOAD,
{},
blocking=True,
)
await hass.async_block_till_done()
assert len(setup_called) == 1
assert len(async_reset_platform_called) == 1
async def test_async_integration_yaml_config(hass):
"""Test loading yaml config for an integration."""
mock_integration(hass, MockModule(DOMAIN))