diff --git a/homeassistant/components/mqtt/__init__.py b/homeassistant/components/mqtt/__init__.py index fcaa05f7921..8868656eb79 100644 --- a/homeassistant/components/mqtt/__init__.py +++ b/homeassistant/components/mqtt/__init__.py @@ -21,8 +21,8 @@ import voluptuous as vol from homeassistant import config_entries from homeassistant.const import ( - CONF_PASSWORD, CONF_PAYLOAD, CONF_PORT, CONF_PROTOCOL, CONF_USERNAME, - CONF_VALUE_TEMPLATE, EVENT_HOMEASSISTANT_STOP, CONF_NAME) + CONF_DEVICE, CONF_PASSWORD, CONF_PAYLOAD, CONF_PORT, CONF_PROTOCOL, + CONF_USERNAME, CONF_VALUE_TEMPLATE, EVENT_HOMEASSISTANT_STOP, CONF_NAME) from homeassistant.core import Event, ServiceCall, callback from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import config_validation as cv @@ -996,9 +996,23 @@ class MqttDiscoveryUpdate(Entity): class MqttEntityDeviceInfo(Entity): """Mixin used for mqtt platforms that support the device registry.""" - def __init__(self, device_config: Optional[ConfigType]) -> None: + def __init__(self, device_config: Optional[ConfigType], + config_entry=None) -> None: """Initialize the device mixin.""" self._device_config = device_config + self._config_entry = config_entry + + async def device_info_discovery_update(self, config: dict): + """Handle updated discovery message.""" + self._device_config = config.get(CONF_DEVICE) + device_registry = await \ + self.hass.helpers.device_registry.async_get_registry() + config_entry_id = self._config_entry.entry_id + device_info = self.device_info + + if config_entry_id is not None and device_info is not None: + device_info['config_entry_id'] = config_entry_id + device_registry.async_get_or_create(**device_info) @property def device_info(self): diff --git a/homeassistant/components/mqtt/switch.py b/homeassistant/components/mqtt/switch.py index 8124dcf811b..ad4356b425b 100644 --- a/homeassistant/components/mqtt/switch.py +++ b/homeassistant/components/mqtt/switch.py @@ -65,7 +65,7 @@ async def async_setup_entry(hass, config_entry, async_add_entities): try: discovery_hash = discovery_payload[ATTR_DISCOVERY_HASH] config = PLATFORM_SCHEMA(discovery_payload) - await _async_setup_entity(config, async_add_entities, + await _async_setup_entity(config, async_add_entities, config_entry, discovery_hash) except Exception: if discovery_hash: @@ -77,10 +77,10 @@ async def async_setup_entry(hass, config_entry, async_add_entities): async_discover) -async def _async_setup_entity(config, async_add_entities, +async def _async_setup_entity(config, async_add_entities, config_entry=None, discovery_hash=None): """Set up the MQTT switch.""" - async_add_entities([MqttSwitch(config, discovery_hash)]) + async_add_entities([MqttSwitch(config, config_entry, discovery_hash)]) # pylint: disable=too-many-ancestors @@ -88,7 +88,7 @@ class MqttSwitch(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate, MqttEntityDeviceInfo, SwitchDevice, RestoreEntity): """Representation of a switch that can be toggled using MQTT.""" - def __init__(self, config, discovery_hash): + def __init__(self, config, config_entry, discovery_hash): """Initialize the MQTT switch.""" self._state = False self._sub_state = None @@ -107,7 +107,7 @@ class MqttSwitch(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate, MqttAvailability.__init__(self, config) MqttDiscoveryUpdate.__init__(self, discovery_hash, self.discovery_update) - MqttEntityDeviceInfo.__init__(self, device_config) + MqttEntityDeviceInfo.__init__(self, device_config, config_entry) async def async_added_to_hass(self): """Subscribe to MQTT events.""" @@ -120,6 +120,7 @@ class MqttSwitch(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate, self._setup_from_config(config) await self.attributes_discovery_update(config) await self.availability_discovery_update(config) + await self.device_info_discovery_update(config) await self._subscribe_topics() self.async_schedule_update_ha_state() diff --git a/tests/components/mqtt/test_switch.py b/tests/components/mqtt/test_switch.py index b282b3149c4..5bbb04e1017 100644 --- a/tests/components/mqtt/test_switch.py +++ b/tests/components/mqtt/test_switch.py @@ -521,6 +521,53 @@ async def test_entity_device_info_with_identifier(hass, mqtt_mock): assert device.sw_version == '0.1-beta' +async def test_entity_device_info_update(hass, mqtt_mock): + """Test device registry update.""" + entry = MockConfigEntry(domain=mqtt.DOMAIN) + entry.add_to_hass(hass) + await async_start(hass, 'homeassistant', {}, entry) + registry = await hass.helpers.device_registry.async_get_registry() + + config = { + 'platform': 'mqtt', + 'name': 'Test 1', + 'state_topic': 'test-topic', + 'command_topic': 'test-command-topic', + 'device': { + 'identifiers': ['helloworld'], + 'connections': [ + ["mac", "02:5b:26:a8:dc:12"], + ], + 'manufacturer': 'Whatever', + 'name': 'Beer', + 'model': 'Glass', + 'sw_version': '0.1-beta', + }, + 'unique_id': 'veryunique' + } + + data = json.dumps(config) + async_fire_mqtt_message(hass, 'homeassistant/switch/bla/config', + data) + await hass.async_block_till_done() + await hass.async_block_till_done() + + device = registry.async_get_device({('mqtt', 'helloworld')}, set()) + assert device is not None + assert device.name == 'Beer' + + config['device']['name'] = 'Milk' + data = json.dumps(config) + async_fire_mqtt_message(hass, 'homeassistant/switch/bla/config', + data) + await hass.async_block_till_done() + await hass.async_block_till_done() + + device = registry.async_get_device({('mqtt', 'helloworld')}, set()) + assert device is not None + assert device.name == 'Milk' + + async def test_entity_id_update(hass, mqtt_mock): """Test MQTT subscriptions are managed when entity_id is updated.""" registry = mock_registry(hass, {})