diff --git a/homeassistant/components/zwave_js/__init__.py b/homeassistant/components/zwave_js/__init__.py index 82e79b83659..1a2cdfa7017 100644 --- a/homeassistant/components/zwave_js/__init__.py +++ b/homeassistant/components/zwave_js/__init__.py @@ -1,9 +1,11 @@ """The Z-Wave JS integration.""" import asyncio import logging +from typing import Callable, List from async_timeout import timeout from zwave_js_server.client import Client as ZwaveClient +from zwave_js_server.exceptions import BaseZwaveJSServerError from zwave_js_server.model.node import Node as ZwaveNode from zwave_js_server.model.notification import Notification from zwave_js_server.model.value import ValueNotification @@ -45,6 +47,8 @@ from .entity import get_device_id LOGGER = logging.getLogger(__name__) CONNECT_TIMEOUT = 10 +DATA_CLIENT_LISTEN_TASK = "client_listen_task" +DATA_START_PLATFORM_TASK = "start_platform_task" async def async_setup(hass: HomeAssistant, config: dict) -> bool: @@ -77,45 +81,8 @@ def register_node_in_dev_reg( async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Set up Z-Wave JS from a config entry.""" client = ZwaveClient(entry.data[CONF_URL], async_get_clientsession(hass)) - connected = asyncio.Event() - initialized = asyncio.Event() dev_reg = await device_registry.async_get_registry(hass) - async def async_on_connect() -> None: - """Handle websocket is (re)connected.""" - LOGGER.info("Connected to Zwave JS Server") - connected.set() - - async def async_on_disconnect() -> None: - """Handle websocket is disconnected.""" - LOGGER.info("Disconnected from Zwave JS Server") - connected.clear() - if initialized.is_set(): - initialized.clear() - # update entity availability - async_dispatcher_send(hass, f"{DOMAIN}_{entry.entry_id}_connection_state") - - async def async_on_initialized() -> None: - """Handle initial full state received.""" - LOGGER.info("Connection to Zwave JS Server initialized.") - initialized.set() - # update entity availability - async_dispatcher_send(hass, f"{DOMAIN}_{entry.entry_id}_connection_state") - - # Check for nodes that no longer exist and remove them - stored_devices = device_registry.async_entries_for_config_entry( - dev_reg, entry.entry_id - ) - known_devices = [ - dev_reg.async_get_device({get_device_id(client, node)}) - for node in client.driver.controller.nodes.values() - ] - - # Devices that are in the device registry that are not known by the controller can be removed - for device in stored_devices: - if device not in known_devices: - dev_reg.async_remove_device(device.id) - @callback def async_on_node_ready(node: ZwaveNode) -> None: """Handle node ready event.""" @@ -209,32 +176,19 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: }, ) - async def handle_ha_shutdown(event: Event) -> None: - """Handle HA shutdown.""" - await client.disconnect() - - # register main event callbacks. - unsubs = [ - client.register_on_initialized(async_on_initialized), - client.register_on_disconnect(async_on_disconnect), - client.register_on_connect(async_on_connect), - hass.bus.async_listen(EVENT_HOMEASSISTANT_STOP, handle_ha_shutdown), - ] - # connect and throw error if connection failed - asyncio.create_task(client.connect()) try: async with timeout(CONNECT_TIMEOUT): - await connected.wait() - except asyncio.TimeoutError as err: - for unsub in unsubs: - unsub() - await client.disconnect() + await client.connect() + except (asyncio.TimeoutError, BaseZwaveJSServerError) as err: raise ConfigEntryNotReady from err + else: + LOGGER.info("Connected to Zwave JS Server") + unsubscribe_callbacks: List[Callable] = [] hass.data[DOMAIN][entry.entry_id] = { DATA_CLIENT: client, - DATA_UNSUBSCRIBE: unsubs, + DATA_UNSUBSCRIBE: unsubscribe_callbacks, } # Set up websocket API @@ -250,9 +204,37 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: ] ) - # Wait till we're initialized - LOGGER.info("Waiting for Z-Wave to be fully initialized") - await initialized.wait() + driver_ready = asyncio.Event() + + async def handle_ha_shutdown(event: Event) -> None: + """Handle HA shutdown.""" + await disconnect_client(hass, entry, client, listen_task, platform_task) + + listen_task = asyncio.create_task( + client_listen(hass, entry, client, driver_ready) + ) + hass.data[DOMAIN][entry.entry_id][DATA_CLIENT_LISTEN_TASK] = listen_task + unsubscribe_callbacks.append( + hass.bus.async_listen(EVENT_HOMEASSISTANT_STOP, handle_ha_shutdown) + ) + + await driver_ready.wait() + + LOGGER.info("Connection to Zwave JS Server initialized") + + # Check for nodes that no longer exist and remove them + stored_devices = device_registry.async_entries_for_config_entry( + dev_reg, entry.entry_id + ) + known_devices = [ + dev_reg.async_get_device({get_device_id(client, node)}) + for node in client.driver.controller.nodes.values() + ] + + # Devices that are in the device registry that are not known by the controller can be removed + for device in stored_devices: + if device not in known_devices: + dev_reg.async_remove_device(device.id) # run discovery on all ready nodes for node in client.driver.controller.nodes.values(): @@ -268,11 +250,46 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: "node removed", lambda event: async_on_node_removed(event["node"]) ) - hass.async_create_task(start_platforms()) + platform_task = hass.async_create_task(start_platforms()) + hass.data[DOMAIN][entry.entry_id][DATA_START_PLATFORM_TASK] = platform_task return True +async def client_listen( + hass: HomeAssistant, + entry: ConfigEntry, + client: ZwaveClient, + driver_ready: asyncio.Event, +) -> None: + """Listen with the client.""" + try: + await client.listen(driver_ready) + except BaseZwaveJSServerError: + # The entry needs to be reloaded since a new driver state + # will be acquired on reconnect. + # All model instances will be replaced when the new state is acquired. + hass.async_create_task(hass.config_entries.async_reload(entry.entry_id)) + + +async def disconnect_client( + hass: HomeAssistant, + entry: ConfigEntry, + client: ZwaveClient, + listen_task: asyncio.Task, + platform_task: asyncio.Task, +) -> None: + """Disconnect client.""" + await client.disconnect() + + listen_task.cancel() + platform_task.cancel() + + await asyncio.gather(listen_task, platform_task) + + LOGGER.info("Disconnected from Zwave JS Server") + + async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Unload a config entry.""" unload_ok = all( @@ -291,7 +308,14 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: for unsub in info[DATA_UNSUBSCRIBE]: unsub() - await info[DATA_CLIENT].disconnect() + if DATA_CLIENT_LISTEN_TASK in info: + await disconnect_client( + hass, + entry, + info[DATA_CLIENT], + info[DATA_CLIENT_LISTEN_TASK], + platform_task=info[DATA_START_PLATFORM_TASK], + ) return True diff --git a/homeassistant/components/zwave_js/entity.py b/homeassistant/components/zwave_js/entity.py index b039113270d..334a2cccd4f 100644 --- a/homeassistant/components/zwave_js/entity.py +++ b/homeassistant/components/zwave_js/entity.py @@ -9,7 +9,6 @@ from zwave_js_server.model.value import Value as ZwaveValue, get_value_id from homeassistant.config_entries import ConfigEntry from homeassistant.core import callback -from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.entity import Entity from .const import DOMAIN @@ -54,14 +53,6 @@ class ZWaveBaseEntity(Entity): self.info.node.on(EVENT_VALUE_UPDATED, self._value_changed) ) - self.async_on_remove( - async_dispatcher_connect( - self.hass, - f"{DOMAIN}_{self.config_entry.entry_id}_connection_state", - self.async_write_ha_state, - ) - ) - @property def device_info(self) -> dict: """Return device information for the device registry.""" diff --git a/homeassistant/components/zwave_js/manifest.json b/homeassistant/components/zwave_js/manifest.json index 586a6492a1a..de77ebbf5e0 100644 --- a/homeassistant/components/zwave_js/manifest.json +++ b/homeassistant/components/zwave_js/manifest.json @@ -3,7 +3,7 @@ "name": "Z-Wave JS", "config_flow": true, "documentation": "https://www.home-assistant.io/integrations/zwave_js", - "requirements": ["zwave-js-server-python==0.15.0"], + "requirements": ["zwave-js-server-python==0.16.0"], "codeowners": ["@home-assistant/z-wave"], "dependencies": ["http", "websocket_api"] } diff --git a/requirements_all.txt b/requirements_all.txt index 2b8efbff9e1..21211668779 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -2384,4 +2384,4 @@ zigpy==0.32.0 zm-py==0.5.2 # homeassistant.components.zwave_js -zwave-js-server-python==0.15.0 +zwave-js-server-python==0.16.0 diff --git a/requirements_test_all.txt b/requirements_test_all.txt index 1998d1baf0c..c2d39fd320e 100644 --- a/requirements_test_all.txt +++ b/requirements_test_all.txt @@ -1203,4 +1203,4 @@ zigpy-znp==0.3.0 zigpy==0.32.0 # homeassistant.components.zwave_js -zwave-js-server-python==0.15.0 +zwave-js-server-python==0.16.0 diff --git a/tests/components/zwave_js/conftest.py b/tests/components/zwave_js/conftest.py index 0e0ebdee3c6..b5301f4cd2f 100644 --- a/tests/components/zwave_js/conftest.py +++ b/tests/components/zwave_js/conftest.py @@ -1,6 +1,7 @@ """Provide common Z-Wave JS fixtures.""" +import asyncio import json -from unittest.mock import DEFAULT, Mock, patch +from unittest.mock import DEFAULT, AsyncMock, patch import pytest from zwave_js_server.event import Event @@ -149,35 +150,31 @@ def in_wall_smart_fan_control_state_fixture(): def mock_client_fixture(controller_state, version_state): """Mock a client.""" - def mock_callback(): - callbacks = [] - - def add_callback(cb): - callbacks.append(cb) - return DEFAULT - - return callbacks, Mock(side_effect=add_callback) - with patch( "homeassistant.components.zwave_js.ZwaveClient", autospec=True ) as client_class: client = client_class.return_value - connect_callback, client.register_on_connect = mock_callback() - initialized_callback, client.register_on_initialized = mock_callback() - async def connect(): - for cb in connect_callback: - await cb() + await asyncio.sleep(0) + client.state = "connected" + client.connected = True - for cb in initialized_callback: - await cb() + async def listen(driver_ready: asyncio.Event) -> None: + driver_ready.set() - client.connect = Mock(side_effect=connect) + async def disconnect(): + client.state = "disconnected" + client.connected = False + + client.connect = AsyncMock(side_effect=connect) + client.listen = AsyncMock(side_effect=listen) + client.disconnect = AsyncMock(side_effect=disconnect) client.driver = Driver(client, controller_state) + client.version = VersionInfo.from_message(version_state) client.ws_server_url = "ws://test:3000/zjs" - client.state = "connected" + yield client diff --git a/tests/components/zwave_js/test_init.py b/tests/components/zwave_js/test_init.py index fa61b3deb27..1aad07400ad 100644 --- a/tests/components/zwave_js/test_init.py +++ b/tests/components/zwave_js/test_init.py @@ -50,17 +50,11 @@ async def test_entry_setup_unload(hass, client, integration): entry = integration assert client.connect.call_count == 1 - assert client.register_on_initialized.call_count == 1 - assert client.register_on_disconnect.call_count == 1 - assert client.register_on_connect.call_count == 1 assert entry.state == ENTRY_STATE_LOADED await hass.config_entries.async_unload(entry.entry_id) assert client.disconnect.call_count == 1 - assert client.register_on_initialized.return_value.call_count == 1 - assert client.register_on_disconnect.return_value.call_count == 1 - assert client.register_on_connect.return_value.call_count == 1 assert entry.state == ENTRY_STATE_NOT_LOADED @@ -71,38 +65,6 @@ async def test_home_assistant_stop(hass, client, integration): assert client.disconnect.call_count == 1 -async def test_availability_reflect_connection_status( - hass, client, multisensor_6, integration -): - """Test we handle disconnect and reconnect.""" - on_initialized = client.register_on_initialized.call_args[0][0] - on_disconnect = client.register_on_disconnect.call_args[0][0] - state = hass.states.get(AIR_TEMPERATURE_SENSOR) - - assert state - assert state.state != STATE_UNAVAILABLE - - client.connected = False - - await on_disconnect() - await hass.async_block_till_done() - - state = hass.states.get(AIR_TEMPERATURE_SENSOR) - - assert state - assert state.state == STATE_UNAVAILABLE - - client.connected = True - - await on_initialized() - await hass.async_block_till_done() - - state = hass.states.get(AIR_TEMPERATURE_SENSOR) - - assert state - assert state.state != STATE_UNAVAILABLE - - async def test_initialized_timeout(hass, client, connect_timeout): """Test we handle a timeout during client initialization.""" entry = MockConfigEntry(domain="zwave_js", data={"url": "ws://test.org"})