Don't duplicate html5 registrations (#11451)
* Don't duplicate html5 registrations If a registration is posted and another registration with the same endpoint URL exists, update that one instead. That way, we preserve the device name that has been configured. The previous behavior used to append 'unnamed device' registrations over and over, leading to multiple copies of the same registration. The endpoint URL is unique per service worker so it is safe to update matching registrations. * Refactor html5 registration view to not write json in the event loop
This commit is contained in:
parent
01896786fb
commit
455c629f47
2 changed files with 172 additions and 73 deletions
|
@ -169,15 +169,35 @@ class HTML5PushRegistrationView(HomeAssistantView):
|
|||
return self.json_message(
|
||||
humanize_error(data, ex), HTTP_BAD_REQUEST)
|
||||
|
||||
name = ensure_unique_string('unnamed device', self.registrations)
|
||||
name = self.find_registration_name(data)
|
||||
previous_registration = self.registrations.get(name)
|
||||
|
||||
self.registrations[name] = data
|
||||
|
||||
if not save_json(self.json_path, self.registrations):
|
||||
try:
|
||||
hass = request.app['hass']
|
||||
|
||||
yield from hass.async_add_job(save_json, self.json_path,
|
||||
self.registrations)
|
||||
return self.json_message(
|
||||
'Push notification subscriber registered.')
|
||||
except HomeAssistantError:
|
||||
if previous_registration is not None:
|
||||
self.registrations[name] = previous_registration
|
||||
else:
|
||||
self.registrations.pop(name)
|
||||
|
||||
return self.json_message(
|
||||
'Error saving registration.', HTTP_INTERNAL_SERVER_ERROR)
|
||||
|
||||
return self.json_message('Push notification subscriber registered.')
|
||||
def find_registration_name(self, data):
|
||||
"""Find a registration name matching data or generate a unique one."""
|
||||
endpoint = data.get(ATTR_SUBSCRIPTION).get(ATTR_ENDPOINT)
|
||||
for key, registration in self.registrations.items():
|
||||
subscription = registration.get(ATTR_SUBSCRIPTION)
|
||||
if subscription.get(ATTR_ENDPOINT) == endpoint:
|
||||
return key
|
||||
return ensure_unique_string('unnamed device', self.registrations)
|
||||
|
||||
@asyncio.coroutine
|
||||
def delete(self, request):
|
||||
|
@ -202,7 +222,12 @@ class HTML5PushRegistrationView(HomeAssistantView):
|
|||
|
||||
reg = self.registrations.pop(found)
|
||||
|
||||
if not save_json(self.json_path, self.registrations):
|
||||
try:
|
||||
hass = request.app['hass']
|
||||
|
||||
yield from hass.async_add_job(save_json, self.json_path,
|
||||
self.registrations)
|
||||
except HomeAssistantError:
|
||||
self.registrations[found] = reg
|
||||
return self.json_message(
|
||||
'Error saving registration.', HTTP_INTERNAL_SERVER_ERROR)
|
||||
|
|
|
@ -4,10 +4,14 @@ import json
|
|||
from unittest.mock import patch, MagicMock, mock_open
|
||||
from aiohttp.hdrs import AUTHORIZATION
|
||||
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.util.json import save_json
|
||||
from homeassistant.components.notify import html5
|
||||
|
||||
from tests.common import mock_http_component_app
|
||||
|
||||
CONFIG_FILE = 'file.conf'
|
||||
|
||||
SUBSCRIPTION_1 = {
|
||||
'browser': 'chrome',
|
||||
'subscription': {
|
||||
|
@ -108,36 +112,30 @@ class TestHtml5Notify(object):
|
|||
'unnamed device': SUBSCRIPTION_1,
|
||||
}
|
||||
|
||||
m = mock_open()
|
||||
with patch(
|
||||
'homeassistant.util.json.open',
|
||||
m, create=True
|
||||
):
|
||||
hass.config.path.return_value = 'file.conf'
|
||||
service = html5.get_service(hass, {})
|
||||
hass.config.path.return_value = CONFIG_FILE
|
||||
service = html5.get_service(hass, {})
|
||||
|
||||
assert service is not None
|
||||
assert service is not None
|
||||
|
||||
# assert hass.called
|
||||
assert len(hass.mock_calls) == 3
|
||||
assert len(hass.mock_calls) == 3
|
||||
|
||||
view = hass.mock_calls[1][1][0]
|
||||
assert view.json_path == hass.config.path.return_value
|
||||
assert view.registrations == {}
|
||||
view = hass.mock_calls[1][1][0]
|
||||
assert view.json_path == hass.config.path.return_value
|
||||
assert view.registrations == {}
|
||||
|
||||
hass.loop = loop
|
||||
app = mock_http_component_app(hass)
|
||||
view.register(app.router)
|
||||
client = yield from test_client(app)
|
||||
hass.http.is_banned_ip.return_value = False
|
||||
resp = yield from client.post(REGISTER_URL,
|
||||
data=json.dumps(SUBSCRIPTION_1))
|
||||
hass.loop = loop
|
||||
app = mock_http_component_app(hass)
|
||||
view.register(app.router)
|
||||
client = yield from test_client(app)
|
||||
hass.http.is_banned_ip.return_value = False
|
||||
resp = yield from client.post(REGISTER_URL,
|
||||
data=json.dumps(SUBSCRIPTION_1))
|
||||
|
||||
content = yield from resp.text()
|
||||
assert resp.status == 200, content
|
||||
assert view.registrations == expected
|
||||
handle = m()
|
||||
assert json.loads(handle.write.call_args[0][0]) == expected
|
||||
content = yield from resp.text()
|
||||
assert resp.status == 200, content
|
||||
assert view.registrations == expected
|
||||
|
||||
hass.async_add_job.assert_called_with(save_json, CONFIG_FILE, expected)
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_registering_new_device_expiration_view(self, loop, test_client):
|
||||
|
@ -147,36 +145,114 @@ class TestHtml5Notify(object):
|
|||
'unnamed device': SUBSCRIPTION_4,
|
||||
}
|
||||
|
||||
m = mock_open()
|
||||
with patch(
|
||||
'homeassistant.util.json.open',
|
||||
m, create=True
|
||||
):
|
||||
hass.config.path.return_value = 'file.conf'
|
||||
service = html5.get_service(hass, {})
|
||||
hass.config.path.return_value = CONFIG_FILE
|
||||
service = html5.get_service(hass, {})
|
||||
|
||||
assert service is not None
|
||||
assert service is not None
|
||||
|
||||
# assert hass.called
|
||||
assert len(hass.mock_calls) == 3
|
||||
# assert hass.called
|
||||
assert len(hass.mock_calls) == 3
|
||||
|
||||
view = hass.mock_calls[1][1][0]
|
||||
assert view.json_path == hass.config.path.return_value
|
||||
assert view.registrations == {}
|
||||
view = hass.mock_calls[1][1][0]
|
||||
assert view.json_path == hass.config.path.return_value
|
||||
assert view.registrations == {}
|
||||
|
||||
hass.loop = loop
|
||||
app = mock_http_component_app(hass)
|
||||
view.register(app.router)
|
||||
client = yield from test_client(app)
|
||||
hass.http.is_banned_ip.return_value = False
|
||||
resp = yield from client.post(REGISTER_URL,
|
||||
data=json.dumps(SUBSCRIPTION_4))
|
||||
hass.loop = loop
|
||||
app = mock_http_component_app(hass)
|
||||
view.register(app.router)
|
||||
client = yield from test_client(app)
|
||||
hass.http.is_banned_ip.return_value = False
|
||||
resp = yield from client.post(REGISTER_URL,
|
||||
data=json.dumps(SUBSCRIPTION_4))
|
||||
|
||||
content = yield from resp.text()
|
||||
assert resp.status == 200, content
|
||||
assert view.registrations == expected
|
||||
handle = m()
|
||||
assert json.loads(handle.write.call_args[0][0]) == expected
|
||||
content = yield from resp.text()
|
||||
assert resp.status == 200, content
|
||||
assert view.registrations == expected
|
||||
|
||||
hass.async_add_job.assert_called_with(save_json, CONFIG_FILE, expected)
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_registering_new_device_fails_view(self, loop, test_client):
|
||||
"""Test subs. are not altered when registering a new device fails."""
|
||||
hass = MagicMock()
|
||||
expected = {}
|
||||
|
||||
hass.config.path.return_value = CONFIG_FILE
|
||||
html5.get_service(hass, {})
|
||||
view = hass.mock_calls[1][1][0]
|
||||
|
||||
hass.loop = loop
|
||||
app = mock_http_component_app(hass)
|
||||
view.register(app.router)
|
||||
client = yield from test_client(app)
|
||||
hass.http.is_banned_ip.return_value = False
|
||||
|
||||
hass.async_add_job.side_effect = HomeAssistantError()
|
||||
|
||||
resp = yield from client.post(REGISTER_URL,
|
||||
data=json.dumps(SUBSCRIPTION_1))
|
||||
|
||||
content = yield from resp.text()
|
||||
assert resp.status == 500, content
|
||||
assert view.registrations == expected
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_registering_existing_device_view(self, loop, test_client):
|
||||
"""Test subscription is updated when registering existing device."""
|
||||
hass = MagicMock()
|
||||
expected = {
|
||||
'unnamed device': SUBSCRIPTION_4,
|
||||
}
|
||||
|
||||
hass.config.path.return_value = CONFIG_FILE
|
||||
html5.get_service(hass, {})
|
||||
view = hass.mock_calls[1][1][0]
|
||||
|
||||
hass.loop = loop
|
||||
app = mock_http_component_app(hass)
|
||||
view.register(app.router)
|
||||
client = yield from test_client(app)
|
||||
hass.http.is_banned_ip.return_value = False
|
||||
|
||||
yield from client.post(REGISTER_URL,
|
||||
data=json.dumps(SUBSCRIPTION_1))
|
||||
resp = yield from client.post(REGISTER_URL,
|
||||
data=json.dumps(SUBSCRIPTION_4))
|
||||
|
||||
content = yield from resp.text()
|
||||
assert resp.status == 200, content
|
||||
assert view.registrations == expected
|
||||
|
||||
hass.async_add_job.assert_called_with(save_json, CONFIG_FILE, expected)
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_registering_existing_device_fails_view(self, loop, test_client):
|
||||
"""Test sub. is not updated when registering existing device fails."""
|
||||
hass = MagicMock()
|
||||
expected = {
|
||||
'unnamed device': SUBSCRIPTION_1,
|
||||
}
|
||||
|
||||
hass.config.path.return_value = CONFIG_FILE
|
||||
html5.get_service(hass, {})
|
||||
view = hass.mock_calls[1][1][0]
|
||||
|
||||
hass.loop = loop
|
||||
app = mock_http_component_app(hass)
|
||||
view.register(app.router)
|
||||
client = yield from test_client(app)
|
||||
hass.http.is_banned_ip.return_value = False
|
||||
|
||||
yield from client.post(REGISTER_URL,
|
||||
data=json.dumps(SUBSCRIPTION_1))
|
||||
|
||||
hass.async_add_job.side_effect = HomeAssistantError()
|
||||
resp = yield from client.post(REGISTER_URL,
|
||||
data=json.dumps(SUBSCRIPTION_4))
|
||||
|
||||
content = yield from resp.text()
|
||||
assert resp.status == 500, content
|
||||
assert view.registrations == expected
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_registering_new_device_validation(self, loop, test_client):
|
||||
|
@ -188,7 +264,7 @@ class TestHtml5Notify(object):
|
|||
'homeassistant.util.json.open',
|
||||
m, create=True
|
||||
):
|
||||
hass.config.path.return_value = 'file.conf'
|
||||
hass.config.path.return_value = CONFIG_FILE
|
||||
service = html5.get_service(hass, {})
|
||||
|
||||
assert service is not None
|
||||
|
@ -240,7 +316,7 @@ class TestHtml5Notify(object):
|
|||
'homeassistant.util.json.open',
|
||||
m, create=True
|
||||
):
|
||||
hass.config.path.return_value = 'file.conf'
|
||||
hass.config.path.return_value = CONFIG_FILE
|
||||
service = html5.get_service(hass, {})
|
||||
|
||||
assert service is not None
|
||||
|
@ -266,8 +342,9 @@ class TestHtml5Notify(object):
|
|||
|
||||
assert resp.status == 200, resp.response
|
||||
assert view.registrations == config
|
||||
handle = m()
|
||||
assert json.loads(handle.write.call_args[0][0]) == config
|
||||
|
||||
hass.async_add_job.assert_called_with(save_json, CONFIG_FILE,
|
||||
config)
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_unregister_device_view_handle_unknown_subscription(
|
||||
|
@ -285,7 +362,7 @@ class TestHtml5Notify(object):
|
|||
'homeassistant.util.json.open',
|
||||
m, create=True
|
||||
):
|
||||
hass.config.path.return_value = 'file.conf'
|
||||
hass.config.path.return_value = CONFIG_FILE
|
||||
service = html5.get_service(hass, {})
|
||||
|
||||
assert service is not None
|
||||
|
@ -309,13 +386,13 @@ class TestHtml5Notify(object):
|
|||
|
||||
assert resp.status == 200, resp.response
|
||||
assert view.registrations == config
|
||||
handle = m()
|
||||
assert handle.write.call_count == 0
|
||||
|
||||
hass.async_add_job.assert_not_called()
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_unregistering_device_view_handles_json_safe_error(
|
||||
def test_unregistering_device_view_handles_save_error(
|
||||
self, loop, test_client):
|
||||
"""Test that the HTML unregister view handles JSON write errors."""
|
||||
"""Test that the HTML unregister view handles save errors."""
|
||||
hass = MagicMock()
|
||||
|
||||
config = {
|
||||
|
@ -328,7 +405,7 @@ class TestHtml5Notify(object):
|
|||
'homeassistant.util.json.open',
|
||||
m, create=True
|
||||
):
|
||||
hass.config.path.return_value = 'file.conf'
|
||||
hass.config.path.return_value = CONFIG_FILE
|
||||
service = html5.get_service(hass, {})
|
||||
|
||||
assert service is not None
|
||||
|
@ -346,16 +423,13 @@ class TestHtml5Notify(object):
|
|||
client = yield from test_client(app)
|
||||
hass.http.is_banned_ip.return_value = False
|
||||
|
||||
with patch('homeassistant.components.notify.html5.save_json',
|
||||
return_value=False):
|
||||
resp = yield from client.delete(REGISTER_URL, data=json.dumps({
|
||||
'subscription': SUBSCRIPTION_1['subscription'],
|
||||
}))
|
||||
hass.async_add_job.side_effect = HomeAssistantError()
|
||||
resp = yield from client.delete(REGISTER_URL, data=json.dumps({
|
||||
'subscription': SUBSCRIPTION_1['subscription'],
|
||||
}))
|
||||
|
||||
assert resp.status == 500, resp.response
|
||||
assert view.registrations == config
|
||||
handle = m()
|
||||
assert handle.write.call_count == 0
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_callback_view_no_jwt(self, loop, test_client):
|
||||
|
@ -367,7 +441,7 @@ class TestHtml5Notify(object):
|
|||
'homeassistant.util.json.open',
|
||||
m, create=True
|
||||
):
|
||||
hass.config.path.return_value = 'file.conf'
|
||||
hass.config.path.return_value = CONFIG_FILE
|
||||
service = html5.get_service(hass, {})
|
||||
|
||||
assert service is not None
|
||||
|
@ -404,7 +478,7 @@ class TestHtml5Notify(object):
|
|||
'homeassistant.util.json.open',
|
||||
m, create=True
|
||||
):
|
||||
hass.config.path.return_value = 'file.conf'
|
||||
hass.config.path.return_value = CONFIG_FILE
|
||||
service = html5.get_service(hass, {'gcm_sender_id': '100'})
|
||||
|
||||
assert service is not None
|
||||
|
|
Loading…
Add table
Reference in a new issue