Support reloading the group notify platform (#39511)

This commit is contained in:
J. Nick Koston 2020-09-02 17:12:07 -05:00 committed by GitHub
parent 2d2efeb9bb
commit a778690b64
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 470 additions and 129 deletions

View file

@ -58,7 +58,7 @@ ATTR_ALL = "all"
SERVICE_SET = "set" SERVICE_SET = "set"
SERVICE_REMOVE = "remove" SERVICE_REMOVE = "remove"
PLATFORMS = ["light", "cover"] PLATFORMS = ["light", "cover", "notify"]
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)

View file

@ -1,6 +1,6 @@
# Describes the format for available group services # Describes the format for available group services
reload: reload:
description: Reload group configuration. description: Reload group configuration, entities, and notify services.
set: set:
description: Create/Update a user group. description: Create/Update a user group.

View file

@ -2,11 +2,12 @@
import asyncio import asyncio
from functools import partial from functools import partial
import logging import logging
from typing import Optional from typing import Any, Dict, Optional
import voluptuous as vol import voluptuous as vol
from homeassistant.const import CONF_NAME, CONF_PLATFORM from homeassistant.const import CONF_NAME, CONF_PLATFORM
from homeassistant.core import ServiceCall
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_per_platform, discovery from homeassistant.helpers import config_per_platform, discovery
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
@ -37,10 +38,6 @@ DOMAIN = "notify"
SERVICE_NOTIFY = "notify" SERVICE_NOTIFY = "notify"
NOTIFY_SERVICES = "notify_services" NOTIFY_SERVICES = "notify_services"
SERVICE = "service"
TARGETS = "targets"
FRIENDLY_NAME = "friendly_name"
TARGET_FRIENDLY_NAME = "target_friendly_name"
PLATFORM_SCHEMA = vol.Schema( PLATFORM_SCHEMA = vol.Schema(
{vol.Required(CONF_PLATFORM): cv.string, vol.Optional(CONF_NAME): cv.string}, {vol.Required(CONF_PLATFORM): cv.string, vol.Optional(CONF_NAME): cv.string},
@ -58,88 +55,160 @@ NOTIFY_SERVICE_SCHEMA = vol.Schema(
@bind_hass @bind_hass
async def async_reload(hass, integration_name): async def async_reload(hass: HomeAssistantType, integration_name: str) -> None:
"""Register notify services for an integration.""" """Register notify services for an integration."""
if not _async_integration_has_notify_services(hass, integration_name):
return
tasks = [
notify_service.async_register_services()
for notify_service in hass.data[NOTIFY_SERVICES][integration_name]
]
await asyncio.gather(*tasks)
@bind_hass
async def async_reset_platform(hass: HomeAssistantType, integration_name: str) -> None:
"""Unregister notify services for an integration."""
if not _async_integration_has_notify_services(hass, integration_name):
return
tasks = [
notify_service.async_unregister_services()
for notify_service in hass.data[NOTIFY_SERVICES][integration_name]
]
await asyncio.gather(*tasks)
del hass.data[NOTIFY_SERVICES][integration_name]
def _async_integration_has_notify_services(
hass: HomeAssistantType, integration_name: str
) -> bool:
"""Determine if an integration has notify services registered."""
if ( if (
NOTIFY_SERVICES not in hass.data NOTIFY_SERVICES not in hass.data
or integration_name not in hass.data[NOTIFY_SERVICES] or integration_name not in hass.data[NOTIFY_SERVICES]
): ):
return return False
tasks = [ return True
_async_setup_notify_services(hass, data)
for data in hass.data[NOTIFY_SERVICES][integration_name]
]
await asyncio.gather(*tasks)
async def _async_setup_notify_services(hass, data): class BaseNotificationService:
"""Create or remove the notify services.""" """An abstract class for notification services."""
notify_service = data[SERVICE]
friendly_name = data[FRIENDLY_NAME]
targets = data[TARGETS]
async def _async_notify_message(service): hass: Optional[HomeAssistantType] = None
"""Handle sending notification message service calls."""
await _async_notify_message_service(hass, service, notify_service, targets)
if hasattr(notify_service, "targets"): def send_message(self, message, **kwargs):
target_friendly_name = data[TARGET_FRIENDLY_NAME] """Send a message.
stale_targets = set(targets)
for name, target in notify_service.targets.items(): kwargs can contain ATTR_TITLE to specify a title.
target_name = slugify(f"{target_friendly_name}_{name}") """
if target_name in stale_targets: raise NotImplementedError()
stale_targets.remove(target_name)
if target_name in targets:
continue
targets[target_name] = target
hass.services.async_register(
DOMAIN,
target_name,
_async_notify_message,
schema=NOTIFY_SERVICE_SCHEMA,
)
for stale_target_name in stale_targets: async def async_send_message(self, message: Any, **kwargs: Any) -> None:
del targets[stale_target_name] """Send a message.
hass.services.async_remove(
DOMAIN,
stale_target_name,
)
friendly_name_slug = slugify(friendly_name) kwargs can contain ATTR_TITLE to specify a title.
if hass.services.has_service(DOMAIN, friendly_name_slug): """
return await self.hass.async_add_job(partial(self.send_message, message, **kwargs)) # type: ignore
hass.services.async_register( async def _async_notify_message_service(self, service: ServiceCall) -> None:
DOMAIN,
friendly_name_slug,
_async_notify_message,
schema=NOTIFY_SERVICE_SCHEMA,
)
async def _async_notify_message_service(hass, service, notify_service, targets):
"""Handle sending notification message service calls.""" """Handle sending notification message service calls."""
kwargs = {} kwargs = {}
message = service.data[ATTR_MESSAGE] message = service.data[ATTR_MESSAGE]
title = service.data.get(ATTR_TITLE) title = service.data.get(ATTR_TITLE)
if title: if title:
title.hass = hass title.hass = self.hass
kwargs[ATTR_TITLE] = title.async_render() kwargs[ATTR_TITLE] = title.async_render()
if targets.get(service.service) is not None: if self._registered_targets.get(service.service) is not None:
kwargs[ATTR_TARGET] = [targets[service.service]] kwargs[ATTR_TARGET] = [self._registered_targets[service.service]]
elif service.data.get(ATTR_TARGET) is not None: elif service.data.get(ATTR_TARGET) is not None:
kwargs[ATTR_TARGET] = service.data.get(ATTR_TARGET) kwargs[ATTR_TARGET] = service.data.get(ATTR_TARGET)
message.hass = hass message.hass = self.hass
kwargs[ATTR_MESSAGE] = message.async_render() kwargs[ATTR_MESSAGE] = message.async_render()
kwargs[ATTR_DATA] = service.data.get(ATTR_DATA) kwargs[ATTR_DATA] = service.data.get(ATTR_DATA)
await notify_service.async_send_message(**kwargs) await self.async_send_message(**kwargs)
async def async_setup(
self,
hass: HomeAssistantType,
service_name: str,
target_service_name_prefix: str,
) -> None:
"""Store the data for the notify service."""
# pylint: disable=attribute-defined-outside-init
self.hass = hass
self._service_name = service_name
self._target_service_name_prefix = target_service_name_prefix
self._registered_targets: Dict = {}
async def async_register_services(self) -> None:
"""Create or update the notify services."""
assert self.hass
if hasattr(self, "targets"):
stale_targets = set(self._registered_targets)
# pylint: disable=no-member
for name, target in self.targets.items(): # type: ignore
target_name = slugify(f"{self._target_service_name_prefix}_{name}")
if target_name in stale_targets:
stale_targets.remove(target_name)
if target_name in self._registered_targets:
continue
self._registered_targets[target_name] = target
self.hass.services.async_register(
DOMAIN,
target_name,
self._async_notify_message_service,
schema=NOTIFY_SERVICE_SCHEMA,
)
for stale_target_name in stale_targets:
del self._registered_targets[stale_target_name]
self.hass.services.async_remove(
DOMAIN,
stale_target_name,
)
if self.hass.services.has_service(DOMAIN, self._service_name):
return
self.hass.services.async_register(
DOMAIN,
self._service_name,
self._async_notify_message_service,
schema=NOTIFY_SERVICE_SCHEMA,
)
async def async_unregister_services(self) -> None:
"""Unregister the notify services."""
assert self.hass
if self._registered_targets:
remove_targets = set(self._registered_targets)
for remove_target_name in remove_targets:
del self._registered_targets[remove_target_name]
self.hass.services.async_remove(
DOMAIN,
remove_target_name,
)
if not self.hass.services.has_service(DOMAIN, self._service_name):
return
self.hass.services.async_remove(
DOMAIN,
self._service_name,
)
async def async_setup(hass, config): async def async_setup(hass, config):
@ -188,31 +257,19 @@ async def async_setup(hass, config):
_LOGGER.exception("Error setting up platform %s", integration_name) _LOGGER.exception("Error setting up platform %s", integration_name)
return return
notify_service.hass = hass
if discovery_info is None: if discovery_info is None:
discovery_info = {} discovery_info = {}
target_friendly_name = ( conf_name = p_config.get(CONF_NAME) or discovery_info.get(CONF_NAME)
p_config.get(CONF_NAME) or discovery_info.get(CONF_NAME) or integration_name target_service_name_prefix = conf_name or integration_name
service_name = slugify(conf_name or SERVICE_NOTIFY)
await notify_service.async_setup(hass, service_name, target_service_name_prefix)
await notify_service.async_register_services()
hass.data[NOTIFY_SERVICES].setdefault(integration_name, []).append(
notify_service
) )
friendly_name = (
p_config.get(CONF_NAME) or discovery_info.get(CONF_NAME) or SERVICE_NOTIFY
)
data = {
FRIENDLY_NAME: friendly_name,
# The targets use a slightly different friendly name
# selection pattern than the base service
TARGET_FRIENDLY_NAME: target_friendly_name,
SERVICE: notify_service,
TARGETS: {},
}
hass.data[NOTIFY_SERVICES].setdefault(integration_name, [])
hass.data[NOTIFY_SERVICES][integration_name].append(data)
await _async_setup_notify_services(hass, data)
hass.config.components.add(f"{DOMAIN}.{integration_name}") hass.config.components.add(f"{DOMAIN}.{integration_name}")
return True return True
@ -232,23 +289,3 @@ async def async_setup(hass, config):
discovery.async_listen_platform(hass, DOMAIN, async_platform_discovered) discovery.async_listen_platform(hass, DOMAIN, async_platform_discovered)
return True return True
class BaseNotificationService:
"""An abstract class for notification services."""
hass: Optional[HomeAssistantType] = None
def send_message(self, message, **kwargs):
"""Send a message.
kwargs can contain ATTR_TITLE to specify a title.
"""
raise NotImplementedError()
async def async_send_message(self, message, **kwargs):
"""Send a message.
kwargs can contain ATTR_TITLE to specify a title.
"""
await self.hass.async_add_job(partial(self.send_message, message, **kwargs))

View file

@ -2,7 +2,7 @@
import asyncio import asyncio
import logging import logging
from typing import Any, Dict, Iterable, Optional from typing import Any, Dict, Iterable, List, Optional
from homeassistant import config as conf_util from homeassistant import config as conf_util
from homeassistant.const import SERVICE_RELOAD from homeassistant.const import SERVICE_RELOAD
@ -12,6 +12,7 @@ from homeassistant.helpers import config_per_platform
from homeassistant.helpers.entity_platform import DATA_ENTITY_PLATFORM, EntityPlatform from homeassistant.helpers.entity_platform import DATA_ENTITY_PLATFORM, EntityPlatform
from homeassistant.helpers.typing import HomeAssistantType from homeassistant.helpers.typing import HomeAssistantType
from homeassistant.loader import async_get_integration from homeassistant.loader import async_get_integration
from homeassistant.setup import async_setup_component
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -34,12 +35,23 @@ async def async_reload_integration_platforms(
_LOGGER.error(err) _LOGGER.error(err)
return return
for integration_platform in integration_platforms: tasks = [
platform = async_get_platform(hass, integration_name, integration_platform) _resetup_platform(
hass, integration_name, integration_platform, unprocessed_conf
)
for integration_platform in integration_platforms
]
if not platform: await asyncio.gather(*tasks)
continue
async def _resetup_platform(
hass: HomeAssistantType,
integration_name: str,
integration_platform: str,
unprocessed_conf: Dict,
) -> None:
"""Resetup a platform."""
integration = await async_get_integration(hass, integration_platform) integration = await async_get_integration(hass, integration_platform)
conf = await conf_util.async_process_component_config( conf = await conf_util.async_process_component_config(
@ -47,16 +59,70 @@ async def async_reload_integration_platforms(
) )
if not conf: if not conf:
continue return
await platform.async_reset()
root_config: Dict = {integration_platform: []}
# Extract only the config for template, ignore the rest. # Extract only the config for template, ignore the rest.
for p_type, p_config in config_per_platform(conf, integration_platform): for p_type, p_config in config_per_platform(conf, integration_platform):
if p_type != integration_name: if p_type != integration_name:
continue continue
await platform.async_setup(p_config) # type: ignore root_config[integration_platform].append(p_config)
component = integration.get_component()
if hasattr(component, "async_reset_platform"):
# If the integration has its own way to reset
# use this method.
await component.async_reset_platform(hass, integration_name) # type: ignore
await component.async_setup(hass, root_config) # type: ignore
return
# If its an entity platform, we use the entity_platform
# async_reset method
platform = async_get_platform(hass, integration_name, integration_platform)
if platform:
await _async_reconfig_platform(platform, root_config[integration_platform])
return
if not root_config[integration_platform]:
# No config for this platform
# and its not loaded. Nothing to do
return
await _async_setup_platform(
hass, integration_name, integration_platform, root_config[integration_platform]
)
async def _async_setup_platform(
hass: HomeAssistantType,
integration_name: str,
integration_platform: str,
platform_configs: List[Dict],
) -> None:
"""Platform for the first time when new configuration is added."""
if integration_platform not in hass.data:
await async_setup_component(
hass, integration_platform, {integration_platform: platform_configs}
)
return
entity_component = hass.data[integration_platform]
tasks = [
entity_component.async_setup_platform(integration_name, p_config)
for p_config in platform_configs
]
await asyncio.gather(*tasks)
async def _async_reconfig_platform(
platform: EntityPlatform, platform_configs: List[Dict]
) -> None:
"""Reconfigure an already loaded platform."""
await platform.async_reset()
tasks = [platform.async_setup(p_config) for p_config in platform_configs] # type: ignore
await asyncio.gather(*tasks)
async def async_integration_yaml_config( async def async_integration_yaml_config(

View file

@ -683,5 +683,79 @@ async def test_reload(hass):
assert hass.states.get("light.outside_patio_lights_g") is not None assert hass.states.get("light.outside_patio_lights_g") is not None
async def test_reload_with_platform_not_setup(hass):
"""Test the ability to reload lights."""
hass.states.async_set("light.bowl", STATE_ON)
await async_setup_component(
hass,
LIGHT_DOMAIN,
{
LIGHT_DOMAIN: [
{"platform": "demo"},
]
},
)
assert await async_setup_component(
hass,
"group",
{
"group": {
"group_zero": {"entities": "light.Bowl", "icon": "mdi:work"},
}
},
)
await hass.async_block_till_done()
yaml_path = path.join(
_get_fixtures_base_path(),
"fixtures",
"group/configuration.yaml",
)
with patch.object(hass_config, "YAML_CONFIG_FILE", yaml_path):
await hass.services.async_call(
DOMAIN,
SERVICE_RELOAD,
{},
blocking=True,
)
await hass.async_block_till_done()
assert hass.states.get("light.light_group") is None
assert hass.states.get("light.master_hall_lights_g") is not None
assert hass.states.get("light.outside_patio_lights_g") is not None
async def test_reload_with_base_integration_platform_not_setup(hass):
"""Test the ability to reload lights."""
assert await async_setup_component(
hass,
"group",
{
"group": {
"group_zero": {"entities": "light.Bowl", "icon": "mdi:work"},
}
},
)
await hass.async_block_till_done()
yaml_path = path.join(
_get_fixtures_base_path(),
"fixtures",
"group/configuration.yaml",
)
with patch.object(hass_config, "YAML_CONFIG_FILE", yaml_path):
await hass.services.async_call(
DOMAIN,
SERVICE_RELOAD,
{},
blocking=True,
)
await hass.async_block_till_done()
assert hass.states.get("light.light_group") is None
assert hass.states.get("light.master_hall_lights_g") is not None
assert hass.states.get("light.outside_patio_lights_g") is not None
def _get_fixtures_base_path(): def _get_fixtures_base_path():
return path.dirname(path.dirname(path.dirname(__file__))) return path.dirname(path.dirname(path.dirname(__file__)))

View file

@ -1,11 +1,14 @@
"""The tests for the notify.group platform.""" """The tests for the notify.group platform."""
import asyncio import asyncio
from os import path
import unittest import unittest
from homeassistant import config as hass_config
import homeassistant.components.demo.notify as demo import homeassistant.components.demo.notify as demo
from homeassistant.components.group import SERVICE_RELOAD
import homeassistant.components.group.notify as group import homeassistant.components.group.notify as group
import homeassistant.components.notify as notify import homeassistant.components.notify as notify
from homeassistant.setup import setup_component from homeassistant.setup import async_setup_component, setup_component
from tests.async_mock import MagicMock, patch from tests.async_mock import MagicMock, patch
from tests.common import assert_setup_component, get_test_home_assistant from tests.common import assert_setup_component, get_test_home_assistant
@ -90,3 +93,58 @@ class TestNotifyGroup(unittest.TestCase):
"title": "Test notification", "title": "Test notification",
"data": {"hello": "world", "test": "message"}, "data": {"hello": "world", "test": "message"},
} }
async def test_reload_notify(hass):
"""Verify we can reload the notify service."""
assert await async_setup_component(
hass,
"group",
{},
)
await hass.async_block_till_done()
assert await async_setup_component(
hass,
notify.DOMAIN,
{
notify.DOMAIN: [
{"name": "demo1", "platform": "demo"},
{"name": "demo2", "platform": "demo"},
{
"name": "group_notify",
"platform": "group",
"services": [{"service": "demo1"}],
},
]
},
)
await hass.async_block_till_done()
assert hass.services.has_service(notify.DOMAIN, "demo1")
assert hass.services.has_service(notify.DOMAIN, "demo2")
assert hass.services.has_service(notify.DOMAIN, "group_notify")
yaml_path = path.join(
_get_fixtures_base_path(),
"fixtures",
"group/configuration.yaml",
)
with patch.object(hass_config, "YAML_CONFIG_FILE", yaml_path):
await hass.services.async_call(
"group",
SERVICE_RELOAD,
{},
blocking=True,
)
await hass.async_block_till_done()
assert hass.services.has_service(notify.DOMAIN, "demo1")
assert hass.services.has_service(notify.DOMAIN, "demo2")
assert not hass.services.has_service(notify.DOMAIN, "group_notify")
assert hass.services.has_service(notify.DOMAIN, "new_group_notify")
def _get_fixtures_base_path():
return path.dirname(path.dirname(path.dirname(__file__)))

View file

@ -9,3 +9,10 @@ light:
entities: entities:
- light.outside_patio_lights - light.outside_patio_lights
- light.outside_patio_lights_2 - light.outside_patio_lights_2
notify:
- platform: group
name: new_group_notify
services:
- service: demo1
- service: demo2

View file

@ -13,8 +13,9 @@ from homeassistant.helpers.reload import (
async_reload_integration_platforms, async_reload_integration_platforms,
async_setup_reload_service, async_setup_reload_service,
) )
from homeassistant.loader import async_get_integration
from tests.async_mock import Mock, patch from tests.async_mock import AsyncMock, Mock, patch
from tests.common import ( from tests.common import (
MockModule, MockModule,
MockPlatform, MockPlatform,
@ -109,6 +110,104 @@ async def test_setup_reload_service(hass):
assert len(setup_called) == 2 assert len(setup_called) == 2
async def test_setup_reload_service_when_async_process_component_config_fails(hass):
"""Test setting up a reload service with the config processing failing."""
component_setup = Mock(return_value=True)
setup_called = []
async def setup_platform(*args):
setup_called.append(args)
mock_integration(hass, MockModule(DOMAIN, setup=component_setup))
mock_integration(hass, MockModule(PLATFORM, dependencies=[DOMAIN]))
mock_platform = MockPlatform(async_setup_platform=setup_platform)
mock_entity_platform(hass, f"{DOMAIN}.{PLATFORM}", mock_platform)
component = EntityComponent(_LOGGER, DOMAIN, hass)
await component.async_setup({DOMAIN: {"platform": PLATFORM, "sensors": None}})
await hass.async_block_till_done()
assert component_setup.called
assert f"{DOMAIN}.{PLATFORM}" in hass.config.components
assert len(setup_called) == 1
await async_setup_reload_service(hass, PLATFORM, [DOMAIN])
yaml_path = path.join(
_get_fixtures_base_path(),
"fixtures",
"helpers/reload_configuration.yaml",
)
with patch.object(config, "YAML_CONFIG_FILE", yaml_path), patch.object(
config, "async_process_component_config", return_value=None
):
await hass.services.async_call(
PLATFORM,
SERVICE_RELOAD,
{},
blocking=True,
)
await hass.async_block_till_done()
assert len(setup_called) == 1
async def test_setup_reload_service_with_platform_that_provides_async_reset_platform(
hass,
):
"""Test setting up a reload service using a platform that has its own async_reset_platform."""
component_setup = AsyncMock(return_value=True)
setup_called = []
async_reset_platform_called = []
async def setup_platform(*args):
setup_called.append(args)
async def async_reset_platform(*args):
async_reset_platform_called.append(args)
mock_integration(hass, MockModule(DOMAIN, async_setup=component_setup))
integration = await async_get_integration(hass, DOMAIN)
integration.get_component().async_reset_platform = async_reset_platform
mock_integration(hass, MockModule(PLATFORM, dependencies=[DOMAIN]))
mock_platform = MockPlatform(async_setup_platform=setup_platform)
mock_entity_platform(hass, f"{DOMAIN}.{PLATFORM}", mock_platform)
component = EntityComponent(_LOGGER, DOMAIN, hass)
await component.async_setup({DOMAIN: {"platform": PLATFORM, "name": "xyz"}})
await hass.async_block_till_done()
assert component_setup.called
assert f"{DOMAIN}.{PLATFORM}" in hass.config.components
assert len(setup_called) == 1
await async_setup_reload_service(hass, PLATFORM, [DOMAIN])
yaml_path = path.join(
_get_fixtures_base_path(),
"fixtures",
"helpers/reload_configuration.yaml",
)
with patch.object(config, "YAML_CONFIG_FILE", yaml_path):
await hass.services.async_call(
PLATFORM,
SERVICE_RELOAD,
{},
blocking=True,
)
await hass.async_block_till_done()
assert len(setup_called) == 1
assert len(async_reset_platform_called) == 1
async def test_async_integration_yaml_config(hass): async def test_async_integration_yaml_config(hass):
"""Test loading yaml config for an integration.""" """Test loading yaml config for an integration."""
mock_integration(hass, MockModule(DOMAIN)) mock_integration(hass, MockModule(DOMAIN))