diff --git a/homeassistant/components/zwave_js/api.py b/homeassistant/components/zwave_js/api.py index d27541fc61c..66c497a791f 100644 --- a/homeassistant/components/zwave_js/api.py +++ b/homeassistant/components/zwave_js/api.py @@ -57,7 +57,7 @@ from homeassistant.core import HomeAssistant, callback from homeassistant.exceptions import Unauthorized from homeassistant.helpers import config_validation as cv from homeassistant.helpers.aiohttp_client import async_get_clientsession -from homeassistant.helpers.device_registry import DeviceEntry +import homeassistant.helpers.device_registry as dr from homeassistant.helpers.dispatcher import async_dispatcher_connect from .config_validation import BITMASK_SCHEMA @@ -607,7 +607,7 @@ async def websocket_add_node( ) @callback - def device_registered(device: DeviceEntry) -> None: + def device_registered(device: dr.DeviceEntry) -> None: device_details = { "name": device.name, "id": device.id, @@ -1108,7 +1108,7 @@ async def websocket_replace_failed_node( ) @callback - def device_registered(device: DeviceEntry) -> None: + def device_registered(device: dr.DeviceEntry) -> None: device_details = { "name": device.name, "id": device.id, @@ -1819,25 +1819,37 @@ async def websocket_subscribe_firmware_update_status( class FirmwareUploadView(HomeAssistantView): """View to upload firmware.""" - url = r"/api/zwave_js/firmware/upload/{config_entry_id}/{node_id:\d+}" + url = r"/api/zwave_js/firmware/upload/{device_id}" name = "api:zwave_js:firmware:upload" - async def post( - self, request: web.Request, config_entry_id: str, node_id: str - ) -> web.Response: + def __init__(self) -> None: + """Initialize view.""" + super().__init__() + self._dev_reg: dr.DeviceRegistry | None = None + + async def post(self, request: web.Request, device_id: str) -> web.Response: """Handle upload.""" if not request["hass_user"].is_admin: raise Unauthorized() hass = request.app["hass"] - if config_entry_id not in hass.data[DOMAIN]: - raise web_exceptions.HTTPBadRequest - entry = hass.config_entries.async_get_entry(config_entry_id) - client: Client = hass.data[DOMAIN][config_entry_id][DATA_CLIENT] - node = client.driver.controller.nodes.get(int(node_id)) - if not node: + try: + node = async_get_node_from_device_id(hass, device_id) + except ValueError as err: + if "not loaded" in err.args[0]: + raise web_exceptions.HTTPBadRequest raise web_exceptions.HTTPNotFound + if not self._dev_reg: + self._dev_reg = dr.async_get(hass) + device = self._dev_reg.async_get(device_id) + assert device + entry = next( + entry + for entry in hass.config_entries.async_entries(DOMAIN) + if entry.entry_id in device.config_entries + ) + # Increase max payload request._client_max_size = 1024 * 1024 * 10 # pylint: disable=protected-access diff --git a/tests/components/zwave_js/test_api.py b/tests/components/zwave_js/test_api.py index 60b83630add..3d491b98f93 100644 --- a/tests/components/zwave_js/test_api.py +++ b/tests/components/zwave_js/test_api.py @@ -2661,11 +2661,12 @@ async def test_firmware_upload_view( ): """Test the HTTP firmware upload view.""" client = await hass_client() + device = get_device(hass, multisensor_6) with patch( "homeassistant.components.zwave_js.api.begin_firmware_update", ) as mock_cmd: resp = await client.post( - f"/api/zwave_js/firmware/upload/{integration.entry_id}/{multisensor_6.node_id}", + f"/api/zwave_js/firmware/upload/{device.id}", data={"file": firmware_file}, ) assert mock_cmd.call_args[0][1:4] == (multisensor_6, "file", bytes(10)) @@ -2677,12 +2678,13 @@ async def test_firmware_upload_view_failed_command( ): """Test failed command for the HTTP firmware upload view.""" client = await hass_client() + device = get_device(hass, multisensor_6) with patch( "homeassistant.components.zwave_js.api.begin_firmware_update", side_effect=FailedCommand("test", "test"), ): resp = await client.post( - f"/api/zwave_js/firmware/upload/{integration.entry_id}/{multisensor_6.node_id}", + f"/api/zwave_js/firmware/upload/{device.id}", data={"file": firmware_file}, ) assert resp.status == HTTPStatus.BAD_REQUEST @@ -2692,9 +2694,10 @@ async def test_firmware_upload_view_invalid_payload( hass, multisensor_6, integration, hass_client ): """Test an invalid payload for the HTTP firmware upload view.""" + device = get_device(hass, multisensor_6) client = await hass_client() resp = await client.post( - f"/api/zwave_js/firmware/upload/{integration.entry_id}/{multisensor_6.node_id}", + f"/api/zwave_js/firmware/upload/{device.id}", data={"wrong_key": bytes(10)}, ) assert resp.status == HTTPStatus.BAD_REQUEST @@ -2702,40 +2705,43 @@ async def test_firmware_upload_view_invalid_payload( @pytest.mark.parametrize( "method, url", - [("post", "/api/zwave_js/firmware/upload/{}/{}")], + [("post", "/api/zwave_js/firmware/upload/{}")], ) async def test_node_view_non_admin_user( - multisensor_6, integration, hass_client, hass_admin_user, method, url + hass, multisensor_6, integration, hass_client, hass_admin_user, method, url ): """Test node level views for non-admin users.""" client = await hass_client() + device = get_device(hass, multisensor_6) # Verify we require admin user hass_admin_user.groups = [] - resp = await client.request( - method, url.format(integration.entry_id, multisensor_6.node_id) - ) + resp = await client.request(method, url.format(device.id)) assert resp.status == HTTPStatus.UNAUTHORIZED @pytest.mark.parametrize( "method, url", [ - ("post", "/api/zwave_js/firmware/upload/INVALID/1"), + ("post", "/api/zwave_js/firmware/upload/{}"), ], ) -async def test_view_invalid_entry_id(integration, hass_client, method, url): - """Test an invalid config entry id parameter.""" +async def test_view_unloaded_config_entry( + hass, multisensor_6, integration, hass_client, method, url +): + """Test an unloaded config entry raises Bad Request.""" client = await hass_client() - resp = await client.request(method, url) + device = get_device(hass, multisensor_6) + await hass.config_entries.async_unload(integration.entry_id) + resp = await client.request(method, url.format(device.id)) assert resp.status == HTTPStatus.BAD_REQUEST @pytest.mark.parametrize( "method, url", - [("post", "/api/zwave_js/firmware/upload/{}/111")], + [("post", "/api/zwave_js/firmware/upload/INVALID")], ) -async def test_view_invalid_node_id(integration, hass_client, method, url): - """Test an invalid config entry id parameter.""" +async def test_view_invalid_device_id(integration, hass_client, method, url): + """Test an invalid device id parameter.""" client = await hass_client() resp = await client.request(method, url.format(integration.entry_id)) assert resp.status == HTTPStatus.NOT_FOUND