Re-add event listeners after Z-Wave server disconnection (#94383)

* Re-add event listeners after Z-Wave server disconnection

* switch order

* Add tests
This commit is contained in:
Raman Gupta 2023-06-11 02:35:52 -04:00 committed by GitHub
parent eab024992e
commit 41d8ba3397
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 187 additions and 42 deletions

View file

@ -1,18 +1,20 @@
"""Offer Z-Wave JS event listening automation trigger."""
from __future__ import annotations
from collections.abc import Callable
import functools
from pydantic import ValidationError
import voluptuous as vol
from zwave_js_server.client import Client
from zwave_js_server.model.controller import CONTROLLER_EVENT_MODEL_MAP
from zwave_js_server.model.driver import DRIVER_EVENT_MODEL_MAP
from zwave_js_server.model.driver import DRIVER_EVENT_MODEL_MAP, Driver
from zwave_js_server.model.node import NODE_EVENT_MODEL_MAP
from homeassistant.const import ATTR_DEVICE_ID, ATTR_ENTITY_ID, CONF_PLATFORM
from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, callback
from homeassistant.helpers import config_validation as cv, device_registry as dr
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo
from homeassistant.helpers.typing import ConfigType
@ -150,7 +152,7 @@ async def async_attach_trigger(
event_name = config[ATTR_EVENT]
event_data_filter = config.get(ATTR_EVENT_DATA, {})
unsubs = []
unsubs: list[Callable] = []
job = HassJob(action)
trigger_data = trigger_info["trigger_data"]
@ -199,26 +201,6 @@ async def async_attach_trigger(
hass.async_run_hass_job(job, {"trigger": payload})
if not nodes:
entry_id = config[ATTR_CONFIG_ENTRY_ID]
client: Client = hass.data[DOMAIN][entry_id][DATA_CLIENT]
assert client.driver
if event_source == "controller":
unsubs.append(client.driver.controller.on(event_name, async_on_event))
else:
unsubs.append(client.driver.on(event_name, async_on_event))
for node in nodes:
driver = node.client.driver
assert driver is not None # The node comes from the driver.
device_identifier = get_device_id(driver, node)
device = dev_reg.async_get_device({device_identifier})
assert device
# We need to store the device for the callback
unsubs.append(
node.on(event_name, functools.partial(async_on_event, device=device))
)
@callback
def async_remove() -> None:
"""Remove state listeners async."""
@ -226,4 +208,45 @@ async def async_attach_trigger(
unsub()
unsubs.clear()
@callback
def _create_zwave_listeners() -> None:
"""Create Z-Wave JS listeners."""
async_remove()
# Nodes list can come from different drivers and we will need to listen to
# server connections for all of them.
drivers: set[Driver] = set()
if not nodes:
entry_id = config[ATTR_CONFIG_ENTRY_ID]
client: Client = hass.data[DOMAIN][entry_id][DATA_CLIENT]
driver = client.driver
assert driver
drivers.add(driver)
if event_source == "controller":
unsubs.append(driver.controller.on(event_name, async_on_event))
else:
unsubs.append(driver.on(event_name, async_on_event))
for node in nodes:
driver = node.client.driver
assert driver is not None # The node comes from the driver.
drivers.add(driver)
device_identifier = get_device_id(driver, node)
device = dev_reg.async_get_device({device_identifier})
assert device
# We need to store the device for the callback
unsubs.append(
node.on(event_name, functools.partial(async_on_event, device=device))
)
for driver in drivers:
unsubs.append(
async_dispatcher_connect(
hass,
f"{DOMAIN}_{driver.controller.home_id}_connected_to_server",
_create_zwave_listeners,
)
)
_create_zwave_listeners()
return async_remove