Add firmware updates support for zwave_js (#50390)

* Add WS API support for zwave_js firmware updates

* move file to fixture

* review comments

* fix logic and test based on upstream changes

* handle failure scenario

* handle failure scenario

* fix tests and adjust message

* Update homeassistant/components/zwave_js/api.py

Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>

* remove return from firmware upload view because client will raise an exception if not successful

* raise if user is not an admin

* raise bad request exception if firmware command fails

* incorporate #50923

* Add test for failed command

* add event name to messages

* change error to not found

Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
This commit is contained in:
Raman Gupta 2021-05-26 02:57:00 -04:00 committed by GitHub
parent c1d5dd7141
commit 5f7964b54b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 412 additions and 5 deletions

View file

@ -6,12 +6,22 @@ from functools import wraps
import json
from typing import Callable
from aiohttp import hdrs, web, web_exceptions
from aiohttp import hdrs, web, web_exceptions, web_request
import voluptuous as vol
from zwave_js_server import dump
from zwave_js_server.client import Client
from zwave_js_server.const import CommandClass, LogLevel
from zwave_js_server.exceptions import InvalidNewValue, NotFoundError, SetValueFailed
from zwave_js_server.exceptions import (
BaseZwaveJSServerError,
InvalidNewValue,
NotFoundError,
SetValueFailed,
)
from zwave_js_server.firmware import begin_firmware_update
from zwave_js_server.model.firmware import (
FirmwareUpdateFinished,
FirmwareUpdateProgress,
)
from zwave_js_server.model.log_config import LogConfig
from zwave_js_server.model.log_message import LogMessage
from zwave_js_server.model.node import Node
@ -28,6 +38,7 @@ from homeassistant.components.websocket_api.const import (
from homeassistant.config_entries import ConfigEntry, ConfigEntryState
from homeassistant.const import CONF_URL
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import Unauthorized
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.device_registry import DeviceEntry
@ -147,7 +158,12 @@ def async_register_api(hass: HomeAssistant) -> None:
hass, websocket_update_data_collection_preference
)
websocket_api.async_register_command(hass, websocket_data_collection_status)
websocket_api.async_register_command(hass, websocket_abort_firmware_update)
websocket_api.async_register_command(
hass, websocket_subscribe_firmware_update_status
)
hass.http.register_view(DumpView())
hass.http.register_view(FirmwareUploadView())
@websocket_api.require_admin
@ -1024,3 +1040,131 @@ class DumpView(HomeAssistantView):
hdrs.CONTENT_DISPOSITION: 'attachment; filename="zwave_js_dump.json"',
},
)
@websocket_api.require_admin
@websocket_api.websocket_command(
{
vol.Required(TYPE): "zwave_js/abort_firmware_update",
vol.Required(ENTRY_ID): str,
vol.Required(NODE_ID): int,
}
)
@websocket_api.async_response
@async_get_node
async def websocket_abort_firmware_update(
hass: HomeAssistant,
connection: ActiveConnection,
msg: dict,
node: Node,
) -> None:
"""Abort a firmware update."""
await node.async_abort_firmware_update()
connection.send_result(msg[ID])
@websocket_api.require_admin
@websocket_api.websocket_command(
{
vol.Required(TYPE): "zwave_js/subscribe_firmware_update_status",
vol.Required(ENTRY_ID): str,
vol.Required(NODE_ID): int,
}
)
@websocket_api.async_response
@async_get_node
async def websocket_subscribe_firmware_update_status(
hass: HomeAssistant,
connection: ActiveConnection,
msg: dict,
node: Node,
) -> None:
"""Subsribe to the status of a firmware update."""
@callback
def async_cleanup() -> None:
"""Remove signal listeners."""
for unsub in unsubs:
unsub()
@callback
def forward_progress(event: dict) -> None:
progress: FirmwareUpdateProgress = event["firmware_update_progress"]
connection.send_message(
websocket_api.event_message(
msg[ID],
{
"event": event["event"],
"sent_fragments": progress.sent_fragments,
"total_fragments": progress.total_fragments,
},
)
)
@callback
def forward_finished(event: dict) -> None:
finished: FirmwareUpdateFinished = event["firmware_update_finished"]
connection.send_message(
websocket_api.event_message(
msg[ID],
{
"event": event["event"],
"status": finished.status,
"wait_time": finished.wait_time,
},
)
)
unsubs = [
node.on("firmware update progress", forward_progress),
node.on("firmware update finished", forward_finished),
]
connection.subscriptions[msg["id"]] = async_cleanup
connection.send_result(msg[ID])
class FirmwareUploadView(HomeAssistantView):
"""View to upload firmware."""
url = r"/api/zwave_js/firmware/upload/{config_entry_id}/{node_id:\d+}"
name = "api:zwave_js:firmware:upload"
async def post(
self, request: web.Request, config_entry_id: str, node_id: str
) -> web.Response:
"""Handle upload."""
if not request["hass_user"].is_admin:
raise Unauthorized()
hass = request.app["hass"]
if config_entry_id not in hass.data[DOMAIN]:
raise web_exceptions.HTTPBadRequest
entry = hass.config_entries.async_get_entry(config_entry_id)
client = hass.data[DOMAIN][config_entry_id][DATA_CLIENT]
node = client.driver.controller.nodes.get(int(node_id))
if not node:
raise web_exceptions.HTTPNotFound
# Increase max payload
request._client_max_size = 1024 * 1024 * 10 # pylint: disable=protected-access
data = await request.post()
if "file" not in data or not isinstance(data["file"], web_request.FileField):
raise web_exceptions.HTTPBadRequest
uploaded_file: web_request.FileField = data["file"]
try:
await begin_firmware_update(
entry.data[CONF_URL],
node,
uploaded_file.filename,
await hass.async_add_executor_job(uploaded_file.file.read),
async_get_clientsession(hass),
)
except BaseZwaveJSServerError as err:
raise web_exceptions.HTTPBadRequest from err
return self.json(None)

View file

@ -1,6 +1,7 @@
"""Provide common Z-Wave JS fixtures."""
import asyncio
import copy
import io
import json
from unittest.mock import AsyncMock, patch
@ -717,3 +718,9 @@ def wallmote_central_scene_fixture(client, wallmote_central_scene_state):
node = Node(client, copy.deepcopy(wallmote_central_scene_state))
client.driver.controller.nodes[node.node_id] = node
return node
@pytest.fixture(name="firmware_file")
def firmware_file_fixture():
"""Return mock firmware file stream."""
return io.BytesIO(bytes(10))

View file

@ -2,9 +2,15 @@
import json
from unittest.mock import patch
import pytest
from zwave_js_server.const import LogLevel
from zwave_js_server.event import Event
from zwave_js_server.exceptions import InvalidNewValue, NotFoundError, SetValueFailed
from zwave_js_server.exceptions import (
FailedCommand,
InvalidNewValue,
NotFoundError,
SetValueFailed,
)
from homeassistant.components.websocket_api.const import ERR_NOT_FOUND
from homeassistant.components.zwave_js.api import (
@ -1123,13 +1129,74 @@ async def test_dump_view(integration, hass_client):
assert json.loads(await resp.text()) == [{"hello": "world"}, {"second": "msg"}]
async def test_dump_view_invalid_entry_id(integration, hass_client):
async def test_firmware_upload_view(
hass, multisensor_6, integration, hass_client, firmware_file
):
"""Test the HTTP firmware upload view."""
client = await hass_client()
with patch(
"homeassistant.components.zwave_js.api.begin_firmware_update",
) as mock_cmd:
resp = await client.post(
f"/api/zwave_js/firmware/upload/{integration.entry_id}/{multisensor_6.node_id}",
data={"file": firmware_file},
)
assert mock_cmd.call_args[0][1:4] == (multisensor_6, "file", bytes(10))
assert json.loads(await resp.text()) is None
async def test_firmware_upload_view_failed_command(
hass, multisensor_6, integration, hass_client, firmware_file
):
"""Test failed command for the HTTP firmware upload view."""
client = await hass_client()
with patch(
"homeassistant.components.zwave_js.api.begin_firmware_update",
side_effect=FailedCommand("test", "test"),
):
resp = await client.post(
f"/api/zwave_js/firmware/upload/{integration.entry_id}/{multisensor_6.node_id}",
data={"file": firmware_file},
)
assert resp.status == 400
async def test_firmware_upload_view_invalid_payload(
hass, multisensor_6, integration, hass_client
):
"""Test an invalid payload for the HTTP firmware upload view."""
client = await hass_client()
resp = await client.post(
f"/api/zwave_js/firmware/upload/{integration.entry_id}/{multisensor_6.node_id}",
data={"wrong_key": bytes(10)},
)
assert resp.status == 400
@pytest.mark.parametrize(
"method, url",
[
("get", "/api/zwave_js/dump/INVALID"),
("post", "/api/zwave_js/firmware/upload/INVALID/1"),
],
)
async def test_view_invalid_entry_id(integration, hass_client, method, url):
"""Test an invalid config entry id parameter."""
client = await hass_client()
resp = await client.get("/api/zwave_js/dump/INVALID")
resp = await client.request(method, url)
assert resp.status == 400
@pytest.mark.parametrize(
"method, url", [("post", "/api/zwave_js/firmware/upload/{}/111")]
)
async def test_view_invalid_node_id(integration, hass_client, method, url):
"""Test an invalid config entry id parameter."""
client = await hass_client()
resp = await client.request(method, url.format(integration.entry_id))
assert resp.status == 404
async def test_subscribe_logs(hass, integration, client, hass_ws_client):
"""Test the subscribe_logs websocket command."""
entry = integration
@ -1468,3 +1535,192 @@ async def test_data_collection(hass, client, integration, hass_ws_client):
assert not msg["success"]
assert msg["error"]["code"] == ERR_NOT_LOADED
async def test_abort_firmware_update(
hass, client, multisensor_6, integration, hass_ws_client
):
"""Test that the abort_firmware_update WS API call works."""
entry = integration
ws_client = await hass_ws_client(hass)
client.async_send_command_no_wait.return_value = {}
await ws_client.send_json(
{
ID: 1,
TYPE: "zwave_js/abort_firmware_update",
ENTRY_ID: entry.entry_id,
NODE_ID: multisensor_6.node_id,
}
)
msg = await ws_client.receive_json()
assert msg["success"]
assert len(client.async_send_command_no_wait.call_args_list) == 1
args = client.async_send_command_no_wait.call_args[0][0]
assert args["command"] == "node.abort_firmware_update"
assert args["nodeId"] == multisensor_6.node_id
async def test_abort_firmware_update_failures(
hass, integration, multisensor_6, client, hass_ws_client
):
"""Test failures for the abort_firmware_update websocket command."""
entry = integration
ws_client = await hass_ws_client(hass)
# Test sending command with improper entry ID fails
await ws_client.send_json(
{
ID: 1,
TYPE: "zwave_js/abort_firmware_update",
ENTRY_ID: "fake_entry_id",
NODE_ID: multisensor_6.node_id,
}
)
msg = await ws_client.receive_json()
assert not msg["success"]
assert msg["error"]["code"] == ERR_NOT_FOUND
# Test sending command with improper node ID fails
await ws_client.send_json(
{
ID: 2,
TYPE: "zwave_js/abort_firmware_update",
ENTRY_ID: entry.entry_id,
NODE_ID: multisensor_6.node_id + 100,
}
)
msg = await ws_client.receive_json()
assert not msg["success"]
assert msg["error"]["code"] == ERR_NOT_FOUND
# Test sending command with not loaded entry fails
await hass.config_entries.async_unload(entry.entry_id)
await hass.async_block_till_done()
await ws_client.send_json(
{
ID: 3,
TYPE: "zwave_js/abort_firmware_update",
ENTRY_ID: entry.entry_id,
NODE_ID: multisensor_6.node_id,
}
)
msg = await ws_client.receive_json()
assert not msg["success"]
assert msg["error"]["code"] == ERR_NOT_LOADED
async def test_subscribe_firmware_update_status(
hass, integration, multisensor_6, client, hass_ws_client
):
"""Test the subscribe_firmware_update_status websocket command."""
entry = integration
ws_client = await hass_ws_client(hass)
client.async_send_command_no_wait.return_value = {}
await ws_client.send_json(
{
ID: 1,
TYPE: "zwave_js/subscribe_firmware_update_status",
ENTRY_ID: entry.entry_id,
NODE_ID: multisensor_6.node_id,
}
)
msg = await ws_client.receive_json()
assert msg["success"]
event = Event(
type="firmware update progress",
data={
"source": "node",
"event": "firmware update progress",
"nodeId": multisensor_6.node_id,
"sentFragments": 1,
"totalFragments": 10,
},
)
multisensor_6.receive_event(event)
msg = await ws_client.receive_json()
assert msg["event"] == {
"event": "firmware update progress",
"sent_fragments": 1,
"total_fragments": 10,
}
event = Event(
type="firmware update finished",
data={
"source": "node",
"event": "firmware update finished",
"nodeId": multisensor_6.node_id,
"status": 255,
"waitTime": 10,
},
)
multisensor_6.receive_event(event)
msg = await ws_client.receive_json()
assert msg["event"] == {
"event": "firmware update finished",
"status": 255,
"wait_time": 10,
}
async def test_subscribe_firmware_update_status_failures(
hass, integration, multisensor_6, client, hass_ws_client
):
"""Test failures for the subscribe_firmware_update_status websocket command."""
entry = integration
ws_client = await hass_ws_client(hass)
# Test sending command with improper entry ID fails
await ws_client.send_json(
{
ID: 1,
TYPE: "zwave_js/subscribe_firmware_update_status",
ENTRY_ID: "fake_entry_id",
NODE_ID: multisensor_6.node_id,
}
)
msg = await ws_client.receive_json()
assert not msg["success"]
assert msg["error"]["code"] == ERR_NOT_FOUND
# Test sending command with improper node ID fails
await ws_client.send_json(
{
ID: 2,
TYPE: "zwave_js/subscribe_firmware_update_status",
ENTRY_ID: entry.entry_id,
NODE_ID: multisensor_6.node_id + 100,
}
)
msg = await ws_client.receive_json()
assert not msg["success"]
assert msg["error"]["code"] == ERR_NOT_FOUND
# Test sending command with not loaded entry fails
await hass.config_entries.async_unload(entry.entry_id)
await hass.async_block_till_done()
await ws_client.send_json(
{
ID: 3,
TYPE: "zwave_js/subscribe_firmware_update_status",
ENTRY_ID: entry.entry_id,
NODE_ID: multisensor_6.node_id,
}
)
msg = await ws_client.receive_json()
assert not msg["success"]
assert msg["error"]["code"] == ERR_NOT_LOADED