Add config flow to Azure Event Hub integration (#61155)
* config flow added, no tests yet
* added tests
* refinement of tests
* small reverses of hub code
* fix small bug
* test fixes from review
* test fixes from review
* further refinement of tests and config flow
* removed true return from hub and added failed reason for import
* added deepcopy to default options
* deleted max_delay from options, can still be in yaml for now
* updated dropped message
* mistaken period at eol
Parent: 619529b40c
Commit: 80833aa7fb
11 changed files with 980 additions and 250 deletions
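In short: credentials move out of configuration.yaml and into a config entry; yaml keeps only the entity filter, and any remaining yaml keys are forwarded once to an import flow. A condensed editorial sketch of that hand-off, simplified from the new async_setup in the __init__.py diff below (deprecation warning omitted):

async def async_setup(hass: HomeAssistant, yaml_config: ConfigType) -> bool:
    """Keep only the filter from yaml; import the rest into a config entry."""
    hass.data.setdefault(DOMAIN, {DATA_FILTER: FILTER_SCHEMA({})})
    if DOMAIN not in yaml_config:
        return True
    # the filter stays yaml-only
    hass.data[DOMAIN][DATA_FILTER] = yaml_config[DOMAIN].pop(CONF_FILTER)
    if not yaml_config[DOMAIN]:
        return True
    # anything else (credentials, intervals) is imported once via the config flow
    hass.async_create_task(
        hass.config_entries.flow.async_init(
            DOMAIN, context={"source": SOURCE_IMPORT}, data=yaml_config[DOMAIN]
        )
    )
    return True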
|
homeassistant/components/azure_event_hub/__init__.py
@@ -9,16 +9,11 @@ import time
|
|||
from typing import Any
|
||||
|
||||
from azure.eventhub import EventData, EventDataBatch
|
||||
from azure.eventhub.aio import EventHubProducerClient, EventHubSharedKeyCredential
|
||||
from azure.eventhub.exceptions import EventHubError
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.const import (
|
||||
EVENT_HOMEASSISTANT_STOP,
|
||||
MATCH_ALL,
|
||||
STATE_UNAVAILABLE,
|
||||
STATE_UNKNOWN,
|
||||
)
|
||||
from homeassistant.config_entries import SOURCE_IMPORT, ConfigEntry, ConfigEntryNotReady
|
||||
from homeassistant.const import MATCH_ALL, STATE_UNAVAILABLE, STATE_UNKNOWN
|
||||
from homeassistant.core import Event, HomeAssistant
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.entityfilter import FILTER_SCHEMA
|
||||
|
@ -26,8 +21,8 @@ from homeassistant.helpers.event import async_call_later
|
|||
from homeassistant.helpers.json import JSONEncoder
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
|
||||
from .client import AzureEventHubClient
|
||||
from .const import (
|
||||
ADDITIONAL_ARGS,
|
||||
CONF_EVENT_HUB_CON_STRING,
|
||||
CONF_EVENT_HUB_INSTANCE_NAME,
|
||||
CONF_EVENT_HUB_NAMESPACE,
|
||||
|
@ -36,6 +31,9 @@ from .const import (
|
|||
CONF_FILTER,
|
||||
CONF_MAX_DELAY,
|
||||
CONF_SEND_INTERVAL,
|
||||
DATA_FILTER,
|
||||
DATA_HUB,
|
||||
DEFAULT_MAX_DELAY,
|
||||
DOMAIN,
|
||||
)
|
||||
|
||||
|
@ -45,18 +43,15 @@ CONFIG_SCHEMA = vol.Schema(
|
|||
{
|
||||
DOMAIN: vol.Schema(
|
||||
{
|
||||
vol.Required(CONF_EVENT_HUB_INSTANCE_NAME): cv.string,
|
||||
vol.Exclusive(CONF_EVENT_HUB_CON_STRING, "setup_methods"): cv.string,
|
||||
vol.Exclusive(CONF_EVENT_HUB_NAMESPACE, "setup_methods"): cv.string,
|
||||
vol.Optional(CONF_EVENT_HUB_INSTANCE_NAME): cv.string,
|
||||
vol.Optional(CONF_EVENT_HUB_CON_STRING): cv.string,
|
||||
vol.Optional(CONF_EVENT_HUB_NAMESPACE): cv.string,
|
||||
vol.Optional(CONF_EVENT_HUB_SAS_POLICY): cv.string,
|
||||
vol.Optional(CONF_EVENT_HUB_SAS_KEY): cv.string,
|
||||
vol.Optional(CONF_SEND_INTERVAL, default=5): cv.positive_int,
|
||||
vol.Optional(CONF_MAX_DELAY, default=30): cv.positive_int,
|
||||
vol.Optional(CONF_SEND_INTERVAL): cv.positive_int,
|
||||
vol.Optional(CONF_MAX_DELAY): cv.positive_int,
|
||||
vol.Optional(CONF_FILTER, default={}): FILTER_SCHEMA,
|
||||
},
|
||||
cv.has_at_least_one_key(
|
||||
CONF_EVENT_HUB_CON_STRING, CONF_EVENT_HUB_NAMESPACE
|
||||
),
|
||||
)
|
||||
},
|
||||
extra=vol.ALLOW_EXTRA,
|
||||
|
@ -64,35 +59,62 @@ CONFIG_SCHEMA = vol.Schema(
|
|||
|
||||
|
||||
async def async_setup(hass: HomeAssistant, yaml_config: ConfigType) -> bool:
|
||||
"""Activate Azure EH component."""
|
||||
config = yaml_config[DOMAIN]
|
||||
if config.get(CONF_EVENT_HUB_CON_STRING):
|
||||
client_args = {
|
||||
"conn_str": config[CONF_EVENT_HUB_CON_STRING],
|
||||
"eventhub_name": config[CONF_EVENT_HUB_INSTANCE_NAME],
|
||||
}
|
||||
conn_str_client = True
|
||||
else:
|
||||
client_args = {
|
||||
"fully_qualified_namespace": f"{config[CONF_EVENT_HUB_NAMESPACE]}.servicebus.windows.net",
|
||||
"eventhub_name": config[CONF_EVENT_HUB_INSTANCE_NAME],
|
||||
"credential": EventHubSharedKeyCredential(
|
||||
policy=config[CONF_EVENT_HUB_SAS_POLICY],
|
||||
key=config[CONF_EVENT_HUB_SAS_KEY],
|
||||
),
|
||||
}
|
||||
conn_str_client = False
|
||||
"""Activate Azure EH component from yaml.
|
||||
|
||||
instance = hass.data[DOMAIN] = AzureEventHub(
|
||||
hass,
|
||||
client_args,
|
||||
conn_str_client,
|
||||
config[CONF_FILTER],
|
||||
config[CONF_SEND_INTERVAL],
|
||||
config[CONF_MAX_DELAY],
|
||||
Adds an empty filter to hass data.
|
||||
Tries to get a filter from yaml, if present set to hass data.
|
||||
If config is empty after getting the filter, return, otherwise emit
|
||||
deprecated warning and pass the rest to the config flow.
|
||||
"""
|
||||
hass.data.setdefault(DOMAIN, {DATA_FILTER: FILTER_SCHEMA({})})
|
||||
if DOMAIN not in yaml_config:
|
||||
return True
|
||||
hass.data[DOMAIN][DATA_FILTER] = yaml_config[DOMAIN].pop(CONF_FILTER)
|
||||
|
||||
if not yaml_config[DOMAIN]:
|
||||
return True
|
||||
_LOGGER.warning(
|
||||
"Loading Azure Event Hub completely via yaml config is deprecated; Only the \
|
||||
Filter can be set in yaml, the rest is done through a config flow and has \
|
||||
been imported, all other keys but filter can be deleted from configuration.yaml"
|
||||
)
|
||||
hass.async_create_task(
|
||||
hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": SOURCE_IMPORT}, data=yaml_config[DOMAIN]
|
||||
)
|
||||
)
|
||||
return True
|
||||
|
||||
hass.async_create_task(instance.async_start())
|
||||
|
||||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Do the setup based on the config entry and the filter from yaml."""
|
||||
hass.data.setdefault(DOMAIN, {DATA_FILTER: FILTER_SCHEMA({})})
|
||||
hub = AzureEventHub(
|
||||
hass,
|
||||
AzureEventHubClient.from_input(**entry.data),
|
||||
hass.data[DOMAIN][DATA_FILTER],
|
||||
entry.options[CONF_SEND_INTERVAL],
|
||||
entry.options.get(CONF_MAX_DELAY),
|
||||
)
|
||||
try:
|
||||
await hub.async_test_connection()
|
||||
except EventHubError as err:
|
||||
raise ConfigEntryNotReady("Could not connect to Azure Event Hub") from err
|
||||
hass.data[DOMAIN][DATA_HUB] = hub
|
||||
entry.async_on_unload(entry.add_update_listener(async_update_listener))
|
||||
await hub.async_start()
|
||||
return True
|
||||
|
||||
|
||||
async def async_update_listener(hass: HomeAssistant, entry: ConfigEntry) -> None:
|
||||
"""Update listener for options."""
|
||||
hass.data[DOMAIN][DATA_HUB].update_options(entry.options)
|
||||
|
||||
|
||||
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Unload a config entry."""
|
||||
hub = hass.data[DOMAIN].pop(DATA_HUB)
|
||||
await hub.async_stop()
|
||||
return True
|
||||
|
||||
|
||||
|
@ -102,40 +124,45 @@ class AzureEventHub:
|
|||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
client_args: dict[str, Any],
|
||||
conn_str_client: bool,
|
||||
client: AzureEventHubClient,
|
||||
entities_filter: vol.Schema,
|
||||
send_interval: int,
|
||||
max_delay: int,
|
||||
max_delay: int | None = None,
|
||||
) -> None:
|
||||
"""Initialize the listener."""
|
||||
self.hass = hass
|
||||
self.queue: asyncio.PriorityQueue[ # pylint: disable=unsubscriptable-object
|
||||
tuple[int, tuple[float, Event | None]]
|
||||
] = asyncio.PriorityQueue()
|
||||
self._client_args = client_args
|
||||
self._conn_str_client = conn_str_client
|
||||
self._client = client
|
||||
self._entities_filter = entities_filter
|
||||
self._send_interval = send_interval
|
||||
self._max_delay = max_delay + send_interval
|
||||
self._max_delay = max_delay if max_delay else DEFAULT_MAX_DELAY
|
||||
self._listener_remover: Callable[[], None] | None = None
|
||||
self._next_send_remover: Callable[[], None] | None = None
|
||||
self.shutdown = False
|
||||
|
||||
async def async_start(self) -> None:
|
||||
"""Start the recorder, suppress logging and register the callbacks and do the first send after five seconds, to capture the startup events."""
|
||||
# suppress the INFO and below logging on the underlying packages, they are very verbose, even at INFO
|
||||
"""Start the hub.
|
||||
|
||||
This suppresses logging and register the listener and
|
||||
schedules the first send.
|
||||
"""
|
||||
# suppress the INFO and below logging on the underlying packages,
|
||||
# they are very verbose, even at INFO
|
||||
logging.getLogger("uamqp").setLevel(logging.WARNING)
|
||||
logging.getLogger("azure.eventhub").setLevel(logging.WARNING)
|
||||
|
||||
self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, self.async_shutdown)
|
||||
self._listener_remover = self.hass.bus.async_listen(
|
||||
MATCH_ALL, self.async_listen
|
||||
)
|
||||
# schedule the first send after 10 seconds to capture startup events, after that each send will schedule the next after the interval.
|
||||
self._next_send_remover = async_call_later(self.hass, 10, self.async_send)
|
||||
# schedule the first send after 10 seconds to capture startup events,
|
||||
# after that each send will schedule the next after the interval.
|
||||
self._next_send_remover = async_call_later(
|
||||
self.hass, self._send_interval, self.async_send
|
||||
)
|
||||
|
||||
async def async_shutdown(self, _: Event) -> None:
|
||||
async def async_stop(self) -> None:
|
||||
"""Shut down the AEH by queueing None and calling send."""
|
||||
if self._next_send_remover:
|
||||
self._next_send_remover()
|
||||
|
@ -144,13 +171,17 @@ class AzureEventHub:
|
|||
await self.queue.put((3, (time.monotonic(), None)))
|
||||
await self.async_send(None)
|
||||
|
||||
async def async_test_connection(self) -> None:
|
||||
"""Test the connection to the event hub."""
|
||||
await self._client.test_connection()
|
||||
|
||||
async def async_listen(self, event: Event) -> None:
|
||||
"""Listen for new messages on the bus and queue them for AEH."""
|
||||
await self.queue.put((2, (time.monotonic(), event)))
|
||||
|
||||
async def async_send(self, _) -> None:
|
||||
"""Write preprocessed events to eventhub, with retry."""
|
||||
async with self._get_client() as client:
|
||||
async with self._client.client as client:
|
||||
while not self.queue.empty():
|
||||
data_batch, dequeue_count = await self.fill_batch(client)
|
||||
_LOGGER.debug(
|
||||
|
@ -175,9 +206,12 @@ class AzureEventHub:
|
|||
async def fill_batch(self, client) -> tuple[EventDataBatch, int]:
|
||||
"""Return a batch of events formatted for writing.
|
||||
|
||||
Uses get_nowait instead of await get, because the functions batches and doesn't wait for each single event, the send function is called.
|
||||
Uses get_nowait instead of await get, because the functions batches and
|
||||
doesn't wait for each single event, the send function is called.
|
||||
|
||||
Throws ValueError on add to batch when the EventDataBatch object reaches max_size. Put the item back in the queue and the next batch will include it.
|
||||
Throws ValueError on add to batch when the EventDataBatch object reaches
|
||||
max_size. Put the item back in the queue and the next batch will include
|
||||
it.
|
||||
"""
|
||||
event_batch = await client.create_batch()
|
||||
dequeue_count = 0
|
||||
|
@ -194,10 +228,12 @@ class AzureEventHub:
|
|||
event_data = self._event_to_filtered_event_data(event)
|
||||
if not event_data:
|
||||
continue
|
||||
if time.monotonic() - timestamp <= self._max_delay:
|
||||
if time.monotonic() - timestamp <= self._max_delay + self._send_interval:
|
||||
try:
|
||||
event_batch.add(event_data)
|
||||
except ValueError:
|
||||
dequeue_count -= 1
|
||||
self.queue.task_done()
|
||||
self.queue.put_nowait((1, (timestamp, event)))
|
||||
break
|
||||
else:
|
||||
|
@ -205,7 +241,7 @@ class AzureEventHub:
|
|||
|
||||
if dropped:
|
||||
_LOGGER.warning(
|
||||
"Dropped %d old events, consider increasing the max_delay", dropped
|
||||
"Dropped %d old events, consider filtering messages", dropped
|
||||
)
|
||||
|
||||
return event_batch, dequeue_count
|
||||
|
@ -221,10 +257,6 @@ class AzureEventHub:
|
|||
return None
|
||||
return EventData(json.dumps(obj=state, cls=JSONEncoder).encode("utf-8"))
|
||||
|
||||
def _get_client(self) -> EventHubProducerClient:
|
||||
"""Get a Event Producer Client."""
|
||||
if self._conn_str_client:
|
||||
return EventHubProducerClient.from_connection_string(
|
||||
**self._client_args, **ADDITIONAL_ARGS
|
||||
)
|
||||
return EventHubProducerClient(**self._client_args, **ADDITIONAL_ARGS)
|
||||
def update_options(self, new_options: dict[str, Any]) -> None:
|
||||
"""Update options."""
|
||||
self._send_interval = new_options[CONF_SEND_INTERVAL]
|
||||
|
|
homeassistant/components/azure_event_hub/client.py (new file, 71 lines)
@@ -0,0 +1,71 @@
"""File for Azure Event Hub models."""
from __future__ import annotations

from dataclasses import dataclass
import logging

from azure.eventhub.aio import EventHubProducerClient, EventHubSharedKeyCredential

from .const import ADDITIONAL_ARGS, CONF_EVENT_HUB_CON_STRING

_LOGGER = logging.getLogger(__name__)


@dataclass
class AzureEventHubClient:
    """Class for the Azure Event Hub client. Use from_input to initialize."""

    event_hub_instance_name: str

    @property
    def client(self) -> EventHubProducerClient:
        """Return the client."""

    async def test_connection(self) -> None:
        """Test connection, will throw EventHubError when it cannot connect."""
        async with self.client as client:
            await client.get_eventhub_properties()

    @classmethod
    def from_input(cls, **kwargs) -> AzureEventHubClient:
        """Create the right class."""
        if CONF_EVENT_HUB_CON_STRING in kwargs:
            return AzureEventHubClientConnectionString(**kwargs)
        return AzureEventHubClientSAS(**kwargs)


@dataclass
class AzureEventHubClientConnectionString(AzureEventHubClient):
    """Class for Connection String based Azure Event Hub Client."""

    event_hub_connection_string: str

    @property
    def client(self) -> EventHubProducerClient:
        """Return the client."""
        return EventHubProducerClient.from_connection_string(
            conn_str=self.event_hub_connection_string,
            eventhub_name=self.event_hub_instance_name,
            **ADDITIONAL_ARGS,
        )


@dataclass
class AzureEventHubClientSAS(AzureEventHubClient):
    """Class for SAS based Azure Event Hub Client."""

    event_hub_namespace: str
    event_hub_sas_policy: str
    event_hub_sas_key: str

    @property
    def client(self) -> EventHubProducerClient:
        """Get a Event Producer Client."""
        return EventHubProducerClient(
            fully_qualified_namespace=f"{self.event_hub_namespace}.servicebus.windows.net",
            eventhub_name=self.event_hub_instance_name,
            credential=EventHubSharedKeyCredential(  # type: ignore
                policy=self.event_hub_sas_policy, key=self.event_hub_sas_key
            ),
            **ADDITIONAL_ARGS,
        )
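A minimal usage sketch of the new client module (editorial illustration, not part of the commit; the connection string value is a placeholder). from_input dispatches on which keys are present in the entry data:

# run inside an async function
data = {
    "event_hub_instance_name": "my-instance",            # CONF_EVENT_HUB_INSTANCE_NAME
    "event_hub_connection_string": "Endpoint=sb://...",  # placeholder value
}
client = AzureEventHubClient.from_input(**data)  # -> connection-string variant
await client.test_connection()                   # raises EventHubError on failure
producer = client.client                         # EventHubProducerClient, ready for send_batch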
homeassistant/components/azure_event_hub/config_flow.py (new file, 196 lines)
@@ -0,0 +1,196 @@
"""Config flow for azure_event_hub integration."""
from __future__ import annotations

from copy import deepcopy
import logging
from typing import Any

from azure.eventhub.exceptions import EventHubError
import voluptuous as vol

from homeassistant import config_entries
from homeassistant.core import callback
from homeassistant.data_entry_flow import FlowResult

from .client import AzureEventHubClient
from .const import (
    CONF_EVENT_HUB_CON_STRING,
    CONF_EVENT_HUB_INSTANCE_NAME,
    CONF_EVENT_HUB_NAMESPACE,
    CONF_EVENT_HUB_SAS_KEY,
    CONF_EVENT_HUB_SAS_POLICY,
    CONF_MAX_DELAY,
    CONF_SEND_INTERVAL,
    CONF_USE_CONN_STRING,
    DEFAULT_OPTIONS,
    DOMAIN,
    STEP_CONN_STRING,
    STEP_SAS,
    STEP_USER,
)

_LOGGER = logging.getLogger(__name__)

BASE_SCHEMA = vol.Schema(
    {
        vol.Required(CONF_EVENT_HUB_INSTANCE_NAME): str,
        vol.Optional(CONF_USE_CONN_STRING, default=False): bool,
    }
)

CONN_STRING_SCHEMA = vol.Schema(
    {
        vol.Required(CONF_EVENT_HUB_CON_STRING): str,
    }
)

SAS_SCHEMA = vol.Schema(
    {
        vol.Required(CONF_EVENT_HUB_NAMESPACE): str,
        vol.Required(CONF_EVENT_HUB_SAS_POLICY): str,
        vol.Required(CONF_EVENT_HUB_SAS_KEY): str,
    }
)


async def validate_data(data: dict[str, Any]) -> dict[str, str] | None:
    """Validate the input."""
    client = AzureEventHubClient.from_input(**data)
    try:
        await client.test_connection()
    except EventHubError:
        return {"base": "cannot_connect"}
    except Exception:  # pylint: disable=broad-except
        _LOGGER.exception("Unknown error")
        return {"base": "unknown"}
    return None


class AEHConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
    """Handle a config flow for azure event hub."""

    VERSION: int = 1

    def __init__(self):
        """Initialize the config flow."""
        self._data: dict[str, Any] = {}
        self._options: dict[str, Any] = deepcopy(DEFAULT_OPTIONS)
        self._conn_string: bool | None = None

    @staticmethod
    @callback
    def async_get_options_flow(config_entry):
        """Get the options flow for this handler."""
        return AEHOptionsFlowHandler(config_entry)

    async def async_step_user(
        self, user_input: dict[str, Any] | None = None
    ) -> FlowResult:
        """Handle the initial user step."""
        if self._async_current_entries():
            return self.async_abort(reason="single_instance_allowed")
        if user_input is None:
            return self.async_show_form(step_id=STEP_USER, data_schema=BASE_SCHEMA)

        self._conn_string = user_input.pop(CONF_USE_CONN_STRING)
        self._data = user_input

        if self._conn_string:
            return await self.async_step_conn_string()
        return await self.async_step_sas()

    async def async_step_conn_string(
        self, user_input: dict[str, Any] | None = None
    ) -> FlowResult:
        """Handle the connection string steps."""
        errors = await self.async_update_and_validate_data(user_input)
        if user_input is None or errors is not None:
            return self.async_show_form(
                step_id=STEP_CONN_STRING,
                data_schema=CONN_STRING_SCHEMA,
                errors=errors,
                description_placeholders=self._data[CONF_EVENT_HUB_INSTANCE_NAME],
                last_step=True,
            )

        return self.async_create_entry(
            title=self._data[CONF_EVENT_HUB_INSTANCE_NAME],
            data=self._data,
            options=self._options,
        )

    async def async_step_sas(
        self, user_input: dict[str, Any] | None = None
    ) -> FlowResult:
        """Handle the sas steps."""
        errors = await self.async_update_and_validate_data(user_input)
        if user_input is None or errors is not None:
            return self.async_show_form(
                step_id=STEP_SAS,
                data_schema=SAS_SCHEMA,
                errors=errors,
                description_placeholders=self._data[CONF_EVENT_HUB_INSTANCE_NAME],
                last_step=True,
            )

        return self.async_create_entry(
            title=self._data[CONF_EVENT_HUB_INSTANCE_NAME],
            data=self._data,
            options=self._options,
        )

    async def async_step_import(self, import_config: dict[str, Any]) -> FlowResult:
        """Import config from configuration.yaml."""
        if self._async_current_entries():
            return self.async_abort(reason="single_instance_allowed")
        if CONF_SEND_INTERVAL in import_config:
            self._options[CONF_SEND_INTERVAL] = import_config.pop(CONF_SEND_INTERVAL)
        if CONF_MAX_DELAY in import_config:
            self._options[CONF_MAX_DELAY] = import_config.pop(CONF_MAX_DELAY)
        self._data = import_config
        errors = await validate_data(self._data)
        if errors:
            return self.async_abort(reason=errors["base"])
        return self.async_create_entry(
            title=self._data[CONF_EVENT_HUB_INSTANCE_NAME],
            data=self._data,
            options=self._options,
        )

    async def async_update_and_validate_data(
        self, user_input: dict[str, Any] | None
    ) -> dict[str, str] | None:
        """Validate the input."""
        if user_input is None:
            return None
        self._data.update(user_input)
        return await validate_data(self._data)


class AEHOptionsFlowHandler(config_entries.OptionsFlow):
    """Handle azure event hub options."""

    def __init__(self, config_entry):
        """Initialize AEH options flow."""
        self.config_entry = config_entry
        self.options = deepcopy(dict(config_entry.options))

    async def async_step_init(
        self, user_input: dict[str, Any] | None = None
    ) -> FlowResult:
        """Manage the AEH options."""
        if user_input is not None:
            return self.async_create_entry(title="", data=user_input)

        return self.async_show_form(
            step_id="init",
            data_schema=vol.Schema(
                {
                    vol.Required(
                        CONF_SEND_INTERVAL,
                        default=self.options.get(CONF_SEND_INTERVAL),
                    ): int
                }
            ),
            last_step=True,
        )
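One detail from the commit notes ("added deepcopy to default options"): each flow instance copies DEFAULT_OPTIONS instead of aliasing the module-level dict, so an import that moves send_interval or max_delay into the options cannot mutate the shared default. A small editorial illustration of the difference (not code from the commit):

from copy import deepcopy

DEFAULT_OPTIONS = {"send_interval": 5}  # shared, module-level dict (as in const.py)

fresh = deepcopy(DEFAULT_OPTIONS)       # what the flow does: an independent copy
fresh["max_delay"] = 10                 # DEFAULT_OPTIONS is still {"send_interval": 5}

aliased = DEFAULT_OPTIONS               # the pitfall: same object, not a copy
aliased["max_delay"] = 10               # this would silently change the module default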
homeassistant/components/azure_event_hub/const.py
@@ -5,6 +5,7 @@ from typing import Any

DOMAIN = "azure_event_hub"

CONF_USE_CONN_STRING = "use_connection_string"
CONF_EVENT_HUB_NAMESPACE = "event_hub_namespace"
CONF_EVENT_HUB_INSTANCE_NAME = "event_hub_instance_name"
CONF_EVENT_HUB_SAS_POLICY = "event_hub_sas_policy"
@@ -12,6 +13,17 @@ CONF_EVENT_HUB_SAS_KEY = "event_hub_sas_key"
CONF_EVENT_HUB_CON_STRING = "event_hub_connection_string"
CONF_SEND_INTERVAL = "send_interval"
CONF_MAX_DELAY = "max_delay"
CONF_FILTER = "filter"
CONF_FILTER = DATA_FILTER = "filter"
DATA_HUB = "hub"

STEP_USER = "user"
STEP_SAS = "sas"
STEP_CONN_STRING = "conn_string"

DEFAULT_SEND_INTERVAL: int = 5
DEFAULT_MAX_DELAY: int = 30
DEFAULT_OPTIONS: dict[str, Any] = {
    CONF_SEND_INTERVAL: DEFAULT_SEND_INTERVAL,
}

ADDITIONAL_ARGS: dict[str, Any] = {"logging_enable": False}
homeassistant/components/azure_event_hub/manifest.json
@@ -4,5 +4,6 @@
  "documentation": "https://www.home-assistant.io/integrations/azure_event_hub",
  "requirements": ["azure-eventhub==5.5.0"],
  "codeowners": ["@eavanvalkenburg"],
  "iot_class": "cloud_push"
  "iot_class": "cloud_push",
  "config_flow": true
}
homeassistant/components/azure_event_hub/strings.json (new file, 49 lines)
@@ -0,0 +1,49 @@
{
  "config": {
    "step": {
      "user": {
        "title": "Setup your Azure Event Hub integration",
        "data": {
          "event_hub_instance_name": "Event Hub Instance Name",
          "use_connection_string": "Use Connection String"
        }
      },
      "conn_string": {
        "title": "Connection String method",
        "description": "Please enter the connection string for: {event_hub_instance_name}",
        "data": {
          "event_hub_connection_string": "Event Hub Connection String"
        }
      },
      "sas": {
        "title": "SAS Credentials method",
        "description": "Please enter the SAS (shared access signature) credentials for: {event_hub_instance_name}",
        "data": {
          "event_hub_namespace": "Event Hub Namespace",
          "event_hub_sas_policy": "Event Hub SAS Policy",
          "event_hub_sas_key": "Event Hub SAS Key"
        }
      }
    },
    "error": {
      "cannot_connect": "[%key:common::config_flow::error::cannot_connect%]",
      "unknown": "[%key:common::config_flow::error::unknown%]"
    },
    "abort": {
      "already_configured": "[%key:common::config_flow::abort::already_configured_service%]",
      "single_instance_allowed": "[%key:common::config_flow::abort::single_instance_allowed%]",
      "cannot_connect": "Connecting with the credentails from the configuration.yaml failed, please remove from yaml and use the config flow.",
      "unknown": "Connecting with the credentails from the configuration.yaml failed with an unknown error, please remove from yaml and use the config flow."
    }
  },
  "options": {
    "step": {
      "options": {
        "title": "Options for the Azure Event Hub.",
        "data": {
          "send_interval": "Interval between sending batches to the hub."
        }
      }
    }
  }
}
homeassistant/generated/config_flows.py
@@ -36,6 +36,7 @@ FLOWS = [
    "awair",
    "axis",
    "azure_devops",
    "azure_event_hub",
    "balboa",
    "blebox",
    "blink",
tests/components/azure_event_hub/conftest.py (new file, 126 lines)
@@ -0,0 +1,126 @@
"""Test fixtures for AEH."""
from dataclasses import dataclass
from datetime import timedelta
import logging
from unittest.mock import MagicMock, patch

from azure.eventhub.aio import EventHubProducerClient
import pytest

from homeassistant.components.azure_event_hub.const import (
    CONF_FILTER,
    CONF_SEND_INTERVAL,
    DOMAIN,
)
from homeassistant.config_entries import ConfigEntryState
from homeassistant.const import STATE_ON
from homeassistant.setup import async_setup_component
from homeassistant.util.dt import utcnow

from .const import AZURE_EVENT_HUB_PATH, BASIC_OPTIONS, PRODUCER_PATH, SAS_CONFIG_FULL

from tests.common import MockConfigEntry, async_fire_time_changed

_LOGGER = logging.getLogger(__name__)


# fixtures for both init and config flow tests
@pytest.fixture(autouse=True, name="mock_get_eventhub_properties")
def mock_get_eventhub_properties_fixture():
    """Mock azure event hub properties, used to test the connection."""
    with patch(f"{PRODUCER_PATH}.get_eventhub_properties") as get_eventhub_properties:
        yield get_eventhub_properties


@pytest.fixture(name="filter_schema")
def mock_filter_schema():
    """Return an empty filter."""
    return {}


@pytest.fixture(name="entry")
async def mock_entry_fixture(hass, filter_schema, mock_create_batch, mock_send_batch):
    """Create the setup in HA."""
    entry = MockConfigEntry(
        domain=DOMAIN,
        data=SAS_CONFIG_FULL,
        title="test-instance",
        options=BASIC_OPTIONS,
    )
    entry.add_to_hass(hass)
    assert await async_setup_component(
        hass, DOMAIN, {DOMAIN: {CONF_FILTER: filter_schema}}
    )
    assert entry.state == ConfigEntryState.LOADED

    # Clear the component_loaded event from the queue.
    async_fire_time_changed(
        hass,
        utcnow() + timedelta(seconds=entry.options[CONF_SEND_INTERVAL]),
    )
    await hass.async_block_till_done()
    return entry


# fixtures for init tests
@pytest.fixture(name="entry_with_one_event")
async def mock_entry_with_one_event(hass, entry):
    """Use the entry and add a single test event to the queue."""
    assert entry.state == ConfigEntryState.LOADED
    hass.states.async_set("sensor.test", STATE_ON)
    return entry


@dataclass
class FilterTest:
    """Class for capturing a filter test."""

    entity_id: str
    expected_count: int


@pytest.fixture(name="mock_send_batch")
def mock_send_batch_fixture():
    """Mock send_batch."""
    with patch(f"{PRODUCER_PATH}.send_batch") as mock_send_batch:
        yield mock_send_batch


@pytest.fixture(autouse=True, name="mock_client")
def mock_client_fixture(mock_send_batch):
    """Mock the azure event hub producer client."""
    with patch(f"{PRODUCER_PATH}.close") as mock_close:
        yield (
            mock_send_batch,
            mock_close,
        )


@pytest.fixture(name="mock_create_batch")
def mock_create_batch_fixture():
    """Mock batch creator and return mocked batch object."""
    mock_batch = MagicMock()
    with patch(f"{PRODUCER_PATH}.create_batch", return_value=mock_batch):
        yield mock_batch


# fixtures for config flow tests
@pytest.fixture(name="mock_from_connection_string")
def mock_from_connection_string_fixture():
    """Mock AEH from connection string creation."""
    mock_aeh = MagicMock(spec=EventHubProducerClient)
    mock_aeh.__aenter__.return_value = mock_aeh
    with patch(
        f"{PRODUCER_PATH}.from_connection_string",
        return_value=mock_aeh,
    ) as from_conn_string:
        yield from_conn_string


@pytest.fixture(name="mock_setup_entry")
def mock_setup_entry():
    """Mock the setup entry call, used for config flow tests."""
    with patch(
        f"{AZURE_EVENT_HUB_PATH}.async_setup_entry", return_value=True
    ) as setup_entry:
        yield setup_entry
tests/components/azure_event_hub/const.py (new file, 56 lines)
@@ -0,0 +1,56 @@
"""Constants for testing AEH."""
from homeassistant.components.azure_event_hub.const import (
    CONF_EVENT_HUB_CON_STRING,
    CONF_EVENT_HUB_INSTANCE_NAME,
    CONF_EVENT_HUB_NAMESPACE,
    CONF_EVENT_HUB_SAS_KEY,
    CONF_EVENT_HUB_SAS_POLICY,
    CONF_MAX_DELAY,
    CONF_SEND_INTERVAL,
    CONF_USE_CONN_STRING,
)

AZURE_EVENT_HUB_PATH = "homeassistant.components.azure_event_hub"
PRODUCER_PATH = f"{AZURE_EVENT_HUB_PATH}.client.EventHubProducerClient"
CLIENT_PATH = f"{AZURE_EVENT_HUB_PATH}.client.AzureEventHubClient"
CONFIG_FLOW_PATH = f"{AZURE_EVENT_HUB_PATH}.config_flow"

BASE_CONFIG_CS = {
    CONF_EVENT_HUB_INSTANCE_NAME: "test-instance",
    CONF_USE_CONN_STRING: True,
}
BASE_CONFIG_SAS = {
    CONF_EVENT_HUB_INSTANCE_NAME: "test-instance",
    CONF_USE_CONN_STRING: False,
}

CS_CONFIG = {CONF_EVENT_HUB_CON_STRING: "test-cs"}
SAS_CONFIG = {
    CONF_EVENT_HUB_NAMESPACE: "test-ns",
    CONF_EVENT_HUB_SAS_POLICY: "test-policy",
    CONF_EVENT_HUB_SAS_KEY: "test-key",
}
CS_CONFIG_FULL = {
    CONF_EVENT_HUB_INSTANCE_NAME: "test-instance",
    CONF_EVENT_HUB_CON_STRING: "test-cs",
}
SAS_CONFIG_FULL = {
    CONF_EVENT_HUB_INSTANCE_NAME: "test-instance",
    CONF_EVENT_HUB_NAMESPACE: "test-ns",
    CONF_EVENT_HUB_SAS_POLICY: "test-policy",
    CONF_EVENT_HUB_SAS_KEY: "test-key",
}

IMPORT_CONFIG = {
    CONF_EVENT_HUB_INSTANCE_NAME: "test-instance",
    CONF_EVENT_HUB_NAMESPACE: "test-ns",
    CONF_EVENT_HUB_SAS_POLICY: "test-policy",
    CONF_EVENT_HUB_SAS_KEY: "test-key",
    CONF_SEND_INTERVAL: 5,
    CONF_MAX_DELAY: 10,
}

BASIC_OPTIONS = {
    CONF_SEND_INTERVAL: 5,
}
UPDATE_OPTIONS = {CONF_SEND_INTERVAL: 100}
tests/components/azure_event_hub/test_config_flow.py (new file, 188 lines)
@@ -0,0 +1,188 @@
"""Test the AEH config flow."""
|
||||
import logging
|
||||
|
||||
from azure.eventhub.exceptions import EventHubError
|
||||
import pytest
|
||||
|
||||
from homeassistant import config_entries, data_entry_flow
|
||||
from homeassistant.components.azure_event_hub.const import (
|
||||
CONF_MAX_DELAY,
|
||||
CONF_SEND_INTERVAL,
|
||||
DOMAIN,
|
||||
STEP_CONN_STRING,
|
||||
STEP_SAS,
|
||||
)
|
||||
|
||||
from .const import (
|
||||
BASE_CONFIG_CS,
|
||||
BASE_CONFIG_SAS,
|
||||
CS_CONFIG,
|
||||
CS_CONFIG_FULL,
|
||||
IMPORT_CONFIG,
|
||||
SAS_CONFIG,
|
||||
SAS_CONFIG_FULL,
|
||||
UPDATE_OPTIONS,
|
||||
)
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"step1_config, step_id, step2_config, data_config",
|
||||
[
|
||||
(BASE_CONFIG_CS, STEP_CONN_STRING, CS_CONFIG, CS_CONFIG_FULL),
|
||||
(BASE_CONFIG_SAS, STEP_SAS, SAS_CONFIG, SAS_CONFIG_FULL),
|
||||
],
|
||||
ids=["connection_string", "sas"],
|
||||
)
|
||||
async def test_form(
|
||||
hass,
|
||||
mock_setup_entry,
|
||||
mock_from_connection_string,
|
||||
step1_config,
|
||||
step_id,
|
||||
step2_config,
|
||||
data_config,
|
||||
):
|
||||
"""Test we get the form."""
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_USER}, data=None
|
||||
)
|
||||
assert result["type"] == "form"
|
||||
assert result["errors"] is None
|
||||
result2 = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
step1_config.copy(),
|
||||
)
|
||||
|
||||
assert result2["type"] == "form"
|
||||
assert result2["step_id"] == step_id
|
||||
result3 = await hass.config_entries.flow.async_configure(
|
||||
result2["flow_id"],
|
||||
step2_config.copy(),
|
||||
)
|
||||
assert result3["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
||||
assert result3["title"] == "test-instance"
|
||||
assert result3["data"] == data_config
|
||||
mock_setup_entry.assert_called_once()
|
||||
|
||||
|
||||
async def test_import(hass, mock_setup_entry):
|
||||
"""Test we get the form."""
|
||||
|
||||
import_config = IMPORT_CONFIG.copy()
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN,
|
||||
context={"source": config_entries.SOURCE_IMPORT},
|
||||
data=IMPORT_CONFIG.copy(),
|
||||
)
|
||||
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
||||
assert result["title"] == "test-instance"
|
||||
options = {
|
||||
CONF_SEND_INTERVAL: import_config.pop(CONF_SEND_INTERVAL),
|
||||
CONF_MAX_DELAY: import_config.pop(CONF_MAX_DELAY),
|
||||
}
|
||||
assert result["data"] == import_config
|
||||
assert result["options"] == options
|
||||
mock_setup_entry.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"source",
|
||||
[config_entries.SOURCE_USER, config_entries.SOURCE_IMPORT],
|
||||
ids=["user", "import"],
|
||||
)
|
||||
async def test_single_instance(hass, source):
|
||||
"""Test uniqueness of username."""
|
||||
entry = MockConfigEntry(
|
||||
domain=DOMAIN,
|
||||
data=CS_CONFIG_FULL,
|
||||
title="test-instance",
|
||||
)
|
||||
entry.add_to_hass(hass)
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN,
|
||||
context={"source": source},
|
||||
data=BASE_CONFIG_CS.copy(),
|
||||
)
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT
|
||||
assert result["reason"] == "single_instance_allowed"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"side_effect, error_message",
|
||||
[(EventHubError("test"), "cannot_connect"), (Exception, "unknown")],
|
||||
ids=["cannot_connect", "unknown"],
|
||||
)
|
||||
async def test_connection_error_sas(
|
||||
hass,
|
||||
mock_get_eventhub_properties,
|
||||
side_effect,
|
||||
error_message,
|
||||
):
|
||||
"""Test we handle connection errors."""
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN,
|
||||
context={"source": config_entries.SOURCE_USER},
|
||||
data=BASE_CONFIG_SAS.copy(),
|
||||
)
|
||||
assert result["type"] == "form"
|
||||
assert result["errors"] is None
|
||||
|
||||
mock_get_eventhub_properties.side_effect = side_effect
|
||||
result2 = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
SAS_CONFIG.copy(),
|
||||
)
|
||||
assert result2["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||
assert result2["errors"] == {"base": error_message}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"side_effect, error_message",
|
||||
[(EventHubError("test"), "cannot_connect"), (Exception, "unknown")],
|
||||
ids=["cannot_connect", "unknown"],
|
||||
)
|
||||
async def test_connection_error_cs(
|
||||
hass,
|
||||
mock_from_connection_string,
|
||||
side_effect,
|
||||
error_message,
|
||||
):
|
||||
"""Test we handle connection errors."""
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN,
|
||||
context={"source": config_entries.SOURCE_USER},
|
||||
data=BASE_CONFIG_CS.copy(),
|
||||
)
|
||||
assert result["type"] == "form"
|
||||
assert result["errors"] is None
|
||||
mock_from_connection_string.return_value.get_eventhub_properties.side_effect = (
|
||||
side_effect
|
||||
)
|
||||
result2 = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
CS_CONFIG.copy(),
|
||||
)
|
||||
assert result2["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||
assert result2["errors"] == {"base": error_message}
|
||||
|
||||
|
||||
async def test_options_flow(hass, entry):
|
||||
"""Test options flow."""
|
||||
result = await hass.config_entries.options.async_init(entry.entry_id)
|
||||
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||
assert result["step_id"] == "init"
|
||||
assert result["last_step"]
|
||||
|
||||
updated = await hass.config_entries.options.async_configure(
|
||||
result["flow_id"], UPDATE_OPTIONS
|
||||
)
|
||||
assert updated["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
||||
assert updated["data"] == UPDATE_OPTIONS
|
||||
await hass.async_block_till_done()
|
|
tests/components/azure_event_hub/test_init.py
@@ -1,83 +1,31 @@
"""The tests for the Azure Event Hub component."""
|
||||
from dataclasses import dataclass
|
||||
from unittest.mock import MagicMock, patch
|
||||
"""Test the init functions for AEH."""
|
||||
from datetime import timedelta
|
||||
import logging
|
||||
from time import monotonic
|
||||
from unittest.mock import patch
|
||||
|
||||
from azure.eventhub.exceptions import EventHubError
|
||||
import pytest
|
||||
|
||||
import homeassistant.components.azure_event_hub as azure_event_hub
|
||||
from homeassistant.components import azure_event_hub
|
||||
from homeassistant.components.azure_event_hub.const import CONF_SEND_INTERVAL, DOMAIN
|
||||
from homeassistant.config_entries import ConfigEntryState
|
||||
from homeassistant.const import STATE_ON
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.util.dt import utcnow
|
||||
|
||||
AZURE_EVENT_HUB_PATH = "homeassistant.components.azure_event_hub"
|
||||
PRODUCER_PATH = f"{AZURE_EVENT_HUB_PATH}.EventHubProducerClient"
|
||||
MIN_CONFIG = {
|
||||
"event_hub_namespace": "namespace",
|
||||
"event_hub_instance_name": "name",
|
||||
"event_hub_sas_policy": "policy",
|
||||
"event_hub_sas_key": "key",
|
||||
}
|
||||
from .conftest import FilterTest
|
||||
from .const import AZURE_EVENT_HUB_PATH, BASIC_OPTIONS, CS_CONFIG_FULL, SAS_CONFIG_FULL
|
||||
|
||||
from tests.common import MockConfigEntry, async_fire_time_changed
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FilterTest:
|
||||
"""Class for capturing a filter test."""
|
||||
|
||||
id: str
|
||||
should_pass: bool
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, name="mock_client", scope="module")
|
||||
def mock_client_fixture():
|
||||
"""Mock the azure event hub producer client."""
|
||||
with patch(f"{PRODUCER_PATH}.send_batch") as mock_send_batch, patch(
|
||||
f"{PRODUCER_PATH}.close"
|
||||
) as mock_close, patch(f"{PRODUCER_PATH}.__init__", return_value=None) as mock_init:
|
||||
yield (
|
||||
mock_init,
|
||||
mock_send_batch,
|
||||
mock_close,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, name="mock_batch")
|
||||
def mock_batch_fixture():
|
||||
"""Mock batch creator and return mocked batch object."""
|
||||
mock_batch = MagicMock()
|
||||
with patch(f"{PRODUCER_PATH}.create_batch", return_value=mock_batch):
|
||||
yield mock_batch
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, name="mock_policy")
|
||||
def mock_policy_fixture():
|
||||
"""Mock azure shared key credential."""
|
||||
with patch(f"{AZURE_EVENT_HUB_PATH}.EventHubSharedKeyCredential") as policy:
|
||||
yield policy
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, name="mock_event_data")
|
||||
def mock_event_data_fixture():
|
||||
"""Mock the azure event data component."""
|
||||
with patch(f"{AZURE_EVENT_HUB_PATH}.EventData") as event_data:
|
||||
yield event_data
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, name="mock_call_later")
|
||||
def mock_call_later_fixture():
|
||||
"""Mock async_call_later to allow queue processing on demand."""
|
||||
with patch(f"{AZURE_EVENT_HUB_PATH}.async_call_later") as mock_call_later:
|
||||
yield mock_call_later
|
||||
|
||||
|
||||
async def test_minimal_config(hass):
|
||||
"""Test the minimal config and defaults of component."""
|
||||
config = {azure_event_hub.DOMAIN: MIN_CONFIG}
|
||||
assert await async_setup_component(hass, azure_event_hub.DOMAIN, config)
|
||||
|
||||
|
||||
async def test_full_config(hass):
|
||||
"""Test the full config of component."""
|
||||
async def test_import(hass):
|
||||
"""Test the popping of the filter and further import of the config."""
|
||||
config = {
|
||||
azure_event_hub.DOMAIN: {
|
||||
DOMAIN: {
|
||||
"send_interval": 10,
|
||||
"max_delay": 10,
|
||||
"filter": {
|
||||
|
@ -90,128 +38,178 @@ async def test_full_config(hass):
|
|||
},
|
||||
}
|
||||
}
|
||||
config[azure_event_hub.DOMAIN].update(MIN_CONFIG)
|
||||
assert await async_setup_component(hass, azure_event_hub.DOMAIN, config)
|
||||
config[DOMAIN].update(CS_CONFIG_FULL)
|
||||
assert await async_setup_component(hass, DOMAIN, config)
|
||||
|
||||
|
||||
async def _setup(hass, mock_call_later, filter_config):
|
||||
"""Shared set up for filtering tests."""
|
||||
config = {azure_event_hub.DOMAIN: {"filter": filter_config}}
|
||||
config[azure_event_hub.DOMAIN].update(MIN_CONFIG)
|
||||
async def test_filter_only_config(hass):
|
||||
"""Test the popping of the filter and further import of the config."""
|
||||
config = {
|
||||
DOMAIN: {
|
||||
"filter": {
|
||||
"include_domains": ["light"],
|
||||
"include_entity_globs": ["sensor.included_*"],
|
||||
"include_entities": ["binary_sensor.included"],
|
||||
"exclude_domains": ["light"],
|
||||
"exclude_entity_globs": ["sensor.excluded_*"],
|
||||
"exclude_entities": ["binary_sensor.excluded"],
|
||||
},
|
||||
}
|
||||
}
|
||||
assert await async_setup_component(hass, DOMAIN, config)
|
||||
|
||||
assert await async_setup_component(hass, azure_event_hub.DOMAIN, config)
|
||||
|
||||
async def test_unload_entry(hass, entry, mock_create_batch):
|
||||
"""Test being able to unload an entry.
|
||||
|
||||
Queue should be empty, so adding events to the batch should not be called,
|
||||
this verifies that the unload, calls async_stop, which calls async_send and
|
||||
shuts down the hub.
|
||||
"""
|
||||
assert await hass.config_entries.async_unload(entry.entry_id)
|
||||
mock_create_batch.add.assert_not_called()
|
||||
assert entry.state == ConfigEntryState.NOT_LOADED
|
||||
|
||||
|
||||
async def test_failed_test_connection(hass, mock_get_eventhub_properties):
|
||||
"""Test being able to unload an entry."""
|
||||
entry = MockConfigEntry(
|
||||
domain=azure_event_hub.DOMAIN,
|
||||
data=SAS_CONFIG_FULL,
|
||||
title="test-instance",
|
||||
options=BASIC_OPTIONS,
|
||||
)
|
||||
entry.add_to_hass(hass)
|
||||
mock_get_eventhub_properties.side_effect = EventHubError("Test")
|
||||
await hass.config_entries.async_setup(entry.entry_id)
|
||||
assert entry.state == ConfigEntryState.SETUP_RETRY
|
||||
|
||||
|
||||
async def test_send_batch_error(hass, entry_with_one_event, mock_send_batch):
|
||||
"""Test a error in send_batch, including recovering at the next interval."""
|
||||
mock_send_batch.reset_mock()
|
||||
mock_send_batch.side_effect = [EventHubError("Test"), None]
|
||||
async_fire_time_changed(
|
||||
hass,
|
||||
utcnow() + timedelta(seconds=entry_with_one_event.options[CONF_SEND_INTERVAL]),
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
mock_call_later.assert_called_once()
|
||||
return mock_call_later.call_args[0][2]
|
||||
mock_send_batch.assert_called_once()
|
||||
mock_send_batch.reset_mock()
|
||||
|
||||
async_fire_time_changed(
|
||||
hass,
|
||||
utcnow() + timedelta(seconds=entry_with_one_event.options[CONF_SEND_INTERVAL]),
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
mock_send_batch.assert_called_once()
|
||||
|
||||
|
||||
async def _run_filter_tests(hass, tests, process_queue, mock_batch):
|
||||
"""Run a series of filter tests on azure event hub."""
|
||||
for test in tests:
|
||||
hass.states.async_set(test.id, STATE_ON)
|
||||
async def test_late_event(hass, entry_with_one_event, mock_create_batch):
|
||||
"""Test the check on late events."""
|
||||
with patch(
|
||||
f"{AZURE_EVENT_HUB_PATH}.time.monotonic",
|
||||
return_value=monotonic() + timedelta(hours=1).seconds,
|
||||
):
|
||||
async_fire_time_changed(
|
||||
hass,
|
||||
utcnow()
|
||||
+ timedelta(seconds=entry_with_one_event.options[CONF_SEND_INTERVAL]),
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
await process_queue(None)
|
||||
|
||||
if test.should_pass:
|
||||
mock_batch.add.assert_called_once()
|
||||
mock_batch.add.reset_mock()
|
||||
else:
|
||||
mock_batch.add.assert_not_called()
|
||||
mock_create_batch.add.assert_not_called()
|
||||
|
||||
|
||||
async def test_allowlist(hass, mock_batch, mock_call_later):
|
||||
"""Test an allowlist only config."""
|
||||
process_queue = await _setup(
|
||||
async def test_full_batch(hass, entry_with_one_event, mock_create_batch):
|
||||
"""Test the full batch behaviour."""
|
||||
mock_create_batch.add.side_effect = [ValueError, None]
|
||||
async_fire_time_changed(
|
||||
hass,
|
||||
mock_call_later,
|
||||
{
|
||||
"include_domains": ["light"],
|
||||
"include_entity_globs": ["sensor.included_*"],
|
||||
"include_entities": ["binary_sensor.included"],
|
||||
},
|
||||
utcnow() + timedelta(seconds=entry_with_one_event.options[CONF_SEND_INTERVAL]),
|
||||
)
|
||||
|
||||
tests = [
|
||||
FilterTest("climate.excluded", False),
|
||||
FilterTest("light.included", True),
|
||||
FilterTest("sensor.excluded_test", False),
|
||||
FilterTest("sensor.included_test", True),
|
||||
FilterTest("binary_sensor.included", True),
|
||||
FilterTest("binary_sensor.excluded", False),
|
||||
]
|
||||
|
||||
await _run_filter_tests(hass, tests, process_queue, mock_batch)
|
||||
await hass.async_block_till_done()
|
||||
assert mock_create_batch.add.call_count == 2
|
||||
|
||||
|
||||
async def test_denylist(hass, mock_batch, mock_call_later):
|
||||
"""Test a denylist only config."""
|
||||
process_queue = await _setup(
|
||||
hass,
|
||||
mock_call_later,
|
||||
{
|
||||
"exclude_domains": ["climate"],
|
||||
"exclude_entity_globs": ["sensor.excluded_*"],
|
||||
"exclude_entities": ["binary_sensor.excluded"],
|
||||
},
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"filter_schema, tests",
|
||||
[
|
||||
(
|
||||
{
|
||||
"include_domains": ["light"],
|
||||
"include_entity_globs": ["sensor.included_*"],
|
||||
"include_entities": ["binary_sensor.included"],
|
||||
},
|
||||
[
|
||||
FilterTest("climate.excluded", 0),
|
||||
FilterTest("light.included", 1),
|
||||
FilterTest("sensor.excluded_test", 0),
|
||||
FilterTest("sensor.included_test", 1),
|
||||
FilterTest("binary_sensor.included", 1),
|
||||
FilterTest("binary_sensor.excluded", 0),
|
||||
],
|
||||
),
|
||||
(
|
||||
{
|
||||
"exclude_domains": ["climate"],
|
||||
"exclude_entity_globs": ["sensor.excluded_*"],
|
||||
"exclude_entities": ["binary_sensor.excluded"],
|
||||
},
|
||||
[
|
||||
FilterTest("climate.excluded", 0),
|
||||
FilterTest("light.included", 1),
|
||||
FilterTest("sensor.excluded_test", 0),
|
||||
FilterTest("sensor.included_test", 1),
|
||||
FilterTest("binary_sensor.included", 1),
|
||||
FilterTest("binary_sensor.excluded", 0),
|
||||
],
|
||||
),
|
||||
(
|
||||
{
|
||||
"include_domains": ["light"],
|
||||
"include_entity_globs": ["*.included_*"],
|
||||
"exclude_domains": ["climate"],
|
||||
"exclude_entity_globs": ["*.excluded_*"],
|
||||
"exclude_entities": ["light.excluded"],
|
||||
},
|
||||
[
|
||||
FilterTest("light.included", 1),
|
||||
FilterTest("light.excluded_test", 0),
|
||||
FilterTest("light.excluded", 0),
|
||||
FilterTest("sensor.included_test", 1),
|
||||
FilterTest("climate.included_test", 0),
|
||||
],
|
||||
),
|
||||
(
|
||||
{
|
||||
"include_entities": ["climate.included", "sensor.excluded_test"],
|
||||
"exclude_domains": ["climate"],
|
||||
"exclude_entity_globs": ["*.excluded_*"],
|
||||
"exclude_entities": ["light.excluded"],
|
||||
},
|
||||
[
|
||||
FilterTest("climate.excluded", 0),
|
||||
FilterTest("climate.included", 1),
|
||||
FilterTest("switch.excluded_test", 0),
|
||||
FilterTest("sensor.excluded_test", 1),
|
||||
FilterTest("light.excluded", 0),
|
||||
FilterTest("light.included", 1),
|
||||
],
|
||||
),
|
||||
],
|
||||
ids=["allowlist", "denylist", "filtered_allowlist", "filtered_denylist"],
|
||||
)
|
||||
async def test_filter(hass, entry, tests, mock_create_batch):
|
||||
"""Test different filters.
|
||||
|
||||
tests = [
|
||||
FilterTest("climate.excluded", False),
|
||||
FilterTest("light.included", True),
|
||||
FilterTest("sensor.excluded_test", False),
|
||||
FilterTest("sensor.included_test", True),
|
||||
FilterTest("binary_sensor.included", True),
|
||||
FilterTest("binary_sensor.excluded", False),
|
||||
]
|
||||
|
||||
await _run_filter_tests(hass, tests, process_queue, mock_batch)
|
||||
|
||||
|
||||
async def test_filtered_allowlist(hass, mock_batch, mock_call_later):
|
||||
"""Test an allowlist config with a filtering denylist."""
|
||||
process_queue = await _setup(
|
||||
hass,
|
||||
mock_call_later,
|
||||
{
|
||||
"include_domains": ["light"],
|
||||
"include_entity_globs": ["*.included_*"],
|
||||
"exclude_domains": ["climate"],
|
||||
"exclude_entity_globs": ["*.excluded_*"],
|
||||
"exclude_entities": ["light.excluded"],
|
||||
},
|
||||
)
|
||||
|
||||
tests = [
|
||||
FilterTest("light.included", True),
|
||||
FilterTest("light.excluded_test", False),
|
||||
FilterTest("light.excluded", False),
|
||||
FilterTest("sensor.included_test", True),
|
||||
FilterTest("climate.included_test", False),
|
||||
]
|
||||
|
||||
await _run_filter_tests(hass, tests, process_queue, mock_batch)
|
||||
|
||||
|
||||
async def test_filtered_denylist(hass, mock_batch, mock_call_later):
|
||||
"""Test a denylist config with a filtering allowlist."""
|
||||
process_queue = await _setup(
|
||||
hass,
|
||||
mock_call_later,
|
||||
{
|
||||
"include_entities": ["climate.included", "sensor.excluded_test"],
|
||||
"exclude_domains": ["climate"],
|
||||
"exclude_entity_globs": ["*.excluded_*"],
|
||||
"exclude_entities": ["light.excluded"],
|
||||
},
|
||||
)
|
||||
|
||||
tests = [
|
||||
FilterTest("climate.excluded", False),
|
||||
FilterTest("climate.included", True),
|
||||
FilterTest("switch.excluded_test", False),
|
||||
FilterTest("sensor.excluded_test", True),
|
||||
FilterTest("light.excluded", False),
|
||||
FilterTest("light.included", True),
|
||||
]
|
||||
|
||||
await _run_filter_tests(hass, tests, process_queue, mock_batch)
|
||||
Filter_schema is also a fixture which is replaced by the filter_schema
|
||||
in the parametrize and added to the entry fixture.
|
||||
"""
|
||||
for test in tests:
|
||||
hass.states.async_set(test.entity_id, STATE_ON)
|
||||
async_fire_time_changed(
|
||||
hass, utcnow() + timedelta(seconds=entry.options[CONF_SEND_INTERVAL])
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
assert mock_create_batch.add.call_count == test.expected_count
|
||||
mock_create_batch.add.reset_mock()
|
||||
|
|