diff --git a/homeassistant/components/zwave_js/__init__.py b/homeassistant/components/zwave_js/__init__.py index 0c716a39c75..9f928eb7d3d 100644 --- a/homeassistant/components/zwave_js/__init__.py +++ b/homeassistant/components/zwave_js/__init__.py @@ -566,7 +566,7 @@ async def setup_driver( # noqa: C901 # If opt in preference hasn't been specified yet, we do nothing, otherwise # we apply the preference if opted_in := entry.data.get(CONF_DATA_COLLECTION_OPTED_IN): - await async_enable_statistics(client) + await async_enable_statistics(driver) elif opted_in is False: await driver.async_disable_statistics() diff --git a/homeassistant/components/zwave_js/api.py b/homeassistant/components/zwave_js/api.py index 66c497a791f..15d3a68d4a4 100644 --- a/homeassistant/components/zwave_js/api.py +++ b/homeassistant/components/zwave_js/api.py @@ -32,6 +32,7 @@ from zwave_js_server.model.controller import ( ProvisioningEntry, QRProvisioningInformation, ) +from zwave_js_server.model.driver import Driver from zwave_js_server.model.firmware import ( FirmwareUpdateFinished, FirmwareUpdateProgress, @@ -243,8 +244,17 @@ def async_get_entry(orig_func: Callable) -> Callable: ) return - client = hass.data[DOMAIN][entry_id][DATA_CLIENT] - await orig_func(hass, connection, msg, entry, client) + client: Client = hass.data[DOMAIN][entry_id][DATA_CLIENT] + + if client.driver is None: + connection.send_error( + msg[ID], + ERR_NOT_LOADED, + f"Config entry {entry_id} not loaded, driver not ready", + ) + return + + await orig_func(hass, connection, msg, entry, client, client.driver) return async_get_entry_func @@ -373,16 +383,20 @@ async def websocket_network_status( msg: dict, entry: ConfigEntry, client: Client, + driver: Driver, ) -> None: """Get the status of the Z-Wave JS network.""" - controller = client.driver.controller + controller = driver.controller + client_version_info = client.version + assert client_version_info # When client is connected version info is set. + await controller.async_get_state() data = { "client": { "ws_server_url": client.ws_server_url, "state": "connected" if client.connected else "disconnected", - "driver_version": client.version.driver_version, - "server_version": client.version.server_version, + "driver_version": client_version_info.driver_version, + "server_version": client_version_info.server_version, }, "controller": { "home_id": controller.home_id, @@ -404,9 +418,7 @@ async def websocket_network_status( "supports_timers": controller.supports_timers, "is_heal_network_active": controller.is_heal_network_active, "inclusion_state": controller.inclusion_state, - "nodes": [ - node_status(node) for node in client.driver.controller.nodes.values() - ], + "nodes": [node_status(node) for node in driver.controller.nodes.values()], }, } connection.send_result( @@ -533,9 +545,10 @@ async def websocket_add_node( msg: dict, entry: ConfigEntry, client: Client, + driver: Driver, ) -> None: """Add a node to the Z-Wave network.""" - controller = client.driver.controller + controller = driver.controller inclusion_strategy = InclusionStrategy(msg[INCLUSION_STRATEGY]) force_security = msg.get(FORCE_SECURITY) provisioning = ( @@ -672,13 +685,14 @@ async def websocket_grant_security_classes( msg: dict, entry: ConfigEntry, client: Client, + driver: Driver, ) -> None: """Choose SecurityClass grants as part of S2 inclusion process.""" inclusion_grant = InclusionGrant( [SecurityClass(sec_cls) for sec_cls in msg[SECURITY_CLASSES]], msg[CLIENT_SIDE_AUTH], ) - await client.driver.controller.async_grant_security_classes(inclusion_grant) + await driver.controller.async_grant_security_classes(inclusion_grant) connection.send_result(msg[ID]) @@ -699,9 +713,10 @@ async def websocket_validate_dsk_and_enter_pin( msg: dict, entry: ConfigEntry, client: Client, + driver: Driver, ) -> None: """Validate DSK and enter PIN as part of S2 inclusion process.""" - await client.driver.controller.async_validate_dsk_and_enter_pin(msg[PIN]) + await driver.controller.async_validate_dsk_and_enter_pin(msg[PIN]) connection.send_result(msg[ID]) @@ -728,6 +743,7 @@ async def websocket_provision_smart_start_node( msg: dict, entry: ConfigEntry, client: Client, + driver: Driver, ) -> None: """Pre-provision a smart start node.""" try: @@ -758,7 +774,7 @@ async def websocket_provision_smart_start_node( "QR code version S2 is not supported for this command", ) return - await client.driver.controller.async_provision_smart_start_node(provisioning_info) + await driver.controller.async_provision_smart_start_node(provisioning_info) connection.send_result(msg[ID]) @@ -780,6 +796,7 @@ async def websocket_unprovision_smart_start_node( msg: dict, entry: ConfigEntry, client: Client, + driver: Driver, ) -> None: """Unprovision a smart start node.""" try: @@ -792,7 +809,7 @@ async def websocket_unprovision_smart_start_node( ) return dsk_or_node_id = msg.get(DSK) or msg[NODE_ID] - await client.driver.controller.async_unprovision_smart_start_node(dsk_or_node_id) + await driver.controller.async_unprovision_smart_start_node(dsk_or_node_id) connection.send_result(msg[ID]) @@ -812,11 +829,10 @@ async def websocket_get_provisioning_entries( msg: dict, entry: ConfigEntry, client: Client, + driver: Driver, ) -> None: """Get provisioning entries (entries that have been pre-provisioned).""" - provisioning_entries = ( - await client.driver.controller.async_get_provisioning_entries() - ) + provisioning_entries = await driver.controller.async_get_provisioning_entries() connection.send_result( msg[ID], [dataclasses.asdict(entry) for entry in provisioning_entries] ) @@ -839,6 +855,7 @@ async def websocket_parse_qr_code_string( msg: dict, entry: ConfigEntry, client: Client, + driver: Driver, ) -> None: """Parse a QR Code String and return QRProvisioningInformation dict.""" qr_provisioning_information = await async_parse_qr_code_string( @@ -864,9 +881,10 @@ async def websocket_supports_feature( msg: dict, entry: ConfigEntry, client: Client, + driver: Driver, ) -> None: """Check if controller supports a particular feature.""" - supported = await client.driver.controller.async_supports_feature(msg[FEATURE]) + supported = await driver.controller.async_supports_feature(msg[FEATURE]) connection.send_result( msg[ID], {"supported": supported}, @@ -889,9 +907,10 @@ async def websocket_stop_inclusion( msg: dict, entry: ConfigEntry, client: Client, + driver: Driver, ) -> None: """Cancel adding a node to the Z-Wave network.""" - controller = client.driver.controller + controller = driver.controller result = await controller.async_stop_inclusion() connection.send_result( msg[ID], @@ -915,9 +934,10 @@ async def websocket_stop_exclusion( msg: dict, entry: ConfigEntry, client: Client, + driver: Driver, ) -> None: """Cancel removing a node from the Z-Wave network.""" - controller = client.driver.controller + controller = driver.controller result = await controller.async_stop_exclusion() connection.send_result( msg[ID], @@ -942,9 +962,10 @@ async def websocket_remove_node( msg: dict, entry: ConfigEntry, client: Client, + driver: Driver, ) -> None: """Remove a node from the Z-Wave network.""" - controller = client.driver.controller + controller = driver.controller @callback def async_cleanup() -> None: @@ -1021,9 +1042,10 @@ async def websocket_replace_failed_node( msg: dict, entry: ConfigEntry, client: Client, + driver: Driver, ) -> None: """Replace a failed node with a new node.""" - controller = client.driver.controller + controller = driver.controller node_id = msg[NODE_ID] inclusion_strategy = InclusionStrategy(msg[INCLUSION_STRATEGY]) force_security = msg.get(FORCE_SECURITY) @@ -1173,7 +1195,9 @@ async def websocket_remove_failed_node( node: Node, ) -> None: """Remove a failed node from the Z-Wave network.""" - controller = node.client.driver.controller + driver = node.client.driver + assert driver is not None # The node comes from the driver instance. + controller = driver.controller @callback def async_cleanup() -> None: @@ -1217,9 +1241,10 @@ async def websocket_begin_healing_network( msg: dict, entry: ConfigEntry, client: Client, + driver: Driver, ) -> None: """Begin healing the Z-Wave network.""" - controller = client.driver.controller + controller = driver.controller result = await controller.async_begin_healing_network() connection.send_result( @@ -1243,9 +1268,10 @@ async def websocket_subscribe_heal_network_progress( msg: dict, entry: ConfigEntry, client: Client, + driver: Driver, ) -> None: """Subscribe to heal Z-Wave network status updates.""" - controller = client.driver.controller + controller = driver.controller @callback def async_cleanup() -> None: @@ -1286,9 +1312,10 @@ async def websocket_stop_healing_network( msg: dict, entry: ConfigEntry, client: Client, + driver: Driver, ) -> None: """Stop healing the Z-Wave network.""" - controller = client.driver.controller + controller = driver.controller result = await controller.async_stop_healing_network() connection.send_result( msg[ID], @@ -1313,7 +1340,10 @@ async def websocket_heal_node( node: Node, ) -> None: """Heal a node on the Z-Wave network.""" - controller = node.client.driver.controller + driver = node.client.driver + assert driver is not None # The node comes from the driver instance. + controller = driver.controller + result = await controller.async_heal_node(node.node_id) connection.send_result( msg[ID], @@ -1540,9 +1570,9 @@ async def websocket_subscribe_log_updates( msg: dict, entry: ConfigEntry, client: Client, + driver: Driver, ) -> None: """Subscribe to log message events from the server.""" - driver = client.driver @callback def async_cleanup() -> None: @@ -1627,9 +1657,10 @@ async def websocket_update_log_config( msg: dict, entry: ConfigEntry, client: Client, + driver: Driver, ) -> None: """Update the driver log config.""" - await client.driver.async_update_log_config(LogConfig(**msg[CONFIG])) + await driver.async_update_log_config(LogConfig(**msg[CONFIG])) connection.send_result( msg[ID], ) @@ -1650,11 +1681,12 @@ async def websocket_get_log_config( msg: dict, entry: ConfigEntry, client: Client, + driver: Driver, ) -> None: """Get log configuration for the Z-Wave JS driver.""" connection.send_result( msg[ID], - dataclasses.asdict(client.driver.log_config), + dataclasses.asdict(driver.log_config), ) @@ -1675,15 +1707,16 @@ async def websocket_update_data_collection_preference( msg: dict, entry: ConfigEntry, client: Client, + driver: Driver, ) -> None: """Update preference for data collection and enable/disable collection.""" opted_in = msg[OPTED_IN] update_data_collection_preference(hass, entry, opted_in) if opted_in: - await async_enable_statistics(client) + await async_enable_statistics(driver) else: - await client.driver.async_disable_statistics() + await driver.async_disable_statistics() connection.send_result( msg[ID], @@ -1706,11 +1739,12 @@ async def websocket_data_collection_status( msg: dict, entry: ConfigEntry, client: Client, + driver: Driver, ) -> None: """Return data collection preference and status.""" result = { OPTED_IN: entry.data.get(CONF_DATA_COLLECTION_OPTED_IN), - ENABLED: await client.driver.async_is_statistics_enabled(), + ENABLED: await driver.async_is_statistics_enabled(), } connection.send_result(msg[ID], result) @@ -1890,9 +1924,10 @@ async def websocket_check_for_config_updates( msg: dict, entry: ConfigEntry, client: Client, + driver: Driver, ) -> None: """Check for config updates.""" - config_update = await client.driver.async_check_for_config_updates() + config_update = await driver.async_check_for_config_updates() connection.send_result( msg[ID], { @@ -1918,9 +1953,10 @@ async def websocket_install_config_update( msg: dict, entry: ConfigEntry, client: Client, + driver: Driver, ) -> None: """Check for config updates.""" - success = await client.driver.async_install_config_update() + success = await driver.async_install_config_update() connection.send_result(msg[ID], success) @@ -1956,6 +1992,7 @@ async def websocket_subscribe_controller_statistics( msg: dict, entry: ConfigEntry, client: Client, + driver: Driver, ) -> None: """Subsribe to the statistics updates for a controller.""" @@ -1979,7 +2016,7 @@ async def websocket_subscribe_controller_statistics( ) ) - controller = client.driver.controller + controller = driver.controller msg[DATA_UNSUBSCRIBE] = unsubs = [ controller.on("statistics updated", forward_stats) diff --git a/homeassistant/components/zwave_js/helpers.py b/homeassistant/components/zwave_js/helpers.py index 68ff0c89b15..0657d8531f3 100644 --- a/homeassistant/components/zwave_js/helpers.py +++ b/homeassistant/components/zwave_js/helpers.py @@ -9,6 +9,7 @@ from typing import Any, cast import voluptuous as vol from zwave_js_server.client import Client as ZwaveClient from zwave_js_server.const import ConfigurationValueType +from zwave_js_server.model.driver import Driver from zwave_js_server.model.node import Node as ZwaveNode from zwave_js_server.model.value import ( ConfigurationValue, @@ -92,10 +93,10 @@ def get_value_of_zwave_value(value: ZwaveValue | None) -> Any | None: return value.value if value else None -async def async_enable_statistics(client: ZwaveClient) -> None: +async def async_enable_statistics(driver: Driver) -> None: """Enable statistics on the driver.""" - await client.driver.async_enable_statistics("Home Assistant", HA_VERSION) - await client.driver.async_enable_error_reporting() + await driver.async_enable_statistics("Home Assistant", HA_VERSION) + await driver.async_enable_error_reporting() @callback @@ -194,7 +195,11 @@ def async_get_node_from_device_id( f"Device {device_id} is not from an existing zwave_js config entry" ) - client = hass.data[DOMAIN][entry.entry_id][DATA_CLIENT] + client: ZwaveClient = hass.data[DOMAIN][entry.entry_id][DATA_CLIENT] + driver = client.driver + + if driver is None: + raise ValueError("Driver is not ready.") # Get node ID from device identifier, perform some validation, and then get the # node @@ -202,10 +207,10 @@ def async_get_node_from_device_id( node_id = identifiers[1] if identifiers else None - if node_id is None or node_id not in client.driver.controller.nodes: + if node_id is None or node_id not in driver.controller.nodes: raise ValueError(f"Node for device {device_id} can't be found") - return client.driver.controller.nodes[node_id] + return driver.controller.nodes[node_id] @callback