Update zwave_js FirmwareUploadView to support controller updates (#87239)

* Update zwave_js FirmwareUploadView to support controller updates

* Add coverage

* Change None check to assertion
This commit is contained in:
Raman Gupta 2023-02-22 11:52:00 -05:00 committed by GitHub
parent 5683d21931
commit 1f9f6ab1f0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 85 additions and 19 deletions

View file

@ -27,13 +27,14 @@ from zwave_js_server.exceptions import (
NotFoundError,
SetValueFailed,
)
from zwave_js_server.firmware import update_firmware
from zwave_js_server.firmware import controller_firmware_update_otw, update_firmware
from zwave_js_server.model.controller import (
ControllerStatistics,
InclusionGrant,
ProvisioningEntry,
QRProvisioningInformation,
)
from zwave_js_server.model.controller.firmware import ControllerFirmwareUpdateData
from zwave_js_server.model.driver import Driver
from zwave_js_server.model.log_config import LogConfig
from zwave_js_server.model.log_message import LogMessage
@ -445,7 +446,7 @@ def async_register_api(hass: HomeAssistant) -> None:
hass, websocket_subscribe_controller_statistics
)
websocket_api.async_register_command(hass, websocket_subscribe_node_statistics)
hass.http.register_view(FirmwareUploadView())
hass.http.register_view(FirmwareUploadView(dr.async_get(hass)))
@websocket_api.require_admin
@ -2071,10 +2072,10 @@ class FirmwareUploadView(HomeAssistantView):
url = r"/api/zwave_js/firmware/upload/{device_id}"
name = "api:zwave_js:firmware:upload"
def __init__(self) -> None:
def __init__(self, dev_reg: dr.DeviceRegistry) -> None:
"""Initialize view."""
super().__init__()
self._dev_reg: dr.DeviceRegistry | None = None
self._dev_reg = dev_reg
async def post(self, request: web.Request, device_id: str) -> web.Response:
"""Handle upload."""
@ -2083,12 +2084,16 @@ class FirmwareUploadView(HomeAssistantView):
hass = request.app["hass"]
try:
node = async_get_node_from_device_id(hass, device_id)
node = async_get_node_from_device_id(hass, device_id, self._dev_reg)
except ValueError as err:
if "not loaded" in err.args[0]:
raise web_exceptions.HTTPBadRequest
raise web_exceptions.HTTPNotFound
# If this was not true, we wouldn't have been able to get the node from the
# device ID above
assert node.client.driver
# Increase max payload
request._client_max_size = 1024 * 1024 * 10 # pylint: disable=protected-access
@ -2100,18 +2105,29 @@ class FirmwareUploadView(HomeAssistantView):
uploaded_file: web_request.FileField = data["file"]
try:
await update_firmware(
node.client.ws_server_url,
node,
[
NodeFirmwareUpdateData(
if node.client.driver.controller.own_node == node:
await controller_firmware_update_otw(
node.client.ws_server_url,
ControllerFirmwareUpdateData(
uploaded_file.filename,
await hass.async_add_executor_job(uploaded_file.file.read),
)
],
async_get_clientsession(hass),
additional_user_agent_components=USER_AGENT,
)
),
async_get_clientsession(hass),
additional_user_agent_components=USER_AGENT,
)
else:
await update_firmware(
node.client.ws_server_url,
node,
[
NodeFirmwareUpdateData(
uploaded_file.filename,
await hass.async_add_executor_job(uploaded_file.file.read),
)
],
async_get_clientsession(hass),
additional_user_agent_components=USER_AGENT,
)
except BaseZwaveJSServerError as err:
raise web_exceptions.HTTPBadRequest(reason=str(err)) from err

View file

@ -28,6 +28,7 @@ from zwave_js_server.model.controller import (
ProvisioningEntry,
QRProvisioningInformation,
)
from zwave_js_server.model.controller.firmware import ControllerFirmwareUpdateData
from zwave_js_server.model.node import Node
from zwave_js_server.model.node.firmware import NodeFirmwareUpdateData
@ -84,7 +85,7 @@ from tests.common import MockUser
from tests.typing import ClientSessionGenerator, WebSocketGenerator
def get_device(hass, node):
def get_device(hass: HomeAssistant, node):
"""Get device ID for a node."""
dev_reg = dr.async_get(hass)
device_id = get_device_id(node.client.driver, node)
@ -2968,7 +2969,9 @@ async def test_firmware_upload_view(
device = get_device(hass, multisensor_6)
with patch(
"homeassistant.components.zwave_js.api.update_firmware",
) as mock_cmd, patch.dict(
) as mock_node_cmd, patch(
"homeassistant.components.zwave_js.api.controller_firmware_update_otw",
) as mock_controller_cmd, patch.dict(
"homeassistant.components.zwave_js.api.USER_AGENT",
{"HomeAssistant": "0.0.0"},
):
@ -2976,11 +2979,40 @@ async def test_firmware_upload_view(
f"/api/zwave_js/firmware/upload/{device.id}",
data={"file": firmware_file},
)
assert mock_cmd.call_args[0][1:3] == (
mock_controller_cmd.assert_not_called()
assert mock_node_cmd.call_args[0][1:3] == (
multisensor_6,
[NodeFirmwareUpdateData("file", bytes(10))],
)
assert mock_cmd.call_args[1] == {
assert mock_node_cmd.call_args[1] == {
"additional_user_agent_components": {"HomeAssistant": "0.0.0"},
}
assert json.loads(await resp.text()) is None
async def test_firmware_upload_view_controller(
hass, client, integration, hass_client: ClientSessionGenerator, firmware_file
) -> None:
"""Test the HTTP firmware upload view for a controller."""
hass_client = await hass_client()
device = get_device(hass, client.driver.controller.nodes[1])
with patch(
"homeassistant.components.zwave_js.api.update_firmware",
) as mock_node_cmd, patch(
"homeassistant.components.zwave_js.api.controller_firmware_update_otw",
) as mock_controller_cmd, patch.dict(
"homeassistant.components.zwave_js.api.USER_AGENT",
{"HomeAssistant": "0.0.0"},
):
resp = await hass_client.post(
f"/api/zwave_js/firmware/upload/{device.id}",
data={"file": firmware_file},
)
mock_node_cmd.assert_not_called()
assert mock_controller_cmd.call_args[0][1:2] == (
ControllerFirmwareUpdateData("file", bytes(10)),
)
assert mock_controller_cmd.call_args[1] == {
"additional_user_agent_components": {"HomeAssistant": "0.0.0"},
}
assert json.loads(await resp.text()) is None
@ -3020,6 +3052,24 @@ async def test_firmware_upload_view_invalid_payload(
assert resp.status == HTTPStatus.BAD_REQUEST
async def test_firmware_upload_view_no_driver(
hass: HomeAssistant,
client,
multisensor_6,
integration,
hass_client: ClientSessionGenerator,
) -> None:
"""Test the HTTP firmware upload view when the driver doesn't exist."""
device = get_device(hass, multisensor_6)
client.driver = None
aiohttp_client = await hass_client()
resp = await aiohttp_client.post(
f"/api/zwave_js/firmware/upload/{device.id}",
data={"wrong_key": bytes(10)},
)
assert resp.status == HTTPStatus.NOT_FOUND
@pytest.mark.parametrize(
("method", "url"),
[("post", "/api/zwave_js/firmware/upload/{}")],