diff --git a/homeassistant/components/zwave_js/__init__.py b/homeassistant/components/zwave_js/__init__.py index 4c05b7a4fab..b768e7a4465 100644 --- a/homeassistant/components/zwave_js/__init__.py +++ b/homeassistant/components/zwave_js/__init__.py @@ -58,25 +58,30 @@ def register_node_in_dev_reg( async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Set up Z-Wave JS from a config entry.""" client = ZwaveClient(entry.data[CONF_URL], async_get_clientsession(hass)) + connected = asyncio.Event() initialized = asyncio.Event() dev_reg = await device_registry.async_get_registry(hass) async def async_on_connect() -> None: """Handle websocket is (re)connected.""" LOGGER.info("Connected to Zwave JS Server") - if initialized.is_set(): - # update entity availability - async_dispatcher_send(hass, f"{DOMAIN}_{entry.entry_id}_connection_state") + connected.set() async def async_on_disconnect() -> None: """Handle websocket is disconnected.""" LOGGER.info("Disconnected from Zwave JS Server") - async_dispatcher_send(hass, f"{DOMAIN}_{entry.entry_id}_connection_state") + connected.clear() + if initialized.is_set(): + initialized.clear() + # update entity availability + async_dispatcher_send(hass, f"{DOMAIN}_{entry.entry_id}_connection_state") async def async_on_initialized() -> None: """Handle initial full state received.""" LOGGER.info("Connection to Zwave JS Server initialized.") initialized.set() + # update entity availability + async_dispatcher_send(hass, f"{DOMAIN}_{entry.entry_id}_connection_state") @callback def async_on_node_ready(node: ZwaveNode) -> None: @@ -127,7 +132,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: asyncio.create_task(client.connect()) try: async with timeout(CONNECT_TIMEOUT): - await initialized.wait() + await connected.wait() except asyncio.TimeoutError as err: for unsub in unsubs: unsub() @@ -152,6 +157,10 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: ] ) + # Wait till we're initialized + LOGGER.info("Waiting for Z-Wave to be fully initialized") + await initialized.wait() + # run discovery on all ready nodes for node in client.driver.controller.nodes.values(): async_on_node_added(node) diff --git a/tests/components/zwave_js/conftest.py b/tests/components/zwave_js/conftest.py index 74bf103a49e..6bf4afefd62 100644 --- a/tests/components/zwave_js/conftest.py +++ b/tests/components/zwave_js/conftest.py @@ -1,6 +1,6 @@ """Provide common Z-Wave JS fixtures.""" import json -from unittest.mock import DEFAULT, patch +from unittest.mock import DEFAULT, Mock, patch import pytest from zwave_js_server.event import Event @@ -97,16 +97,37 @@ def in_wall_smart_fan_control_state_fixture(): @pytest.fixture(name="client") def mock_client_fixture(controller_state, version_state): """Mock a client.""" + + def mock_callback(): + callbacks = [] + + def add_callback(cb): + callbacks.append(cb) + return DEFAULT + + return callbacks, Mock(side_effect=add_callback) + with patch( "homeassistant.components.zwave_js.ZwaveClient", autospec=True ) as client_class: - driver = Driver(client_class.return_value, controller_state) - version = VersionInfo.from_message(version_state) - client_class.return_value.driver = driver - client_class.return_value.version = version - client_class.return_value.ws_server_url = "ws://test:3000/zjs" - client_class.return_value.state = "connected" - yield client_class.return_value + client = client_class.return_value + + connect_callback, client.register_on_connect = mock_callback() + initialized_callback, client.register_on_initialized = mock_callback() + + async def connect(): + for cb in connect_callback: + await cb() + + for cb in initialized_callback: + await cb() + + client.connect = Mock(side_effect=connect) + client.driver = Driver(client, controller_state) + client.version = VersionInfo.from_message(version_state) + client.ws_server_url = "ws://test:3000/zjs" + client.state = "connected" + yield client @pytest.fixture(name="multisensor_6") @@ -190,14 +211,6 @@ async def integration_fixture(hass, client): """Set up the zwave_js integration.""" entry = MockConfigEntry(domain="zwave_js", data={"url": "ws://test.org"}) entry.add_to_hass(hass) - - def initialize_client(async_on_initialized): - """Init the client.""" - hass.async_create_task(async_on_initialized()) - return DEFAULT - - client.register_on_initialized.side_effect = initialize_client - await hass.config_entries.async_setup(entry.entry_id) await hass.async_block_till_done() diff --git a/tests/components/zwave_js/test_init.py b/tests/components/zwave_js/test_init.py index a117afa53bf..46b75331379 100644 --- a/tests/components/zwave_js/test_init.py +++ b/tests/components/zwave_js/test_init.py @@ -1,6 +1,6 @@ """Test the Z-Wave JS init module.""" from copy import deepcopy -from unittest.mock import DEFAULT, patch +from unittest.mock import patch import pytest from zwave_js_server.model.node import Node @@ -51,9 +51,11 @@ async def test_home_assistant_stop(hass, client, integration): assert client.disconnect.call_count == 1 -async def test_on_connect_disconnect(hass, client, multisensor_6, integration): +async def test_availability_reflect_connection_status( + hass, client, multisensor_6, integration +): """Test we handle disconnect and reconnect.""" - on_connect = client.register_on_connect.call_args[0][0] + on_initialized = client.register_on_initialized.call_args[0][0] on_disconnect = client.register_on_disconnect.call_args[0][0] state = hass.states.get(AIR_TEMPERATURE_SENSOR) @@ -72,7 +74,7 @@ async def test_on_connect_disconnect(hass, client, multisensor_6, integration): client.connected = True - await on_connect() + await on_initialized() await hass.async_block_till_done() state = hass.states.get(AIR_TEMPERATURE_SENSOR) @@ -182,13 +184,6 @@ async def test_existing_node_not_ready(hass, client, multisensor_6, device_regis entry = MockConfigEntry(domain="zwave_js", data={"url": "ws://test.org"}) entry.add_to_hass(hass) - def initialize_client(async_on_initialized): - """Init the client.""" - hass.async_create_task(async_on_initialized()) - return DEFAULT - - client.register_on_initialized.side_effect = initialize_client - await hass.config_entries.async_setup(entry.entry_id) await hass.async_block_till_done()