From 160c7fc68509d87e66637f00279f88caf3f65b33 Mon Sep 17 00:00:00 2001 From: Anders Melchiorsen Date: Sat, 9 Sep 2017 19:20:48 +0200 Subject: [PATCH] Add HTTP Basic auth to RESTful Switch (#9162) * Add HTTP Basic auth to RESTful Switch * Remove redundant hass passing * Initialize to current state The state used to be None until the first periodic poll. This commit refactors async_update so it can be used during setup as well, allowing the state to start out with the correct value. * Refactor turn_on/turn_off device communication * Remove lint * Fix Travis errors --- homeassistant/components/switch/rest.py | 96 ++++++++++++++----------- tests/components/switch/test_rest.py | 4 +- 2 files changed, 57 insertions(+), 43 deletions(-) diff --git a/homeassistant/components/switch/rest.py b/homeassistant/components/switch/rest.py index 31d4f0f3e06..c0f75509425 100644 --- a/homeassistant/components/switch/rest.py +++ b/homeassistant/components/switch/rest.py @@ -13,7 +13,8 @@ import voluptuous as vol from homeassistant.components.switch import (SwitchDevice, PLATFORM_SCHEMA) from homeassistant.const import ( - CONF_NAME, CONF_RESOURCE, CONF_TIMEOUT, CONF_METHOD) + CONF_NAME, CONF_RESOURCE, CONF_TIMEOUT, CONF_METHOD, CONF_USERNAME, + CONF_PASSWORD) from homeassistant.helpers.aiohttp_client import async_get_clientsession import homeassistant.helpers.config_validation as cv from homeassistant.helpers.template import Template @@ -41,6 +42,8 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({ vol.All(vol.Lower, vol.In(SUPPORT_REST_METHODS)), vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string, vol.Optional(CONF_TIMEOUT, default=DEFAULT_TIMEOUT): cv.positive_int, + vol.Inclusive(CONF_USERNAME, 'authentication'): cv.string, + vol.Inclusive(CONF_PASSWORD, 'authentication'): cv.string, }) @@ -53,8 +56,13 @@ def async_setup_platform(hass, config, async_add_devices, discovery_info=None): is_on_template = config.get(CONF_IS_ON_TEMPLATE) method = config.get(CONF_METHOD) name = config.get(CONF_NAME) + username = config.get(CONF_USERNAME) + password = config.get(CONF_PASSWORD) resource = config.get(CONF_RESOURCE) - websession = async_get_clientsession(hass) + + auth = None + if username: + auth = aiohttp.BasicAuth(username, password=password) if is_on_template is not None: is_on_template.hass = hass @@ -65,37 +73,32 @@ def async_setup_platform(hass, config, async_add_devices, discovery_info=None): timeout = config.get(CONF_TIMEOUT) try: - with async_timeout.timeout(timeout, loop=hass.loop): - req = yield from websession.get(resource) + switch = RestSwitch(name, resource, method, auth, body_on, body_off, + is_on_template, timeout) + req = yield from switch.get_device_state(hass) if req.status >= 400: _LOGGER.error("Got non-ok response from resource: %s", req.status) - return False - + else: + async_add_devices([switch]) except (TypeError, ValueError): _LOGGER.error("Missing resource or schema in configuration. " "Add http:// or https:// to your URL") - return False except (asyncio.TimeoutError, aiohttp.ClientError): _LOGGER.error("No route to resource/endpoint: %s", resource) - return False - - async_add_devices( - [RestSwitch(hass, name, resource, method, body_on, body_off, - is_on_template, timeout)]) class RestSwitch(SwitchDevice): """Representation of a switch that can be toggled using REST.""" - def __init__(self, hass, name, resource, method, body_on, body_off, + def __init__(self, name, resource, method, auth, body_on, body_off, is_on_template, timeout): """Initialize the REST switch.""" self._state = None - self.hass = hass self._name = name self._resource = resource self._method = method + self._auth = auth self._body_on = body_on self._body_off = body_off self._is_on_template = is_on_template @@ -115,54 +118,61 @@ class RestSwitch(SwitchDevice): def async_turn_on(self, **kwargs): """Turn the device on.""" body_on_t = self._body_on.async_render() - websession = async_get_clientsession(self.hass) try: - with async_timeout.timeout(self._timeout, loop=self.hass.loop): - request = yield from getattr(websession, self._method)( - self._resource, data=bytes(body_on_t, 'utf-8')) + req = yield from self.set_device_state(body_on_t) + + if req.status == 200: + self._state = True + else: + _LOGGER.error( + "Can't turn on %s. Is resource/endpoint offline?", + self._resource) except (asyncio.TimeoutError, aiohttp.ClientError): _LOGGER.error("Error while turn on %s", self._resource) - return - - if request.status == 200: - self._state = True - else: - _LOGGER.error("Can't turn on %s. Is resource/endpoint offline?", - self._resource) @asyncio.coroutine def async_turn_off(self, **kwargs): """Turn the device off.""" body_off_t = self._body_off.async_render() - websession = async_get_clientsession(self.hass) try: - with async_timeout.timeout(self._timeout, loop=self.hass.loop): - request = yield from getattr(websession, self._method)( - self._resource, data=bytes(body_off_t, 'utf-8')) + req = yield from self.set_device_state(body_off_t) + if req.status == 200: + self._state = False + else: + _LOGGER.error( + "Can't turn off %s. Is resource/endpoint offline?", + self._resource) except (asyncio.TimeoutError, aiohttp.ClientError): _LOGGER.error("Error while turn off %s", self._resource) - return - if request.status == 200: - self._state = False - else: - _LOGGER.error("Can't turn off %s. Is resource/endpoint offline?", - self._resource) + @asyncio.coroutine + def set_device_state(self, body): + """Send a state update to the device.""" + websession = async_get_clientsession(self.hass) + + with async_timeout.timeout(self._timeout, loop=self.hass.loop): + req = yield from getattr(websession, self._method)( + self._resource, auth=self._auth, data=bytes(body, 'utf-8')) + return req @asyncio.coroutine def async_update(self): - """Get the latest data from REST API and update the state.""" - websession = async_get_clientsession(self.hass) - + """Get the current state, catching errors.""" try: - with async_timeout.timeout(self._timeout, loop=self.hass.loop): - request = yield from websession.get(self._resource) - text = yield from request.text() + yield from self.get_device_state(self.hass) except (asyncio.TimeoutError, aiohttp.ClientError): _LOGGER.exception("Error while fetch data.") - return + + @asyncio.coroutine + def get_device_state(self, hass): + """Get the latest data from REST API and update the state.""" + websession = async_get_clientsession(hass) + + with async_timeout.timeout(self._timeout, loop=hass.loop): + req = yield from websession.get(self._resource, auth=self._auth) + text = yield from req.text() if self._is_on_template is not None: text = self._is_on_template.async_render_with_possible_json_value( @@ -181,3 +191,5 @@ class RestSwitch(SwitchDevice): self._state = False else: self._state = None + + return req diff --git a/tests/components/switch/test_rest.py b/tests/components/switch/test_rest.py index 97911fccbfd..1b8215660bd 100644 --- a/tests/components/switch/test_rest.py +++ b/tests/components/switch/test_rest.py @@ -99,11 +99,13 @@ class TestRestSwitch: self.name = 'foo' self.method = 'post' self.resource = 'http://localhost/' + self.auth = None self.body_on = Template('on', self.hass) self.body_off = Template('off', self.hass) self.switch = rest.RestSwitch( - self.hass, self.name, self.resource, self.method, self.body_on, + self.name, self.resource, self.method, self.auth, self.body_on, self.body_off, None, 10) + self.switch.hass = self.hass def teardown_method(self): """Stop everything that was started."""