Don't duplicate html5 registrations (#11451)

* Don't duplicate html5 registrations

If a registration is posted and another registration with the same
endpoint URL exists, update that one instead. That way, we preserve
the device name that has been configured. The previous behavior used to
append 'unnamed device' registrations over and over, leading to
multiple copies of the same registration. The endpoint URL is unique per
service worker so it is safe to update matching registrations.

* Refactor html5 registration view to not write json in the event loop
This commit is contained in:
Christopher Viel 2018-01-05 17:29:27 -05:00 committed by Paulus Schoutsen
parent 01896786fb
commit 455c629f47
2 changed files with 172 additions and 73 deletions

View file

@ -169,15 +169,35 @@ class HTML5PushRegistrationView(HomeAssistantView):
return self.json_message(
humanize_error(data, ex), HTTP_BAD_REQUEST)
name = ensure_unique_string('unnamed device', self.registrations)
name = self.find_registration_name(data)
previous_registration = self.registrations.get(name)
self.registrations[name] = data
if not save_json(self.json_path, self.registrations):
try:
hass = request.app['hass']
yield from hass.async_add_job(save_json, self.json_path,
self.registrations)
return self.json_message(
'Push notification subscriber registered.')
except HomeAssistantError:
if previous_registration is not None:
self.registrations[name] = previous_registration
else:
self.registrations.pop(name)
return self.json_message(
'Error saving registration.', HTTP_INTERNAL_SERVER_ERROR)
return self.json_message('Push notification subscriber registered.')
def find_registration_name(self, data):
"""Find a registration name matching data or generate a unique one."""
endpoint = data.get(ATTR_SUBSCRIPTION).get(ATTR_ENDPOINT)
for key, registration in self.registrations.items():
subscription = registration.get(ATTR_SUBSCRIPTION)
if subscription.get(ATTR_ENDPOINT) == endpoint:
return key
return ensure_unique_string('unnamed device', self.registrations)
@asyncio.coroutine
def delete(self, request):
@ -202,7 +222,12 @@ class HTML5PushRegistrationView(HomeAssistantView):
reg = self.registrations.pop(found)
if not save_json(self.json_path, self.registrations):
try:
hass = request.app['hass']
yield from hass.async_add_job(save_json, self.json_path,
self.registrations)
except HomeAssistantError:
self.registrations[found] = reg
return self.json_message(
'Error saving registration.', HTTP_INTERNAL_SERVER_ERROR)

View file

@ -4,10 +4,14 @@ import json
from unittest.mock import patch, MagicMock, mock_open
from aiohttp.hdrs import AUTHORIZATION
from homeassistant.exceptions import HomeAssistantError
from homeassistant.util.json import save_json
from homeassistant.components.notify import html5
from tests.common import mock_http_component_app
CONFIG_FILE = 'file.conf'
SUBSCRIPTION_1 = {
'browser': 'chrome',
'subscription': {
@ -108,36 +112,30 @@ class TestHtml5Notify(object):
'unnamed device': SUBSCRIPTION_1,
}
m = mock_open()
with patch(
'homeassistant.util.json.open',
m, create=True
):
hass.config.path.return_value = 'file.conf'
service = html5.get_service(hass, {})
hass.config.path.return_value = CONFIG_FILE
service = html5.get_service(hass, {})
assert service is not None
assert service is not None
# assert hass.called
assert len(hass.mock_calls) == 3
assert len(hass.mock_calls) == 3
view = hass.mock_calls[1][1][0]
assert view.json_path == hass.config.path.return_value
assert view.registrations == {}
view = hass.mock_calls[1][1][0]
assert view.json_path == hass.config.path.return_value
assert view.registrations == {}
hass.loop = loop
app = mock_http_component_app(hass)
view.register(app.router)
client = yield from test_client(app)
hass.http.is_banned_ip.return_value = False
resp = yield from client.post(REGISTER_URL,
data=json.dumps(SUBSCRIPTION_1))
hass.loop = loop
app = mock_http_component_app(hass)
view.register(app.router)
client = yield from test_client(app)
hass.http.is_banned_ip.return_value = False
resp = yield from client.post(REGISTER_URL,
data=json.dumps(SUBSCRIPTION_1))
content = yield from resp.text()
assert resp.status == 200, content
assert view.registrations == expected
handle = m()
assert json.loads(handle.write.call_args[0][0]) == expected
content = yield from resp.text()
assert resp.status == 200, content
assert view.registrations == expected
hass.async_add_job.assert_called_with(save_json, CONFIG_FILE, expected)
@asyncio.coroutine
def test_registering_new_device_expiration_view(self, loop, test_client):
@ -147,36 +145,114 @@ class TestHtml5Notify(object):
'unnamed device': SUBSCRIPTION_4,
}
m = mock_open()
with patch(
'homeassistant.util.json.open',
m, create=True
):
hass.config.path.return_value = 'file.conf'
service = html5.get_service(hass, {})
hass.config.path.return_value = CONFIG_FILE
service = html5.get_service(hass, {})
assert service is not None
assert service is not None
# assert hass.called
assert len(hass.mock_calls) == 3
# assert hass.called
assert len(hass.mock_calls) == 3
view = hass.mock_calls[1][1][0]
assert view.json_path == hass.config.path.return_value
assert view.registrations == {}
view = hass.mock_calls[1][1][0]
assert view.json_path == hass.config.path.return_value
assert view.registrations == {}
hass.loop = loop
app = mock_http_component_app(hass)
view.register(app.router)
client = yield from test_client(app)
hass.http.is_banned_ip.return_value = False
resp = yield from client.post(REGISTER_URL,
data=json.dumps(SUBSCRIPTION_4))
hass.loop = loop
app = mock_http_component_app(hass)
view.register(app.router)
client = yield from test_client(app)
hass.http.is_banned_ip.return_value = False
resp = yield from client.post(REGISTER_URL,
data=json.dumps(SUBSCRIPTION_4))
content = yield from resp.text()
assert resp.status == 200, content
assert view.registrations == expected
handle = m()
assert json.loads(handle.write.call_args[0][0]) == expected
content = yield from resp.text()
assert resp.status == 200, content
assert view.registrations == expected
hass.async_add_job.assert_called_with(save_json, CONFIG_FILE, expected)
@asyncio.coroutine
def test_registering_new_device_fails_view(self, loop, test_client):
"""Test subs. are not altered when registering a new device fails."""
hass = MagicMock()
expected = {}
hass.config.path.return_value = CONFIG_FILE
html5.get_service(hass, {})
view = hass.mock_calls[1][1][0]
hass.loop = loop
app = mock_http_component_app(hass)
view.register(app.router)
client = yield from test_client(app)
hass.http.is_banned_ip.return_value = False
hass.async_add_job.side_effect = HomeAssistantError()
resp = yield from client.post(REGISTER_URL,
data=json.dumps(SUBSCRIPTION_1))
content = yield from resp.text()
assert resp.status == 500, content
assert view.registrations == expected
@asyncio.coroutine
def test_registering_existing_device_view(self, loop, test_client):
"""Test subscription is updated when registering existing device."""
hass = MagicMock()
expected = {
'unnamed device': SUBSCRIPTION_4,
}
hass.config.path.return_value = CONFIG_FILE
html5.get_service(hass, {})
view = hass.mock_calls[1][1][0]
hass.loop = loop
app = mock_http_component_app(hass)
view.register(app.router)
client = yield from test_client(app)
hass.http.is_banned_ip.return_value = False
yield from client.post(REGISTER_URL,
data=json.dumps(SUBSCRIPTION_1))
resp = yield from client.post(REGISTER_URL,
data=json.dumps(SUBSCRIPTION_4))
content = yield from resp.text()
assert resp.status == 200, content
assert view.registrations == expected
hass.async_add_job.assert_called_with(save_json, CONFIG_FILE, expected)
@asyncio.coroutine
def test_registering_existing_device_fails_view(self, loop, test_client):
"""Test sub. is not updated when registering existing device fails."""
hass = MagicMock()
expected = {
'unnamed device': SUBSCRIPTION_1,
}
hass.config.path.return_value = CONFIG_FILE
html5.get_service(hass, {})
view = hass.mock_calls[1][1][0]
hass.loop = loop
app = mock_http_component_app(hass)
view.register(app.router)
client = yield from test_client(app)
hass.http.is_banned_ip.return_value = False
yield from client.post(REGISTER_URL,
data=json.dumps(SUBSCRIPTION_1))
hass.async_add_job.side_effect = HomeAssistantError()
resp = yield from client.post(REGISTER_URL,
data=json.dumps(SUBSCRIPTION_4))
content = yield from resp.text()
assert resp.status == 500, content
assert view.registrations == expected
@asyncio.coroutine
def test_registering_new_device_validation(self, loop, test_client):
@ -188,7 +264,7 @@ class TestHtml5Notify(object):
'homeassistant.util.json.open',
m, create=True
):
hass.config.path.return_value = 'file.conf'
hass.config.path.return_value = CONFIG_FILE
service = html5.get_service(hass, {})
assert service is not None
@ -240,7 +316,7 @@ class TestHtml5Notify(object):
'homeassistant.util.json.open',
m, create=True
):
hass.config.path.return_value = 'file.conf'
hass.config.path.return_value = CONFIG_FILE
service = html5.get_service(hass, {})
assert service is not None
@ -266,8 +342,9 @@ class TestHtml5Notify(object):
assert resp.status == 200, resp.response
assert view.registrations == config
handle = m()
assert json.loads(handle.write.call_args[0][0]) == config
hass.async_add_job.assert_called_with(save_json, CONFIG_FILE,
config)
@asyncio.coroutine
def test_unregister_device_view_handle_unknown_subscription(
@ -285,7 +362,7 @@ class TestHtml5Notify(object):
'homeassistant.util.json.open',
m, create=True
):
hass.config.path.return_value = 'file.conf'
hass.config.path.return_value = CONFIG_FILE
service = html5.get_service(hass, {})
assert service is not None
@ -309,13 +386,13 @@ class TestHtml5Notify(object):
assert resp.status == 200, resp.response
assert view.registrations == config
handle = m()
assert handle.write.call_count == 0
hass.async_add_job.assert_not_called()
@asyncio.coroutine
def test_unregistering_device_view_handles_json_safe_error(
def test_unregistering_device_view_handles_save_error(
self, loop, test_client):
"""Test that the HTML unregister view handles JSON write errors."""
"""Test that the HTML unregister view handles save errors."""
hass = MagicMock()
config = {
@ -328,7 +405,7 @@ class TestHtml5Notify(object):
'homeassistant.util.json.open',
m, create=True
):
hass.config.path.return_value = 'file.conf'
hass.config.path.return_value = CONFIG_FILE
service = html5.get_service(hass, {})
assert service is not None
@ -346,16 +423,13 @@ class TestHtml5Notify(object):
client = yield from test_client(app)
hass.http.is_banned_ip.return_value = False
with patch('homeassistant.components.notify.html5.save_json',
return_value=False):
resp = yield from client.delete(REGISTER_URL, data=json.dumps({
'subscription': SUBSCRIPTION_1['subscription'],
}))
hass.async_add_job.side_effect = HomeAssistantError()
resp = yield from client.delete(REGISTER_URL, data=json.dumps({
'subscription': SUBSCRIPTION_1['subscription'],
}))
assert resp.status == 500, resp.response
assert view.registrations == config
handle = m()
assert handle.write.call_count == 0
@asyncio.coroutine
def test_callback_view_no_jwt(self, loop, test_client):
@ -367,7 +441,7 @@ class TestHtml5Notify(object):
'homeassistant.util.json.open',
m, create=True
):
hass.config.path.return_value = 'file.conf'
hass.config.path.return_value = CONFIG_FILE
service = html5.get_service(hass, {})
assert service is not None
@ -404,7 +478,7 @@ class TestHtml5Notify(object):
'homeassistant.util.json.open',
m, create=True
):
hass.config.path.return_value = 'file.conf'
hass.config.path.return_value = CONFIG_FILE
service = html5.get_service(hass, {'gcm_sender_id': '100'})
assert service is not None