Add support for EventBridge to aws integration (#77573)
* Added EventBridge support to aws integration * Added type hints for all aws notification services + Added unit tests for EventBridge AWS integration * Increase line coverage for unit tests for aws integration.
This commit is contained in:
parent
aea0067e49
commit
dbfca8def8
3 changed files with 169 additions and 5 deletions
|
@ -51,7 +51,7 @@ DEFAULT_CREDENTIAL = [
|
||||||
{CONF_NAME: "default", CONF_PROFILE_NAME: "default", CONF_VALIDATE: False}
|
{CONF_NAME: "default", CONF_PROFILE_NAME: "default", CONF_VALIDATE: False}
|
||||||
]
|
]
|
||||||
|
|
||||||
SUPPORTED_SERVICES = ["lambda", "sns", "sqs"]
|
SUPPORTED_SERVICES = ["lambda", "sns", "sqs", "events"]
|
||||||
|
|
||||||
NOTIFY_PLATFORM_SCHEMA = vol.Schema(
|
NOTIFY_PLATFORM_SCHEMA = vol.Schema(
|
||||||
{
|
{
|
||||||
|
|
|
@ -3,6 +3,7 @@ import asyncio
|
||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from aiobotocore.session import AioSession
|
from aiobotocore.session import AioSession
|
||||||
|
|
||||||
|
@ -105,6 +106,9 @@ async def async_get_service(hass, config, discovery_info=None):
|
||||||
if service == "sqs":
|
if service == "sqs":
|
||||||
return AWSSQS(session, aws_config)
|
return AWSSQS(session, aws_config)
|
||||||
|
|
||||||
|
if service == "events":
|
||||||
|
return AWSEventBridge(session, aws_config)
|
||||||
|
|
||||||
# should not reach here since service was checked in schema
|
# should not reach here since service was checked in schema
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@ -128,7 +132,7 @@ class AWSLambda(AWSNotify):
|
||||||
super().__init__(session, aws_config)
|
super().__init__(session, aws_config)
|
||||||
self.context = context
|
self.context = context
|
||||||
|
|
||||||
async def async_send_message(self, message="", **kwargs):
|
async def async_send_message(self, message: str = "", **kwargs: Any) -> None:
|
||||||
"""Send notification to specified LAMBDA ARN."""
|
"""Send notification to specified LAMBDA ARN."""
|
||||||
if not kwargs.get(ATTR_TARGET):
|
if not kwargs.get(ATTR_TARGET):
|
||||||
_LOGGER.error("At least one target is required")
|
_LOGGER.error("At least one target is required")
|
||||||
|
@ -161,7 +165,7 @@ class AWSSNS(AWSNotify):
|
||||||
|
|
||||||
service = "sns"
|
service = "sns"
|
||||||
|
|
||||||
async def async_send_message(self, message="", **kwargs):
|
async def async_send_message(self, message: str = "", **kwargs: Any) -> None:
|
||||||
"""Send notification to specified SNS ARN."""
|
"""Send notification to specified SNS ARN."""
|
||||||
if not kwargs.get(ATTR_TARGET):
|
if not kwargs.get(ATTR_TARGET):
|
||||||
_LOGGER.error("At least one target is required")
|
_LOGGER.error("At least one target is required")
|
||||||
|
@ -199,7 +203,7 @@ class AWSSQS(AWSNotify):
|
||||||
|
|
||||||
service = "sqs"
|
service = "sqs"
|
||||||
|
|
||||||
async def async_send_message(self, message="", **kwargs):
|
async def async_send_message(self, message: str = "", **kwargs: Any) -> None:
|
||||||
"""Send notification to specified SQS ARN."""
|
"""Send notification to specified SQS ARN."""
|
||||||
if not kwargs.get(ATTR_TARGET):
|
if not kwargs.get(ATTR_TARGET):
|
||||||
_LOGGER.error("At least one target is required")
|
_LOGGER.error("At least one target is required")
|
||||||
|
@ -231,3 +235,52 @@ class AWSSQS(AWSNotify):
|
||||||
|
|
||||||
if tasks:
|
if tasks:
|
||||||
await asyncio.gather(*tasks)
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
|
||||||
|
class AWSEventBridge(AWSNotify):
|
||||||
|
"""Implement the notification service for the AWS EventBridge service."""
|
||||||
|
|
||||||
|
service = "events"
|
||||||
|
|
||||||
|
async def async_send_message(self, message: str = "", **kwargs: Any) -> None:
|
||||||
|
"""Send notification to specified EventBus."""
|
||||||
|
|
||||||
|
cleaned_kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||||
|
data = cleaned_kwargs.get(ATTR_DATA, {})
|
||||||
|
detail = (
|
||||||
|
json.dumps(data["detail"])
|
||||||
|
if "detail" in data
|
||||||
|
else json.dumps({"message": message})
|
||||||
|
)
|
||||||
|
|
||||||
|
async with self.session.create_client(
|
||||||
|
self.service, **self.aws_config
|
||||||
|
) as client:
|
||||||
|
tasks = []
|
||||||
|
entries = []
|
||||||
|
for target in kwargs.get(ATTR_TARGET, [None]):
|
||||||
|
entry = {
|
||||||
|
"Source": data.get("source", "homeassistant"),
|
||||||
|
"Resources": data.get("resources", []),
|
||||||
|
"Detail": detail,
|
||||||
|
"DetailType": data.get("detail_type", ""),
|
||||||
|
}
|
||||||
|
if target:
|
||||||
|
entry["EventBusName"] = target
|
||||||
|
|
||||||
|
entries.append(entry)
|
||||||
|
for i in range(0, len(entries), 10):
|
||||||
|
tasks.append(
|
||||||
|
client.put_events(Entries=entries[i : min(i + 10, len(entries))])
|
||||||
|
)
|
||||||
|
|
||||||
|
if tasks:
|
||||||
|
results = await asyncio.gather(*tasks)
|
||||||
|
for result in results:
|
||||||
|
for entry in result["Entries"]:
|
||||||
|
if len(entry.get("EventId", "")) == 0:
|
||||||
|
_LOGGER.error(
|
||||||
|
"Failed to send event: ErrorCode=%s ErrorMessage=%s",
|
||||||
|
entry["ErrorCode"],
|
||||||
|
entry["ErrorMessage"],
|
||||||
|
)
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
"""Tests for the aws component config and setup."""
|
"""Tests for the aws component config and setup."""
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch as async_patch
|
import json
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, call, patch as async_patch
|
||||||
|
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
|
@ -13,6 +14,7 @@ class MockAioSession:
|
||||||
self.invoke = AsyncMock()
|
self.invoke = AsyncMock()
|
||||||
self.publish = AsyncMock()
|
self.publish = AsyncMock()
|
||||||
self.send_message = AsyncMock()
|
self.send_message = AsyncMock()
|
||||||
|
self.put_events = AsyncMock()
|
||||||
|
|
||||||
def create_client(self, *args, **kwargs):
|
def create_client(self, *args, **kwargs):
|
||||||
"""Create a mocked client."""
|
"""Create a mocked client."""
|
||||||
|
@ -23,6 +25,7 @@ class MockAioSession:
|
||||||
invoke=self.invoke, # lambda
|
invoke=self.invoke, # lambda
|
||||||
publish=self.publish, # sns
|
publish=self.publish, # sns
|
||||||
send_message=self.send_message, # sqs
|
send_message=self.send_message, # sqs
|
||||||
|
put_events=self.put_events, # events
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
__aexit__=AsyncMock(),
|
__aexit__=AsyncMock(),
|
||||||
|
@ -289,3 +292,111 @@ async def test_service_call_extra_data(hass):
|
||||||
"AWS.SNS.SMS.SenderID": {"StringValue": "HA-notify", "DataType": "String"}
|
"AWS.SNS.SMS.SenderID": {"StringValue": "HA-notify", "DataType": "String"}
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_events_service_call(hass):
|
||||||
|
"""Test events service (EventBridge) call works as expected."""
|
||||||
|
mock_session = MockAioSession()
|
||||||
|
with async_patch(
|
||||||
|
"homeassistant.components.aws.AioSession", return_value=mock_session
|
||||||
|
):
|
||||||
|
await async_setup_component(
|
||||||
|
hass,
|
||||||
|
"aws",
|
||||||
|
{
|
||||||
|
"aws": {
|
||||||
|
"notify": [
|
||||||
|
{
|
||||||
|
"service": "events",
|
||||||
|
"name": "Events Test",
|
||||||
|
"region_name": "us-east-1",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert hass.services.has_service("notify", "events_test") is True
|
||||||
|
|
||||||
|
mock_session.put_events.return_value = {
|
||||||
|
"Entries": [{"EventId": "", "ErrorCode": 0, "ErrorMessage": "test-error"}]
|
||||||
|
}
|
||||||
|
|
||||||
|
await hass.services.async_call(
|
||||||
|
"notify",
|
||||||
|
"events_test",
|
||||||
|
{
|
||||||
|
"message": "test",
|
||||||
|
"target": "ARN",
|
||||||
|
"data": {},
|
||||||
|
},
|
||||||
|
blocking=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_session.put_events.assert_called_once_with(
|
||||||
|
Entries=[
|
||||||
|
{
|
||||||
|
"EventBusName": "ARN",
|
||||||
|
"Detail": json.dumps({"message": "test"}),
|
||||||
|
"DetailType": "",
|
||||||
|
"Source": "homeassistant",
|
||||||
|
"Resources": [],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_events_service_call_10_targets(hass):
|
||||||
|
"""Test events service (EventBridge) call works with more than 10 targets."""
|
||||||
|
mock_session = MockAioSession()
|
||||||
|
with async_patch(
|
||||||
|
"homeassistant.components.aws.AioSession", return_value=mock_session
|
||||||
|
):
|
||||||
|
await async_setup_component(
|
||||||
|
hass,
|
||||||
|
"aws",
|
||||||
|
{
|
||||||
|
"aws": {
|
||||||
|
"notify": [
|
||||||
|
{
|
||||||
|
"service": "events",
|
||||||
|
"name": "Events Test",
|
||||||
|
"region_name": "us-east-1",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert hass.services.has_service("notify", "events_test") is True
|
||||||
|
await hass.services.async_call(
|
||||||
|
"notify",
|
||||||
|
"events_test",
|
||||||
|
{
|
||||||
|
"message": "",
|
||||||
|
"target": [f"eventbus{i}" for i in range(11)],
|
||||||
|
"data": {
|
||||||
|
"detail_type": "test_event",
|
||||||
|
"detail": {"eventkey": "eventvalue"},
|
||||||
|
"source": "HomeAssistant-test",
|
||||||
|
"resources": ["resource1", "resource2"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
blocking=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
entry = {
|
||||||
|
"Detail": json.dumps({"eventkey": "eventvalue"}),
|
||||||
|
"DetailType": "test_event",
|
||||||
|
"Source": "HomeAssistant-test",
|
||||||
|
"Resources": ["resource1", "resource2"],
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_session.put_events.assert_has_calls(
|
||||||
|
[
|
||||||
|
call(Entries=[entry | {"EventBusName": f"eventbus{i}"} for i in range(10)]),
|
||||||
|
call(Entries=[entry | {"EventBusName": "eventbus10"}]),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue